Skip to content

LFP/ECP Plotting API

bmtool.bmplot.lfp.plot_spectrogram(sxx_xarray, remove_aperiodic=None, log_power=False, plt_range=None, clr_freq_range=None, pad=0.03, ax=None, vmin=None, vmax=None)

Plot a power spectrogram with optional aperiodic removal and frequency-based coloring.

Parameters:

Name Type Description Default
sxx_xarray array-like

Spectrogram data as an xarray DataArray with PSD values.

required
remove_aperiodic optional

FOOOF model object for aperiodic subtraction. If None, raw spectrum is displayed.

None
log_power bool or str

If True or 'dB', convert power to log scale. Default is False.

False
plt_range tuple of float

Frequency range to display as (f_min, f_max). If None, displays full range.

None
clr_freq_range tuple of float

Frequency range to use for determining color limits. If None, uses full range.

None
pad float

Padding for colorbar. Default is 0.03.

0.03
ax Axes

Axes to plot on. If None, creates a new figure and axes.

None
vmin float

Minimum value for colorbar scaling. If None, computed from data.

None
vmax float

Maximum value for colorbar scaling. If None, computed from data.

None

Returns:

Type Description
Figure

The figure object containing the spectrogram.

Examples:

>>> fig = plot_spectrogram(
...     sxx_xarray, log_power='dB',
...     plt_range=(10, 100), clr_freq_range=(20, 50)
... )
Source code in bmtool/bmplot/lfp.py
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
def plot_spectrogram(
    sxx_xarray: Any,
    remove_aperiodic: Optional[Any] = None,
    log_power: bool = False,
    plt_range: Optional[Tuple[float, float]] = None,
    clr_freq_range: Optional[Tuple[float, float]] = None,
    pad: float = 0.03,
    ax: Optional[plt.Axes] = None,
    vmin: Optional[float] = None,
    vmax: Optional[float] = None,
) -> Figure:
    """
    Plot a power spectrogram with optional aperiodic removal and frequency-based coloring.

    Parameters
    ----------
    sxx_xarray : array-like
        Spectrogram data as an xarray object exposing ``PSD`` (indexed as
        frequency x time), ``time`` and ``frequency``. If it also contains
        ``cone_of_influence_frequency``, that curve is drawn and shaded.
    remove_aperiodic : optional
        FOOOF model object for aperiodic subtraction (its ``aperiodic_params``
        are used). If None, the raw spectrum is displayed.
    log_power : bool or str, optional
        If True or 'dB', convert power to log10 scale ('dB' additionally
        multiplies by 10). Default is False.
    plt_range : tuple of float or float, optional
        Frequency range to display as (f_min, f_max). A single value is treated
        as f_max. If None, displays the full range.
    clr_freq_range : tuple of float, optional
        Frequency range used to determine color limits (only for limits not
        given explicitly via ``vmin``/``vmax``). Non-finite values (e.g. the
        ``-inf`` produced by ``log10(0)``) are ignored. If None, matplotlib
        autoscales.
    pad : float, optional
        Padding for colorbar. Default is 0.03.
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. If None, creates a new figure and axes.
    vmin : float, optional
        Minimum value for colorbar scaling. If None, computed from data.
    vmax : float, optional
        Maximum value for colorbar scaling. If None, computed from data.

    Returns
    -------
    matplotlib.figure.Figure
        The figure object containing the spectrogram.

    Raises
    ------
    ValueError
        If ``plt_range`` selects no frequencies, or ``clr_freq_range`` selects
        no finite spectrogram values while a color limit must be derived from it.

    Examples
    --------
    >>> fig = plot_spectrogram(
    ...     sxx_xarray, log_power='dB',
    ...     plt_range=(10, 100), clr_freq_range=(20, 50)
    ... )
    """
    sxx = sxx_xarray.PSD.values.copy()
    t = sxx_xarray.time.values.copy()
    f = sxx_xarray.frequency.values.copy()

    # Index of the first non-zero frequency: a 0 Hz bin is excluded from the
    # aperiodic fit and from the default lower display bound on a log scale.
    f1_idx = 0 if f[0] else 1

    cbar_label = "PSD" if remove_aperiodic is None else "PSD Residual"
    if log_power:
        with np.errstate(divide="ignore"):
            sxx = np.log10(sxx)
        cbar_label += " dB" if log_power == "dB" else " log(power)"

    if remove_aperiodic is not None:
        ap_fit = gen_aperiodic(f[f1_idx:], remove_aperiodic.aperiodic_params)
        # ap_fit is treated as log10 power: subtracted directly in log space,
        # exponentiated back to linear power otherwise.
        sxx[f1_idx:, :] -= (ap_fit if log_power else 10**ap_fit)[:, None]
        sxx[:f1_idx, :] = 0.0

    if log_power == "dB":
        sxx *= 10

    plt_range = np.array(f[-1]) if plt_range is None else np.array(plt_range)
    if plt_range.size == 1:
        plt_range = [f[f1_idx] if log_power else 0.0, plt_range.item()]
    f_idx = (f >= plt_range[0]) & (f <= plt_range[1])
    if not np.any(f_idx):
        raise ValueError(
            f"plt_range [{plt_range[0]}, {plt_range[1]}] does not overlap the "
            "spectrogram frequencies."
        )

    # Explicit vmin/vmax take precedence; otherwise derive the missing limit(s)
    # from the finite values inside clr_freq_range; otherwise let matplotlib
    # autoscale (vmin/vmax stay None).
    if clr_freq_range is not None and (vmin is None or vmax is None):
        c_idx = (f >= clr_freq_range[0]) & (f <= clr_freq_range[1])
        clr_vals = sxx[c_idx, :]
        clr_vals = clr_vals[np.isfinite(clr_vals)]
        if clr_vals.size == 0:
            raise ValueError(
                f"clr_freq_range [{clr_freq_range[0]}, {clr_freq_range[1]}] "
                "contains no finite spectrogram values."
            )
        if vmin is None:
            vmin = clr_vals.min()
        if vmax is None:
            vmax = clr_vals.max()

    if ax is None:
        _, ax = plt.subplots(1, 1)

    f = f[f_idx]
    pcm = ax.pcolormesh(
        t, f, sxx[f_idx, :], shading="gouraud", vmin=vmin, vmax=vmax,
        rasterized=True, cmap='viridis',
    )
    if "cone_of_influence_frequency" in sxx_xarray:
        coif = sxx_xarray.cone_of_influence_frequency
        ax.plot(t, coif)
        ax.fill_between(t, coif, step="mid", alpha=0.2)
    ax.set_xlim(t[0], t[-1])
    ax.set_ylim(f[0], f[-1])
    # Attach the colorbar to the axes' own figure rather than pyplot's current one.
    ax.figure.colorbar(pcm, ax=ax, label=cbar_label, pad=pad)
    ax.set_xlabel("Time (sec)")
    ax.set_ylabel("Frequency (Hz)")
    return ax.figure

bmtool.bmplot.lfp.plot_population_spike_rates_with_lfp(spikes_df, lfp, freq_of_interest, freq_labels, freq_colors, time_range, pop_names, pop_color, pop_groups=None, FR_type='smoothed', stimulus_time=None)

Plot population spike rates with LFP power overlays, with optional trial averaging.

Parameters:

Name Type Description Default
spikes_df DataFrame

DataFrame with spike data.

required
lfp array-like

LFP data (xarray or similar format).

required
freq_of_interest list of float

List of frequencies for LFP power analysis (required).

required
freq_labels list of str

Labels for the frequencies (required).

required
freq_colors list of str

Colors for the frequency plots (required).

required
time_range tuple of float or list of tuple

If tuple (start, end): plots continuous data in that time range. If list of tuples: trial times for averaging. E.g., [(1000,2000), (2500,3500)]. For trial averaging, mean is computed across trials (required).

required
pop_names list of str

List of population names (required).

required
pop_color dict

Dictionary mapping population names to colors (required).

required
pop_groups list of list of str

List of population groups to plot on the same subplot. E.g., [['PV', 'SST'], ['ET', 'IT']] plots PV and SST on one plot, ET and IT on another. If None, each population gets its own subplot (default).

None
FR_type str

Type of firing rate to plot ('raw', 'smoothed', etc.). Default is 'smoothed'.

'smoothed'
stimulus_time float

Time of stimulus onset. For trial averaging: relative to the start of the trial window (e.g., stimulus_time=200 means 200ms after trial start). For continuous plots: absolute time value (e.g., stimulus_time=2500 means stimulus at 2500ms). When provided, the x-axis will be relative to stimulus time (0 = stimulus onset). Default is None.

None

Returns:

Type Description
Figure or None

Figure object containing the plot, or None if no data to plot.

Examples:

>>> # Continuous plot
>>> fig = plot_population_spike_rates_with_lfp(
...     spikes_df, lfp, [40, 80], ['Beta', 'Gamma'],
...     ['blue', 'red'], (0, 10), ['PV', 'SST'],
...     {'PV': 'blue', 'SST': 'red'},
...     pop_groups=[['PV', 'SST']]
... )
>>> # Trial-averaged plot
>>> fig = plot_population_spike_rates_with_lfp(
...     spikes_df, lfp, [40, 80], ['Beta', 'Gamma'],
...     ['blue', 'red'], [(1000,2000), (2500,3500)], ['PV', 'SST'],
...     {'PV': 'blue', 'SST': 'red'},
...     pop_groups=[['PV', 'SST']]
... )
Source code in bmtool/bmplot/lfp.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
def plot_population_spike_rates_with_lfp(
    spikes_df: pd.DataFrame,
    lfp: Any,
    freq_of_interest: List[float],
    freq_labels: List[str],
    freq_colors: List[str],
    time_range: Any,
    pop_names: List[str],
    pop_color: Dict[str, str],
    pop_groups: Optional[List[List[str]]] = None,
    FR_type: str = 'smoothed',
    stimulus_time: Optional[float] = None,
    spike_rate_fs: float = 400,
    network_name: str = 'cortex',
) -> Optional[Figure]:
    """
    Plot population spike rates with LFP power overlays, with optional trial averaging.

    Parameters
    ----------
    spikes_df : pd.DataFrame
        DataFrame with spike data.
    lfp : array-like
        LFP data (xarray or similar format) exposing an ``fs`` attribute.
    freq_of_interest : list of float
        List of frequencies for LFP power analysis (required).
    freq_labels : list of str
        Labels for the frequencies (required).
    freq_colors : list of str
        Colors for the frequency plots (required).
    time_range : tuple of float or list of tuple
        If tuple (start, end): plots continuous data in that time range.
        If a list of (start, end) pairs (tuples or lists): trial times for
        averaging, e.g. [(1000,2000), (2500,3500)]. For trial averaging, the
        mean is computed across trials and the time axis uses the duration of
        the first trial (required).
    pop_names : list of str
        List of population names (required).
    pop_color : dict
        Dictionary mapping population names to colors (required).
    pop_groups : list of list of str, optional
        List of population groups to plot on the same subplot.
        E.g., [['PV', 'SST'], ['ET', 'IT']] plots PV and SST on one plot, ET and IT on another.
        If None, each population gets its own subplot (default). Populations not
        present in the spike-rate data are skipped; groups left empty get no subplot.
    FR_type : str, optional
        Type of firing rate to plot ('raw', 'smoothed', etc.). Default is 'smoothed'.
    stimulus_time : float, optional
        Time of stimulus onset.
        For trial averaging: relative to the start of the trial window (e.g., stimulus_time=200 means 200ms after trial start).
        For continuous plots: absolute time value (e.g., stimulus_time=2500 means stimulus at 2500ms).
        When provided, the x-axis will be relative to stimulus time (0 = stimulus onset). Default is None.
    spike_rate_fs : float, optional
        Sampling rate passed to ``get_population_spike_rate``. Default is 400.
    network_name : str, optional
        Network name passed to ``get_population_spike_rate``. Default is 'cortex'.

    Returns
    -------
    matplotlib.figure.Figure or None
        Figure object containing the plot, or None if none of the requested
        populations are present in the spike-rate data.

    Examples
    --------
    >>> # Continuous plot
    >>> fig = plot_population_spike_rates_with_lfp(
    ...     spikes_df, lfp, [40, 80], ['Beta', 'Gamma'],
    ...     ['blue', 'red'], (0, 10), ['PV', 'SST'],
    ...     {'PV': 'blue', 'SST': 'red'},
    ...     pop_groups=[['PV', 'SST']]
    ... )

    >>> # Trial-averaged plot
    >>> fig = plot_population_spike_rates_with_lfp(
    ...     spikes_df, lfp, [40, 80], ['Beta', 'Gamma'],
    ...     ['blue', 'red'], [(1000,2000), (2500,3500)], ['PV', 'SST'],
    ...     {'PV': 'blue', 'SST': 'red'},
    ...     pop_groups=[['PV', 'SST']]
    ... )
    """
    spike_rate = get_population_spike_rate(spikes_df, fs=spike_rate_fs, network_name=network_name)

    # One subplot per population unless explicit groups are given. Keep only the
    # populations present in the spike-rate data and drop groups that end up
    # empty, so no blank subplots are created.
    requested_groups = [[pop] for pop in pop_names] if pop_groups is None else pop_groups
    available_pops = spike_rate.population.values
    plot_groups = []
    for group in requested_groups:
        valid_pops = [pop for pop in group if pop in available_pops]
        if valid_pops:
            plot_groups.append(valid_pops)
    if not plot_groups:
        return None

    # Compute power for each frequency of interest
    powers = [
        get_lfp_power(lfp, freq_of_interest=freq, fs=lfp.fs, filter_method="wavelet", bandwidth=1.0)
        for freq in freq_of_interest
    ]

    # Trial averaging is requested by a non-empty list of (start, end) pairs.
    is_trial_avg = (
        isinstance(time_range, list)
        and len(time_range) > 0
        and isinstance(time_range[0], (tuple, list))
    )

    if is_trial_avg:
        spike_rate_trials, power_trials, trial_times = _extract_trials(
            spike_rate, powers, time_range
        )
        # trial_times from _extract_trials is normalized (0 to 1); convert it to
        # milliseconds using the duration of the first trial.
        trial_duration = float(time_range[0][1]) - float(time_range[0][0])
        plot_time = trial_times * trial_duration
        # Here stimulus_time is relative to the trial start.
        if stimulus_time is not None:
            plot_time = plot_time - stimulus_time
        target_length = len(trial_times)
        fr_time = None
    else:
        # Continuous plot: time_range is (start, end); the x-limits are shifted
        # by stimulus_time when the axes are formatted below.
        spike_rate_trials = power_trials = None
        target_length = None
        plot_time = time_range
        fr_time = spike_rate.time.values
        if stimulus_time is not None:
            fr_time = fr_time - stimulus_time

    # The LFP overlay is the same on every subplot, so build each trace
    # (x, y, sem-or-None, label, color) once instead of once per subplot.
    lfp_traces = []
    if is_trial_avg:
        for power_trial, label, color in zip(power_trials, freq_labels, freq_colors):
            power_mean, power_sem = _compute_trial_average_power(power_trial, target_length=target_length)
            lfp_traces.append((plot_time, power_mean, power_sem, label, color))
    else:
        for power, label, color in zip(powers, freq_labels, freq_colors):
            lfp_time = power['time'].values
            if stimulus_time is not None:
                lfp_time = lfp_time - stimulus_time
            lfp_traces.append((lfp_time, power.values.squeeze(), None, label, color))

    num_subplots = len(plot_groups)
    fig, axes = plt.subplots(num_subplots, 1, figsize=(12, 3.5 * num_subplots), squeeze=False)

    for ax, valid_pops in zip(axes[:, 0], plot_groups):
        # Firing rates (trial mean with SEM shading, or the continuous trace)
        fr_handles = []
        for pop in valid_pops:
            if is_trial_avg:
                fr_mean, fr_sem = _compute_trial_average(spike_rate_trials, pop, FR_type, target_length=target_length)
                line, = ax.plot(plot_time, fr_mean, color=pop_color[pop], label=f'{pop} FR', linewidth=2)
                ax.fill_between(plot_time, fr_mean - fr_sem, fr_mean + fr_sem,
                                color=pop_color[pop], alpha=0.2)
            else:
                line, = ax.plot(fr_time,
                                spike_rate.sel(type=FR_type, population=pop).values,
                                color=pop_color[pop], label=f'{pop} FR', linewidth=2)
            fr_handles.append(line)

        avg_text = ' (Trial Avg)' if is_trial_avg else ''
        ax.set_title(' + '.join(valid_pops) + avg_text, fontsize=12)
        ax.set_ylabel('Spike Rate (Hz)', fontsize=11)

        # LFP power on a twin y-axis
        ax2 = ax.twinx()
        lfp_handles = []
        for x_vals, y_vals, sem, label, color in lfp_traces:
            line, = ax2.plot(x_vals, y_vals, color=color, label=label, linestyle='--', linewidth=2)
            if sem is not None:
                ax2.fill_between(x_vals, y_vals - sem, y_vals + sem, color=color, alpha=0.1)
            lfp_handles.append(line)
        ax2.set_ylabel('LFP Power', fontsize=11)

        # Combined legend for both y-axes
        all_handles = fr_handles + lfp_handles
        ax.legend(all_handles, [h.get_label() for h in all_handles], loc='upper right', fontsize=10)

        if is_trial_avg:
            ax.set_xlim(plot_time[0], plot_time[-1])
            if stimulus_time is not None:
                ax.set_xlabel('Time relative to stimulus (ms)', fontsize=11)
            else:
                ax.set_xlabel('Time from trial start (ms)', fontsize=11)
        elif stimulus_time is not None:
            ax.set_xlim((plot_time[0] - stimulus_time, plot_time[1] - stimulus_time))
            ax.set_xlabel('Time relative to stimulus (ms)', fontsize=11)
        else:
            ax.set_xlim(plot_time)
            ax.set_xlabel('Time (ms)', fontsize=11)

    fig.tight_layout()
    return fig

bmtool.bmplot.lfp.plot_spike_rate_coherence(spike_rates, fooof_params=None, plt_range=None, plt_log=False, plt_db=True, figsize=(10, 3), ax=None)

Plot coherence between spike rate populations.

Computes coherence as in the Analyze_PSD_ziao notebook: calculates the coherence between each pair of populations and applies FOOOF fitting to the resulting coherence spectra.

Parameters:

Name Type Description Default
spike_rates DataArray

Spike rate data with dimensions (population, time) and 'fs' attribute

required
fooof_params dict

Parameters for FOOOF fitting. If None, uses default parameters

None
plt_range tuple

Frequency range to display (default: [2, 100])

None
plt_log bool

Use log scale for frequency axis, default: False

False
plt_db bool

Plot power in dB, default: True

True
figsize tuple

Figure size, default: (10, 3)

(10, 3)
ax Axes

Axes to plot on. If None, creates new figure

None

Returns:

Type Description
Figure

Figure object containing the coherence plots

Examples:

>>> fig = plot_spike_rate_coherence(spike_rates=spike_rate_data)
Source code in bmtool/bmplot/lfp.py
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
def plot_spike_rate_coherence(
    spike_rates: Any,
    fooof_params: Optional[Dict] = None,
    plt_range: Optional[Tuple[float, float]] = None,
    plt_log: bool = False,
    plt_db: bool = True,
    figsize: Tuple[int, int] = (10, 3),
    ax: Optional[plt.Axes] = None,
    FR_type: str = 'smoothed',
) -> Figure:
    """
    Plot coherence between spike rate populations.

    Computes coherence as in the Analyze_PSD_ziao notebook: calculates the
    coherence between every pair of populations and applies FOOOF fitting
    (with plotting) to each coherence spectrum.

    Parameters
    ----------
    spike_rates : xr.DataArray
        Spike rate data selectable by ``type`` and ``population``, with an
        'fs' attribute.
    fooof_params : dict, optional
        Parameters for FOOOF fitting. If None, uses default parameters
    plt_range : tuple, optional
        Frequency range to display (default: [2, 100])
    plt_log : bool
        Use log scale for frequency axis, default: False
    plt_db : bool
        Accepted for API compatibility; currently not used by this function.
    figsize : tuple
        Figure size, default: (10, 3)
    ax : matplotlib.axes.Axes, optional
        Axes to plot on. Only valid when there is exactly one population pair.
        If None, creates a new figure with at most 3 subplots per row.
    FR_type : str, optional
        Firing-rate type selected from ``spike_rates``. Default is 'smoothed'.

    Returns
    -------
    matplotlib.figure.Figure
        Figure object containing the coherence plots

    Raises
    ------
    ValueError
        If ``spike_rates`` has no 'fs' attribute, has fewer than 2 populations,
        or ``ax`` is given while more than one population pair must be plotted.

    Examples
    --------
    >>> fig = plot_spike_rate_coherence(spike_rates=spike_rate_data)
    """
    from itertools import combinations

    from scipy import signal

    from ..analysis.lfp import fit_fooof

    if not hasattr(spike_rates, 'attrs') or 'fs' not in spike_rates.attrs:
        raise ValueError("spike_rates must have 'fs' attribute")
    fs = spike_rates.attrs['fs']

    if fooof_params is None:
        fooof_params = dict(aperiodic_mode='knee', freq_range=(1, 100),
                            peak_width_limits=100., max_n_peaks=1, dB_threshold=0.05)

    if plt_range is None:
        plt_range = [2., 100.]

    # Every unordered population pair (i < j), as in Analyze_PSD_ziao
    pop_names = spike_rates.population.values
    grp_pairs = list(combinations(range(len(pop_names)), 2))
    npairs = len(grp_pairs)

    if npairs == 0:
        raise ValueError("Need at least 2 populations for coherence analysis")

    if ax is None:
        # Grid with at most 3 plots per row
        ncols = min(npairs, 3)
        nrows = (npairs + ncols - 1) // ncols
        fig, axes = plt.subplots(nrows, ncols, figsize=(figsize[0], figsize[1]), squeeze=False)
        axes = axes.ravel()
        # Hide the unused cells of the last row
        for unused_ax in axes[npairs:]:
            unused_ax.set_visible(False)
    else:
        if npairs > 1:
            raise ValueError(
                f"A single `ax` was given but {npairs} population pairs must be "
                "plotted; pass ax=None to create a grid of subplots."
            )
        fig = ax.get_figure()
        axes = [ax]

    for cur_ax, (i1, i2) in zip(axes, grp_pairs):
        pop1_name = pop_names[i1]
        pop2_name = pop_names[i2]
        title = f'Coherence {pop1_name}-{pop2_name}'

        signal1 = spike_rates.sel(type=FR_type, population=pop1_name).values
        signal2 = spike_rates.sel(type=FR_type, population=pop2_name).values

        # Coherence is undefined when either signal is constant
        if np.std(signal1) == 0 or np.std(signal2) == 0:
            cur_ax.text(0.5, 0.5, 'No variation in data', ha='center', va='center',
                        transform=cur_ax.transAxes)
            cur_ax.set_title(title)
            continue

        # Coherence over the entire time series
        f, cxy = signal.coherence(signal1, signal2, fs=fs)

        # Keep only positive, finite coherence values for the FOOOF fit
        idx = (cxy > 0) & np.isfinite(cxy)
        if not np.any(idx):
            cur_ax.text(0.5, 0.5, 'No valid coherence', ha='center', va='center',
                        transform=cur_ax.transAxes)
            cur_ax.set_title(title)
            continue

        # fit_fooof(plot=True) draws on pyplot's current axes, so select ours first
        plt.sca(cur_ax)
        _fooof_results, _fm = fit_fooof(f[idx], cxy[idx], **fooof_params,
                                        report=False, plot=True)

        cur_ax.set_xlabel('Frequency (Hz)')
        cur_ax.set_ylabel('Coherence')
        cur_ax.set_title(title)
        if plt_log:
            cur_ax.set_xscale('log')
        cur_ax.set_xlim(plt_range)
        cur_ax.grid(True, alpha=0.3)

    # Lay out the figure we plotted on, not whichever figure pyplot considers current
    fig.tight_layout()
    return fig