Stimulus API Reference

Core Classes

StimulusBuilder

Class to manage and generate stimuli for BMTK networks.

This class provides a unified interface for defining node assemblies and generating time-varying Poisson spike trains (SONATA format) for those assemblies.

Attributes:

- config (dict): BMTK simulation configuration.
- nodes (dict): Dictionary of pandas DataFrames for each network.
- assemblies (dict): Named groups of node IDs.
- net_seed (int): Seed for assembly generation.
- psg_seed (int): Seed for Poisson spike generation.
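
A minimal end-to-end sketch of the intended workflow. The import path is inferred from the source location (bmtool/stimulus/core.py); the config path, network name ('thalamus'), column names, and output paths below are illustrative assumptions, not part of the API.

# Hypothetical workflow sketch
from bmtool.stimulus.core import StimulusBuilder

# Build from an existing BMTK simulation config (path is illustrative)
sb = StimulusBuilder(config='config.json', net_seed=123, psg_seed=1)

# Group 'thalamus' nodes by an existing column into named assemblies
sb.create_assemblies(name='stim_groups', network_name='thalamus',
                     method='property', property_name='pulse_group_id')

# Write a time-varying Poisson stimulus (SONATA .h5) for those assemblies
sb.generate_stimulus(output_path='inputs/stim.h5', pattern_type='long',
                     assembly_name='stim_groups', population='thalamus',
                     firing_rate=(0.0, 50.0, 0.0), on_time=1.0, off_time=0.5,
                     t_start=1.0, t_stop=15.0)

# Add constant or lognormal background activity grouped by 'pop_name'
sb.generate_background(output_path='inputs/background.h5', network_name='thalamus',
                       population_params={'PN': {'mean_firing_rate': 20.0, 'stdev': 2.0}},
                       t_start=0.0, t_stop=15.0)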

Source code in bmtool/stimulus/core.py
class StimulusBuilder:
    """Class to manage and generate stimuli for BMTK networks.

    This class provides a unified interface for defining node assemblies and 
    generating time-varying Poisson spike trains (SONATA format) for those assemblies.

    Attributes:
        config (dict): BMTK simulation configuration.
        nodes (dict): Dictionary of pandas DataFrames for each network.
        assemblies (dict): Named groups of node IDs.
        net_seed (int): Seed for assembly generation.
        psg_seed (int): Seed for Poisson spike generation.
    """

    def __init__(self, config=None, net_seed=123, psg_seed=1):
        """Initialize the StimulusBuilder.

        Args:
            config (str/dict, optional): Path to BMTK config file or dictionary.
            net_seed (int): Random seed for assembly assignment (default: 123).
            psg_seed (int): Random seed for Poisson spike generation (default: 1).
        """
        self.config = config
        self.nodes = util.load_nodes_from_config(config)
        self.assemblies = {} 
        self.net_seed = net_seed
        self.psg_seed = psg_seed
        self.rng = np.random.default_rng(net_seed)

    def get_nodes(self, network_name, pop_name=None, node_query=None):
        """Helper to get node IDs from loaded nodes.

        Args:
            network_name (str): Name of the network (e.g., 'thalamus').
            pop_name (str, optional): Filter by 'pop_name' column.
            node_query (str, optional): Placeholder for custom filtering logic.

        Returns:
            pd.DataFrame: Filtered nodes.
        """
        if network_name not in self.nodes:
            raise ValueError(f"Network {network_name} not found in configuration.")

        df = self.nodes[network_name]

        if pop_name:
            df = util.get_pop(df, pop_name, key='pop_name')

        if node_query:
            # Generic query support if needed, for now just custom DF filtering by user
            pass

        return df

    def create_assemblies(self, name, network_name, method='random', seed=None, **kwargs):
        """Create node assemblies (subsets) and store them by name.

        Args:
            name (str): Unique name for this assembly group.
            network_name (str): Name of the network to draw nodes from.
            method (str): Method to group nodes. Options:
                - 'random': Assign nodes to n_assemblies randomly.
                - 'grid': Group nodes into a spatial grid based on x, y position.
                - 'property': Group nodes by a column name (e.g., 'pulse_group_id').
            seed (int, optional): Random seed for reproducibility. Overrides instance net_seed for this call.
            **kwargs: Arguments passed to the respective assembly generator:
                - pop_name (str): Filter nodes before assembly creation.
                - n_assemblies (int): Number of assemblies for 'random'.
                - prob_in_assembly (float): Probability of node inclusion (0-1).
                - property_name (str): Column name for 'property' grouping.
                - grid_id (ndarray): 2D array of assembly IDs for 'grid'.
                - grid_size (list): [[min_x, max_x], [min_y, max_y]] for 'grid'.
        """
        nodes_df = self.get_nodes(network_name, kwargs.get('pop_name'))
        node_ids = nodes_df.index.values

        # Use provided seed or default to instance net_seed
        rng = np.random.default_rng(seed if seed is not None else self.net_seed)

        if method == 'random':
            n_assemblies = kwargs.get('n_assemblies', 1)
            prob = kwargs.get('prob_in_assembly', 1.0)

            # Use utility to get assignments
            assy_indices = assemblies.assign_assembly(
                len(node_ids), n_assemblies, rng=rng, seed=None, prob_in_assembly=prob
            )

            # Map back to node IDs
            assembly_list = assemblies.get_assembly_ids(node_ids, assy_indices)
            self.assemblies[name] = assembly_list

        elif method == 'grid':
            grid_id = kwargs.get('grid_id')
            grid_size = kwargs.get('grid_size')

            nodes_assy, _ = assemblies.get_grid_assembly(nodes_df, grid_id, grid_size)
            self.assemblies[name] = nodes_assy

        elif method == 'property':
            prop_name = kwargs.get('property_name')
            prob = kwargs.get('probability', 1.0)

            assembly_list = assemblies.get_assemblies_by_property(
                nodes_df, prop_name, probability=prob, rng=rng, seed=None
            )
            self.assemblies[name] = assembly_list

        else:
            raise ValueError(f"Unknown assembly method: {method}")

    def _generate_firing_rates(self, n_nodes, mean, stdev, distribution='lognormal'):
        """Helper to generate firing rates based on distribution.

        Args:
            n_nodes (int): Number of rates to generate.
            mean (float): Mean firing rate.
            stdev (float): Standard deviation of firing rates.
            distribution (str): 'lognormal' or 'normal'.

        Returns:
            np.ndarray: Array of firing rates.
        """
        if distribution == 'lognormal':
            sigma2 = np.log((stdev / mean) ** 2 + 1)
            mu = np.log(mean) - sigma2 / 2
            sigma = sigma2 ** 0.5
            rates = self.rng.lognormal(mu, sigma, n_nodes)
        elif distribution == 'normal':
            rates = self.rng.normal(mean, stdev, n_nodes)
            rates = np.maximum(rates, 0.0)  # Clamp to non-negative
        else:
            raise ValueError(f"Unknown distribution: {distribution}. Must be 'lognormal' or 'normal'.")

        return rates

    def generate_background(self, output_path, network_name, population_params,
                           groupby='pop_name', t_start=0.0, t_stop=10.0, 
                           verbose=False, seed=None):
        """Generate background (spontaneous) activity for network nodes grouped by property.

        This function generates baseline spiking activity, grouped by a specified node property.
        Each group can use either a constant firing rate or a distribution-based rate.

        Args:
            output_path (str): Path to save the resulting .h5 file.
            network_name (str): BMTK network name.
            population_params (dict): Parameters for each population/group.
                Keys should match values in the node property specified by groupby.
                Each value is a dict with:
                    - 'mean_firing_rate' (float): Mean firing rate in Hz (required)
                    - 'stdev' (float, optional): Standard deviation. If provided, uses lognormal distribution.
                                                If omitted, uses constant firing rate.
                Example:
                    {
                        'PN': {'mean_firing_rate': 20.0, 'stdev': 2.0},
                        'PV': {'mean_firing_rate': 30.0},  # constant rate
                        'SST': {'mean_firing_rate': 15.0, 'stdev': 1.5}
                    }
            groupby (str): Node property to group by (default: 'pop_name').
                          Will match against keys in population_params.
            t_start, t_stop (float): Time range for activity (seconds).
            verbose (bool): If True, print detailed information (default: False).
            seed (int, optional): Random seed for distribution sampling. Overrides instance psg_seed.

        Examples:
            # Population-specific rates with mixed distributions
            params = {
                'PN': {'mean_firing_rate': 20.0, 'stdev': 2.0},
                'PV': {'mean_firing_rate': 30.0},  # constant rate
                'SST': {'mean_firing_rate': 15.0, 'stdev': 1.5}
            }
            sb.generate_background(
                output_path='background.h5',
                network_name='input',
                population_params=params,
                t_start=0.0, t_stop=15.0
            )

            # Group by custom property (e.g., layer)
            layer_params = {
                'L1': {'mean_firing_rate': 10.0, 'stdev': 1.0},
                'L2/3': {'mean_firing_rate': 15.0, 'stdev': 2.0}
            }
            sb.generate_background(
                output_path='layer_background.h5',
                network_name='input',
                population_params=layer_params,
                groupby='layer'
            )
        """
        if population_params is None or not isinstance(population_params, dict):
            raise ValueError("population_params must be a non-empty dict")

        nodes_df = self.get_nodes(network_name)

        # Verify groupby column exists
        if groupby not in nodes_df.columns:
            raise ValueError(f"Node property '{groupby}' not found in network '{network_name}'")

        # Use provided seed or default to instance psg_seed
        psg_seed = seed if seed is not None else self.psg_seed

        population = network_name  # Default population name in PSG
        psg = PoissonSpikeGenerator(population=population, seed=psg_seed)

        times = (t_start, t_stop)
        total_nodes = 0

        for group_key, params in population_params.items():
            # Find nodes matching this group
            nodes_in_group = nodes_df[nodes_df[groupby] == group_key].index.values

            if len(nodes_in_group) == 0:
                if verbose:
                    print(f"  Warning: No nodes found with {groupby}='{group_key}'")
                continue

            total_nodes += len(nodes_in_group)

            if not isinstance(params, dict) or 'mean_firing_rate' not in params:
                raise ValueError(f"params['{group_key}'] must be a dict with 'mean_firing_rate' key")

            mean_rate = params['mean_firing_rate']
            stdev = params.get('stdev', None)

            # Determine: constant vs distribution-based
            if stdev is not None:
                # Use distribution (lognormal)
                firing_rates = self._generate_firing_rates(len(nodes_in_group), mean_rate, stdev, 'lognormal')
                for node_id, rate in zip(nodes_in_group, firing_rates):
                    psg.add(node_ids=node_id, firing_rate=rate, times=times)
                if verbose:
                    print(f"  {group_key}: {len(nodes_in_group)} nodes, {mean_rate:.1f}±{stdev:.1f} Hz (lognormal)")
            else:
                # Use constant firing rate
                psg.add(node_ids=nodes_in_group.tolist(), firing_rate=mean_rate, times=times)
                if verbose:
                    print(f"  {group_key}: {len(nodes_in_group)} nodes, {mean_rate:.1f} Hz (constant)")

        # Write to file
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        psg.to_sonata(output_path)
        if verbose:
            print(f"Generated background activity: {total_nodes} nodes to {output_path}")

    def generate_stimulus(self, output_path, pattern_type, assembly_name, verbose=False, seed=None, **kwargs):
        """Generate a BMTK Poisson spike file (SONATA) for a specific assembly group.

        Use create_assemblies() first to define your stimulus assemblies, then call this
        function to generate time-varying firing patterns for those assemblies.

        Args:
            output_path (str): Path to save the resulting .h5 file.
            pattern_type (str): Firing rate template ('short', 'long', 'ramp', etc).
            assembly_name (str): Name of the assembly group created via create_assemblies.
            verbose (bool): If True, print detailed information (default: False).
            seed (int, optional): Random seed for Poisson spike generation. Overrides instance psg_seed.
            **kwargs: Arguments passed to the generator function and PoissonSpikeGenerator.
                - population (str): Name of the spike population (for BMTK).
                - firing_rate (3-tuple): (off_rate, burst_rate, silent_rate).
                - on_time (float): Duration of active period.
                - off_time (float): Duration of silent period.
                - t_start (float): Start time of cycles.
                - t_stop (float): End time of cycles.

        Example:
            # First create assemblies
            sb.create_assemblies(name='stim_groups', network_name='thalamus', 
                                method='property', property_name='pulse_group_id')

            # Then generate stimulus
            sb.generate_stimulus(output_path='stim.h5', pattern_type='long', 
                                assembly_name='stim_groups', population='thalamus',
                                firing_rate=(0.0, 50.0, 0.0), t_start=1.0, t_stop=15.0,
                                on_time=1.0, off_time=0.5)
        """
        if assembly_name not in self.assemblies:
            raise ValueError(f"Assembly '{assembly_name}' not defined. Use create_assemblies() first.")

        assembly_list = self.assemblies[assembly_name]
        n_assemblies = len(assembly_list)

        # Get population name for PSG (consumed here)
        population = kwargs.pop('population', 'stimulus')

        # Use provided seed or default to instance psg_seed
        psg_seed = seed if seed is not None else self.psg_seed

        # Dispatch to generator
        generator_func = getattr(generators, f"get_fr_{pattern_type}", None)
        if not generator_func:
             raise ValueError(f"Unknown pattern type: {pattern_type}")

        # Generate traces
        params = generator_func(n_assemblies=n_assemblies, verbose=verbose, **kwargs)

        # Create PSG
        psg = PoissonSpikeGenerator(population=population, seed=psg_seed)

        # Add spikes
        if verbose:
            print(f"Generating spiking for {n_assemblies} assemblies...")
        for ids, param_dict in zip(assembly_list, params):
            psg.add(node_ids=ids, firing_rate=param_dict['firing_rate'], times=param_dict['times'])

        # Write to file
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        psg.to_sonata(output_path)
        if verbose:
            print(f"Written stimulus to {output_path}")

__init__(config=None, net_seed=123, psg_seed=1)

Initialize the StimulusBuilder.

Args:

- config (str/dict, optional): Path to BMTK config file or dictionary.
- net_seed (int): Random seed for assembly assignment (default: 123).
- psg_seed (int): Random seed for Poisson spike generation (default: 1).

Source code in bmtool/stimulus/core.py
def __init__(self, config=None, net_seed=123, psg_seed=1):
    """Initialize the StimulusBuilder.

    Args:
        config (str/dict, optional): Path to BMTK config file or dictionary.
        net_seed (int): Random seed for assembly assignment (default: 123).
        psg_seed (int): Random seed for Poisson spike generation (default: 1).
    """
    self.config = config
    self.nodes = util.load_nodes_from_config(config)
    self.assemblies = {} 
    self.net_seed = net_seed
    self.psg_seed = psg_seed
    self.rng = np.random.default_rng(net_seed)

get_nodes(network_name, pop_name=None, node_query=None)

Helper to get node IDs from loaded nodes.

Args:

- network_name (str): Name of the network (e.g., 'thalamus').
- pop_name (str, optional): Filter by 'pop_name' column.
- node_query (str, optional): Placeholder for custom filtering logic.

Returns: pd.DataFrame: Filtered nodes.
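
For example (a sketch assuming a loaded builder sb whose 'thalamus' network has a 'pop_name' column; the population label 'PN' is illustrative):

# Full node table for one network (pandas DataFrame indexed by node ID)
thal_nodes = sb.get_nodes('thalamus')

# Restrict to a single population via the 'pop_name' column
pn_nodes = sb.get_nodes('thalamus', pop_name='PN')
print(len(pn_nodes), "PN nodes")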

Source code in bmtool/stimulus/core.py
def get_nodes(self, network_name, pop_name=None, node_query=None):
    """Helper to get node IDs from loaded nodes.

    Args:
        network_name (str): Name of the network (e.g., 'thalamus').
        pop_name (str, optional): Filter by 'pop_name' column.
        node_query (str, optional): Placeholder for custom filtering logic.

    Returns:
        pd.DataFrame: Filtered nodes.
    """
    if network_name not in self.nodes:
        raise ValueError(f"Network {network_name} not found in configuration.")

    df = self.nodes[network_name]

    if pop_name:
        df = util.get_pop(df, pop_name, key='pop_name')

    if node_query:
        # Generic query support if needed, for now just custom DF filtering by user
        pass

    return df

create_assemblies(name, network_name, method='random', seed=None, **kwargs)

Create node assemblies (subsets) and store them by name.

Args:

- name (str): Unique name for this assembly group.
- network_name (str): Name of the network to draw nodes from.
- method (str): Method to group nodes. Options:
    - 'random': Assign nodes to n_assemblies randomly.
    - 'grid': Group nodes into a spatial grid based on x, y position.
    - 'property': Group nodes by a column name (e.g., 'pulse_group_id').
- seed (int, optional): Random seed for reproducibility. Overrides instance net_seed for this call.
- **kwargs: Arguments passed to the respective assembly generator:
    - pop_name (str): Filter nodes before assembly creation.
    - n_assemblies (int): Number of assemblies for 'random'.
    - prob_in_assembly (float): Probability of node inclusion (0-1).
    - property_name (str): Column name for 'property' grouping.
    - grid_id (ndarray): 2D array of assembly IDs for 'grid'.
    - grid_size (list): [[min_x, max_x], [min_y, max_y]] for 'grid'.
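
Two illustrative calls, as a sketch: they assume a builder sb whose 'thalamus' network has x/y positions, and the grid values and seed below are made up for the example.

# 5 random assemblies, keeping each node with 80% probability
sb.create_assemblies(name='rand_groups', network_name='thalamus',
                     method='random', n_assemblies=5, prob_in_assembly=0.8, seed=7)

# Spatial 2x2 grid of assemblies covering the region given by grid_size
import numpy as np
grid_id = np.array([[0, 1], [2, 3]])        # assembly ID per grid cell
grid_size = [[0.0, 400.0], [0.0, 400.0]]    # [[min_x, max_x], [min_y, max_y]]
sb.create_assemblies(name='grid_groups', network_name='thalamus',
                     method='grid', grid_id=grid_id, grid_size=grid_size)

# Later calls refer to the stored groups by name
print(len(sb.assemblies['rand_groups']), "assemblies")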

Source code in bmtool/stimulus/core.py
def create_assemblies(self, name, network_name, method='random', seed=None, **kwargs):
    """Create node assemblies (subsets) and store them by name.

    Args:
        name (str): Unique name for this assembly group.
        network_name (str): Name of the network to draw nodes from.
        method (str): Method to group nodes. Options:
            - 'random': Assign nodes to n_assemblies randomly.
            - 'grid': Group nodes into a spatial grid based on x, y position.
            - 'property': Group nodes by a column name (e.g., 'pulse_group_id').
        seed (int, optional): Random seed for reproducibility. Overrides instance net_seed for this call.
        **kwargs: Arguments passed to the respective assembly generator:
            - pop_name (str): Filter nodes before assembly creation.
            - n_assemblies (int): Number of assemblies for 'random'.
            - prob_in_assembly (float): Probability of node inclusion (0-1).
            - property_name (str): Column name for 'property' grouping.
            - grid_id (ndarray): 2D array of assembly IDs for 'grid'.
            - grid_size (list): [[min_x, max_x], [min_y, max_y]] for 'grid'.
    """
    nodes_df = self.get_nodes(network_name, kwargs.get('pop_name'))
    node_ids = nodes_df.index.values

    # Use provided seed or default to instance net_seed
    rng = np.random.default_rng(seed if seed is not None else self.net_seed)

    if method == 'random':
        n_assemblies = kwargs.get('n_assemblies', 1)
        prob = kwargs.get('prob_in_assembly', 1.0)

        # Use utility to get assignments
        assy_indices = assemblies.assign_assembly(
            len(node_ids), n_assemblies, rng=rng, seed=None, prob_in_assembly=prob
        )

        # Map back to node IDs
        assembly_list = assemblies.get_assembly_ids(node_ids, assy_indices)
        self.assemblies[name] = assembly_list

    elif method == 'grid':
        grid_id = kwargs.get('grid_id')
        grid_size = kwargs.get('grid_size')

        nodes_assy, _ = assemblies.get_grid_assembly(nodes_df, grid_id, grid_size)
        self.assemblies[name] = nodes_assy

    elif method == 'property':
        prop_name = kwargs.get('property_name')
        prob = kwargs.get('probability', 1.0)

        assembly_list = assemblies.get_assemblies_by_property(
            nodes_df, prop_name, probability=prob, rng=rng, seed=None
        )
        self.assemblies[name] = assembly_list

    else:
        raise ValueError(f"Unknown assembly method: {method}")

generate_background(output_path, network_name, population_params, groupby='pop_name', t_start=0.0, t_stop=10.0, verbose=False, seed=None)

Generate background (spontaneous) activity for network nodes grouped by property.

This function generates baseline spiking activity, grouped by a specified node property. Each group can use either a constant firing rate or a distribution-based rate.

Args:

- output_path (str): Path to save the resulting .h5 file.
- network_name (str): BMTK network name.
- population_params (dict): Parameters for each population/group. Keys should match values in the node property specified by groupby. Each value is a dict with:
    - 'mean_firing_rate' (float): Mean firing rate in Hz (required).
    - 'stdev' (float, optional): Standard deviation. If provided, uses a lognormal distribution; if omitted, uses a constant firing rate.
    Example: {'PN': {'mean_firing_rate': 20.0, 'stdev': 2.0}, 'PV': {'mean_firing_rate': 30.0}, 'SST': {'mean_firing_rate': 15.0, 'stdev': 1.5}}
- groupby (str): Node property to group by (default: 'pop_name'). Will match against keys in population_params.
- t_start, t_stop (float): Time range for activity (seconds).
- verbose (bool): If True, print detailed information (default: False).
- seed (int, optional): Random seed for distribution sampling. Overrides instance psg_seed.

Examples:

# Population-specific rates with mixed distributions
params = {
    'PN': {'mean_firing_rate': 20.0, 'stdev': 2.0},
    'PV': {'mean_firing_rate': 30.0},  # constant rate
    'SST': {'mean_firing_rate': 15.0, 'stdev': 1.5}
}
sb.generate_background(
    output_path='background.h5',
    network_name='input',
    population_params=params,
    t_start=0.0, t_stop=15.0
)

# Group by custom property (e.g., layer)
layer_params = {
    'L1': {'mean_firing_rate': 10.0, 'stdev': 1.0},
    'L2/3': {'mean_firing_rate': 15.0, 'stdev': 2.0}
}
sb.generate_background(
    output_path='layer_background.h5',
    network_name='input',
    population_params=layer_params,
    groupby='layer'
)
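
When 'stdev' is given, per-node rates are drawn from a lognormal distribution whose parameters are derived from the requested mean and standard deviation, mirroring the internal helper shown in the source below. A standalone sketch of that parametrization (the example numbers are arbitrary):

import numpy as np

mean, stdev = 20.0, 2.0                    # requested Hz
sigma2 = np.log((stdev / mean) ** 2 + 1)   # lognormal shape parameter squared
mu = np.log(mean) - sigma2 / 2             # lognormal location parameter
rates = np.random.default_rng(1).lognormal(mu, np.sqrt(sigma2), size=1000)
print(rates.mean(), rates.std())           # approximately 20 and 2
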
Source code in bmtool/stimulus/core.py
def generate_background(self, output_path, network_name, population_params,
                       groupby='pop_name', t_start=0.0, t_stop=10.0, 
                       verbose=False, seed=None):
    """Generate background (spontaneous) activity for network nodes grouped by property.

    This function generates baseline spiking activity, grouped by a specified node property.
    Each group can use either a constant firing rate or a distribution-based rate.

    Args:
        output_path (str): Path to save the resulting .h5 file.
        network_name (str): BMTK network name.
        population_params (dict): Parameters for each population/group.
            Keys should match values in the node property specified by groupby.
            Each value is a dict with:
                - 'mean_firing_rate' (float): Mean firing rate in Hz (required)
                - 'stdev' (float, optional): Standard deviation. If provided, uses lognormal distribution.
                                            If omitted, uses constant firing rate.
            Example:
                {
                    'PN': {'mean_firing_rate': 20.0, 'stdev': 2.0},
                    'PV': {'mean_firing_rate': 30.0},  # constant rate
                    'SST': {'mean_firing_rate': 15.0, 'stdev': 1.5}
                }
        groupby (str): Node property to group by (default: 'pop_name').
                      Will match against keys in population_params.
        t_start, t_stop (float): Time range for activity (seconds).
        verbose (bool): If True, print detailed information (default: False).
        seed (int, optional): Random seed for distribution sampling. Overrides instance psg_seed.

    Examples:
        # Population-specific rates with mixed distributions
        params = {
            'PN': {'mean_firing_rate': 20.0, 'stdev': 2.0},
            'PV': {'mean_firing_rate': 30.0},  # constant rate
            'SST': {'mean_firing_rate': 15.0, 'stdev': 1.5}
        }
        sb.generate_background(
            output_path='background.h5',
            network_name='input',
            population_params=params,
            t_start=0.0, t_stop=15.0
        )

        # Group by custom property (e.g., layer)
        layer_params = {
            'L1': {'mean_firing_rate': 10.0, 'stdev': 1.0},
            'L2/3': {'mean_firing_rate': 15.0, 'stdev': 2.0}
        }
        sb.generate_background(
            output_path='layer_background.h5',
            network_name='input',
            population_params=layer_params,
            groupby='layer'
        )
    """
    if population_params is None or not isinstance(population_params, dict):
        raise ValueError("population_params must be a non-empty dict")

    nodes_df = self.get_nodes(network_name)

    # Verify groupby column exists
    if groupby not in nodes_df.columns:
        raise ValueError(f"Node property '{groupby}' not found in network '{network_name}'")

    # Use provided seed or default to instance psg_seed
    psg_seed = seed if seed is not None else self.psg_seed

    population = network_name  # Default population name in PSG
    psg = PoissonSpikeGenerator(population=population, seed=psg_seed)

    times = (t_start, t_stop)
    total_nodes = 0

    for group_key, params in population_params.items():
        # Find nodes matching this group
        nodes_in_group = nodes_df[nodes_df[groupby] == group_key].index.values

        if len(nodes_in_group) == 0:
            if verbose:
                print(f"  Warning: No nodes found with {groupby}='{group_key}'")
            continue

        total_nodes += len(nodes_in_group)

        if not isinstance(params, dict) or 'mean_firing_rate' not in params:
            raise ValueError(f"params['{group_key}'] must be a dict with 'mean_firing_rate' key")

        mean_rate = params['mean_firing_rate']
        stdev = params.get('stdev', None)

        # Determine: constant vs distribution-based
        if stdev is not None:
            # Use distribution (lognormal)
            firing_rates = self._generate_firing_rates(len(nodes_in_group), mean_rate, stdev, 'lognormal')
            for node_id, rate in zip(nodes_in_group, firing_rates):
                psg.add(node_ids=node_id, firing_rate=rate, times=times)
            if verbose:
                print(f"  {group_key}: {len(nodes_in_group)} nodes, {mean_rate:.1f}±{stdev:.1f} Hz (lognormal)")
        else:
            # Use constant firing rate
            psg.add(node_ids=nodes_in_group.tolist(), firing_rate=mean_rate, times=times)
            if verbose:
                print(f"  {group_key}: {len(nodes_in_group)} nodes, {mean_rate:.1f} Hz (constant)")

    # Write to file
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    psg.to_sonata(output_path)
    if verbose:
        print(f"Generated background activity: {total_nodes} nodes to {output_path}")

generate_stimulus(output_path, pattern_type, assembly_name, verbose=False, seed=None, **kwargs)

Generate a BMTK Poisson spike file (SONATA) for a specific assembly group.

Use create_assemblies() first to define your stimulus assemblies, then call this function to generate time-varying firing patterns for those assemblies.

Args:

- output_path (str): Path to save the resulting .h5 file.
- pattern_type (str): Firing rate template ('short', 'long', 'ramp', etc.).
- assembly_name (str): Name of the assembly group created via create_assemblies.
- verbose (bool): If True, print detailed information (default: False).
- seed (int, optional): Random seed for Poisson spike generation. Overrides instance psg_seed.
- **kwargs: Arguments passed to the generator function and PoissonSpikeGenerator:
    - population (str): Name of the spike population (for BMTK).
    - firing_rate (3-tuple): (off_rate, burst_rate, silent_rate).
    - on_time (float): Duration of active period.
    - off_time (float): Duration of silent period.
    - t_start (float): Start time of cycles.
    - t_stop (float): End time of cycles.

Example:

# First create assemblies
sb.create_assemblies(name='stim_groups', network_name='thalamus',
                     method='property', property_name='pulse_group_id')

# Then generate stimulus
sb.generate_stimulus(output_path='stim.h5', pattern_type='long', 
                    assembly_name='stim_groups', population='thalamus',
                    firing_rate=(0.0, 50.0, 0.0), t_start=1.0, t_stop=15.0,
                    on_time=1.0, off_time=0.5)
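
pattern_type is resolved by name to a generator function (get_fr_<pattern_type>) in bmtool/stimulus/generators.py, so the patterns documented below are selected roughly like this. A sketch: the import path is inferred from the source location, and only the generators shown on this page are assumed to exist.

from bmtool.stimulus import generators

pattern_type = 'ramp'
generator_func = getattr(generators, f"get_fr_{pattern_type}", None)  # e.g. get_fr_ramp
if generator_func is None:
    raise ValueError(f"Unknown pattern type: {pattern_type}")
params = generator_func(n_assemblies=4, firing_rate=(0.0, 10.0, 50.0, 0.0),
                        on_time=1.0, off_time=0.5, t_start=1.0, t_stop=15.0)
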
Source code in bmtool/stimulus/core.py
def generate_stimulus(self, output_path, pattern_type, assembly_name, verbose=False, seed=None, **kwargs):
    """Generate a BMTK Poisson spike file (SONATA) for a specific assembly group.

    Use create_assemblies() first to define your stimulus assemblies, then call this
    function to generate time-varying firing patterns for those assemblies.

    Args:
        output_path (str): Path to save the resulting .h5 file.
        pattern_type (str): Firing rate template ('short', 'long', 'ramp', etc).
        assembly_name (str): Name of the assembly group created via create_assemblies.
        verbose (bool): If True, print detailed information (default: False).
        seed (int, optional): Random seed for Poisson spike generation. Overrides instance psg_seed.
        **kwargs: Arguments passed to the generator function and PoissonSpikeGenerator.
            - population (str): Name of the spike population (for BMTK).
            - firing_rate (3-tuple): (off_rate, burst_rate, silent_rate).
            - on_time (float): Duration of active period.
            - off_time (float): Duration of silent period.
            - t_start (float): Start time of cycles.
            - t_stop (float): End time of cycles.

    Example:
        # First create assemblies
        sb.create_assemblies(name='stim_groups', network_name='thalamus', 
                            method='property', property_name='pulse_group_id')

        # Then generate stimulus
        sb.generate_stimulus(output_path='stim.h5', pattern_type='long', 
                            assembly_name='stim_groups', population='thalamus',
                            firing_rate=(0.0, 50.0, 0.0), t_start=1.0, t_stop=15.0,
                            on_time=1.0, off_time=0.5)
    """
    if assembly_name not in self.assemblies:
        raise ValueError(f"Assembly '{assembly_name}' not defined. Use create_assemblies() first.")

    assembly_list = self.assemblies[assembly_name]
    n_assemblies = len(assembly_list)

    # Get population name for PSG (consumed here)
    population = kwargs.pop('population', 'stimulus')

    # Use provided seed or default to instance psg_seed
    psg_seed = seed if seed is not None else self.psg_seed

    # Dispatch to generator
    generator_func = getattr(generators, f"get_fr_{pattern_type}", None)
    if not generator_func:
         raise ValueError(f"Unknown pattern type: {pattern_type}")

    # Generate traces
    params = generator_func(n_assemblies=n_assemblies, verbose=verbose, **kwargs)

    # Create PSG
    psg = PoissonSpikeGenerator(population=population, seed=psg_seed)

    # Add spikes
    if verbose:
        print(f"Generating spiking for {n_assemblies} assemblies...")
    for ids, param_dict in zip(assembly_list, params):
        psg.add(node_ids=ids, firing_rate=param_dict['firing_rate'], times=param_dict['times'])

    # Write to file
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    psg.to_sonata(output_path)
    if verbose:
        print(f"Written stimulus to {output_path}")

Firing Rate Generators

Functions for generating time-varying firing rate patterns:

Short Bursts

Short burst is delivered to each assembly sequentially within each cycle.

Args:

- n_assemblies: Total number of assemblies.
- firing_rate: 3-tuple of firing rates (off_rate, burst_fr, silent_rate):
    - firing_rate[0] = off_rate (background rate during non-burst on-time)
    - firing_rate[1] = burst_fr (burst firing rate)
    - firing_rate[2] = silent_rate (rate during off-time and before t_start)
- on_time: Duration of on period (s).
- off_time: Duration of off period (s).
- t_start: Start time of the stimulus cycles (s).
- t_stop: Stop time of the stimulus cycles (s).
- n_rounds: Number of short bursts each assembly receives per cycle. Can be fractional; some assemblies will receive one more burst per cycle.
- verbose: If True, print detailed information.
- assembly_index: List of selected assembly indices. If provided, traces are still generated for all n_assemblies; non-selected assemblies fire at off_rate during on_time.

Returns: list: Firing rate traces for each assembly
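
A direct call, as a sketch: the import path is inferred from the source location below, and the resulting traces have the same structure that generate_stimulus passes to bmtk's PoissonSpikeGenerator.

from bmtool.stimulus.generators import get_fr_short

# 3 assemblies, each receiving one 50 Hz burst in sequence within each 1.5 s cycle
params = get_fr_short(n_assemblies=3, firing_rate=(0.0, 50.0, 0.0),
                      on_time=1.0, off_time=0.5, t_start=1.0, t_stop=10.0)
trace = params[0]                      # {'firing_rate': [...], 'times': [...]}
print(len(params), len(trace['times']))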

Source code in bmtool/stimulus/generators.py
def get_fr_short(n_assemblies, firing_rate=(0., 0., 0.),
                 on_time=1.0, off_time=0.5,
                 t_start=0.0, t_stop=10.0, n_cycles=None, n_rounds=1, verbose=False,
                 assembly_index=None):
    """Short burst is delivered to each assembly sequentially within each cycle.

    Args:
        n_assemblies: Total number of assemblies.
        firing_rate: 3-tuple of firing rates (off_rate, burst_fr, silent_rate)
            firing_rate[0] = off_rate (background rate during non-burst on-time)
            firing_rate[1] = burst_fr (burst firing rate)
            firing_rate[2] = silent_rate (rate during off-time and before t_start)
        on_time: Duration of on period (s)
        off_time: Duration of off period (s)
        t_start: Start time of the stimulus cycles (s)
        t_stop: Stop time of the stimulus cycles (s)
        n_rounds: Number of short bursts each assembly receives per cycle.
            Can be fractional; some assemblies will receive one more burst per cycle.
        verbose: If True, print detailed information
        assembly_index: List of selected assembly indices. If provided, generates traces for all
                       n_assemblies. Non-selected assemblies fire at off_rate during on_time.

    Returns:
        list: Firing rate traces for each assembly
    """
    # Set assembly_index to all if not provided
    if assembly_index is None:
        assembly_index = list(range(n_assemblies))

    # In generalized version, n_assemblies passed IS the total assemblies.
    total_assemblies = n_assemblies

    if verbose:
        print("\nStarting get_fr_short...")
        print(f"Selected assemblies: {assembly_index} out of {total_assemblies}")
        print(f"Firing rates (off, burst, silent): {firing_rate}")
        print(f"Time parameters - start: {t_start}, stop: {t_stop}, on_time: {on_time}, off_time: {off_time}")
        print(f"Rounds per cycle: {n_rounds}")

    # Ensure firing_rate is properly formatted as 3-tuple
    firing_rate = np.asarray(firing_rate).ravel()[:3]
    if len(firing_rate) < 3:
        # Pad with zeros if insufficient values provided
        firing_rate = np.concatenate((firing_rate, np.zeros(3 - firing_rate.size)))

    off_rate, burst_fr, silent_rate = firing_rate
    assembly_index = list(assembly_index)  # Ensure it's a list

    t_cycle = on_time + off_time
    if n_cycles is not None:
        n_cycle = n_cycles
    else:
        n_cycle = int((t_stop - t_start) // t_cycle)

    # Calculate burst timing within each cycle (based on selected assemblies only)
    n_selected = len(assembly_index)
    if n_selected == 0:
        n_bursts_per_cycle = 0
    else:
        n_bursts_per_cycle = int(np.ceil(n_rounds * n_selected))
    n_rounds_int = int(np.ceil(n_rounds))

    if verbose:
        print(f"\nCycle information:")
        print(f"Time per cycle: {t_cycle}")
        print(f"Number of cycles: {n_cycle}")
        print(f"Bursts per cycle: {n_bursts_per_cycle}")
        print(f"Rounds (integer): {n_rounds_int}")

    # Calculate time slots for bursts within on_time
    if n_bursts_per_cycle > 0:
        burst_duration = on_time / n_bursts_per_cycle
        burst_times = np.linspace(0, on_time - burst_duration, n_bursts_per_cycle)
    else:
        burst_times = []

    params = []

    for i in range(total_assemblies):
        if verbose:
            print(f"\nProcessing assembly {i}...")

        # Initialize with silent rate from 0 to t_start
        times = [0.0, t_start]
        rates = [silent_rate, silent_rate]

        is_selected = i in assembly_index

        for cycle in range(n_cycle):
            cycle_start = t_start + cycle * t_cycle
            on_period_start = cycle_start
            on_period_end = cycle_start + on_time
            cycle_end = cycle_start + t_cycle

            if verbose and cycle == 0:
                print(f"  Cycle {cycle}: start={cycle_start}, on_end={on_period_end}, cycle_end={cycle_end}")

            if is_selected:
                # Determine which bursts this assembly gets in this cycle
                # Map assembly i to its position in assembly_index
                try:
                    selected_position = assembly_index.index(i)
                except ValueError:
                    selected_position = -1 # Should not happen if is_selected checked

                assembly_bursts = []
                for round_num in range(n_rounds_int):
                    burst_index = round_num * n_selected + selected_position
                    if burst_index < n_bursts_per_cycle:
                        burst_start_time = on_period_start + burst_times[burst_index]
                        burst_end_time = burst_start_time + burst_duration
                        # Make sure burst doesn't exceed on_period
                        burst_end_time = min(burst_end_time, on_period_end)
                        assembly_bursts.append((burst_start_time, burst_end_time))

                if verbose and cycle == 0:
                    print(f"  Assembly {i} gets {len(assembly_bursts)} bursts in cycle {cycle}")

                # Add timepoints for this cycle
                current_time = on_period_start

                # Handle the on period with bursts
                for burst_start, burst_end in assembly_bursts:
                    # Before burst (background rate)
                    if current_time < burst_start:
                        times.extend([current_time, current_time, burst_start, burst_start])
                        rates.extend([silent_rate, off_rate, off_rate, silent_rate])

                    # During burst (burst rate)
                    times.extend([burst_start, burst_start, burst_end, burst_end])
                    rates.extend([silent_rate, burst_fr, burst_fr, silent_rate])

                    current_time = burst_end

                # After all bursts until end of on period (background rate)
                if current_time < on_period_end:
                    times.extend([current_time, current_time, on_period_end, on_period_end])
                    rates.extend([silent_rate, off_rate, off_rate, silent_rate])
            else:
                # Non-selected assembly: fires at off_rate during on_time
                times.extend([on_period_start, on_period_start, on_period_end, on_period_end])
                rates.extend([silent_rate, off_rate, off_rate, silent_rate])

            # Off period (silent rate)
            times.extend([on_period_end, on_period_end, cycle_end, cycle_end])
            rates.extend([silent_rate, silent_rate, silent_rate, silent_rate])

        # Add final timepoint
        times.append(t_stop)
        rates.append(silent_rate)

        params.append({
            'firing_rate': rates,
            'times': times
        })

    return params

Long Bursts

Long burst where one assembly is active per cycle.

Args:

- n_assemblies (int): Total number of assemblies.
- firing_rate (tuple): 3-tuple (off_rate, burst_rate, silent_rate).
- on_time (float): Duration of active cycle (s).
- off_time (float): Duration of silent period (s).
- t_start (float): Start time (s).
- t_stop (float): Stop time (s).
- n_cycles (int, optional): Number of cycles. Defaults to the number of full cycles that fit between t_start and t_stop.
- verbose (bool): Whether to print debug info.
- assembly_index (list, optional): List of assemblies to generate traces for.

Returns: list: Firing rate parameters for each assembly.
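
A direct call, as a sketch (import path inferred from the source location below). Assemblies take turns being active, cycling through assembly_index in order.

from bmtool.stimulus.generators import get_fr_long

# 4 assemblies; in cycle k, assembly k % 4 bursts at 40 Hz while the
# others sit at the 5 Hz background (off) rate during on_time
params = get_fr_long(n_assemblies=4, firing_rate=(5.0, 40.0, 0.0),
                     on_time=1.0, off_time=0.5, t_start=0.0, t_stop=12.0)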

Source code in bmtool/stimulus/generators.py
def get_fr_long(n_assemblies, firing_rate=(0., 0., 0.),
                on_time=1.0, off_time=0.5,
                t_start=0.0, t_stop=10.0, n_cycles=None, verbose=False,
                assembly_index=None):
    """Long burst where one assembly is active per cycle.

    Args:
        n_assemblies (int): Total number of assemblies.
        firing_rate (tuple): 3-tuple (off_rate, burst_rate, silent_rate).
        on_time (float): Duration of active cycle (s).
        off_time (float): Duration of silent period (s).
        t_start (float): Start time (s).
        t_stop (float): Stop time (s).
        n_cycles (int, optional): Number of cycles. Defaults to floor of duration.
        verbose (bool): Whether to print debug info.
        assembly_index (list, optional): List of assemblies to generate traces for.

    Returns:
        list: Firing rate parameters for each assembly.
    """
    if assembly_index is None:
        assembly_index = list(range(n_assemblies))

    total_assemblies = n_assemblies

    if verbose:
        print("\nStarting get_fr_long...")
        print(f"Selected assemblies: {assembly_index} out of {total_assemblies}")
        print(f"Firing rates (off, burst, silent): {firing_rate}")
        print(f"Time parameters - start: {t_start}, stop: {t_stop}, on_time: {on_time}, off_time: {off_time}")

    # Ensure firing_rate is properly formatted as 3-tuple
    firing_rate = np.asarray(firing_rate).ravel()[:3]
    if len(firing_rate) < 3:
        # Pad with zeros if insufficient values provided
        firing_rate = np.concatenate((firing_rate, np.zeros(3 - firing_rate.size)))

    off_rate, burst_fr, silent_rate = firing_rate
    assembly_index = list(assembly_index)  # Ensure it's a list

    t_cycle = on_time + off_time
    if n_cycles is not None:
        n_cycle = n_cycles
    else:
        n_cycle = int((t_stop - t_start) // t_cycle)
    n_selected = len(assembly_index)

    if verbose:
        print(f"\nCycle information:")
        print(f"Time per cycle: {t_cycle}")
        print(f"Number of cycles: {n_cycle}")

    params = []

    for i in range(total_assemblies):
        # Initialize with silent rate from 0 to t_start
        times = [0.0, t_start]
        rates = [silent_rate, silent_rate]

        is_selected = i in assembly_index

        for cycle in range(n_cycle):
            cycle_start = t_start + cycle * t_cycle
            burst_start = cycle_start
            burst_end = cycle_start + on_time
            cycle_end = cycle_start + t_cycle

            # Map to position in selected assemblies only
            if n_selected > 0:
                active_position = cycle % n_selected
                active_assembly = assembly_index[active_position]
            else:
                active_assembly = -1

            if i == active_assembly:
                # This assembly is active - burst_fr during on_time, silent during off_time
                times.extend([
                    burst_start,
                    burst_start,
                    burst_end,
                    burst_end,
                    cycle_end,
                    cycle_end
                ])
                rates.extend([
                    silent_rate,  # At start of burst
                    burst_fr,  # During burst
                    burst_fr,  # During burst
                    silent_rate,  # End of burst
                    silent_rate,  # During off time
                    silent_rate   # Until next cycle
                ])
            elif is_selected:
                # This assembly is selected but inactive - off_rate during on_time
                times.extend([
                    burst_start,
                    burst_start,
                    burst_end,
                    burst_end,
                    cycle_end,
                    cycle_end
                ])
                rates.extend([
                    silent_rate,  # At start of burst
                    off_rate,  # During active assembly's burst
                    off_rate,  # During active assembly's burst
                    silent_rate,  # End of burst
                    silent_rate,  # During off time
                    silent_rate   # Until next cycle
                ])
            else:
                # Non-selected assembly: always silent or background (use off_rate during on_time)
                times.extend([
                    burst_start,
                    burst_start,
                    burst_end,
                    burst_end,
                    cycle_end,
                    cycle_end
                ])
                rates.extend([
                    silent_rate,  # At start of burst
                    off_rate,  # Background during on_time
                    off_rate,  # Background during on_time
                    silent_rate,  # End of burst
                    silent_rate,  # During off time
                    silent_rate   # Until next cycle
                ])

        # Add final timepoint
        times.append(t_stop)
        rates.append(silent_rate)

        params.append({
            'firing_rate': rates,
            'times': times
        })

    return params

Ramp Patterns

Ramping input where one assembly is active per cycle, with linear rate changes.

Args:

- n_assemblies (int): Total number of assemblies.
- firing_rate (tuple): 4-tuple (off_rate, ramp_start_fr, ramp_end_fr, silent_rate).
- on_time (float): Duration of active cycle (s).
- off_time (float): Duration of silent period (s).
- ramp_on_time (float, optional): Offset within on_time to start the ramp.
- ramp_off_time (float, optional): Offset within on_time to end the ramp.
- t_start (float): Start time (s).
- t_stop (float): Stop time (s).
- n_cycles (int, optional): Number of cycles.
- verbose (bool): Whether to print debug info.
- assembly_index (list, optional): List of assemblies to generate traces for.

Returns: list: Firing rate parameters for each assembly.
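
A direct call, as a sketch (import path inferred from the source location below). Note the 4-tuple firing_rate and the optional ramp window inside on_time; the numbers are illustrative.

from bmtool.stimulus.generators import get_fr_ramp

# Rate ramps linearly from 10 Hz to 60 Hz between 0.2 s and 0.8 s of each
# 1 s on period; non-active assemblies stay at the 2 Hz background rate
params = get_fr_ramp(n_assemblies=2, firing_rate=(2.0, 10.0, 60.0, 0.0),
                     on_time=1.0, off_time=0.5,
                     ramp_on_time=0.2, ramp_off_time=0.8,
                     t_start=0.0, t_stop=9.0)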

Source code in bmtool/stimulus/generators.py
def get_fr_ramp(n_assemblies, firing_rate=(0., 0., 0., 0.),
                on_time=1.0, off_time=0.5,
                ramp_on_time=None, ramp_off_time=None,
                t_start=0.0, t_stop=10.0, n_cycles=None, verbose=False,
                assembly_index=None):
    """Ramping input where one assembly is active per cycle, with linear rate changes.

    Args:
        n_assemblies (int): Total number of assemblies.
        firing_rate (tuple): 4-tuple (off_rate, ramp_start_fr, ramp_end_fr, silent_rate).
        on_time (float): Duration of active cycle (s).
        off_time (float): Duration of silent period (s).
        ramp_on_time (float, optional): Offset within on_time to start ramp.
        ramp_off_time (float, optional): Offset within on_time to end ramp.
        t_start (float): Start time (s).
        t_stop (float): Stop time (s).
        n_cycles (int, optional): Number of cycles.
        verbose (bool): Whether to print debug info.
        assembly_index (list, optional): List of assemblies to generate traces for.

    Returns:
        list: Firing rate parameters for each assembly.
    """
    if assembly_index is None:
        assembly_index = list(range(n_assemblies))

    total_assemblies = n_assemblies

    if verbose:
        print("\nStarting get_fr_ramp...")
        print(f"Selected assemblies: {assembly_index} out of {total_assemblies}")
        print(f"Firing rates (off, ramp_start, ramp_end, silent): {firing_rate}")
        print(f"Time parameters - start: {t_start}, stop: {t_stop}, on_time: {on_time}, off_time: {off_time}")

    # Ensure firing_rate is properly formatted
    firing_rate = np.asarray(firing_rate).ravel()[:4]
    firing_rate = np.concatenate((np.zeros(4 - firing_rate.size), firing_rate))

    off_rate = firing_rate[0]
    ramp_start_fr = firing_rate[1]
    ramp_end_fr = firing_rate[2]
    silent_rate = firing_rate[3]
    assembly_index = list(assembly_index)  # Ensure it's a list

    # Set ramp timing within on_time
    ramp_off_time = on_time if ramp_off_time is None else min(ramp_off_time, on_time)
    ramp_on_time = 0. if ramp_on_time is None else min(ramp_on_time, ramp_off_time)

    t_cycle = on_time + off_time
    if n_cycles is not None:
        n_cycle = n_cycles
    else:
        n_cycle = int((t_stop - t_start) // t_cycle)
    n_selected = len(assembly_index)

    if verbose:
        print(f"\nCycle information:")
        print(f"Time per cycle: {t_cycle}")
        print(f"Number of cycles: {n_cycle}")
        print(f"Ramp timing: {ramp_on_time} to {ramp_off_time} within on_time of {on_time}")

    params = []

    for i in range(total_assemblies):
        # Initialize with zero rate from 0 to t_start
        times = [0.0, t_start]
        rates = [silent_rate, silent_rate]

        is_selected = i in assembly_index

        for cycle in range(n_cycle):
            cycle_start = t_start + cycle * t_cycle
            on_period_start = cycle_start
            on_period_end = cycle_start + on_time
            cycle_end = cycle_start + t_cycle

            # Map to position in selected assemblies only
            if n_selected > 0:
                active_position = cycle % n_selected
                active_assembly = assembly_index[active_position]
            else:
                active_assembly = -1

            if i == active_assembly:
                # This assembly is active - ramping pattern during on_time, silent during off_time

                # Before ramp starts (constant at ramp_start_fr)
                if ramp_on_time > 0:
                    times.extend([on_period_start, on_period_start, 
                                on_period_start + ramp_on_time, on_period_start + ramp_on_time])
                    rates.extend([silent_rate, ramp_start_fr, ramp_start_fr, silent_rate])

                # During ramp (linear increase from ramp_start_fr to ramp_end_fr)
                times.extend([on_period_start + ramp_on_time, on_period_start + ramp_on_time,
                            on_period_start + ramp_off_time, on_period_start + ramp_off_time])
                rates.extend([silent_rate, ramp_start_fr, ramp_end_fr, silent_rate])

                # After ramp ends (constant at ramp_end_fr)
                if ramp_off_time < on_time:
                    times.extend([on_period_start + ramp_off_time, on_period_start + ramp_off_time,
                                on_period_end, on_period_end])
                    rates.extend([silent_rate, ramp_end_fr, ramp_end_fr, silent_rate])

                # Off period (silent rate)
                times.extend([on_period_end, on_period_end, cycle_end, cycle_end])
                rates.extend([silent_rate, silent_rate, silent_rate, silent_rate])

            elif is_selected:
                # This assembly is selected but inactive - off_rate during on_time
                times.extend([
                    on_period_start,
                    on_period_start,
                    on_period_end,
                    on_period_end,
                    cycle_end,
                    cycle_end
                ])
                rates.extend([
                    silent_rate,  # At start of active assembly's burst
                    off_rate,  # During active assembly's on_time (background rate)
                    off_rate,  # During active assembly's on_time (background rate)
                    silent_rate,  # End of on_time
                    silent_rate,  # During off time
                    silent_rate   # Until next cycle
                ])
            else:
                # Non-selected assembly: fires at off_rate during on_time
                times.extend([
                    on_period_start,
                    on_period_start,
                    on_period_end,
                    on_period_end,
                    cycle_end,
                    cycle_end
                ])
                rates.extend([
                    silent_rate,  # At start
                    off_rate,  # Background during on_time
                    off_rate,  # Background during on_time
                    silent_rate,  # End of on_time
                    silent_rate,  # During off time
                    silent_rate   # Until next cycle
                ])

        # Add final timepoint
        times.append(t_stop)
        rates.append(silent_rate)

        params.append({
            'firing_rate': rates,
            'times': times
        })

    return params

Join (Gradual Recruitment)

Input is delivered to an increasing portion of one assembly in each cycle.

This function generates multiple parameter sets per assembly (controlled by n_steps), simulating a gradual recruitment ('join') or withdrawal ('quit') of neurons.

Args:
    n_assemblies (int): Total number of assemblies.
    firing_rate (tuple): 3-tuple (off_rate, on_rate, silent_rate).
    on_time (float): Duration of active cycle (s).
    off_time (float): Duration of silent period (s).
    quit (bool): If True, neurons start on and quit. If False, join one by one.
    ramp_on_time (float, optional): Offset within on_time to start recruitment.
    ramp_off_time (float, optional): Offset within on_time to end recruitment.
    t_start (float): Start time (s).
    t_stop (float): Stop time (s).
    n_cycles (int, optional): If given, overrides the number of cycles computed from t_start, t_stop, and the cycle duration.
    n_steps (int): Number of steps (neuron subgroups) within each assembly.
    verbose (bool): If True, print detailed information.
    assembly_index (list, optional): List of selected assembly indices.

Returns:
    list of dict: Firing rate parameters, including 'assembly' and 'step' metadata.
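
A minimal usage sketch, assuming get_fr_join is importable from bmtool.stimulus.generators (per the source path below); the rates, durations, and step count are illustrative:

from bmtool.stimulus.generators import get_fr_join

# Three assemblies, each split into 10 recruitment steps; rates in Hz as
# (off_rate, on_rate, silent_rate).
params = get_fr_join(
    n_assemblies=3,
    firing_rate=(2.0, 20.0, 0.0),
    on_time=1.0, off_time=0.5,
    t_start=0.0, t_stop=9.0,
    n_steps=10,
)

# Selected assemblies contribute n_steps parameter sets each; 'assembly' and
# 'step' identify which neuron sub-group a set applies to.
print(len(params))                               # 3 assemblies x 10 steps = 30
print(params[0]['assembly'], params[0]['step'])  # 0 0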

Source code in bmtool/stimulus/generators.py
def get_fr_join(n_assemblies, firing_rate=(0., 0., 0.),
                on_time=1.0, off_time=0.5,
                quit=False, ramp_on_time=None, ramp_off_time=None,
                t_start=0.0, t_stop=10.0, n_cycles=None, n_steps=20, verbose=False,
                assembly_index=None):
    """Input is delivered to an increasing portion of one assembly in each cycle.

    This function generates multiple parameter sets per assembly (controlled by n_steps),
    simulating a gradual recruitment ('join') or withdrawal ('quit') of neurons.

    Args:
        n_assemblies (int): Total number of assemblies.
        firing_rate (tuple): 3-tuple (off_rate, on_rate, silent_rate).
        on_time (float): Duration of active cycle (s).
        off_time (float): Duration of silent period (s).
        quit (bool): If True, neurons start on and quit. If False, join one by one.
        ramp_on_time (float, optional): Offset within on_time to start recruitment.
        ramp_off_time (float, optional): Offset within on_time to end recruitment.
        t_start (float): Start time (s).
        t_stop (float): Stop time (s).
        n_steps (int): Number of steps (neuron subgroups) within each assembly.
        assembly_index (list, optional): List of selected assembly indices.

    Returns:
        list of dict: Firing rate parameters, including 'assembly' and 'step' metadata.
    """
    if assembly_index is None:
        assembly_index = list(range(n_assemblies))

    total_assemblies = n_assemblies

    if verbose:
        print("\nStarting get_fr_join...")
        print(f"Selected assemblies: {assembly_index} out of {total_assemblies}")
        print(f"Firing rates (off, on, silent): {firing_rate}")
        print(f"n_steps: {n_steps}, quit mode: {quit}")

    # Ensure firing_rate is properly formatted
    firing_rate = np.asarray(firing_rate).ravel()[:3]
    firing_rate = np.concatenate((np.zeros(3 - firing_rate.size), firing_rate))

    off_rate = firing_rate[0]
    on_rate = firing_rate[1]
    silent_rate = firing_rate[2]
    assembly_index = list(assembly_index)  # Ensure it's a list

    # Set recruitment timing within on_time
    ramp_off_time = on_time if ramp_off_time is None else min(ramp_off_time, on_time)
    ramp_on_time = 0. if ramp_on_time is None else min(ramp_on_time, ramp_off_time)

    t_cycle = on_time + off_time
    if n_cycles is not None:
        n_cycle = n_cycles
    else:
        n_cycle = int((t_stop - t_start) // t_cycle)
    n_selected = len(assembly_index)

    # Calculate step timing offsets (when each step gets recruited)
    t_offset = np.linspace(ramp_on_time, ramp_off_time, n_steps, endpoint=False)
    if quit:
        t_offset = t_offset[::-1]

    if verbose:
        print(f"Cycle information:")
        print(f"Time per cycle: {t_cycle}")
        print(f"Number of cycles: {n_cycle}")
        print(f"Recruitment timing: {ramp_on_time} to {ramp_off_time}")
        print(f"Step times: {t_offset}")

    # Generate one parameter set per assembly
    # Each selected assembly will have n_steps sub-groups that use this same pattern
    # but applied to different neuron groups (handled externally)
    all_params = []

    for assy_idx in range(total_assemblies):
        is_selected = assy_idx in assembly_index

        if is_selected:
            # Selected assembly: generate n_steps parameter sets (one for each neuron sub-group)
            for step_idx, step_time in enumerate(t_offset):
                # Initialize with silent rate from 0 to t_start
                times = [0.0, t_start]
                rates = [silent_rate, silent_rate]

                for cycle in range(n_cycle):
                    cycle_start = t_start + cycle * t_cycle
                    on_period_start = cycle_start
                    on_period_end = cycle_start + on_time
                    cycle_end = cycle_start + t_cycle

                    # Map to position in selected assemblies only
                    if n_selected > 0:
                        active_position = cycle % n_selected
                        active_assembly = assembly_index[active_position]
                    else:
                        active_assembly = -1

                    if assy_idx == active_assembly:
                        # This assembly is active in this cycle
                        # This specific step gets recruited at step_time

                        if quit:
                            # Quit mode: start with on_rate, switch to silent at step_time
                            recruit_time = on_period_start + step_time

                            # Before quit time (on_rate)
                            if step_time > 0:
                                times.extend([on_period_start, on_period_start,
                                            recruit_time, recruit_time])
                                rates.extend([silent_rate, on_rate, on_rate, silent_rate])

                            # After quit time (silent)
                            times.extend([recruit_time, recruit_time,
                                        on_period_end, on_period_end])
                            rates.extend([silent_rate, silent_rate, silent_rate, silent_rate])
                        else:
                            # Join mode: start silent, switch to on_rate at step_time
                            recruit_time = on_period_start + step_time

                            # Before recruit time (silent)
                            if step_time > 0:
                                times.extend([on_period_start, on_period_start,
                                            recruit_time, recruit_time])
                                rates.extend([silent_rate, silent_rate, silent_rate, silent_rate])

                            # After recruit time (on_rate)
                            times.extend([recruit_time, recruit_time,
                                        on_period_end, on_period_end])
                            rates.extend([silent_rate, on_rate, on_rate, silent_rate])

                        # Off period (silent)
                        times.extend([on_period_end, on_period_end, cycle_end, cycle_end])
                        rates.extend([silent_rate, silent_rate, silent_rate, silent_rate])

                    else:
                        # This selected assembly is inactive - fire at off_rate during on_time
                        times.extend([
                            on_period_start,
                            on_period_start,
                            on_period_end,
                            on_period_end,
                            cycle_end,
                            cycle_end
                        ])
                        rates.extend([
                            silent_rate,  # At start of on_time
                            off_rate,  # During active assembly's on_time (background rate)
                            off_rate,  # During active assembly's on_time (background rate)
                            silent_rate,  # End of on_time
                            silent_rate,  # During off time
                            silent_rate   # Until next cycle
                        ])

                # Add final timepoint
                times.append(t_stop)
                rates.append(silent_rate)

                all_params.append({
                    'firing_rate': rates,
                    'times': times,
                    'assembly': assy_idx,
                    'step': step_idx
                })
        else:
            # Non-selected assembly: single parameter set (fires at off_rate during all on_times)
            times = [0.0, t_start]
            rates = [silent_rate, silent_rate]

            for cycle in range(n_cycle):
                cycle_start = t_start + cycle * t_cycle
                on_period_start = cycle_start
                on_period_end = cycle_start + on_time
                cycle_end = cycle_start + t_cycle

                # Always fire at off_rate during on_time, silent during off_time
                times.extend([
                    on_period_start,
                    on_period_start,
                    on_period_end,
                    on_period_end,
                    cycle_end,
                    cycle_end
                ])
                rates.extend([
                    silent_rate,  # At start of on_time
                    off_rate,  # Background during on_time
                    off_rate,  # Background during on_time
                    silent_rate,  # End of on_time
                    silent_rate,  # During off time
                    silent_rate   # Until next cycle
                ])

            # Add final timepoint
            times.append(t_stop)
            rates.append(silent_rate)

            all_params.append({
                'firing_rate': rates,
                'times': times,
                'assembly': assy_idx,
                'step': None
            })

    return all_params

Fade Transitions

Input fades in and out between a pair of assemblies in each cycle.

Args:
    n_assemblies: Total number of assemblies.
    off_rate: Firing rate of assemblies not involved in the current fade cycle.
    firing_rate: 4-tuple of firing rates (fade_out_start, fade_out_end, fade_in_start, fade_in_end).
    on_time, off_time: On / off time durations.
    ramp_on_time, ramp_off_time: Start and end time of the ramp within the on-time duration.
    t_start, t_stop: Start and stop time of the stimulus cycles.
    n_cycles: If given, overrides the number of cycles computed from t_start, t_stop, and the cycle duration.
    verbose: If True, print detailed information.
    assembly_index: List of selected assembly indices.

Returns:
    list of firing rate parameter dictionaries
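
A minimal usage sketch, assuming get_fr_fade is importable from bmtool.stimulus.generators (per the source path below); the rate values are illustrative:

from bmtool.stimulus.generators import get_fr_fade

# Four assemblies paired as (0, 1) and (2, 3): in each cycle one pair is active,
# with the first member ramping 20 -> 0 Hz and the second ramping 0 -> 20 Hz,
# while the remaining assemblies fire at off_rate.
params = get_fr_fade(
    n_assemblies=4,
    off_rate=5.0,
    firing_rate=(20.0, 0.0, 0.0, 20.0),
    on_time=1.0, off_time=0.5,
    t_start=0.0, t_stop=6.0,
)
print(len(params))  # one parameter dict per assembly -> 4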

Source code in bmtool/stimulus/generators.py
def get_fr_fade(n_assemblies, off_rate=10., firing_rate=(0., 0., 0., 0.),
                on_time=1.0, off_time=0.5,
                ramp_on_time=None, ramp_off_time=None,
                t_start=0.0, t_stop=10.0, n_cycles=None, verbose=False,
                assembly_index=None):
    """Input fades in and out between a pair of assemblies in each cycle.

    Args:
        n_assemblies: Total number of assemblies
        off_rate: firing rate of assemblies not involved in current fade cycle
        firing_rate: 4-tuple of firing rates (fade_out_start, fade_out_end, fade_in_start, fade_in_end)
        on_time, off_time: on / off time durations
        ramp_on_time, ramp_off_time: start and end time of ramp in on time duration
        t_start, t_stop: start and stop time of the stimulus cycles
        verbose: if True, print detailed information
        assembly_index: List of selected assembly indices.

    Return: list of firing rate parameter dictionaries
    """
    if assembly_index is None:
        assembly_index = list(range(n_assemblies))

    total_assemblies = n_assemblies

    if verbose:
        print("\nStarting get_fr_fade...")
        print(f"Selected assemblies: {assembly_index} out of {total_assemblies}")
        print(f"Firing rates (fade_out_start, fade_out_end, fade_in_start, fade_in_end): {firing_rate}")
        print(f"Off rate for non-active assemblies: {off_rate}")
        print(f"Time parameters - start: {t_start}, stop: {t_stop}, on_time: {on_time}, off_time: {off_time}")

    # Ensure firing_rate is properly formatted
    firing_rate = np.asarray(firing_rate).ravel()[:4]
    if firing_rate.size < 4:
        firing_rate = np.concatenate((np.zeros(4 - firing_rate.size), firing_rate))
    fade_out_start, fade_out_end, fade_in_start, fade_in_end = firing_rate
    assembly_index = list(assembly_index)  # Ensure it's a list

    # Set ramp timing within on_time
    ramp_off_time = on_time if ramp_off_time is None else min(ramp_off_time, on_time)
    ramp_on_time = 0. if ramp_on_time is None else min(ramp_on_time, ramp_off_time)

    # Calculate cycle parameters
    t_cycle = on_time + off_time
    if n_cycles is not None:
        n_cycle = n_cycles
    else:
        n_cycle = int((t_stop - t_start) // t_cycle)
    silent_rate = 0.0
    n_selected = len(assembly_index)

    if verbose:
        print(f"\nCycle information:")
        print(f"Time per cycle: {t_cycle}")
        print(f"Number of cycles: {n_cycle}")
        print(f"Ramp timing: {ramp_on_time} to {ramp_off_time} within on_time of {on_time}")

    params = []

    for i in range(total_assemblies):
        # Initialize with zero rate from 0 to t_start
        times = [0.0, t_start]
        rates = [silent_rate, silent_rate]

        is_selected = i in assembly_index

        for cycle in range(n_cycle):
            cycle_start = t_start + cycle * t_cycle
            on_period_start = cycle_start
            on_period_end = cycle_start + on_time
            cycle_end = cycle_start + t_cycle

            # Determine which pair is active
            n_pairs = n_selected // 2 if n_selected >= 2 else 0

            if n_pairs > 0:
                pair_idx = cycle % n_pairs
                if pair_idx * 2 + 1 < n_selected:
                    fade_out_assembly = assembly_index[pair_idx * 2]
                    fade_in_assembly = assembly_index[pair_idx * 2 + 1]
                else:
                    fade_out_assembly = -1
                    fade_in_assembly = -1
            else:
                fade_out_assembly = -1
                fade_in_assembly = -1

            if i == fade_out_assembly:
                # Fading out
                if ramp_on_time > 0:
                    times.extend([on_period_start, on_period_start,
                                on_period_start + ramp_on_time, on_period_start + ramp_on_time])
                    rates.extend([silent_rate, fade_out_start, fade_out_start, silent_rate])

                times.extend([on_period_start + ramp_on_time, on_period_start + ramp_on_time,
                            on_period_start + ramp_off_time, on_period_start + ramp_off_time])
                rates.extend([silent_rate, fade_out_start, fade_out_end, silent_rate])

                if ramp_off_time < on_time:
                    times.extend([on_period_start + ramp_off_time, on_period_start + ramp_off_time,
                                on_period_end, on_period_end])
                    rates.extend([silent_rate, fade_out_end, fade_out_end, silent_rate])

                times.extend([on_period_end, cycle_end, cycle_end])
                rates.extend([silent_rate, silent_rate, silent_rate])

            elif i == fade_in_assembly:
                # Fading in
                if ramp_on_time > 0:
                    times.extend([on_period_start, on_period_start,
                                on_period_start + ramp_on_time, on_period_start + ramp_on_time])
                    rates.extend([silent_rate, fade_in_start, fade_in_start, silent_rate])

                times.extend([on_period_start + ramp_on_time, on_period_start + ramp_on_time,
                            on_period_start + ramp_off_time, on_period_start + ramp_off_time])
                rates.extend([silent_rate, fade_in_start, fade_in_end, silent_rate])

                if ramp_off_time < on_time:
                    times.extend([on_period_start + ramp_off_time, on_period_start + ramp_off_time,
                                on_period_end, on_period_end])
                    rates.extend([silent_rate, fade_in_end, fade_in_end, silent_rate])

                times.extend([on_period_end, cycle_end, cycle_end])
                rates.extend([silent_rate, silent_rate, silent_rate])

            else:
                # Inactive assembly: fires at off_rate during active pair's on_time
                times.extend([
                    on_period_start,
                    on_period_start,
                    on_period_end,
                    on_period_end,
                    cycle_end,
                    cycle_end
                ])
                rates.extend([
                    silent_rate,
                    off_rate,
                    off_rate,
                    silent_rate,
                    silent_rate,
                    silent_rate
                ])

        # Add final timepoint
        times.append(t_stop)
        rates.append(silent_rate)

        params.append({
            'firing_rate': rates,
            'times': times
        })

    return params

Loop Patterns

Poisson input is on during each cycle's on period (rate change points within the period are given by on_times, starting at t_start), then off for off_time. Cycles are assigned to the assemblies in round-robin order, with firing_rate[0] as the baseline rate.
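
A minimal usage sketch, assuming get_fr_loop is importable from bmtool.stimulus.generators (per the source path below); the values are illustrative:

from bmtool.stimulus.generators import get_fr_loop

# firing_rate needs one more value than on_times after a leading 0 offset is
# prepended internally (here 2 offsets -> 3 rates): firing_rate[0] is the
# baseline, the rest map to the on-time offsets 0.0 s and 1.0 s.
params = get_fr_loop(
    n_assemblies=4,
    firing_rate=(0.0, 10.0, 20.0),
    on_times=(1.0,),
    off_time=0.5,
    t_start=0.0, t_stop=12.0,
)
print(len(params))  # one parameter dict per assembly -> 4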

Source code in bmtool/stimulus/generators.py
def get_fr_loop(n_assemblies, firing_rate=(0., 0., 0.),
                on_times=(1.0, ), off_time=0.5,
                t_start=0.0, t_stop=10.0, verbose=False):
    """Poisson input is first on for on_time starting at t_start, then off for
    off_time.
    """
    firing_rate = np.asarray(firing_rate).ravel()
    # Sort the on-time offsets, clip negatives to zero, and prepend a 0 offset
    # if one is not already present so the rate profile starts at the beginning
    # of each on period.
    on_times = np.fmax(np.sort(np.asarray(on_times).ravel()), 0)
    if on_times[0]:
        on_times = np.insert(on_times, 0, 0.)
    if firing_rate.size - on_times.size != 1:
        raise ValueError("Length of `firing_rate` should be len(on_times) + 1.")
    t_cycle, n_cycle = get_stim_cycle(on_times[-1], off_time, t_start, t_stop)

    # Assign cycles to assemblies in round-robin order; the first and last time
    # points of each cycle are duplicated so the rate can jump at the boundaries.
    times = [[0] for _ in range(n_assemblies)]
    for j in range(n_cycle):
        ts = t_start + t_cycle * j + on_times
        times[j % n_assemblies].extend(np.insert(ts, [0, -1], ts[[0, -1]]))

    # Build the matching rate sequence: one block of firing_rate values per
    # cycle, bracketed by the baseline rate firing_rate[0].
    params = []
    fr = []
    fr0 = firing_rate[0]
    for ts in times:
        ts.append(t_stop)
        n = (len(ts) - 2) // (on_times.size + 2)
        if len(fr) != len(ts):
            fr = np.append(np.tile(np.insert(firing_rate, 0, fr0), n), [fr0, fr0])
        params.append(dict(firing_rate=fr, times=ts))
    return params

Assembly Functions

Functions for creating and managing node assemblies:

Assign N units to n_assemblies randomly.

Args:
    N (int): Total number of units to assign.
    n_assemblies (int): Number of assemblies to create.
    rng (Generator, optional): Random number generator. If None and seed is provided, creates one from seed.
    seed (int, optional): Random seed for reproducibility. Creates RNG if rng is None.
    prob_in_assembly (float): Probability of a unit being included in its assigned assembly (0-1).

Returns:
    list of np.ndarray: Indices of units assigned to each assembly.
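
A minimal usage sketch, assuming assign_assembly is importable from bmtool.stimulus.assemblies (per the source path below):

from bmtool.stimulus.assemblies import assign_assembly

# Split 100 units into 4 roughly equal assemblies, keeping ~80% of each.
assemblies = assign_assembly(N=100, n_assemblies=4, seed=42, prob_in_assembly=0.8)
for i, idx in enumerate(assemblies):
    print(f"assembly {i}: {len(idx)} units")  # ~20 units per assembly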

Source code in bmtool/stimulus/assemblies.py
def assign_assembly(N, n_assemblies, rng=None, seed=None, prob_in_assembly=1.0):
    """Assign N units to n_assemblies randomly.

    Args:
        N (int): Total number of units to assign.
        n_assemblies (int): Number of assemblies to create.
        rng (Generator, optional): Random number generator. If None and seed is provided, creates one from seed.
        seed (int, optional): Random seed for reproducibility. Creates RNG if rng is None.
        prob_in_assembly (float): Probability of a unit being included in its assigned assembly (0-1).

    Returns:
        list of np.ndarray: Indices of units assigned to each assembly.
    """
    if rng is None:
        if seed is not None:
            rng = np.random.default_rng(seed)
        else:
            rng = np.random.default_rng()
    n_per_assemb = num_prop(np.ones(n_assemblies), N)
    split_idx = np.cumsum(n_per_assemb)[:-1]  # indices at which to split
    assy_idx = rng.permutation(N)  # random shuffle for assemblies
    assy_idx = np.split(assy_idx, split_idx)  # split into assemblies

    # Reduce each assembly to the specified proportion
    if prob_in_assembly < 1.0:
        assy_idx = [rng.choice(idx, size=int(len(idx) * prob_in_assembly), replace=False) for idx in assy_idx]

    assy_idx = [np.sort(idx) for idx in assy_idx]
    return assy_idx

Get assemblies based on a property column in the nodes dataframe.

Args:
    nodes_df: DataFrame of nodes.
    property_name: Column name to group by (e.g. 'pulse_group_id').
    probability: Probability of selecting a node within its group.
    rng (Generator, optional): Random number generator. If None and seed is provided, creates one from seed.
    seed (int, optional): Random seed for reproducibility. Creates RNG if rng is None.

Returns:
    list of node ID arrays, one for each unique property value (sorted)
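
A minimal usage sketch, assuming get_assemblies_by_property is importable from bmtool.stimulus.assemblies (per the source path below); the toy DataFrame is illustrative:

import pandas as pd
from bmtool.stimulus.assemblies import get_assemblies_by_property

# Seven nodes in three groups; the node with a NaN group value is skipped.
nodes_df = pd.DataFrame({'pulse_group_id': [0, 0, 1, 1, 2, 2, None]})
assemblies = get_assemblies_by_property(nodes_df, 'pulse_group_id',
                                        probability=0.5, seed=7)
# One sorted array of node indices per group, each reduced to ~50% of its nodes.
print([a.tolist() for a in assemblies])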

Source code in bmtool/stimulus/assemblies.py
def get_assemblies_by_property(nodes_df, property_name, probability=1.0, rng=None, seed=None):
    """Get assemblies based on a property column in the nodes dataframe.

    Args:
        nodes_df: DataFrame of nodes
        property_name: Column name to group by (e.g. 'pulse_group_id')
        probability: Probability of selecting a node within its group
        rng (Generator, optional): Random number generator. If None and seed is provided, creates one from seed.
        seed (int, optional): Random seed for reproducibility. Creates RNG if rng is None.

    Returns:
        list of node ID arrays, one for each unique property value (sorted)
    """
    if rng is None:
        if seed is not None:
            rng = np.random.default_rng(seed)
        else:
            rng = np.random.default_rng()

    if property_name not in nodes_df.columns:
        raise ValueError(f"Property {property_name} not found in nodes dataframe")

    groups = nodes_df[property_name].unique()
    # Sort groups to ensure consistent ordering (e.g. 0, 1, 2...)
    try:
        groups = np.sort(groups)
    except:
        pass # If mixed types or unsortable, leave as is

    assemblies = []

    # Skip NaN group values (matches build_input.py's `if pd.isna(group): continue`)
    for group in groups:
        if pd.isna(group): 
            continue

        idx = nodes_df[nodes_df[property_name] == group].index.to_list()

        if probability < 1.0:
            size = int(len(idx) * probability)
            selected_idx = rng.choice(idx, size=size, replace=False)
            assemblies.append(np.sort(selected_idx))
        else:
            assemblies.append(np.sort(idx))

    return assemblies

Divide nodes into assemblies based on lateral location (x, y).

Args:
    nodes_df: DataFrame with pos_x, pos_y columns.
    grid_id: Assembly IDs arranged in a 2D array corresponding to grid locations.
    grid_size: Bounds of the grid area in (x, y) coordinates (um): [[min_x, max_x], [min_y, max_y]].
    linked_nodes_list: Optional list of other node lists that map 1:1 to nodes_df.

Returns:
    tuple: (assemblies for nodes_df, grid_id), or (assemblies for nodes_df, assemblies for each linked node list, grid_id) when linked_nodes_list is given.
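
A minimal usage sketch, assuming get_grid_assembly is importable from bmtool.stimulus.assemblies (per the source path below); the node positions and grid layout are illustrative:

import numpy as np
import pandas as pd
from bmtool.stimulus.assemblies import get_grid_assembly

# 50 nodes scattered over a 200 x 200 um sheet, split into a 2 x 2 grid of assemblies.
rng = np.random.default_rng(0)
nodes_df = pd.DataFrame({'pos_x': rng.uniform(0, 200, 50),
                         'pos_y': rng.uniform(0, 200, 50)})
grid_id = np.array([[0, 1],
                    [2, 3]])           # assembly id for each grid cell
grid_size = [[0, 200], [0, 200]]       # [[min_x, max_x], [min_y, max_y]]

nodes_assy, grid_id_out = get_grid_assembly(nodes_df, grid_id, grid_size)
for i, ids in enumerate(nodes_assy):
    print(f"assembly {i}: {len(ids)} nodes")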

Source code in bmtool/stimulus/assemblies.py
def get_grid_assembly(nodes_df, grid_id, grid_size, linked_nodes_list=None):
    """Divide nodes into assemblies based on lateral location (x, y).

    nodes_df: DataFrame with pos_x, pos_y columns.
    grid_id: assembly ids arranged in 2d-array corresponding to grid locations.
    grid_size: the bounds of the grid area in (x, y) coordinates (um). [[min_x, max_x], [min_y, max_y]]
    linked_nodes_list: Optional list of other node lists that map 1:1 to nodes_df

    Returns:
        tuple: (list of assemblies for nodes_df, [list of assemblies for linked nodes], grid_id)
    """
    grid_id = np.asarray(grid_id)
    grid_size = np.asarray(grid_size)

    # Store original helper column to avoid modifying input df permanently
    df = nodes_df.copy()

    bins = []
    for i in range(2):
        bins.append(np.linspace(*grid_size[i], grid_id.shape[i] + 1)[1:])
        bins[i][-1] += 1.  # Ensure last bin captures edge

    # Assign assembly ID based on position
    df['assy_id'] = grid_id[np.digitize(df['pos_x'], bins[0]),
                            np.digitize(df['pos_y'], bins[1])]

    # Create boolean masks for each assembly
    sorted_grid_ids = np.sort(grid_id, axis=None)
    assy_idx = [df['assy_id'].values == i for i in sorted_grid_ids]

    # Get IDs for the main nodes
    nodes_assy = get_assembly_ids(df.index, assy_idx=assy_idx)

    # Get IDs for linked nodes
    linked_assy = []
    if linked_nodes_list:
        for nodes in linked_nodes_list:
            linked_assy.append(get_assembly_ids(nodes, assy_idx=assy_idx))

    if linked_nodes_list:
        return nodes_assy, linked_assy, grid_id
    else:
        return nodes_assy, grid_id