Skip to content

Spike Plotting API

bmtool.bmplot.spikes.raster(spikes_df=None, config=None, network_name=None, groupby='pop_name', ax=None, tstart=None, tstop=None, color_map=None, dot_size=0.3)

Plots a raster plot of neural spikes, with different colors for each population.

Parameters:

spikes_df : pd.DataFrame, optional — DataFrame containing spike data with columns 'timestamps', 'node_ids', and optionally 'pop_name'.
config : str, optional — Path to the configuration file used to load node data.
network_name : str, optional — Specific network name to select from the configuration; if not provided, the first network is used.
groupby : str, optional — Column used to split spikes into populations, each drawn in its own color (default 'pop_name').
ax : matplotlib.axes.Axes, optional — Axes on which to plot the raster; if None, a new figure and axes are created.
tstart : float, optional — Start time for filtering spikes; only spikes with timestamps strictly greater than tstart are plotted.
tstop : float, optional — Stop time for filtering spikes; only spikes with timestamps strictly less than tstop are plotted.
color_map : dict, optional — Dictionary specifying a color for each population: keys are population names, values are matplotlib colors.
dot_size : float, optional — Size of each dot in the scatter plot.

Returns:

matplotlib.axes.Axes Axes with the raster plot.

Notes:
  • If config is provided, the function merges population names from the node data with spikes_df.
  • If color_map is not specified, each unique value of the groupby column in spikes_df is drawn in a different color.
  • If color_map is provided, it must contain a color for every unique value of the groupby column (pop_name by default) in spikes_df.
Source code in bmtool/bmplot/spikes.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
def raster(
    spikes_df: Optional[pd.DataFrame] = None,
    config: Optional[str] = None,
    network_name: Optional[str] = None,
    groupby: Optional[str] = "pop_name",
    ax: Optional[Axes] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    color_map: Optional[Dict[str, str]] = None,
    dot_size: Optional[float] = 0.3,
) -> Axes:
    """
    Plots a raster plot of neural spikes, with different colors for each population.

    Parameters:
    ----------
    spikes_df : pd.DataFrame
        DataFrame containing spike data with columns 'timestamps' and 'node_ids', plus the
        `groupby` column unless `config` is given. Required despite the ``None`` default.
    config : str, optional
        Path to the configuration file used to load node data. When given, the `groupby`
        column is taken from the node table (joined on 'node_ids' against the node table's
        index) instead of from `spikes_df`.
    network_name : str, optional
        Specific network name to select from the configuration; if not provided, uses the first network.
    groupby : str, optional
        Column used to split spikes into populations, one color per group. Defaults to 'pop_name'.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the raster; if None, a new figure and axes are created.
    tstart : float, optional
        Start time for filtering spikes; only spikes with timestamps strictly greater than `tstart` are plotted.
    tstop : float, optional
        Stop time for filtering spikes; only spikes with timestamps strictly less than `tstop` are plotted.
    color_map : dict, optional
        Dictionary mapping each value of the `groupby` column to a matplotlib color.
    dot_size : float, optional
        Marker size (matplotlib ``s``) of each spike dot.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the raster plot.

    Raises:
    ------
    ValueError
        If `spikes_df` is None, if `network_name` is not present in the config, if the config
        contains no networks, or if `color_map` lacks a color for some population.

    Notes:
    -----
    - If `config` is provided, the function merges the `groupby` column from the node data into `spikes_df`.
    - Each unique value of the `groupby` column is drawn in a different color if `color_map` is not specified.
    - If `color_map` is provided, it must contain a color for every unique `groupby` value in `spikes_df`.
    - Spikes whose population is missing (e.g. node_ids absent from the node table) are not plotted.
    """
    if spikes_df is None:
        raise ValueError("spikes_df is required to plot a raster.")

    # Initialize axes if none provided
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # Filter spikes by time range if specified (both bounds are exclusive)
    if tstart is not None:
        spikes_df = spikes_df[spikes_df["timestamps"] > tstart]
    if tstop is not None:
        spikes_df = spikes_df[spikes_df["timestamps"] < tstop]

    # Load and merge node population data if config is provided
    if config:
        networks = load_nodes_from_config(config)
        if network_name:
            if network_name not in networks:
                raise ValueError(
                    f"Network '{network_name}' not found in config; "
                    f"available networks: {list(networks)}"
                )
            nodes = networks[network_name]
        else:
            if not networks:
                raise ValueError(f"No networks found in config '{config}'.")
            nodes = next(iter(networks.values()))
            print(
                "Grabbing first network; specify a network name to ensure correct node population is selected."
            )

        # Only the groupby column is merged in from the node table, so it is the only column
        # that can collide (and get _x/_y suffixes). Drop just that stale copy; other columns
        # such as 'timestamps' must survive. Never drop the join key itself.
        if groupby != "node_ids":
            spikes_df = spikes_df.drop(columns=[groupby], errors="ignore")
        spikes_df = spikes_df.merge(
            nodes[groupby], left_on="node_ids", right_index=True, how="left"
        )

    # groupby() below drops missing keys (e.g. spikes whose node_ids had no match in the
    # node table), so ignore them here as well to keep color assignment and plotted groups in sync.
    unique_pop_names = spikes_df[groupby].dropna().unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap("tab10")  # Default colormap
        color_map = {
            pop_name: cmap(i / len(unique_pop_names)) for i, pop_name in enumerate(unique_pop_names)
        }
    else:
        # Ensure color_map contains all population names
        missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

    # Plot each population with its specified or generated color
    legend_handles = []
    for pop_name, group in spikes_df.groupby(groupby):
        ax.scatter(group["timestamps"], group["node_ids"], color=color_map[pop_name], s=dot_size)
        # Empty scatter with a fixed, larger marker so legend entries stay visible
        # regardless of how small dot_size is.
        handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
        legend_handles.append(handle)

    # Label axes
    ax.set_xlabel("Time")
    ax.set_ylabel("Node ID")
    ax.legend(handles=legend_handles, title="Population", loc="upper right", framealpha=0.9)

    return ax