14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
def raster(
    spikes_df: Optional[pd.DataFrame] = None,
    config: Optional[str] = None,
    network_name: Optional[str] = None,
    groupby: str = "pop_name",
    sortby: Optional[str] = None,
    ax: Optional[Axes] = None,
    tstart: Optional[float] = None,
    tstop: Optional[float] = None,
    color_map: Optional[Dict[str, str]] = None,
    dot_size: float = 0.3,
) -> Axes:
    """
    Plots a raster plot of neural spikes, with different colors for each population.

    Parameters:
    ----------
    spikes_df : pd.DataFrame
        DataFrame containing spike data with columns 'timestamps', 'node_ids', and
        optionally the `groupby` / `sortby` columns. Required, despite the `None` default.
    config : str, optional
        Path to the configuration file used to load node data.
    network_name : str, optional
        Specific network name to select from the configuration; if not provided, uses the first network.
    groupby : str, optional
        Column name to group spikes by for coloring. Default is 'pop_name'.
    sortby : str, optional
        Column name to sort node_ids within each group. If provided, nodes within each
        population are ranked by this column and populations are stacked on the y-axis.
    ax : matplotlib.axes.Axes, optional
        Axes on which to plot the raster; if None, a new figure and axes are created.
    tstart : float, optional
        Start time for filtering spikes; only spikes with timestamps greater than `tstart` will be plotted.
    tstop : float, optional
        Stop time for filtering spikes; only spikes with timestamps less than `tstop` will be plotted.
    color_map : dict, optional
        Dictionary specifying colors for each population. Keys should be population names,
        and values should be color values.
    dot_size : float, optional
        Size of the dot to display on the scatterplot.

    Returns:
    -------
    matplotlib.axes.Axes
        Axes with the raster plot.

    Raises:
    ------
    ValueError
        If `spikes_df` is None, if the requested network cannot be found in `config`,
        if `color_map` lacks a color for a plotted population, or if a node id cannot
        be assigned a y-position while sorting.
    KeyError
        If the `groupby` or `sortby` column is not available after the optional merge.

    Notes:
    -----
    - If `config` is provided, columns that exist in both `spikes_df` and the node table
      (other than 'node_ids') are dropped from `spikes_df`, then `groupby` (and `sortby`,
      when it is a node-table column) are merged back in from the node table.
    - Each unique, non-null value of `groupby` gets its own color if `color_map` is not specified.
    - If `color_map` is provided, it must contain a color for every non-null `groupby` value.
    - Applies the seaborn "whitegrid" style globally as a side effect.
    """
    # Fail fast, before touching global style or creating a figure.
    if spikes_df is None:
        raise ValueError("spikes_df must be provided to plot a raster.")

    # Initialize axes if none provided
    sns.set_style("whitegrid")
    if ax is None:
        _, ax = plt.subplots(1, 1)

    # Filter spikes by time range if specified (both bounds are exclusive)
    if tstart is not None:
        spikes_df = spikes_df[spikes_df["timestamps"] > tstart]
    if tstop is not None:
        spikes_df = spikes_df[spikes_df["timestamps"] < tstop]

    # Load and merge node population data if config is provided
    if config:
        networks = load_nodes_from_config(config)
        if network_name:
            nodes = networks.get(network_name)
            if nodes is None:
                raise ValueError(
                    f"Network '{network_name}' not found in config; "
                    f"available networks: {list(networks)}"
                )
        else:
            if not networks:
                raise ValueError("No node networks were loaded from the provided config.")
            nodes = next(iter(networks.values()))
            print(
                "Grabbing first network; specify a network name to ensure correct node population is selected."
            )

        # Drop every column shared with the node table (except the join key) so the
        # node table is the single source of truth and the merge creates no _x/_y suffixes.
        common_columns = [
            col for col in spikes_df.columns.intersection(nodes.columns) if col != "node_ids"
        ]
        spikes_df = spikes_df.drop(columns=common_columns)

        # Bring back the grouping column, plus the sort column when it lives in the node
        # table; otherwise sorting would fail on a column that was dropped or never merged.
        merge_columns = [groupby]
        if (
            sortby
            and sortby != groupby
            and sortby in nodes.columns
            and sortby not in spikes_df.columns
        ):
            merge_columns.append(sortby)
        spikes_df = spikes_df.merge(
            nodes[merge_columns], left_on="node_ids", right_index=True, how="left"
        )

    if groupby not in spikes_df.columns:
        raise KeyError(
            f"groupby column '{groupby}' not found in spikes_df; "
            "provide it in spikes_df or pass a config to merge it from the node table."
        )
    if sortby and sortby not in spikes_df.columns:
        raise KeyError(
            f"sortby column '{sortby}' not found in spikes_df or in the node table."
        )

    # Population labels that will actually be plotted: groupby() skips null keys,
    # so they must not take part in color generation or color_map validation.
    unique_pop_names = spikes_df[groupby].dropna().unique()

    # Generate colors if no color_map is provided
    if color_map is None:
        cmap = plt.get_cmap("tab10")  # Default colormap
        n_pops = len(unique_pop_names)
        color_map = {pop_name: cmap(i / n_pops) for i, pop_name in enumerate(unique_pop_names)}
    else:
        # Ensure color_map contains all population names
        missing_colors = [pop for pop in unique_pop_names if pop not in color_map]
        if missing_colors:
            raise ValueError(f"color_map is missing colors for populations: {missing_colors}")

    # Plot each population with its specified or generated color
    legend_handles = []
    y_offset = 0  # Running y-position so sorted populations stack without overlap
    for pop_name, group in spikes_df.groupby(groupby):
        if sortby:
            # Rank nodes by the sort column (NaN last); each node keeps one y-position
            # for all of its spikes, based on where it first appears in the sorted order.
            ordered_nodes = group.sort_values(by=sortby, na_position="last")[
                "node_ids"
            ].drop_duplicates()
            node_to_y = {
                node_id: y_offset + rank for rank, node_id in enumerate(ordered_nodes)
            }
            y_positions = group["node_ids"].map(node_to_y)
            # Explicit check instead of assert so it still runs under `python -O`.
            if y_positions.isna().any():
                raise ValueError(f"Unmapped node_ids found in population {pop_name}")
            y_offset += len(ordered_nodes)
        else:
            y_positions = group["node_ids"]

        ax.scatter(group["timestamps"], y_positions, color=color_map[pop_name], s=dot_size)
        # Dummy scatter for consistent legend appearance regardless of dot_size
        handle = ax.scatter([], [], color=color_map[pop_name], label=pop_name, s=20)
        legend_handles.append(handle)

    # Label axes
    ax.set_xlabel("Time")
    ax.set_ylabel("Node ID")
    ax.legend(handles=legend_handles, title="Population", loc="upper right", framealpha=0.9)
    return ax
|