Skip to content

Commit

Permalink
Fixed a missing input argument in post_analysis.
Browse files Browse the repository at this point in the history
Adjusted fontsize in the cluster plot.
  • Loading branch information
Yichen Gu committed Jan 8, 2024
1 parent b9bd498 commit 11c4fa9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 9 deletions.
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[metadata]
name = velovae
version = 0.1.3
version = 0.1.3-beta
author = Yichen Gu
author_email = [email protected]
description = Bayesian Inference of RNA Velocity
Expand Down
9 changes: 5 additions & 4 deletions velovae/analysis/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,10 +562,13 @@ def post_analysis(adata,
pd.set_option("display.precision", 3)

print("--- Plotting Results ---")
save_format = kwargs["save_format"] if "save_format" in kwargs else "png"

if 'cluster' in plot_type or "all" in plot_type:
plot_cluster(adata.obsm[f"X_{embed}"],
adata.obs[cluster_key].to_numpy(),
embed=embed,
palette=palette,
save=(None if figure_path is None else
f"{figure_path}/{test_id}_umap.png"))

Expand Down Expand Up @@ -602,8 +605,6 @@ def post_analysis(adata,
if len(genes) == 0:
return

format = kwargs["format"] if "format" in kwargs else "png"

if palette is None:
palette = get_colors(len(cell_types_raw))

Expand All @@ -629,7 +630,7 @@ def post_analysis(adata,
Labels_phase_demo,
path=figure_path,
figname=test_id,
format=format)
format=save_format)

if 'gene' in plot_type or 'all' in plot_type:
T = {}
Expand Down Expand Up @@ -671,7 +672,7 @@ def post_analysis(adata,
palette=palette,
path=figure_path,
figname=test_id,
format=format)
format=save_format)

if 'stream' in plot_type or 'all' in plot_type:
try:
Expand Down
23 changes: 19 additions & 4 deletions velovae/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,14 @@ def plot_phase(u, s,
save_fig(fig, save, (lgd,))


def plot_cluster(X_embed, cell_labels, color_map=None, embed='umap', show_labels=True, save=None):
def plot_cluster(X_embed,
cell_labels,
color_map=None,
embed='umap',
show_labels=True,
fontsize=None,
palette=None,
save=None):
"""Plot the predicted cell types from the encoder
Args:
Expand All @@ -320,6 +327,10 @@ def plot_cluster(X_embed, cell_labels, color_map=None, embed='umap', show_labels
Embedding name. Used for labeling axes. Defaults to 'umap'.
show_labels (bool, optional):
Whether to add cell cluster names to the plot. Defaults to True.
fontsize (int, optional):
Font size for cell cluster names. Defaults to None.
palette (list[str], optional):
List of colors for plotting. Defaults to None.
save (str, optional):
Figure name for saving (including path). Defaults to None.
"""
Expand All @@ -329,16 +340,20 @@ def plot_cluster(X_embed, cell_labels, color_map=None, embed='umap', show_labels
y = X_embed[:, 1]
x_range = x.max()-x.min()
y_range = y.max()-y.min()
colors = get_colors(len(cell_types), color_map)
if palette is None:
palette = get_colors(len(cell_types), color_map)

n_char_max = np.max([len(x) for x in cell_types])
for i, typei in enumerate(cell_types):
mask = cell_labels == typei
xbar, ybar = np.mean(x[mask]), np.mean(y[mask])
ax.plot(x[mask], y[mask], '.', color=colors[i % len(colors)])
ax.plot(x[mask], y[mask], '.', color=palette[i % len(palette)])
n_char = len(typei)
if show_labels:
txt = ax.text(xbar - x_range*4e-3*n_char, ybar - y_range*4e-3, typei, fontsize=200//n_char_max, color='k')
if fontsize is None:
fontsize = min(200//n_char_max, 400//len(cell_types))
fontsize = max(fontsize, 8)
txt = ax.text(xbar - x_range*4e-3*n_char, ybar - y_range*4e-3, typei, fontsize=fontsize, color='k')
txt.set_bbox(dict(facecolor='white', alpha=0.5, edgecolor='black'))

ax.set_xlabel(f'{embed} 1')
Expand Down

0 comments on commit 11c4fa9

Please sign in to comment.