Skip to content

Added custom inputs and basic cluster analysis #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file removed automind/__init__.py
Empty file.
494 changes: 494 additions & 0 deletions automind/sim/b2_inputs.py

Large diffs are not rendered by default.

134 changes: 117 additions & 17 deletions automind/sim/b2_models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
### Brian2 models for network construction.
import brian2 as b2
import numpy as np
from automind.sim import b2_inputs


def adaptive_exp_net(all_param_dict):
Expand Down Expand Up @@ -64,14 +65,14 @@ def adaptive_exp_net(all_param_dict):
)

### TO DO: also randomly initialize w to either randint(?)*b or randn*(v-v_rest)*a

poisson_input_E = b2.PoissonInput(
''' poisson_input_E = b2.PoissonInput(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happened here? This func didn't need to be changed right (or is used at all)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happened here? This func didn't need to be changed right (or is used at all)?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happened here? This func didn't need to be changed right (or is used at all)? The quotes are typos / accidents?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happened here? This func didn't need to be changed right (or is used at all)? The quotes are typos / accidents?

Copy link
Author

@nairb1234 nairb1234 May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes sorry I probably wanted to check if the custom inputs were working and silenced this to confirm, the Poisson input should be kept.

target=E_pop,
target_var="ge",
N=param_dict_neuron_E["N_poisson"],
rate=param_dict_neuron_E["poisson_rate"],
weight=param_dict_neuron_E["Q_poisson"],
)
)'''


if has_inh:
# make adlif if delta_t is 0, otherwise adex
Expand Down Expand Up @@ -268,12 +269,21 @@ def make_clustered_network(
return membership, shared_membership, conn_in, conn_out


def adaptive_exp_net_clustered(all_param_dict):
"""Adaptive exponential integrate-and-fire network with clustered connections."""
#Modified function incorporating inputs to specific clusters
def adaptive_exp_net_clustered_cog(all_param_dict, mode='default', custom_input=None, stim_cluster=None, custom_cluster_input=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

potentially think about absorbing the other args into either params_network, params_settings, or a new subdict params_input in all_param_dict since it's more consistent with existing interface, and will make inference on these easier. But I'm not sure if there is a lot more design considerations.

this is more of a comment for myself than you @nairb1234, unless you think this is easily doable in the short-term. Otherwise, I will open an issue for myself on this.

'''
Adaptive exponential integrate-and-fire network with clustered connections.

3 modes
- Default mode - no input
- Single mode - Single input -> User can define any input sequence. DM_simple is used when no inputs are provided
- Cluster mode - Each cluster gets different input, can be defined by user. DM_simple with different mean is used for each cluster when no inputs are provided.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a bit more details on the other args, like what it represents and var type (e.g., str? array? etc.)

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added

- Can also select number of clusters to stimulate
'''

# separate parameter dictionaries
param_dict_net = all_param_dict["params_net"]
param_dict_settings = all_param_dict["params_settings"]

# set random seeds
b2.seed(param_dict_settings["random_seed"])
np.random.seed(param_dict_settings["random_seed"])
Expand All @@ -287,25 +297,27 @@ def adaptive_exp_net_clustered(all_param_dict):

#### NETWORK CONSTRUCTION ############
######################################

### get cell counts
N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"]
N_exc = int(N_pop * exc_prop)
N_inh = N_pop - N_exc
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great


### define neuron equation
adex_coba_eq = """dv/dt = (-g_L * (v - v_rest) + g_L * delta_T * exp((v - v_thresh)/delta_T) - w + I)/C : volt (unless refractory)"""
adlif_coba_eq = (
"""dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)"""
)

adlif_coba_eq = """dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)"""

network_eqs = """
dw/dt = (-w + a * (v - v_rest))/tau_w : amp
dge/dt = -ge / tau_ge : siemens
dgi/dt = -gi / tau_gi : siemens
Ie = ge * (E_ge - v): amp
Ii = gi * (E_gi - v): amp
I = I_bias + Ie + Ii : amp
I_ext: amp
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

great!

I = I_bias + Ie + Ii + I_ext: amp
"""

### get cell counts
N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"]
N_exc, N_inh = int(N_pop * exc_prop), int(N_pop * (1 - exc_prop))

### make neuron populations, set initial values and connect poisson inputs ###
# make adlif if delta_t is 0, otherwise adex
neuron_eq = (
Expand Down Expand Up @@ -422,6 +434,7 @@ def adaptive_exp_net_clustered(all_param_dict):
p_out,
param_dict_net["order_clusters"],
)
param_dict_net["membership"] = membership
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is necessary for figuring out the cluster membership post hoc? I kinda want to avoid doing it like this, because introducing a new param into the dict is a bit adhoc unless we define it with a default beforehand (which we can do, in the default_configs), and also that this saves a relatively big array into the dict. I don't know if that's necessarily a problem since the dict doesn't get saved or anything, just feels clunky. Let me know what your thoughts are here @nairb1234, also fine if there's not a better way of doing it.


# scale synaptic weight
Q_ge_out = param_dict_neuron_E["Q_ge"]
Expand Down Expand Up @@ -491,6 +504,93 @@ def adaptive_exp_net_clustered(all_param_dict):
)
syn_i2i.connect("i!=j", p=param_dict_net["p_i2i"])

### Handle different input modes ###
if mode == 'default': #No input
stim_time_values = b2_inputs.DM_simple(all_param_dict,0,0) #Just change this to an input (pass in an stim array)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is there a possibility to skip this altogether or simply define I_ext as a scalar 0 if in this default case? My concern is whether this increases run-time or memory use.

Copy link
Author

@nairb1234 nairb1234 May 21, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Skipped that now and just put a single pass under the default condition and worked fine.

dt = param_dict_settings["dt"]
stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt)

# Define network operation to update I_ext
@b2.network_operation(dt=dt)
def update_test_input(t):
E_pop.I_ext = stim_timed_array(t)

elif mode == 'single':
# Check if custom input is provided in the parameter dictionary
custom_input = all_param_dict.get("custom_input", None)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah I think having a structure like all_param_dict['params_input']['custom_input'] might be better, so we put all the input related stuff in params_input

if custom_input is not None:
stim_time_values = custom_input
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yeah so a bit more explanation here on custom_input would be good. I guess it's an array at the same time resolution? But does it have to be the same length as the sim_time? Also, relevant to my comment above, I guess we have to think more carefully about putting it into params_dict.

else:
# Use default DM_simple if no custom input is provided
stim_time_values = b2_inputs.DM_simple(all_param_dict)

dt = param_dict_settings["dt"]
stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt)

# Define network operation to update I_ext
@b2.network_operation(dt=dt)
def update_test_input(t):
E_pop.I_ext = stim_timed_array(t)

elif mode == 'cluster':
# Determine if network has clusters
has_clusters = (
"n_clusters" in param_dict_net.keys()
and param_dict_net["n_clusters"] >= 2
and param_dict_net["R_pe2e"] != 1
)

if has_clusters:
n_clusters_original = int(param_dict_net["n_clusters"])
stimulated_clusters_count = n_clusters_original

#Check if user defined number of clusters to stimulate
if stim_cluster is not None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe name this n_clusters_to_stim or something

stimulated_clusters_count = stim_cluster
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same for stimulated_clusters_count, a bit confusing as a variable name

if stim_cluster > n_clusters_original:
stimulated_clusters_count = n_clusters_original
print(f"No. of clusters picked ({stim_cluster}) exceeds actual no. of clusters. Stimulating all {n_clusters_original} clusters instead.")
#Select number of clusters
selected_clusters = np.random.choice(
n_clusters_original,
stimulated_clusters_count,
replace=False
)
cluster_lists = [[c] for c in selected_clusters]

# Generate cluster-specific inputs for selected clusters - see b2_inputs
if custom_cluster_input is not None:
stim_list = custom_cluster_input
_, weight_list = b2_inputs.cluster_specific_stim(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just to make sure I understand correctly: stim_list has the list of stimulation values over time, weight_list has the stim amplitude factor for each cluster, and cluster_list has which clusters to stimulate?

Could you explain a bit more here? There's a lot going on, and the overwriting of variable values after ifs also makes it a bit hard to follow

all_param_dict,
n_clusters=stimulated_clusters_count,
)
else:
stim_list, weight_list = b2_inputs.cluster_specific_stim(
all_param_dict,
n_clusters=stimulated_clusters_count,
)

# Create input configurations
input_configs = b2_inputs.get_input_configs(
cluster_lists,
stim_list,
weight_list,
)
input_op = b2_inputs.create_input_operation(E_pop, input_configs, membership)
param_dict_net['input'] = stim_list
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

same here, I would put input in a different subdict for inputs, not params_net

else:
# Fallback to test mode if network doesn't have clusters initially
print("Network does not have clusters. All neurons will receive the same DM_simple input ")
stim_time_values = b2_inputs.test_stim(all_param_dict)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what happens here?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh i see, test_stim is the sequence. Okay see my comment in that b2_inputs re naming

dt = param_dict_settings["dt"]
stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt)

@b2.network_operation(dt=dt)
def update_cluster_fallback_input(t):
E_pop.I_ext = stim_timed_array(t)
param_dict_net['input'] = stim_timed_array

### define monitors ###
rate_monitors, spike_monitors, trace_monitors = [], [], []
rec_defs = param_dict_settings["record_defs"]
Expand All @@ -510,12 +610,12 @@ def adaptive_exp_net_clustered(all_param_dict):
# and later drop randomly before saving, otherwise
# recording only from first n neurons, which heavily overlap
# with those stimulated, and the first few clusters
rec_idx = np.arange(N_exc)
rec_idx = np.arange(N_exc)
else:
rec_idx = (
np.arange(rec_defs[pop_name]["spikes"])
np.arange(rec_defs[pop_name]["spikes"])
if type(rec_defs[pop_name]["spikes"]) is int
else rec_defs[pop_name]["spikes"]
else rec_defs[pop_name]["spikes"] #Change param_settings.record_Defs to 2000
)
spike_monitors.append(
b2.SpikeMonitor(pop[rec_idx], name=pop_name + "_spikes")
Expand Down
32 changes: 31 additions & 1 deletion automind/utils/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

def _filter_spikes_random(spike_trains, n_to_save):
"""Filter a subset of spike trains randomly for saving."""
np.random.seed(42)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I forgot if we talked about this but should this be reproduceably random, i.e., using the model random_seed?

record_subset = np.sort(
np.random.choice(len(spike_trains), n_to_save, replace=False)
)
Expand All @@ -33,13 +34,14 @@ def collect_spikes(net_collect, params_dict):
sm.name.split("_")[0]
]["spikes"]
n_to_save = pop_save_def if type(pop_save_def) == int else len(pop_save_def)
#n_to_save = len(spike_trains)
if n_to_save == len(spike_trains):
# recorded and to-be saved is the same length, go on a per usual
spike_dict[sm.name] = b2_interface._deunitize_spiketimes(spike_trains)
else:
# recorded more than necessary, subselect for saving
spike_dict[sm.name] = b2_interface._deunitize_spiketimes(
_filter_spikes_random(spike_trains, n_to_save)
_filter_spikes_random(spike_trains, n_to_save) # THIS IS WHERE THE NEURONS ARE RANDOMLY DROPPED BEFORE SAVING (AND PLOTTING??)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so it's not being randomly dropped anymore?

)
return spike_dict

Expand Down Expand Up @@ -674,3 +676,31 @@ def load_df_posteriors(path_dict):
path_dict["root_path"] + path_dict["params_dict_analysis_updated"]
)
return df_posterior_sims, posterior, params_dict_default

def sort_neurons(membership, sorting_method="cluster_identity"):
    """
    Sort neurons based on the specified method.

    Parameters:
        membership (2D array): Per-neuron cluster membership, shape (n_neurons, 2).
            Column 0 is the primary cluster id; column 1 equals column 0 when the
            neuron belongs to a single cluster, and differs when it belongs to two.
            # NOTE(review): inferred from the equality test below — confirm against
            # make_clustered_network's membership format.
        sorting_method (str): "cluster_identity" or "n_clusters".

    Returns:
        sorted_indices (list of 1D arrays):
            - "cluster_identity": one array of neuron indices ordered by their
              primary cluster id.
            - "n_clusters": two arrays — indices of single-cluster neurons first,
              then indices of two-cluster neurons.

    Raises:
        ValueError: If sorting_method is not one of the supported options.
    """
    sorted_indices = []

    if sorting_method == "cluster_identity":
        # Order neurons by the first (primary) cluster identity.
        sorted_indices.append(np.argsort(membership[:, 0]))
    elif sorting_method == "n_clusters":
        # Neurons in one cluster have the same value in both columns.
        # np.where(cond) returns a tuple of arrays; take [0] so callers get
        # plain index arrays, consistent with the "cluster_identity" branch.
        sorted_indices.append(np.where(membership[:, 0] == membership[:, 1])[0])
        sorted_indices.append(np.where(membership[:, 0] != membership[:, 1])[0])
    else:
        raise ValueError("Invalid sorting_method. Use 'cluster_identity' or 'n_clusters'.")
    return sorted_indices
124 changes: 123 additions & 1 deletion automind/utils/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,6 @@ def _plot_raster_pretty(
ax.set_ylabel("Raster", fontsize=fontsize)
return ax


def _plot_rates_pretty(
rates,
XL,
Expand Down Expand Up @@ -584,3 +583,126 @@ def plot_corr_pv(pvals, ax, alpha_level=0.05, fmt="w*", ms=0.5):
for j in range(pvals.shape[0]):
if pvals[i, j] < alpha_level:
ax.plot(j, i, fmt, ms=ms, alpha=1)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

so this is not being used anymore because you extracted the sorting function right? Then I think it can be removed


#Just use the default plotting function with the sorted spikes
'''
def plot_raster(
spikes,
membership,
XL,
plotting_method="cluster_identity",
every_other=1,
ax=None,
fontsize=14,
plot_inh=False,
E_colors=None,
I_color="gray",
single_cluster_style="|",
double_cluster_style="x",
mew=0.5,
ms=1,
**plot_kwargs,
):
"""
Plot raster plot with neurons sorted by cluster identity or number of clusters.

Parameters:
spikes (dict): Dictionary containing 'exc_spikes' and 'inh_spikes'.
membership: Array/list of 2D membership arrays (from params_net['membership']).
XL (list): X-axis limits.
plotting_method (str): "cluster_identity" or "n_clusters".
every_other (int): Plot every nth spike.
ax (matplotlib axis): Axis to plot on.
fontsize (int): Font size for labels.
plot_inh (bool): Whether to plot inhibitory spikes.
E_colors (list): Colors for excitatory clusters.
I_color (str): Color for inhibitory spikes.
single_cluster_style (str): Marker style for single-cluster neurons.
double_cluster_style (str): Marker style for two-cluster neurons.
mew (float): Marker edge width.
ms (float): Marker size.
"""
if ax is None:
ax = plt.axes()

exc_spikes = spikes["exc_spikes"]
inh_spikes = spikes.get("inh_spikes", {})

if plotting_method == "cluster_identity":
# Sort by cluster identity
sorted_indices = data_utils.sort_neurons(membership, sorting_method='cluster_identity')
sorted_indices_list = sorted_indices[0].tolist() # Convert to list of Python integers
sorted_exc_spikes = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices_list)}
#exc_spikes_to_plot = sorted_exc_spikes.values()
elif plotting_method == "n_clusters":
# Sort by number of clusters
sorted_indices = data_utils.sort_neurons(membership, sorting_method='n_clusters')
sorted_exc_spikes_single = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices[0][0])}
sorted_exc_spikes_double = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices[1][0])}
#exc_spikes_to_plot.append(sorted_exc_spikes_single)
#exc_spikes_to_plot.append(sorted_exc_spikes_double)
else:
raise ValueError("Invalid plotting_method. Use 'cluster_identity' or 'n_clusters'.")

# Plot excitatory spikes, single cluster in blue and double cluster in red respectively
[
(
ax.plot(
v[::every_other],
i_v * np.ones_like(v[::every_other]),
single_cluster_style,
color='blue',
alpha=1,
ms=ms,
mew=mew,
)
if len(v) > 0
else None
)
for i_v, (t,v) in enumerate(sorted_exc_spikes_single.items())
]
[
(
ax.plot(
v[::every_other],
(i_v+ len(sorted_indices[0][0])) * np.ones_like(v[::every_other]),
single_cluster_style,
color='red',
alpha=1,
ms=ms,
mew=mew,
)
if len(v) > 0
else None
)
for i_v, (t,v) in enumerate(sorted_exc_spikes_double.items())
]

# Plot inhibitory spikes
if plot_inh:
[
(
ax.plot(
v[::every_other],
(i_v + len(sorted_indices[0][0]) + len(sorted_indices[1][0])) * np.ones_like(v[::every_other]),
"|",
color=I_color,
alpha=1,
ms=ms,
mew=mew,
)
if len(v) > 0
else None
)
for i_v, v in enumerate(inh_spikes.values())
]

ax.set_xticks([])
ax.set_yticks([])
ax.spines.left.set_visible(False)
ax.spines.bottom.set_visible(False)
ax.set_xlim(XL)
ax.set_ylabel("Raster", fontsize=fontsize)
return ax
'''
Loading