Skip to content

Commit

Permalink
More black formatting
Browse files Browse the repository at this point in the history
  • Loading branch information
pgleeson committed Apr 2, 2024
1 parent 5ad7388 commit b3bf95b
Showing 1 changed file with 73 additions and 48 deletions.
121 changes: 73 additions & 48 deletions NEST_SLI/spike_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,32 +27,39 @@
import os
import re

# --- Simulation-parameter parsing -------------------------------------------
# NOTE(review): reconstructed from a diff rendering that contained both the
# pre- and post-black-format version of each statement; the post-format
# (double-quoted) form is kept and the lost indentation restored.
datapath = "."

# get simulation time and numbers of neurons recorded from sim_params.sli
with open(os.path.join(datapath, "sim_params.sli"), "r") as f:
    sim_params_contents = f.read()

# total simulated time (presumably ms — TODO confirm against sim_params.sli)
T = float(re.search(r"/t_sim (.+) def", sim_params_contents).group(1))
# True when a *fraction* of neurons (rather than a fixed count) was recorded
record_frac = (
    re.search(
        r"/record_fraction_neurons_spikes (.+) def", sim_params_contents
    ).group(1)
    == "true"
)

if record_frac:
    # fraction of neurons whose spikes were recorded
    frac_rec = float(
        re.search(r"/frac_rec_spikes (.+) def", sim_params_contents).group(1)
    )
else:
    # absolute number of recorded neurons
    n_rec = int(re.search(r"/n_rec_spikes (.+) def", sim_params_contents).group(1))

T_start = 200.0  # starting point of analysis (to avoid transients)

# --- Load per-population global node IDs ------------------------------------
# NOTE(review): duplicated pre/post-format diff lines collapsed to the
# post-commit (black-formatted) version.

# load node IDs
node_ids = np.loadtxt(os.path.join(datapath, "population_nodeIDs.dat"), dtype=int)
print("Global IDs:")
print(node_ids)
print()

# number of populations (one row per population in population_nodeIDs.dat)
num_pops = len(node_ids)
print("Number of populations:")
print(num_pops)
print()

Expand All @@ -73,29 +80,29 @@

# --- Population index bookkeeping -------------------------------------------
# NOTE(review): `pop_sizes` and `raw_first_node_ids` are defined in a part of
# the file not shown in this diff — verify against the full source.

# first node ID of each population once device node IDs are dropped
first_node_ids = [int(1 + np.sum(pop_sizes[:i])) for i in np.arange(len(pop_sizes))]

# last node ID of each population once device node IDs are dropped
last_node_ids = [int(np.sum(pop_sizes[: i + 1])) for i in np.arange(len(pop_sizes))]

# convert lists to a nicer format, i.e. [[2/3e, 2/3i], []....]
Pop_sizes = [pop_sizes[i : i + 2] for i in range(0, len(pop_sizes), 2)]
print("Population sizes:")
print(Pop_sizes)
print()

# pair up consecutive (excitatory, inhibitory) entries for each layer
Raw_first_node_ids = [
    raw_first_node_ids[i : i + 2] for i in range(0, len(raw_first_node_ids), 2)
]
First_node_ids = [first_node_ids[i : i + 2] for i in range(0, len(first_node_ids), 2)]
Last_node_ids = [last_node_ids[i : i + 2] for i in range(0, len(last_node_ids), 2)]

# total number of neurons in the simulation = last ID of the last population
# (idiom fix: negative index instead of len()-1)
num_neurons = last_node_ids[-1]
print("Total number of neurons:")
print(num_neurons)
print()

Expand All @@ -105,25 +112,38 @@
# NOTE(review): this span is a GitHub diff rendering — both the pre-format
# (single-quoted) and post-format (double-quoted, black-style) versions of
# several statements appear, and leading indentation was lost in extraction.
# The loop body continues in a region hidden behind an "Expand" marker, so
# code tokens are left byte-identical here; only comments are added.
# will contain neuron id resolved spike trains
neuron_spikes = [[] for i in np.arange(num_neurons + 1)]
# container for population-resolved spike data
# (pre-format version of the initializer:)
spike_data = [[[], []], [[], []], [[], []], [[], []], [[], []], [[], []],
[[], []], [[], []]]
# (post-format duplicate of the same initializer — one [ids, times] pair per
# population, 4 layers x {e, i}:)
spike_data = [
    [[], []],
    [[], []],
    [[], []],
    [[], []],
    [[], []],
    [[], []],
    [[], []],
    [[], []],
]

# index of the current population within spike_data
counter = 0

# (pre-format version of the merge loop head:)
for layer in ['0', '1', '2', '3']:
for population in ['0', '1']:
output = os.path.join(datapath,
'population_spikes-{}-{}.gdf'.format(layer,
population))
file_pattern = os.path.join(datapath,
'spikes_{}_{}*'.format(layer, population))
# (post-format duplicate: merge per-rank spike files into one .gdf per
# population:)
for layer in ["0", "1", "2", "3"]:
    for population in ["0", "1"]:
        output = os.path.join(
            datapath, "population_spikes-{}-{}.gdf".format(layer, population)
        )
        file_pattern = os.path.join(datapath, "spikes_{}_{}*".format(layer, population))
files = glob.glob(file_pattern)
# (pre-format version of the progress message:)
print('Merge ' + str(
len(files)) + ' spike files from L' + layer + 'P' + population)
# (post-format duplicate:)
print(
    "Merge "
    + str(len(files))
    + " spike files from L"
    + layer
    + "P"
    + population
)
if files:
# pre/post duplicates: open the merged output file for writing
merged_file = open(output, 'w')
merged_file = open(output, "w")
for file in files:
# pre/post duplicates: open one per-rank spike file
data = open(file, 'r')
data = open(file, "r")
# skip the three NEST header lines of each spike file
nest_version = next(data)
backend_version = next(data)
column_header = next(data)
Expand All @@ -135,18 +155,18 @@
# NOTE(review): fragment of the merge loop whose enclosing `for line ...`
# header and the assignment of `a` (parsed [node_id, spike_time]) are in a
# hidden diff region; indentation was lost in extraction. Code tokens left
# byte-identical; only comments added.
first_node_id = First_node_ids[int(layer)][int(population)]
# remap the raw (device-offset) node ID onto the dense 1..num_neurons range
a[0] = a[0] - raw_first_node_id + first_node_id

# pre/post duplicates of the transient-discard guard:
if (a[1] > T_start): # discard data in the start-up phase
if a[1] > T_start:  # discard data in the start-up phase
spike_data[counter][0].append(num_neurons - a[0])
spike_data[counter][1].append(a[1] - T_start)
neuron_spikes[a[0]].append(a[1] - T_start)

# pre/post duplicates: write "<id>\t<time>\n" to the merged .gdf file
converted_line = str(a[0]) + '\t' + str(a[1]) + '\n'
converted_line = str(a[0]) + "\t" + str(a[1]) + "\n"
merged_file.write(converted_line)
data.close()
merged_file.close()
# advance to the next population's slot in spike_data
counter += 1

# Grayscale color per population, alternating for the e/i pair of each layer.
# NOTE(review): duplicated pre/post-format diff lines collapsed to the
# post-commit version.
clrs = ["0", "0.5", "0", "0.5", "0", "0.5", "0", "0.5"]
plt.ion()

# raster plot
Expand All @@ -155,15 +175,20 @@
# --- Raster plot ------------------------------------------------------------
# NOTE(review): reconstructed from duplicated diff lines; `rec_sizes` is
# defined in a hidden region, and `counter += 1` is assumed to sit inside the
# inner (per-neuron) loop as the lost indentation suggests — TODO confirm.
total_rec = sum(rec_sizes)  # hoisted loop invariant
counter = 1
for j in np.arange(num_pops):
    for i in np.arange(first_node_ids[j], first_node_ids[j] + rec_sizes[j]):
        # one row per recorded neuron, populations stacked top-to-bottom
        plt.plot(
            neuron_spikes[i],
            np.ones_like(neuron_spikes[i]) + total_rec - counter,
            "k o",
            ms=1,
            mfc=clrs[j],
            mec=clrs[j],
        )
        counter += 1
plt.xlim(0, T - T_start)
plt.ylim(0, total_rec)
plt.xlabel(r"time (ms)")
plt.ylabel(r"neuron id")
plt.savefig(os.path.join(datapath, "rasterplot.png"))

# firing rates

Expand All @@ -177,16 +202,16 @@
# --- Firing-rate bar chart --------------------------------------------------
# NOTE(review): duplicated pre/post-format diff lines collapsed to the
# post-commit version. `rates` is computed in a hidden region above, and
# `temp = 0` presumably resets an accumulator used there — confirm against
# the full source before removing.
temp = 0

print()
print("Firing rates:")
print(rates)

plt.figure(2)
ticks = np.arange(num_pops)
plt.bar(ticks, rates, width=0.9, color="k")
xticklabels = ["L2/3e", "L2/3i", "L4e", "L4i", "L5e", "L5i", "L6e", "L6i"]
# center the labels on the bars
plt.setp(plt.gca(), xticks=ticks + 0.5, xticklabels=xticklabels)
plt.xlabel(r"subpopulation")
plt.ylabel(r"firing rate (spikes/s)")
plt.savefig(os.path.join(datapath, "firing_rates.png"))

plt.show()

0 comments on commit b3bf95b

Please sign in to comment.