diff --git a/neuronpp/utils/graphs/network_status_graph.py b/neuronpp/utils/graphs/network_status_graph.py index e62b30ab..1382d732 100644 --- a/neuronpp/utils/graphs/network_status_graph.py +++ b/neuronpp/utils/graphs/network_status_graph.py @@ -26,6 +26,7 @@ def __init__(self, cells, weight_name='w', plot_fixed_weight_edges=True): self.correct_position = (0.1, 0.1) + self.population_sizes = dict() self.population_names = self._get_population_names() self.edges = self._get_edges(weight_name) @@ -69,6 +70,10 @@ def update_spikes(self, sim_time): def _get_edges(self, weight_name): result = [] + # compute the number of cells in each layer + for c in self.cells: + pop_name = c.name.split('[')[0] + self.population_sizes[pop_name] = self.population_sizes.get(pop_name, 0) + 1 for c in self.cells: soma = c.filter_secs('soma') if c._spike_detector is None: @@ -81,11 +86,13 @@ def _get_edges(self, weight_name): except ValueError: x_pos = self.population_names.index(pop_name) - y_pos = int(split_name[-1][:-1]) + # shift down the layer by half its size to vertically center graph + y_pos = int(split_name[-1][:-1]) - self.population_sizes[pop_name] // 2 if 'inh' in c.name: self.colors.append('red') - y_pos -= 5 + # todo center vertically by half width of the hid layer + y_pos -= 6 elif 'hid' in c.name: self.colors.append('blue') else: @@ -110,7 +117,8 @@ def _find_target(self, c, x_pos, y_pos, weight_name): except ValueError: x_trg = self.population_names.index(pop_name) - y_trg = int(split_target[-1][:-1]) + # center veritically + y_trg = int(split_target[-1][:-1]) - self.population_sizes[pop_name] // 2 weight = None if self.plot_constant_connections and hasattr(nc.target.hoc, weight_name): @@ -125,7 +133,8 @@ def _find_weights(c, weight_name): for nc in c.ncs: if "SpikeDetector" in nc.name: continue - elif isinstance(nc.source, Seg) and isinstance(nc.target, PointProcess) and hasattr(nc.target.hoc, weight_name): + elif isinstance(nc.source, Seg) and isinstance(nc.target, PointProcess) and hasattr(nc.target.hoc, + weight_name): weight = getattr(nc.target.hoc, weight_name) targets.append(weight) return targets diff --git a/neuronpp/utils/record.py b/neuronpp/utils/record.py index 9d175575..e68d30e6 100644 --- a/neuronpp/utils/record.py +++ b/neuronpp/utils/record.py @@ -1,13 +1,18 @@ -from nrn import Segment, Section -from collections import defaultdict +from collections import defaultdict, namedtuple import numpy as np import pandas as pd from neuron import h import matplotlib.pyplot as plt +from neuronpp.core.hocwrappers.sec import Sec + +from neuronpp.core.hocwrappers.point_process import PointProcess from neuronpp.core.hocwrappers.seg import Seg +MarkerParams = namedtuple("Simulation_params", + "agent_class agent_stepsize dt input_cell_num output_cell_num output_labels") + class Record: def __init__(self, elements, variables='v'): @@ -36,13 +41,20 @@ def __init__(self, elements, variables='v'): for elem in elements: for var in variables: if isinstance(elem, Seg): - name = elem.parent.name + cell_name = elem.parent.parent.name + name = "%s_%s" % (cell_name, elem.name) + elif isinstance(elem, PointProcess): + cell_name = elem.cell.name + name = "%s_%s" % (cell_name, elem.name) + elif isinstance(elem, Sec): + raise TypeError("Record element cannot be of type Sec, however you can specify Seg eg. soma(0.5) and pass as element.") else: name = elem.name try: s = getattr(elem.hoc, "_ref_%s" % var) except AttributeError: - raise AttributeError("there is no attribute of %s. Maybe you forgot to append loc param for sections?" % var) + raise AttributeError( + "there is no attribute of %s. Maybe you forgot to append loc param for sections?" % var) rec = h.Vector().record(s) self.recs[var].append((name, rec)) @@ -87,16 +99,19 @@ def _plot_static(self, position=None): for i, (name, rec) in enumerate(section_recs): rec_np = rec.as_numpy() if np.max(np.isnan(rec_np)): - raise ValueError("Vector recorded for variable: '%s' and segment: '%s' contains nan values." % (var_name, name)) + raise ValueError( + "Vector recorded for variable: '%s' and segment: '%s' contains nan values." % (var_name, name)) if position is not "merge": - ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), index=i + 1) + ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), + index=i + 1) ax.set_title("Variable: %s" % var_name) ax.plot(self.t, rec, label=name) ax.set(xlabel='t (ms)', ylabel=var_name) ax.legend() - def _plot_animate(self, steps=10000, y_lim=None, position=None): + def _plot_animate(self, steps=10000, y_lim=None, position=None, true_class=None, pred_class=None, + show_true_predicted=False, marker_params: MarkerParams = None): """ Call each time you want to redraw plot. @@ -109,8 +124,20 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None): * position=(3,3) -> if you have 9 neurons and want to display 'v' on 3x3 matrix * position='merge' -> it will display all figures on the same graph. * position=None -> Default, each neuron has separated axis (row) on the figure. + :param true_class: list of true class labels in this window + :param pred_class: list of predicted class labels in window + :param show_true_predicted: whther to print true/predicted class' marks on the plot + :param marker_params: MarkerParams namedtuple contains inner params: + :param agent_stepsize: agent readout time step + :param dt: agent integration time step + :param input_cell_num: number of input cells + :param output_cell_num: number of output cells + :param output_labels: list of true labels for the consecutive plots :return: """ + if show_true_predicted and marker_params is None: + raise ValueError( + "Running parameters run_params need to be passed if true/predicted markers are to be shown") create_fig = False for var_name, section_recs in self.recs.items(): if var_name not in self.figs: @@ -119,21 +146,25 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None): fig = self.figs[var_name] if fig is None: create_fig = True - fig = plt.figure() + fig = plt.figure(figsize=(16.5, 5.5)) fig.canvas.draw() self.figs[var_name] = fig + if show_true_predicted: + if len(marker_params.output_labels) != len(section_recs): + raise ValueError( + "show_predicted is true but the number of true labels given is not equal to actual number of elemens to plot.") for i, (name, rec) in enumerate(section_recs): if create_fig: if position == 'merge': ax = fig.add_subplot(1, 1, 1) else: - ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), index=i + 1) + ax = self._get_subplot(fig=fig, var_name=var_name, position=position, row_len=len(section_recs), + index=i + 1) if y_lim: ax.set_ylim(y_lim[0], y_lim[1]) line, = ax.plot([], lw=1, label=name) - ax.set_title("Variable: %s" % var_name) ax.set_ylabel(var_name) ax.set_xlabel("t (ms)") ax.legend() @@ -146,17 +177,83 @@ def _plot_animate(self, steps=10000, y_lim=None, position=None): ax.set_xlim(t.min(), t.max()) if y_lim is None: - ax.set_ylim(r.min()-(np.abs(r.min()*0.05)), r.max()+(np.abs(r.max()*0.05))) + # compute per-plot OY limits if global are not given + current_y_lim = (r.min() - (np.abs(r.min() * 0.05)), r.max() + (np.abs(r.max() * 0.05))) + ax.set_ylim(current_y_lim) + else: + current_y_lim = y_lim # update data line.set_data(t, r) + if show_true_predicted: + # info draw markers for true and predicted classes + self._show_true_predicted_marks(ax=ax, label=marker_params.output_labels[i], true_class=true_class, + pred_class=pred_class, + t=t, y_limits=current_y_lim, marker_params=marker_params) + if create_fig and i == 0: + # draw legend only the first time and only on the uppermost graph + ax.legend() + # info join plots by removing labels and ticks from subplots that are not on the edge + if create_fig: + fig.subplots_adjust(left=0.09, bottom=0.075, right=0.99, top=0.98, wspace=None, hspace=0.00) fig.canvas.draw() fig.canvas.flush_events() if create_fig: plt.show(block=False) + def _show_true_predicted_marks(self, ax, label, true_class, pred_class, t, y_limits, marker_params): + """ + draw triangles for true and predicted classes + :param ax: the canvas + :param true_class: list of true class labels in this window + :param pred_class: list of predicted class labels in window + :param y_limits: this canvas OY limits for y axis. Default is (-80, 50) + :param run_params: a namedtuple containing + :param agent_stepsize: agent readout time step + :param dt: agent integration time step + :param input_cell_num: number of input cells + :param output_cell_num: number of output cells + :param output_labels: list of true labels for the consecutive plots + :return: + """ + if marker_params.output_labels is not None: + true_x, pred_x = self._get_labels_timestamps(label=label, + true_class=true_class, + pred_class=pred_class, t=t, + marker_params=marker_params) + else: + raise ValueError("True_labels parameter need to be given if show_true_prediction is True") + true_y = [y_limits[0] + np.abs(y_limits[0]) * 0.09] * len(true_x) + pred_y = [y_limits[1] - np.abs(y_limits[1] * 0.12)] * len(pred_x) + ax.scatter(true_x, true_y, c="orange", marker="^", alpha=0.95, label="true") + ax.scatter(pred_x, pred_y, c="magenta", marker="v", alpha=0.95, label="predicted") + + @staticmethod + def _get_labels_timestamps(label, true_class, pred_class, t, marker_params): + """ + find and return lists of time steps for true and predicted labels + :param label: the label id (an int) + :param true_class: list of true classes for the whole time region + :param pred_class: list of predicted labels (class ids) for the whole time region + :param t: the region time steps + :param marker_params: + :return: lists of marks for true_x: true classes, pred_x: predicted classes + """ + n = len(true_class) + x = t[::int(2 * marker_params.agent_stepsize / marker_params.dt)][-n:] + true_x = [] + pred_x = [] + # todo change lists into numpy arrays for speed + for k in range(n): + # get the true classes for the current label + if true_class[k] == label: + true_x.append(x[k]) + if pred_class[k] == label: + pred_x.append(x[k]) + return true_x, pred_x + def to_csv(self, filename): cols = ['time'] data = [self.t.as_numpy().tolist()]