Skip to content

Entrainment Plotting API

bmtool.bmplot.entrainment.plot_spike_power_correlation(spike_df, lfp_data, fs, pop_names, filter_method='wavelet', bandwidth=2.0, lowcut=None, highcut=None, freq_range=(10, 100), freq_step=5, type_name='raw', figsize=(12, 8))

Calculate and plot spike rate-LFP power correlation across frequencies for full signal.

Analyzes the relationship between population spike rates and LFP power across a range of frequencies, using Spearman correlation for the entire signal duration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `spike_df` | `DataFrame` | DataFrame containing spike data with columns `'timestamps'`, `'node_ids'`, and `'pop_name'`. | *required* |
| `lfp_data` | `DataArray` | LFP data with a time dimension. | *required* |
| `fs` | `float` | Sampling frequency in Hz. | *required* |
| `pop_names` | `List[str]` | List of population names to analyze. | *required* |
| `filter_method` | `str` | Filtering method: `'wavelet'` or `'butter'`. | `'wavelet'` |
| `bandwidth` | `float` | Bandwidth parameter for the wavelet filter. | `2.0` |
| `lowcut` | `float` | Lower frequency bound (Hz) for the Butterworth filter. Required if `filter_method='butter'`. | `None` |
| `highcut` | `float` | Upper frequency bound (Hz) for the Butterworth filter. Required if `filter_method='butter'`. | `None` |
| `freq_range` | `Tuple[float, float]` | Minimum and maximum frequency to analyze, in Hz. | `(10, 100)` |
| `freq_step` | `float` | Step size for the frequency sweep, in Hz. | `5` |
| `type_name` | `str` | Which type of spike rate to use. | `'raw'` |
| `figsize` | `Tuple[float, float]` | Figure size (width, height) in inches. | `(12, 8)` |

Returns:

| Type | Description |
|------|-------------|
| `Figure` | Figure containing the correlation plot. |

Notes
  • Uses Spearman correlation (rank-based, robust to outliers).
  • Pre-computes LFP power at all frequencies for efficiency.

Examples:

>>> fig = plot_spike_power_correlation(
...     spike_df=spike_df,
...     lfp_data=lfp,
...     fs=400,
...     pop_names=['PV', 'SST'],
...     freq_range=(10, 100),
...     freq_step=5
... )
Source code in `bmtool/bmplot/entrainment.py`:
def plot_spike_power_correlation(
    spike_df: pd.DataFrame,
    lfp_data: xr.DataArray,
    fs: float,
    pop_names: List[str],
    filter_method: str = "wavelet",
    bandwidth: float = 2.0,
    lowcut: Optional[float] = None,
    highcut: Optional[float] = None,
    freq_range: Tuple[float, float] = (10, 100),
    freq_step: float = 5,
    type_name: str = "raw",
    figsize: Tuple[float, float] = (12, 8),
) -> Figure:
    """
    Calculate and plot spike rate-LFP power correlation across frequencies for full signal.

    Analyzes the relationship between population spike rates and LFP power across a range
    of frequencies, using Spearman correlation for the entire signal duration.

    Parameters
    ----------
    spike_df : pd.DataFrame
        DataFrame containing spike data with columns 'timestamps', 'node_ids', and 'pop_name'.
    lfp_data : xr.DataArray
        LFP data with time dimension.
    fs : float
        Sampling frequency in Hz.
    pop_names : List[str]
        List of population names to analyze.
    filter_method : str, optional
        Filtering method: 'wavelet' or 'butter' (default: 'wavelet').
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter (default: 2.0).
    lowcut : float, optional
        Lower frequency bound (Hz) for Butterworth filter. Required if filter_method='butter'.
    highcut : float, optional
        Upper frequency bound (Hz) for Butterworth filter. Required if filter_method='butter'.
    freq_range : Tuple[float, float], optional
        Min and max frequency to analyze in Hz, both inclusive (default: (10, 100)).
        The grid is ``freq_range[0], freq_range[0] + freq_step, ...`` up to and
        including ``freq_range[1]``; it never exceeds ``freq_range[1]``.
    freq_step : float, optional
        Step size for frequency analysis in Hz; must be positive (default: 5).
    type_name : str, optional
        Which type of spike rate to use (default: 'raw').
    figsize : Tuple[float, float], optional
        Figure size (width, height) in inches (default: (12, 8)).

    Returns
    -------
    matplotlib.figure.Figure
        Figure containing the correlation plot.

    Raises
    ------
    ValueError
        If ``freq_step`` is not positive, or if ``freq_range`` contains no frequency
        on the ``freq_step`` grid (e.g. ``freq_range[0] > freq_range[1]``). Raised
        before any spike-rate or LFP-power computation.

    Warns
    -----
    UserWarning
        When a population's spike-rate series and the LFP power series differ in
        length; that (population, frequency) pair is skipped.

    Notes
    -----
    - Uses Spearman correlation (rank-based, robust to outliers).
    - Pre-computes LFP power at all frequencies for efficiency.
    - Frequencies whose correlation is NaN are not drawn; populations with nothing
      to draw are omitted from the legend.
    - Calls ``sns.set_style("whitegrid")``, which changes the global seaborn style.

    Examples
    --------
    >>> fig = plot_spike_power_correlation(
    ...     spike_df=spike_df,
    ...     lfp_data=lfp,
    ...     fs=400,
    ...     pop_names=['PV', 'SST'],
    ...     freq_range=(10, 100),
    ...     freq_step=5
    ... )
    """
    import warnings

    # Validate the frequency grid up front so bad arguments fail before the
    # (potentially expensive) spike-rate and LFP-power computations.
    if freq_step <= 0:
        raise ValueError(f"freq_step must be positive, got {freq_step!r}")
    f_min, f_max = freq_range

    # Inclusive grid [f_min, f_max]. Using `f_max + 1` as the arange stop overshoots
    # f_max whenever freq_step < 1 (step 0.5 would add f_max + 0.5), so extend by one
    # full step and trim back to f_max. The small tolerance keeps an endpoint that
    # float accumulation lands a hair above f_max. Integer inputs keep an int grid.
    frequencies = np.arange(f_min, f_max + freq_step, freq_step)
    frequencies = frequencies[frequencies <= f_max + 1e-9 * max(1.0, abs(f_max))]
    if frequencies.size == 0:
        raise ValueError(
            f"freq_range={freq_range!r} with freq_step={freq_step!r} yields no frequencies"
        )

    # Compute spike rate for all spikes.
    # NOTE(review): presumably binned at `fs` so it aligns sample-for-sample with the
    # LFP power series; the length check below guards the case where it does not.
    spike_rate = bmspikes.get_population_spike_rate(spike_df, fs=fs)

    # Pre-calculate LFP power once per frequency; it is reused for every population.
    # NOTE(review): with filter_method='butter' the band comes from lowcut/highcut,
    # so confirm get_lfp_power actually varies with `freq` in that mode.
    power_by_freq = {
        freq: get_lfp_power(
            lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
        )
        for freq in frequencies
    }

    # Spearman correlation for each population at each frequency. Inner dicts are
    # filled in `frequencies` order, which the plotting loop below relies on.
    results = {}
    for pop in pop_names:
        results[pop] = {}
        pop_spike_rate = spike_rate.sel(population=pop, type=type_name)

        for freq in frequencies:
            lfp_power = power_by_freq[freq]

            if len(pop_spike_rate) != len(lfp_power):
                warnings.warn(
                    f"Length mismatch for {pop} at {freq} Hz: spike rate has "
                    f"{len(pop_spike_rate)} samples, LFP power has {len(lfp_power)}; "
                    "skipping.",
                    stacklevel=2,
                )
                continue

            corr, p_val = stats.spearmanr(pop_spike_rate.values, lfp_power.values)
            results[pop][freq] = {"correlation": corr, "p_value": p_val}

    # Create plot
    sns.set_style("whitegrid")
    fig = plt.figure(figsize=figsize)

    colors = plt.get_cmap("tab10")
    legend_handles = []
    for i, pop in enumerate(pop_names):
        valid = [
            (freq, entry["correlation"])
            for freq, entry in results[pop].items()
            if not np.isnan(entry["correlation"])
        ]
        if not valid:
            continue

        plot_freqs = np.array([freq for freq, _ in valid])
        plot_corrs = np.array([corr for _, corr in valid])

        # Wrap the color index: a ListedColormap maps integer indices >= N to its
        # "over" color (the last entry), which would make every population past the
        # tenth share one color. The index follows the position in pop_names, so a
        # population's color does not depend on whether earlier ones were skipped.
        (line,) = plt.plot(
            plot_freqs,
            plot_corrs,
            marker="o",
            label=pop,
            linewidth=2,
            markersize=6,
            color=colors(i % colors.N),
        )
        legend_handles.append(line)

    # Formatting
    plt.xlabel("Frequency (Hz)", fontsize=12)
    plt.ylabel("Spike Rate-Power Correlation", fontsize=12)

    plt.title(
        "Spike Rate-LFP Power Correlation",
        fontsize=14,
    )
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)

    # Legend only for populations that were actually drawn.
    if legend_handles:
        plt.legend(handles=legend_handles, fontsize=10, loc="best")

    # Axis formatting: thin out tick labels on dense grids.
    if len(frequencies) > 10:
        plt.xticks(frequencies[::2])
    else:
        plt.xticks(frequencies)
    # Identical left/right limits trigger a matplotlib warning; leave autoscaling
    # in place for a single-frequency grid.
    if frequencies.size > 1:
        plt.xlim(frequencies[0], frequencies[-1])

    # Always show at least [-0.1, 0.1] so the zero line is visible in context.
    y_min, y_max = plt.ylim()
    plt.ylim(min(y_min, -0.1), max(y_max, 0.1))

    plt.tight_layout()
    return fig