-
Notifications
You must be signed in to change notification settings - Fork 4
Added custom inputs and basic cluster analysis #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
### Brian2 models for network construction. | ||
import brian2 as b2 | ||
import numpy as np | ||
from automind.sim import b2_inputs | ||
|
||
|
||
def adaptive_exp_net(all_param_dict): | ||
|
@@ -64,14 +65,14 @@ def adaptive_exp_net(all_param_dict): | |
) | ||
|
||
### TO DO: also randomly initialize w to either randint(?)*b or randn*(v-v_rest)*a | ||
|
||
poisson_input_E = b2.PoissonInput( | ||
''' poisson_input_E = b2.PoissonInput( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happened here? This func didn't need to be changed right (or is used at all)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happened here? This func didn't need to be changed right (or is used at all)? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happened here? This func didn't need to be changed right (or is used at all)? The quotes are typos / accidents? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happened here? This func didn't need to be changed right (or is used at all)? The quotes are typos / accidents? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes sorry I probably wanted to check if the custom inputs were working and silenced this to confirm, the Poisson input should be kept. |
||
target=E_pop, | ||
target_var="ge", | ||
N=param_dict_neuron_E["N_poisson"], | ||
rate=param_dict_neuron_E["poisson_rate"], | ||
weight=param_dict_neuron_E["Q_poisson"], | ||
) | ||
)''' | ||
|
||
|
||
if has_inh: | ||
# make adlif if delta_t is 0, otherwise adex | ||
|
@@ -268,12 +269,21 @@ def make_clustered_network( | |
return membership, shared_membership, conn_in, conn_out | ||
|
||
|
||
def adaptive_exp_net_clustered(all_param_dict): | ||
"""Adaptive exponential integrate-and-fire network with clustered connections.""" | ||
#Modified function incorporating inputs to specific clusters | ||
def adaptive_exp_net_clustered_cog(all_param_dict, mode='default', custom_input=None, stim_cluster=None, custom_cluster_input=None): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. potentially think about absorbing the other args into either this is more of a comment for myself than you @nairb1234, unless you think this is easily doable in the short-term. Otherwise, I will open an issue for myself on this. |
||
''' | ||
Adaptive exponential integrate-and-fire network with clustered connections. | ||
|
||
3 modes | ||
- Default mode - no input | ||
- Single mode - Single input -> User can define any input sequence. DM_simple is used when no inputs are provided | ||
- Cluster mode - Each cluster gets different input, can be defined by user. DM_simple with different mean is used for each cluster when no inputs are provided. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe a bit more details on the other args, like what it represents and var type (e.g., str? array? etc.) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added |
||
- Can also select number of clusters to stimulate | ||
''' | ||
|
||
# separate parameter dictionaries | ||
param_dict_net = all_param_dict["params_net"] | ||
param_dict_settings = all_param_dict["params_settings"] | ||
|
||
# set random seeds | ||
b2.seed(param_dict_settings["random_seed"]) | ||
np.random.seed(param_dict_settings["random_seed"]) | ||
|
@@ -287,25 +297,27 @@ def adaptive_exp_net_clustered(all_param_dict): | |
|
||
#### NETWORK CONSTRUCTION ############ | ||
###################################### | ||
|
||
### get cell counts | ||
N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"] | ||
N_exc = int(N_pop * exc_prop) | ||
N_inh = N_pop - N_exc | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great |
||
|
||
### define neuron equation | ||
adex_coba_eq = """dv/dt = (-g_L * (v - v_rest) + g_L * delta_T * exp((v - v_thresh)/delta_T) - w + I)/C : volt (unless refractory)""" | ||
adlif_coba_eq = ( | ||
"""dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)""" | ||
) | ||
|
||
adlif_coba_eq = """dv/dt = (-g_L * (v - v_rest) - w + I)/C : volt (unless refractory)""" | ||
|
||
network_eqs = """ | ||
dw/dt = (-w + a * (v - v_rest))/tau_w : amp | ||
dge/dt = -ge / tau_ge : siemens | ||
dgi/dt = -gi / tau_gi : siemens | ||
Ie = ge * (E_ge - v): amp | ||
Ii = gi * (E_gi - v): amp | ||
I = I_bias + Ie + Ii : amp | ||
I_ext: amp | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. great! |
||
I = I_bias + Ie + Ii + I_ext: amp | ||
""" | ||
|
||
### get cell counts | ||
N_pop, exc_prop = param_dict_net["N_pop"], param_dict_net["exc_prop"] | ||
N_exc, N_inh = int(N_pop * exc_prop), int(N_pop * (1 - exc_prop)) | ||
|
||
### make neuron populations, set initial values and connect poisson inputs ### | ||
# make adlif if delta_t is 0, otherwise adex | ||
neuron_eq = ( | ||
|
@@ -422,6 +434,7 @@ def adaptive_exp_net_clustered(all_param_dict): | |
p_out, | ||
param_dict_net["order_clusters"], | ||
) | ||
param_dict_net["membership"] = membership | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this is necessary for figuring out the cluster membership post hoc? I kinda want to avoid doing it like this, because introducing a new param into the dict is a bit adhoc unless we define it with a default beforehand (which we can do, in the |
||
|
||
# scale synaptic weight | ||
Q_ge_out = param_dict_neuron_E["Q_ge"] | ||
|
@@ -491,6 +504,93 @@ def adaptive_exp_net_clustered(all_param_dict): | |
) | ||
syn_i2i.connect("i!=j", p=param_dict_net["p_i2i"]) | ||
|
||
### Handle different input modes ### | ||
if mode == 'default': #No input | ||
stim_time_values = b2_inputs.DM_simple(all_param_dict,0,0) #Just change this to an input (pass in an stim array) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. is there a possibility to skip this altogether or simply define I_ext as a scalar There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Skipped that now and just put a single |
||
dt = param_dict_settings["dt"] | ||
stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt) | ||
|
||
# Define network operation to update I_ext | ||
@b2.network_operation(dt=dt) | ||
def update_test_input(t): | ||
E_pop.I_ext = stim_timed_array(t) | ||
|
||
elif mode == 'single': | ||
# Check if custom input is provided in the parameter dictionary | ||
custom_input = all_param_dict.get("custom_input", None) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah I think having a structure like |
||
if custom_input is not None: | ||
stim_time_values = custom_input | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah so a bit more explanation here on |
||
else: | ||
# Use default DM_simple if no custom input is provided | ||
stim_time_values = b2_inputs.DM_simple(all_param_dict) | ||
|
||
dt = param_dict_settings["dt"] | ||
stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt) | ||
|
||
# Define network operation to update I_ext | ||
@b2.network_operation(dt=dt) | ||
def update_test_input(t): | ||
E_pop.I_ext = stim_timed_array(t) | ||
|
||
elif mode == 'cluster': | ||
# Determine if network has clusters | ||
has_clusters = ( | ||
"n_clusters" in param_dict_net.keys() | ||
and param_dict_net["n_clusters"] >= 2 | ||
and param_dict_net["R_pe2e"] != 1 | ||
) | ||
|
||
if has_clusters: | ||
n_clusters_original = int(param_dict_net["n_clusters"]) | ||
stimulated_clusters_count = n_clusters_original | ||
|
||
#Check if user defined number of clusters to stimulate | ||
if stim_cluster is not None: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe name this
nairb1234 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
stimulated_clusters_count = stim_cluster | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same for |
||
if stim_cluster > n_clusters_original: | ||
stimulated_clusters_count = n_clusters_original | ||
print(f"No. of clusters picked ({stim_cluster}) exceeds actual no. of clusters. Stimulating all {n_clusters_original} clusters instead.") | ||
#Select number of clusters | ||
selected_clusters = np.random.choice( | ||
n_clusters_original, | ||
stimulated_clusters_count, | ||
replace=False | ||
) | ||
cluster_lists = [[c] for c in selected_clusters] | ||
|
||
# Generate cluster-specific inputs for selected clusters - see b2_inputs | ||
if custom_cluster_input is not None: | ||
stim_list = custom_cluster_input | ||
_, weight_list = b2_inputs.cluster_specific_stim( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. just to make sure I understand correctly: Could you explain a bit more here? There's a lot going on, and the overwriting of variable values after ifs also makes it a bit hard to follow |
||
all_param_dict, | ||
n_clusters=stimulated_clusters_count, | ||
) | ||
else: | ||
stim_list, weight_list = b2_inputs.cluster_specific_stim( | ||
all_param_dict, | ||
n_clusters=stimulated_clusters_count, | ||
) | ||
|
||
# Create input configurations | ||
input_configs = b2_inputs.get_input_configs( | ||
cluster_lists, | ||
stim_list, | ||
weight_list, | ||
) | ||
input_op = b2_inputs.create_input_operation(E_pop, input_configs, membership) | ||
param_dict_net['input'] = stim_list | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. same here, I would put |
||
else: | ||
# Fallback to test mode if network doesn't have clusters initially | ||
print("Network does not have clusters. All neurons will receive the same DM_simple input ") | ||
stim_time_values = b2_inputs.test_stim(all_param_dict) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what happens here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh i see, |
||
dt = param_dict_settings["dt"] | ||
stim_timed_array = b2.TimedArray(stim_time_values * b2.amp, dt=dt) | ||
|
||
@b2.network_operation(dt=dt) | ||
def update_cluster_fallback_input(t): | ||
E_pop.I_ext = stim_timed_array(t) | ||
param_dict_net['input'] = stim_timed_array | ||
|
||
### define monitors ### | ||
rate_monitors, spike_monitors, trace_monitors = [], [], [] | ||
rec_defs = param_dict_settings["record_defs"] | ||
|
@@ -510,12 +610,12 @@ def adaptive_exp_net_clustered(all_param_dict): | |
# and later drop randomly before saving, otherwise | ||
# recording only from first n neurons, which heavily overlap | ||
# with those stimulated, and the first few clusters | ||
rec_idx = np.arange(N_exc) | ||
rec_idx = np.arange(N_exc) | ||
else: | ||
rec_idx = ( | ||
np.arange(rec_defs[pop_name]["spikes"]) | ||
np.arange(rec_defs[pop_name]["spikes"]) | ||
if type(rec_defs[pop_name]["spikes"]) is int | ||
else rec_defs[pop_name]["spikes"] | ||
else rec_defs[pop_name]["spikes"] #Change param_settings.record_Defs to 2000 | ||
) | ||
spike_monitors.append( | ||
b2.SpikeMonitor(pop[rec_idx], name=pop_name + "_spikes") | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
|
||
def _filter_spikes_random(spike_trains, n_to_save): | ||
"""Filter a subset of spike trains randomly for saving.""" | ||
np.random.seed(42) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I forgot if we talked about this but should this be reproduceably random, i.e., using the model |
||
record_subset = np.sort( | ||
np.random.choice(len(spike_trains), n_to_save, replace=False) | ||
) | ||
|
@@ -33,13 +34,14 @@ def collect_spikes(net_collect, params_dict): | |
sm.name.split("_")[0] | ||
]["spikes"] | ||
n_to_save = pop_save_def if type(pop_save_def) == int else len(pop_save_def) | ||
#n_to_save = len(spike_trains) | ||
if n_to_save == len(spike_trains): | ||
# recorded and to-be saved is the same length, go on a per usual | ||
spike_dict[sm.name] = b2_interface._deunitize_spiketimes(spike_trains) | ||
else: | ||
# recorded more than necessary, subselect for saving | ||
spike_dict[sm.name] = b2_interface._deunitize_spiketimes( | ||
_filter_spikes_random(spike_trains, n_to_save) | ||
_filter_spikes_random(spike_trains, n_to_save) # THIS IS WHERE THE NEURONS ARE RANDOMLY DROPPED BEFORE SAVING (AND PLOTTING??) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so it's not being randomly dropped anymore? |
||
) | ||
return spike_dict | ||
|
||
|
@@ -674,3 +676,31 @@ def load_df_posteriors(path_dict): | |
path_dict["root_path"] + path_dict["params_dict_analysis_updated"] | ||
) | ||
return df_posterior_sims, posterior, params_dict_default | ||
|
||
def sort_neurons(membership, sorting_method="cluster_identity"): | ||
""" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice this should be quite useful |
||
Sort neurons based on the specified method. | ||
|
||
Parameters: | ||
membership (list/array of 2D arrays): Membership arrays for each simulation. | ||
sorting_method (str): "cluster_identity" or "n_clusters". | ||
|
||
Returns: | ||
sorted_indices (list of arrays): Sorted indices for each simulation. | ||
""" | ||
sorted_indices = [] | ||
#Sort by whether neurons are in one cluster or two clusters | ||
|
||
if sorting_method == "cluster_identity": | ||
# Sort by the first cluster identity | ||
sorted_idx = np.argsort(membership[:, 0]) | ||
sorted_indices.append(sorted_idx) | ||
elif sorting_method == "n_clusters": | ||
#Neurons in one cluster have the same values in both columns | ||
single = np.where(membership[:,0] == membership[:,1]) | ||
double = np.where(membership[:,0] != membership[:,1]) | ||
sorted_indices.append(single) | ||
sorted_indices.append(double) | ||
else: | ||
raise ValueError("Invalid sorting_method. Use 'cluster_identity' or 'n_clusters'.") | ||
return sorted_indices |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -327,7 +327,6 @@ def _plot_raster_pretty( | |
ax.set_ylabel("Raster", fontsize=fontsize) | ||
return ax | ||
|
||
|
||
def _plot_rates_pretty( | ||
rates, | ||
XL, | ||
|
@@ -584,3 +583,126 @@ def plot_corr_pv(pvals, ax, alpha_level=0.05, fmt="w*", ms=0.5): | |
for j in range(pvals.shape[0]): | ||
if pvals[i, j] < alpha_level: | ||
ax.plot(j, i, fmt, ms=ms, alpha=1) | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so this is not being used anymore because you extracted the sorting function right? Then I think it can be removed |
||
|
||
#Just use the default plotting function with the sorted spikes | ||
''' | ||
def plot_raster( | ||
spikes, | ||
membership, | ||
XL, | ||
plotting_method="cluster_identity", | ||
every_other=1, | ||
ax=None, | ||
fontsize=14, | ||
plot_inh=False, | ||
E_colors=None, | ||
I_color="gray", | ||
single_cluster_style="|", | ||
double_cluster_style="x", | ||
mew=0.5, | ||
ms=1, | ||
**plot_kwargs, | ||
): | ||
""" | ||
Plot raster plot with neurons sorted by cluster identity or number of clusters. | ||
|
||
Parameters: | ||
spikes (dict): Dictionary containing 'exc_spikes' and 'inh_spikes'. | ||
membership: Array/list of 2D membership arrays (from params_net['membership']). | ||
XL (list): X-axis limits. | ||
plotting_method (str): "cluster_identity" or "n_clusters". | ||
every_other (int): Plot every nth spike. | ||
ax (matplotlib axis): Axis to plot on. | ||
fontsize (int): Font size for labels. | ||
plot_inh (bool): Whether to plot inhibitory spikes. | ||
E_colors (list): Colors for excitatory clusters. | ||
I_color (str): Color for inhibitory spikes. | ||
single_cluster_style (str): Marker style for single-cluster neurons. | ||
double_cluster_style (str): Marker style for two-cluster neurons. | ||
mew (float): Marker edge width. | ||
ms (float): Marker size. | ||
""" | ||
if ax is None: | ||
ax = plt.axes() | ||
|
||
exc_spikes = spikes["exc_spikes"] | ||
inh_spikes = spikes.get("inh_spikes", {}) | ||
|
||
if plotting_method == "cluster_identity": | ||
# Sort by cluster identity | ||
sorted_indices = data_utils.sort_neurons(membership, sorting_method='cluster_identity') | ||
sorted_indices_list = sorted_indices[0].tolist() # Convert to list of Python integers | ||
sorted_exc_spikes = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices_list)} | ||
#exc_spikes_to_plot = sorted_exc_spikes.values() | ||
elif plotting_method == "n_clusters": | ||
# Sort by number of clusters | ||
sorted_indices = data_utils.sort_neurons(membership, sorting_method='n_clusters') | ||
sorted_exc_spikes_single = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices[0][0])} | ||
sorted_exc_spikes_double = {i: exc_spikes[idx] for i, idx in enumerate(sorted_indices[1][0])} | ||
#exc_spikes_to_plot.append(sorted_exc_spikes_single) | ||
#exc_spikes_to_plot.append(sorted_exc_spikes_double) | ||
else: | ||
raise ValueError("Invalid plotting_method. Use 'cluster_identity' or 'n_clusters'.") | ||
|
||
# Plot excitatory spikes, single cluster in blue and double cluster in red respectively | ||
[ | ||
( | ||
ax.plot( | ||
v[::every_other], | ||
i_v * np.ones_like(v[::every_other]), | ||
single_cluster_style, | ||
color='blue', | ||
alpha=1, | ||
ms=ms, | ||
mew=mew, | ||
) | ||
if len(v) > 0 | ||
else None | ||
) | ||
for i_v, (t,v) in enumerate(sorted_exc_spikes_single.items()) | ||
] | ||
[ | ||
( | ||
ax.plot( | ||
v[::every_other], | ||
(i_v+ len(sorted_indices[0][0])) * np.ones_like(v[::every_other]), | ||
single_cluster_style, | ||
color='red', | ||
alpha=1, | ||
ms=ms, | ||
mew=mew, | ||
) | ||
if len(v) > 0 | ||
else None | ||
) | ||
for i_v, (t,v) in enumerate(sorted_exc_spikes_double.items()) | ||
] | ||
|
||
# Plot inhibitory spikes | ||
if plot_inh: | ||
[ | ||
( | ||
ax.plot( | ||
v[::every_other], | ||
(i_v + len(sorted_indices[0][0]) + len(sorted_indices[1][0])) * np.ones_like(v[::every_other]), | ||
"|", | ||
color=I_color, | ||
alpha=1, | ||
ms=ms, | ||
mew=mew, | ||
) | ||
if len(v) > 0 | ||
else None | ||
) | ||
for i_v, v in enumerate(inh_spikes.values()) | ||
] | ||
|
||
ax.set_xticks([]) | ||
ax.set_yticks([]) | ||
ax.spines.left.set_visible(False) | ||
ax.spines.bottom.set_visible(False) | ||
ax.set_xlim(XL) | ||
ax.set_ylabel("Raster", fontsize=fontsize) | ||
return ax | ||
''' |
Uh oh!
There was an error while loading. Please reload this page.