diff --git a/src/MEArec/tools.py b/src/MEArec/tools.py index 42176a5..e59ce21 100755 --- a/src/MEArec/tools.py +++ b/src/MEArec/tools.py @@ -3098,7 +3098,7 @@ def plot_templates( else: colors = plt.rcParams["axes.prop_cycle"].by_key()["color"] - gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax) + gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax.get_subplotspec()) for i_n, n in enumerate(template_ids): r = i_n // ncols @@ -3318,7 +3318,7 @@ def plot_waveforms( nrows = 1 ncols = n_units - gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax) + gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax.get_subplotspec()) for i, wf in enumerate(waveforms): r = i // ncols @@ -3336,7 +3336,7 @@ def plot_waveforms( nrows = 1 ncols = len(spiketrain_id) - gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax) + gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax.get_subplotspec()) # find ylim min_wf = 0 @@ -3488,7 +3488,7 @@ def plot_amplitudes( nrows = 1 ncols = n_units - gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax) + gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax.get_subplotspec()) for i_n, n in enumerate(spiketrain_id): r = i_n // ncols @@ -3626,7 +3626,7 @@ def plot_pca_map( nrows = len(pc_dims) * len(elec_dims) ncols = nrows - gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax) + gs = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec=ax.get_subplotspec()) for p1 in pc_dims: for i1, ch1 in enumerate(elec_dims):