Skip to content

Entrainment Plotting API

bmtool.bmplot.entrainment.plot_spike_power_correlation(spike_df, lfp_data, firing_quantile, fs, pop_names, filter_method='wavelet', bandwidth=2.0, lowcut=None, highcut=None, freq_range=(10, 100), freq_step=5, type_name='raw', time_windows=None, error_type='ci')

Calculate and plot correlation between population spike rates and LFP power across frequencies. Supports both single-signal and trial-based analysis with error bars.

Parameters:

Name Type Description Default
spike_df DataFrame

Spike data with timestamps and population identifiers

required
firing_quantile float

Lower quantile cutoff in [0, 1); only cells firing above this quantile (the top (1 - firing_quantile) fraction) are analyzed

required
lfp_data DataArray

LFP data

required
fs float

Sampling frequency

required
pop_names list

List of population names to analyze

required
filter_method str

Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')

'wavelet'
bandwidth float

Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)

2.0
lowcut float

Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'

None
highcut float

Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'

None
freq_range tuple

Min and max frequency to analyze (default: (10, 100))

(10, 100)
freq_step float

Step size for frequency analysis (default: 5)

5
type_name str

Which type of spike rate to use if 'type' dimension exists (default: 'raw')

'raw'
time_windows list

List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal

None
error_type str

Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation

'ci'
Source code in bmtool/bmplot/entrainment.py
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
def plot_spike_power_correlation(
    spike_df: pd.DataFrame,
    lfp_data: xr.DataArray,
    firing_quantile: float,
    fs: float,
    pop_names: list,
    filter_method: str = "wavelet",
    bandwidth: float = 2.0,
    lowcut: float = None,
    highcut: float = None,
    freq_range: tuple = (10, 100),
    freq_step: float = 5,
    type_name: str = "raw",
    time_windows: list = None,
    error_type: str = "ci",  # "ci" for confidence interval, "sem" for standard error, "std" for standard deviation
):
    """
    Calculate and plot correlation between population spike rates and LFP power across frequencies.
    Supports both single-signal and trial-based analysis with error bars.

    Parameters
    ----------
    spike_df : pd.DataFrame
        Spike data with (at least) a 'timestamps' column and population identifiers,
        as consumed by bmspikes.find_highest_firing_cells / get_population_spike_rate
    lfp_data : xr.DataArray
        LFP data
    firing_quantile : float
        Lower quantile cutoff in [0, 1); only cells firing above this quantile
        (i.e. the top (1 - firing_quantile) fraction of cells) are analyzed
    fs : float
        Sampling frequency
    pop_names : list
        List of population names to analyze
    filter_method : str, optional
        Filtering method to use, either 'wavelet' or 'butter' (default: 'wavelet')
    bandwidth : float, optional
        Bandwidth parameter for wavelet filter when method='wavelet' (default: 2.0)
    lowcut : float, optional
        Lower frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    highcut : float, optional
        Upper frequency bound (Hz) for butterworth bandpass filter, required if filter_method='butter'
    freq_range : tuple, optional
        Min and max frequency to analyze (default: (10, 100))
    freq_step : float, optional
        Step size for frequency analysis (default: 5)
    type_name : str, optional
        Which type of spike rate to use if 'type' dimension exists (default: 'raw')
    time_windows : list, optional
        List of (start, end) time tuples for trial-based analysis. If None, analyze entire signal
    error_type : str, optional
        Type of error bars to plot: "ci" for 95% confidence interval, "sem" for standard error, "std" for standard deviation
    """

    if not (0 <= firing_quantile < 1):
        raise ValueError("firing_quantile must be between 0 and 1")

    if error_type not in ["ci", "sem", "std"]:
        raise ValueError(
            "error_type must be 'ci' for confidence interval, 'sem' for standard error, or 'std' for standard deviation"
        )

    # Human-readable label for the chosen error bar type; used in both the title
    # and the legend so the figure annotation always matches error_type.
    error_labels = {"ci": "95% CI", "sem": "±SEM", "std": "±1 SD"}
    error_label = error_labels[error_type]

    # Setup
    is_trial_based = time_windows is not None

    # Convert spike_df to spike rate with trial-based filtering of high firing cells
    if is_trial_based:
        # Initialize storage for trial-based spike rates
        trial_rates = []

        for start_time, end_time in time_windows:
            # Get spikes for this trial
            trial_spikes = spike_df[
                (spike_df["timestamps"] >= start_time) & (spike_df["timestamps"] <= end_time)
            ].copy()

            # Filter for high firing cells within this trial (selection is
            # per-trial, so the cell subset may differ across trials)
            trial_spikes = bmspikes.find_highest_firing_cells(
                trial_spikes, upper_quantile=firing_quantile
            )
            # Calculate rate for this trial's filtered spikes
            trial_rate = bmspikes.get_population_spike_rate(
                trial_spikes, fs=fs, t_start=start_time, t_stop=end_time
            )
            trial_rates.append(trial_rate)

        # Combine all trial rates along a new 'trial' dimension
        spike_rate = xr.concat(trial_rates, dim="trial")
    else:
        # For non-trial analysis, proceed as before
        # NOTE(review): unlike the trial path, fs/t_start/t_stop are not passed
        # here — presumably get_population_spike_rate has suitable defaults; verify.
        spike_df = bmspikes.find_highest_firing_cells(spike_df, upper_quantile=firing_quantile)
        spike_rate = bmspikes.get_population_spike_rate(spike_df)

    # Setup frequencies for analysis
    frequencies = np.arange(freq_range[0], freq_range[1] + 1, freq_step)

    # Pre-calculate LFP power for all frequencies (loop-invariant per population)
    power_by_freq = {}
    for freq in frequencies:
        power_by_freq[freq] = get_lfp_power(
            lfp_data, freq, fs, filter_method, lowcut=lowcut, highcut=highcut, bandwidth=bandwidth
        )

    # Calculate correlations
    results = {}
    for pop in pop_names:
        pop_spike_rate = spike_rate.sel(population=pop, type=type_name)
        results[pop] = {}

        for freq in frequencies:
            lfp_power = power_by_freq[freq]

            if not is_trial_based:
                # Single signal analysis: one Spearman correlation per frequency
                if len(pop_spike_rate) != len(lfp_power):
                    print(f"Warning: Length mismatch for {pop} at {freq} Hz")
                    continue

                corr, p_val = stats.spearmanr(pop_spike_rate, lfp_power)
                results[pop][freq] = {
                    "correlation": corr,
                    "p_value": p_val,
                }
            else:
                # Trial-based analysis using pre-filtered trial rates
                trial_correlations = []

                for trial_idx in range(len(time_windows)):
                    # Get time window first
                    start_time, end_time = time_windows[trial_idx]

                    # Get the pre-filtered spike rate for this trial
                    trial_spike_rate = pop_spike_rate.sel(trial=trial_idx)

                    # Get corresponding LFP power for this trial window
                    trial_lfp_power = lfp_power.sel(time=slice(start_time, end_time))

                    # Ensure both signals have same time points (concat may have
                    # union-aligned trial time axes, so restrict to the overlap)
                    common_times = np.intersect1d(trial_spike_rate.time, trial_lfp_power.time)

                    if len(common_times) > 0:
                        trial_sr = trial_spike_rate.sel(time=common_times).values
                        trial_lfp = trial_lfp_power.sel(time=common_times).values

                        if (
                            len(trial_sr) > 1 and len(trial_lfp) > 1
                        ):  # Need at least 2 points for correlation
                            corr, _ = stats.spearmanr(trial_sr, trial_lfp)
                            if not np.isnan(corr):
                                trial_correlations.append(corr)

                # Calculate trial statistics
                if len(trial_correlations) > 0:
                    trial_correlations = np.array(trial_correlations)
                    mean_corr = np.mean(trial_correlations)

                    if len(trial_correlations) > 1:
                        if error_type == "ci":
                            # 95% confidence interval via t-distribution
                            dof = len(trial_correlations) - 1
                            sem = stats.sem(trial_correlations)
                            t_critical = stats.t.ppf(0.975, dof)  # 95% CI, two-tailed
                            error_val = t_critical * sem
                            error_lower = mean_corr - error_val
                            error_upper = mean_corr + error_val
                        elif error_type == "sem":
                            # Standard error of the mean
                            sem = stats.sem(trial_correlations)
                            error_lower = mean_corr - sem
                            error_upper = mean_corr + sem
                        elif error_type == "std":
                            # Sample standard deviation
                            std = np.std(trial_correlations, ddof=1)
                            error_lower = mean_corr - std
                            error_upper = mean_corr + std
                    else:
                        # Single trial: no spread to report
                        error_lower = error_upper = mean_corr

                    results[pop][freq] = {
                        "correlation": mean_corr,
                        "error_lower": error_lower,
                        "error_upper": error_upper,
                        "n_trials": len(trial_correlations),
                        "trial_correlations": trial_correlations,
                    }
                else:
                    # No valid trials
                    results[pop][freq] = {
                        "correlation": np.nan,
                        "error_lower": np.nan,
                        "error_upper": np.nan,
                        "n_trials": 0,
                        "trial_correlations": np.array([]),
                    }

    # Plotting
    sns.set_style("whitegrid")
    plt.figure(figsize=(12, 8))

    # One colormap lookup, shared by the line plot and the legend below
    colors = plt.get_cmap("tab10")

    for i, pop in enumerate(pop_names):
        # Extract data for plotting
        plot_freqs = []
        plot_corrs = []
        plot_ci_lower = []
        plot_ci_upper = []

        for freq in frequencies:
            if freq in results[pop] and not np.isnan(results[pop][freq]["correlation"]):
                plot_freqs.append(freq)
                plot_corrs.append(results[pop][freq]["correlation"])

                if is_trial_based:
                    plot_ci_lower.append(results[pop][freq]["error_lower"])
                    plot_ci_upper.append(results[pop][freq]["error_upper"])

        if len(plot_freqs) == 0:
            continue

        # Convert to arrays
        plot_freqs = np.array(plot_freqs)
        plot_corrs = np.array(plot_corrs)

        # Get color for this population
        color = colors(i)

        # Plot main line
        plt.plot(
            plot_freqs, plot_corrs, marker="o", label=pop, linewidth=2, markersize=6, color=color
        )

        # Plot error bands for trial-based analysis
        if is_trial_based and len(plot_ci_lower) > 0:
            plot_ci_lower = np.array(plot_ci_lower)
            plot_ci_upper = np.array(plot_ci_upper)
            plt.fill_between(plot_freqs, plot_ci_lower, plot_ci_upper, alpha=0.2, color=color)

    # Formatting
    plt.xlabel("Frequency (Hz)", fontsize=12)
    plt.ylabel("Spike Rate-Power Correlation", fontsize=12)

    # Calculate percentage for title
    firing_percentage = round(float((1 - firing_quantile) * 100), 1)
    if is_trial_based:
        # Title reflects the actual error bar type instead of hardcoding "95% CI"
        title = f"Trial-averaged Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells ({error_label})"
    else:
        title = f"Spike Rate-LFP Power Correlation\nTop {firing_percentage}% Firing Cells"

    plt.title(title, fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.axhline(y=0, color="gray", linestyle="-", alpha=0.5)

    # Legend
    # Create legend elements for each population
    from matplotlib.lines import Line2D

    legend_elements = [
        Line2D([0], [0], color=colors(i), marker="o", linestyle="-", label=pop)
        for i, pop in enumerate(pop_names)
    ]

    # Add error band legend element for trial-based analysis
    if is_trial_based:
        legend_elements.append(
            Line2D([0], [0], color="gray", alpha=0.3, linewidth=10, label=error_label)
        )

    plt.legend(handles=legend_elements, fontsize=10, loc="best")

    # Axis formatting
    if len(frequencies) > 10:
        plt.xticks(frequencies[::2])
    else:
        plt.xticks(frequencies)
    plt.xlim(frequencies[0], frequencies[-1])

    y_min, y_max = plt.ylim()
    plt.ylim(min(y_min, -0.1), max(y_max, 0.1))

    plt.tight_layout()
    plt.show()