-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
branch: models_igr: graph display + record selection #6
base: master
Are you sure you want to change the base?
Changes from all commits
6fe2e9a
d5df424
2ddc41c
e5fe31f
4fb520f
7758ed8
06fc25d
3ed9930
c8fb1e2
597409e
cf742d9
ef7ae7f
79f38d8
34a7336
387e21a
f2d346e
8411e81
2ca9dec
be385e8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,13 +1,18 @@ | ||
from nrn import Segment, Section | ||
from collections import defaultdict | ||
from collections import defaultdict, namedtuple | ||
|
||
import numpy as np | ||
import pandas as pd | ||
from neuron import h | ||
import matplotlib.pyplot as plt | ||
from neuronpp.core.hocwrappers.sec import Sec | ||
|
||
from neuronpp.core.hocwrappers.point_process import PointProcess | ||
|
||
from neuronpp.core.hocwrappers.seg import Seg | ||
|
||
MarkerParams = namedtuple("Simulation_params", | ||
"agent_class agent_stepsize dt input_cell_num output_cell_num output_labels") | ||
|
||
|
||
class Record: | ||
def __init__(self, elements, variables='v'): | ||
|
@@ -36,13 +41,20 @@ def __init__(self, elements, variables='v'): | |
for elem in elements: | ||
for var in variables: | ||
if isinstance(elem, Seg): | ||
name = elem.parent.name | ||
cell_name = elem.parent.parent.name | ||
name = "%s_%s" % (cell_name, elem.name) | ||
elif isinstance(elem, PointProcess): | ||
cell_name = elem.cell.name | ||
name = "%s_%s" % (cell_name, elem.name) | ||
elif isinstance(elem, Sec): | ||
raise TypeError("Record element cannot be of type Sec, however you can specify Seg eg. soma(0.5) and pass as element.") | ||
else: | ||
name = elem.name | ||
try: | ||
s = getattr(elem.hoc, "_ref_%s" % var) | ||
except AttributeError: | ||
raise AttributeError("there is no attribute of %s. Maybe you forgot to append loc param for sections?" % var) | ||
raise AttributeError( | ||
"there is no attribute of %s. Maybe you forgot to append loc param for sections?" % var) | ||
|
||
rec = h.Vector().record(s) | ||
self.recs[var].append((name, rec)) | ||
|
@@ -87,16 +99,19 @@ def _plot_static(self, position=None): | |
for i, (name, rec) in enumerate(section_recs): | ||
rec_np = rec.as_numpy() | ||
if np.max(np.isnan(rec_np)): | ||
raise ValueError("Vector recorded for variable: '%s' and segment: '%s' contains nan values." % (var_name, name)) | ||
raise ValueError( | ||
"Vector recorded for variable: '%s' and segment: '%s' contains nan values." % (var_name, name)) | ||
|
||
if position is not "merge": | ||
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), index=i + 1) | ||
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), | ||
index=i + 1) | ||
ax.set_title("Variable: %s" % var_name) | ||
ax.plot(self.t, rec, label=name) | ||
ax.set(xlabel='t (ms)', ylabel=var_name) | ||
ax.legend() | ||
|
||
def _plot_animate(self, steps=10000, y_lim=None, position=None): | ||
def _plot_animate(self, steps=10000, y_lim=None, position=None, true_class=None, pred_class=None, | ||
show_true_predicted=False, marker_params: MarkerParams = None): | ||
""" | ||
Call each time you want to redraw plot. | ||
|
||
|
@@ -109,8 +124,20 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None): | |
* position=(3,3) -> if you have 9 neurons and want to display 'v' on 3x3 matrix | ||
* position='merge' -> it will display all figures on the same graph. | ||
* position=None -> Default, each neuron has separated axis (row) on the figure. | ||
:param true_class: list of true class labels in this window | ||
:param pred_class: list of predicted class labels in window | ||
:param show_true_predicted: whther to print true/predicted class' marks on the plot | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think there is no point of using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is a matter of liking. Lots of code is written so that switch some behaviour on/off using a single switch. |
||
:param marker_params: MarkerParams namedtuple contains inner params: | ||
:param agent_stepsize: agent readout time step | ||
:param dt: agent integration time step | ||
:param input_cell_num: number of input cells | ||
:param output_cell_num: number of output cells | ||
:param output_labels: list of true labels for the consecutive plots | ||
:return: | ||
""" | ||
if show_true_predicted and marker_params is None: | ||
raise ValueError( | ||
"Running parameters run_params need to be passed if true/predicted markers are to be shown") | ||
create_fig = False | ||
for var_name, section_recs in self.recs.items(): | ||
if var_name not in self.figs: | ||
|
@@ -119,21 +146,25 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None): | |
fig = self.figs[var_name] | ||
if fig is None: | ||
create_fig = True | ||
fig = plt.figure() | ||
fig = plt.figure(figsize=(16.5, 5.5)) | ||
fig.canvas.draw() | ||
self.figs[var_name] = fig | ||
|
||
if show_true_predicted: | ||
if len(marker_params.output_labels) != len(section_recs): | ||
raise ValueError( | ||
"show_predicted is true but the number of true labels given is not equal to actual number of elemens to plot.") | ||
for i, (name, rec) in enumerate(section_recs): | ||
if create_fig: | ||
if position == 'merge': | ||
ax = fig.add_subplot(1, 1, 1) | ||
else: | ||
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), index=i + 1) | ||
ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), | ||
index=i + 1) | ||
|
||
if y_lim: | ||
ax.set_ylim(y_lim[0], y_lim[1]) | ||
line, = ax.plot([], lw=1, label=name) | ||
ax.set_title("Variable: %s" % var_name) | ||
ax.set_ylabel(var_name) | ||
ax.set_xlabel("t (ms)") | ||
ax.legend() | ||
|
@@ -146,17 +177,83 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None): | |
|
||
ax.set_xlim(t.min(), t.max()) | ||
if y_lim is None: | ||
ax.set_ylim(r.min()-(np.abs(r.min()*0.05)), r.max()+(np.abs(r.max()*0.05))) | ||
# compute per-plot OY limits if global are not given | ||
current_y_lim = (r.min() - (np.abs(r.min() * 0.05)), r.max() + (np.abs(r.max() * 0.05))) | ||
ax.set_ylim(current_y_lim) | ||
else: | ||
current_y_lim = y_lim | ||
|
||
# update data | ||
line.set_data(t, r) | ||
if show_true_predicted: | ||
# info draw markers for true and predicted classes | ||
self._show_true_predicted_marks(ax=ax, label=marker_params.output_labels[i], true_class=true_class, | ||
pred_class=pred_class, | ||
t=t, y_limits=current_y_lim, marker_params=marker_params) | ||
if create_fig and i == 0: | ||
# draw legend only the first time and only on the uppermost graph | ||
ax.legend() | ||
|
||
# info join plots by removing labels and ticks from subplots that are not on the edge | ||
if create_fig: | ||
igorpodolak marked this conversation as resolved.
Show resolved
Hide resolved
|
||
fig.subplots_adjust(left=0.09, bottom=0.075, right=0.99, top=0.98, wspace=None, hspace=0.00) | ||
fig.canvas.draw() | ||
fig.canvas.flush_events() | ||
|
||
if create_fig: | ||
plt.show(block=False) | ||
|
||
def _show_true_predicted_marks(self, ax, label, true_class, pred_class, t, y_limits, marker_params): | ||
""" | ||
draw triangles for true and predicted classes | ||
:param ax: the canvas | ||
:param true_class: list of true class labels in this window | ||
:param pred_class: list of predicted class labels in window | ||
:param y_limits: this canvas OY limits for y axis. Default is (-80, 50) | ||
:param run_params: a namedtuple containing | ||
:param agent_stepsize: agent readout time step | ||
:param dt: agent integration time step | ||
:param input_cell_num: number of input cells | ||
:param output_cell_num: number of output cells | ||
:param output_labels: list of true labels for the consecutive plots | ||
:return: | ||
""" | ||
if marker_params.output_labels is not None: | ||
true_x, pred_x = self._get_labels_timestamps(label=label, | ||
true_class=true_class, | ||
pred_class=pred_class, t=t, | ||
marker_params=marker_params) | ||
else: | ||
raise ValueError("True_labels parameter need to be given if show_true_prediction is True") | ||
true_y = [y_limits[0] + np.abs(y_limits[0]) * 0.09] * len(true_x) | ||
pred_y = [y_limits[1] - np.abs(y_limits[1] * 0.12)] * len(pred_x) | ||
ax.scatter(true_x, true_y, c="orange", marker="^", alpha=0.95, label="true") | ||
ax.scatter(pred_x, pred_y, c="magenta", marker="v", alpha=0.95, label="predicted") | ||
|
||
@staticmethod | ||
def _get_labels_timestamps(label, true_class, pred_class, t, marker_params): | ||
""" | ||
find and return lists of time steps for true and predicted labels | ||
:param label: the label id (an int) | ||
:param true_class: list of true classes for the whole time region | ||
:param pred_class: list of predicted labels (class ids) for the whole time region | ||
:param t: the region time steps | ||
:param marker_params: | ||
:return: lists of marks for true_x: true classes, pred_x: predicted classes | ||
""" | ||
n = len(true_class) | ||
x = t[::int(2 * marker_params.agent_stepsize / marker_params.dt)][-n:] | ||
true_x = [] | ||
pred_x = [] | ||
# todo change lists into numpy arrays for speed | ||
for k in range(n): | ||
# get the true classes for the current label | ||
if true_class[k] == label: | ||
true_x.append(x[k]) | ||
if pred_class[k] == label: | ||
pred_x.append(x[k]) | ||
return true_x, pred_x | ||
|
||
def to_csv(self, filename): | ||
cols = ['time'] | ||
data = [self.t.as_numpy().tolist()] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this concept of recording with labels should be moved to separated class which inherits from
Record
, eg.RecordWithLabels
. In this case -dt
andtrue_labels
params from the_plot_animate()
metod can be moved into consructor.