diff --git a/flox/visualize.py b/flox/visualize.py index c3cd6c816..bdf33b4fe 100644 --- a/flox/visualize.py +++ b/flox/visualize.py @@ -82,6 +82,7 @@ def visualize_groups_1d(array, labels, axis=-1, colors=None, cmap=None): assert labels.ndim == 1 factorized, unique_labels = pd.factorize(labels) assert np.array(labels).ndim == 1 + assert len(labels) == array.shape[axis] chunks = array.chunks[axis] if colors is None: @@ -96,7 +97,7 @@ def visualize_groups_1d(array, labels, axis=-1, colors=None, cmap=None): plt.figure() i0 = 0 for i in chunks: - lab = labels[i0 : i0 + i] + lab = factorized[i0 : i0 + i] col = [colors[label] for label in lab] + [(1, 1, 1)] draw_mesh( 1, @@ -170,3 +171,22 @@ def visualize_cohorts_2d(by, array, method="cohorts"): ax[1].set_title(f"{len(before_merged)} cohorts") ax[2].set_title(f"{len(merged)} merged cohorts") f.set_size_inches((6, 6)) + + +def visualize_groups_1d_long(array, labels, axis): + labels = np.asarray(labels) + assert labels.ndim == 1 + factorized, unique_labels = pd.factorize(labels) + assert np.array(labels).ndim == 1 + assert len(labels) == array.shape[axis] + chunks = array.chunks[axis] + + idx = np.concatenate([np.concatenate([np.ones((c,)), np.array([np.nan])]) for c in chunks])[:-1] + labels_ = labels[np.nancumsum(idx).astype(int) - 1].astype(float) + labels_[np.isnan(idx)] = np.nan + ncol = 5 * (20 + 1) + extra = ncol - len(idx) % ncol + idx2d = np.pad(labels_, (0, extra), constant_values=np.nan).reshape(-1, ncol) + + plt.figure() + plt.pcolormesh(idx2d, cmap=mpl.cm.Reds)