Skip to content

Spike Plotting API

bmtool.bmplot.spikes.raster(spikes_df=None, config=None, network_name=None, groupby='pop_name', sortby=None, ax=None, tstart=None, tstop=None, color_map=None, dot_size=0.3)

Plots a raster plot of neural spikes, with different colors for each population.

Parameters:

  • spikes_df : pd.DataFrame, optional — DataFrame containing spike data with columns 'timestamps' and 'node_ids', and optionally 'pop_name'.
  • config : str, optional — Path to the configuration file used to load node data.
  • network_name : str, optional — Name of the network to select from the configuration; if not provided, the first network is used.
  • groupby : str, optional — Column name used to group spikes for coloring. Default is 'pop_name'.
  • sortby : str, optional — Column name used to order node_ids within each group. If provided, nodes within each population are sorted by this column.
  • ax : matplotlib.axes.Axes, optional — Axes on which to plot the raster; if None, a new figure and axes are created.
  • tstart : float, optional — Start time for filtering spikes; only spikes with timestamps strictly greater than tstart are plotted.
  • tstop : float, optional — Stop time for filtering spikes; only spikes with timestamps strictly less than tstop are plotted.
  • color_map : dict, optional — Dictionary specifying a color for each population. Keys are population names and values are colors.
  • dot_size : float, optional — Size of the dots in the scatter plot. Default is 0.3.

Returns:

matplotlib.axes.Axes Axes with the raster plot.

Notes:
  • If config is provided, the function merges population names from the node data with spikes_df.
  • Each unique value of the groupby column in spikes_df is drawn in a different color if color_map is not specified.
  • If color_map is provided, it must contain a color for every unique value of the groupby column in spikes_df; otherwise a ValueError is raised.
Source code in bmtool/bmplot/spikes.py
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
def raster(
    spikes_df: Optional[pd.DataFrame] = None,
    config: Optional[str] = None,
    network_name: Optional[str] = None,
    groupby: str = "pop_name",
    sortby: Optional[str] = None,
    ax: Optional[Axes] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    color_map: Optional[Dict[str, str]] = None,
    dot_size: float = 0.3,
) -> Axes:
    """
    Plots a raster plot of neural spikes, with different colors for each population.

    Parameters:
    ----------
    spikes_df : pd.DataFrame
        DataFrame containing spike data with columns 'timestamps', 'node_ids', and optional 'pop_name'.
        Although it defaults to None for signature compatibility, it is required.
    config : str, optional
        Path to the configuration file used to load node data. When given, the `groupby`
        column (and the `sortby` column, if the node table has it) is taken from the node table.
    network_name : str, optional
        Specific network name to select from the configuration; if not provided, uses the first network.
    groupby : str, optional
        Column name to group spikes by for coloring. Default is 'pop_name'.
    sortby : str, optional
        Column name to sort node_ids within each group. If provided, nodes within each population
        are stacked in this order and populations are placed one after another on the y-axis.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the raster; if None, a new figure and axes are created.
    tstart : float, optional
        Start time for filtering spikes; only spikes with timestamps greater than `tstart` will be plotted.
    tstop : float, optional
        Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
    color_map : dict, optional
        Dictionary specifying colors for each population. Keys should be population names, and values should be color values.
    dot_size : float, optional
        Size of the dot to display on the scatterplot.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the raster plot.

    Raises:
    ------
    ValueError
        If `spikes_df` is None, if `network_name` is not present in the configuration,
        if the configuration contains no networks, or if `color_map` lacks a color for
        one of the plotted groups.

    Notes:
    -----
    - If `config` is provided, the function merges population names from the node data with `spikes_df`.
    - Each unique value of the `groupby` column in `spikes_df` will be represented by a different color if `color_map` is not specified.
    - If `color_map` is provided, it must contain colors for all unique (non-missing) values of the `groupby` column.
    - Spikes whose `groupby` value is missing (e.g. node ids absent from the node table) are not plotted.
    """
    # Fail early with a clear message instead of a TypeError on None indexing.
    if spikes_df is None:
        raise ValueError("spikes_df must be provided to plot a raster.")

    # Initialize axes if none provided
    sns.set_style("whitegrid")
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # Filter spikes by time range if specified (both bounds are exclusive)
    if tstart is not None:
        spikes_df = spikes_df[spikes_df["timestamps"] > tstart]
    if tstop is not None:
        spikes_df = spikes_df[spikes_df["timestamps"] < tstop]

    # Load and merge node population data if config is provided
    if config:
        all_networks = load_nodes_from_config(config)
        if network_name:
            if network_name not in all_networks:
                raise ValueError(
                    f"Network {network_name!r} not found in config; "
                    f"available networks: {list(all_networks)}"
                )
            nodes = all_networks[network_name]
        else:
            if not all_networks:
                raise ValueError(f"No networks found in config {config!r}.")
            nodes = next(iter(all_networks.values()))
            print(
                "Grabbing first network; specify a network name to ensure correct node population is selected."
            )

        # Node-table columns to attach to each spike: always the grouping column, plus
        # the sort column when the node table provides it (otherwise it would be lost
        # below if spikes_df also had it, or never be available at all).
        node_columns = [groupby]
        if sortby and sortby not in (groupby, "node_ids") and sortby in nodes.columns:
            node_columns.append(sortby)

        # Drop overlapping columns (except the join key) so the node table is the
        # single source of truth and the merge does not produce _x/_y suffixes.
        common_columns = [
            col for col in spikes_df.columns.intersection(nodes.columns) if col != "node_ids"
        ]
        spikes_df = spikes_df.drop(columns=common_columns)
        spikes_df = spikes_df.merge(
            nodes[node_columns], left_on="node_ids", right_index=True, how="left"
        )

    # Unique groups that will actually be plotted: groupby() below drops missing keys,
    # so missing values must not count toward colors or the color_map validation.
    unique_pop_names = spikes_df[groupby].dropna().unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap("tab10")  # Default colormap
        n_pops = len(unique_pop_names)
        color_map = {pop_name: cmap(i / n_pops) for i, pop_name in enumerate(unique_pop_names)}
    else:
        # Ensure color_map contains all population names
        missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

    # Plot each population with its specified or generated color
    legend_handles = []
    y_offset = 0  # Next free y-position when populations are stacked (sortby mode)

    for pop_name, group in spikes_df.groupby(groupby):
        if sortby:
            # Order this population's nodes by `sortby` (missing values last); each
            # node keeps the position of its first occurrence in that order.
            ordered_nodes = group.sort_values(by=sortby, na_position="last")[
                "node_ids"
            ].drop_duplicates()
            node_to_y = {node_id: y_offset + i for i, node_id in enumerate(ordered_nodes)}
            # Every spike of a node gets that node's row.
            y_positions = group["node_ids"].map(node_to_y)
            y_offset += len(node_to_y)
        else:
            y_positions = group["node_ids"]

        ax.scatter(group["timestamps"], y_positions, color=color_map[pop_name], s=dot_size)
        # Dummy scatter so the legend marker is visible regardless of dot_size.
        handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
        legend_handles.append(handle)

    # Label axes
    ax.set_xlabel("Time")
    ax.set_ylabel("Node ID")
    ax.legend(handles=legend_handles, title="Population", loc="upper right", framealpha=0.9)

    return ax