# Normalization helpers recovered from patch 1f918f5 ("Added normalization
# functions (for future use)") against velovae/model/model_util.py.
# They rely on the module-level ``import numpy as np`` of that file.


def scale_by_gene(U, S, train_idx=None, mode='scale_u'):
    """Rescale each column of U and/or S so their standard deviations match.

    The per-column standard deviations are estimated only from rows where
    both U and S are strictly positive, and only when more than 3 such rows
    exist; otherwise the column keeps a standard deviation of 1 for both
    matrices (i.e. it is left unscaled).

    Args:
        U: 2-D array (presumably cells x genes of unspliced counts -- confirm
            against the caller). Columns are scaled individually.
        S: 2-D array with the same shape as U (presumably spliced counts).
        train_idx: Optional row index (anything usable as ``U[train_idx, i]``)
            restricting which rows are used to estimate the statistics.
            All rows are still rescaled.
        mode: One of
            'scale_u' -- divide U by std(u)/std(s) so std(u) matches std(s);
            'scale_s' -- divide S by std(s)/std(u) so std(s) matches std(u);
            'auto'    -- per column, shrink whichever of U/S has the larger
                         std (the divisor is max(1, ratio), so nothing is
                         ever scaled up).
            Any other value leaves both matrices unscaled.

    Returns:
        Tuple ``(U / scaling_u, S / scaling_s, scaling_u, scaling_s)`` where
        the scaling vectors have one entry per column.
    """
    n_gene = U.shape[1]
    scaling_u = np.ones(n_gene)
    scaling_s = np.ones(n_gene)
    std_u = np.ones(n_gene)
    std_s = np.ones(n_gene)

    # Select rows column by column so that a fancy ``train_idx`` never forces
    # a copy of the whole matrix.
    rows = slice(None) if train_idx is None else train_idx
    for i in range(n_gene):
        s_col, u_col = S[rows, i], U[rows, i]
        expressed = (s_col > 0) & (u_col > 0)  # use only nonzero data points
        if np.count_nonzero(expressed) > 3:
            std_u[i] = np.std(u_col[expressed])
            std_s[i] = np.std(s_col[expressed])

    # Patch zero standard deviations so the ratios below are well defined:
    # if only one side is zero it borrows the other side's value (ratio 1);
    # if both are zero both become 1.
    zero_u, zero_s = std_u == 0, std_s == 0
    std_u = np.where(zero_u, np.where(zero_s, 1.0, std_s), std_u)
    std_s = np.where(zero_s, np.where(zero_u, 1.0, std_u), std_s)

    if mode == 'auto':
        scaling_u = np.maximum(scaling_u, std_u / std_s)
        scaling_s = np.maximum(scaling_s, std_s / std_u)
    elif mode == 'scale_u':
        scaling_u = std_u / std_s
    elif mode == 'scale_s':
        scaling_s = std_s / std_u
    return U / scaling_u, S / scaling_s, scaling_u, scaling_s


def _divide_by_cell_scale(counts, scale):
    """Return ``counts / scale`` row-wise, leaving rows with a zero scale at 0."""
    counts = np.asarray(counts)
    scale = np.asarray(scale)
    out = np.zeros(counts.shape, dtype=np.result_type(counts.dtype, scale.dtype))
    # A zero scale factor means the row summed to zero; skipping the division
    # there avoids 0/0 = NaN and keeps the row at zero.
    np.divide(counts, scale, out=out, where=(scale != 0))
    return out


def scale_by_cell(U, S, train_idx=None, separate_us_scale=True):
    """Normalize each row of U and S by its relative total count.

    The scale factors are those returned by ``get_cell_scale`` with the same
    arguments. Rows whose scale factor is zero are returned as zeros rather
    than NaN.

    Args:
        U: 2-D array (presumably cells x genes, unspliced -- confirm).
        S: 2-D array with the same number of rows as U.
        train_idx: Optional row index restricting which rows define the
            median reference count.
        separate_us_scale: If True, U and S get independent factors;
            otherwise both share a factor based on the combined row total.

    Returns:
        Tuple ``(U / lu, S / ls, lu, ls)`` with ``lu`` and ``ls`` of shape
        (n_rows, 1).
    """
    lu, ls = get_cell_scale(U, S, train_idx, separate_us_scale)
    return _divide_by_cell_scale(U, lu), _divide_by_cell_scale(S, ls), lu, ls


def get_cell_scale(U, S, train_idx=None, separate_us_scale=True):
    """Compute per-row scale factors: row total divided by a median total.

    Args:
        U: 2-D array (presumably cells x genes, unspliced -- confirm).
        S: 2-D array with the same number of rows as U.
        train_idx: Optional row index; when given, the median is taken over
            these rows only, but factors are returned for every row.
        separate_us_scale: If True, ``lu`` uses U's row sums and ``ls`` uses
            S's row sums, each with its own median. If False, both are the
            same array, built from the summed row totals of U and S.

    Returns:
        Tuple ``(lu, ls)``, each of shape (n_rows, 1). If the reference
        median is 0 the factors are not finite.
    """
    nu = U.sum(1, keepdims=True)
    ns = S.sum(1, keepdims=True)
    if separate_us_scale:
        ref_u = nu if train_idx is None else nu[train_idx]
        ref_s = ns if train_idx is None else ns[train_idx]
        lu = nu / np.median(ref_u)
        ls = ns / np.median(ref_s)
    else:
        total = nu + ns
        ref = total if train_idx is None else total[train_idx]
        lu = total / np.median(ref)
        ls = lu
    return lu, ls