From 11c4fa996e1ba941fc56ad8a26c7707eb8979ef5 Mon Sep 17 00:00:00 2001 From: Yichen Gu Date: Mon, 8 Jan 2024 12:36:00 -0500 Subject: [PATCH] Fixed a missing input argument in post_analysis. Adjusted fontsize in the cluster plot. --- setup.cfg | 2 +- velovae/analysis/evaluation.py | 9 +++++---- velovae/plotting.py | 23 +++++++++++++++++++---- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/setup.cfg b/setup.cfg index 6a8e4ab..26c6dcd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,6 +1,6 @@ [metadata] name = velovae -version = 0.1.3 +version = 0.1.3-beta author = Yichen Gu author_email = gyichen@umich.edu description = Bayesian Inference of RNA Velocity diff --git a/velovae/analysis/evaluation.py b/velovae/analysis/evaluation.py index 3500514..8258ff1 100644 --- a/velovae/analysis/evaluation.py +++ b/velovae/analysis/evaluation.py @@ -562,10 +562,13 @@ def post_analysis(adata, pd.set_option("display.precision", 3) print("--- Plotting Results ---") + save_format = kwargs["save_format"] if "save_format" in kwargs else "png" + if 'cluster' in plot_type or "all" in plot_type: plot_cluster(adata.obsm[f"X_{embed}"], adata.obs[cluster_key].to_numpy(), embed=embed, + palette=palette, save=(None if figure_path is None else f"{figure_path}/{test_id}_umap.png")) @@ -602,8 +605,6 @@ def post_analysis(adata, if len(genes) == 0: return - format = kwargs["format"] if "format" in kwargs else "png" - if palette is None: palette = get_colors(len(cell_types_raw)) @@ -629,7 +630,7 @@ def post_analysis(adata, Labels_phase_demo, path=figure_path, figname=test_id, - format=format) + format=save_format) if 'gene' in plot_type or 'all' in plot_type: T = {} @@ -671,7 +672,7 @@ def post_analysis(adata, palette=palette, path=figure_path, figname=test_id, - format=format) + format=save_format) if 'stream' in plot_type or 'all' in plot_type: try: diff --git a/velovae/plotting.py b/velovae/plotting.py index 062f6d3..5b9e9b3 100644 --- a/velovae/plotting.py +++ b/velovae/plotting.py @@ -306,7 +306,14 @@ def plot_phase(u, s, save_fig(fig, save, (lgd,)) -def plot_cluster(X_embed, cell_labels, color_map=None, embed='umap', show_labels=True, save=None): +def plot_cluster(X_embed, + cell_labels, + color_map=None, + embed='umap', + show_labels=True, + fontsize=None, + palette=None, + save=None): """Plot the predicted cell types from the encoder Args: @@ -320,6 +327,10 @@ def plot_cluster(X_embed, cell_labels, color_map=None, embed='umap', show_labels Embedding name. Used for labeling axes. Defaults to 'umap'. show_labels (bool, optional): Whether to add cell cluster names to the plot. Defaults to True. + fontsize (int, optional): + Font size for cell cluster names. Defaults to None. + palette (list[str], optional): + List of colors for plotting. Defaults to None. save (str, optional): Figure name for saving (including path). Defaults to None. """ @@ -329,16 +340,20 @@ def plot_cluster(X_embed, cell_labels, color_map=None, embed='umap', show_labels y = X_embed[:, 1] x_range = x.max()-x.min() y_range = y.max()-y.min() - colors = get_colors(len(cell_types), color_map) + if palette is None: + palette = get_colors(len(cell_types), color_map) n_char_max = np.max([len(x) for x in cell_types]) for i, typei in enumerate(cell_types): mask = cell_labels == typei xbar, ybar = np.mean(x[mask]), np.mean(y[mask]) - ax.plot(x[mask], y[mask], '.', color=colors[i % len(colors)]) + ax.plot(x[mask], y[mask], '.', color=palette[i % len(palette)]) n_char = len(typei) if show_labels: - txt = ax.text(xbar - x_range*4e-3*n_char, ybar - y_range*4e-3, typei, fontsize=200//n_char_max, color='k') + if fontsize is None: + fontsize = min(200//n_char_max, 400//len(cell_types)) + fontsize = max(fontsize, 8) + txt = ax.text(xbar - x_range*4e-3*n_char, ybar - y_range*4e-3, typei, fontsize=fontsize, color='k') txt.set_bbox(dict(facecolor='white', alpha=0.5, edgecolor='black')) ax.set_xlabel(f'{embed} 1')