diff --git a/src/multivelo/dynamical_chrom_func.py b/src/multivelo/dynamical_chrom_func.py index 8680759..5c66d8a 100644 --- a/src/multivelo/dynamical_chrom_func.py +++ b/src/multivelo/dynamical_chrom_func.py @@ -16,7 +16,8 @@ import scvelo as scv import pandas as pd import seaborn as sns -from numba import jit +from numba import njit +import numba from numba.typed import List from tqdm.auto import tqdm from joblib import Parallel, delayed @@ -27,7 +28,7 @@ src_path = os.path.join(current_path, "..") sys.path.append(src_path) - +# a funciton to check for invalid values of different parameters def check_params(alpha_c, alpha, beta, @@ -127,69 +128,15 @@ def check_params(alpha_c, return new_alpha_c, new_alpha, new_beta, new_gamma -# @jit(nopython=True, fastmath=True, debug=True) -# def check_params(alpha_c, -# alpha, -# beta, -# gamma, -# c0=None, -# u0=None, -# s0=None): - -# # check if any of our parameters are infinite -# if c0 is not None and math.isinf(c0): -# logg.error("c0 is infinite.", v=1) -# if u0 is not None and math.isinf(u0): -# logg.error("u0 is infinite.", v=1) -# if s0 is not None and math.isinf(s0): -# logg.error("s0 is infinite.", v=1) -# if math.isinf(alpha_c): -# logg.error("alpha_c is infinite.", v=1) -# if math.isinf(alpha): -# logg.error("alpha is infinite.", v=1) -# if math.isinf(beta): -# logg.error("beta is infinite.", v=1) -# if math.isinf(gamma): -# logg.error("gamma is infinite.", v=1) - -# # check if any of our parameters are nan -# if c0 is not None and math.isnan(c0): -# logg.error("c0 is infinite.", v=1) -# if u0 is not None and math.isnan(u0): -# logg.error("u0 is infinite.", v=1) -# if s0 is not None and math.isnan(s0): -# logg.error("s0 is infinite.", v=1) -# if math.isnan(alpha_c): -# logg.error("alpha_c is infinite.", v=1) -# if math.isnan(alpha): -# logg.error("alpha is infinite.", v=1) -# if math.isnan(beta): -# logg.error("beta is infinite.", v=1) -# if math.isnan(gamma): -# logg.error("gamma is infinite.", v=1) - -# # check if any of our rate parameters are 0 -# if alpha_c < 1e-7: -# logg.error("alpha_c is zero.", v=1) -# if alpha < 1e-7: -# logg.error("alpha is zero.", v=1) -# if beta < 1e-7: -# logg.error("beta is zero.", v=1) -# if gamma < 1e-7: -# logg.error("gamma is zero.", v=1) - -# if beta == alpha_c: -# logg.error("alpha_c and beta are equal, leading to divide by zero", -# v=1) -# if beta == gamma: -# logg.error("gamma and beta are equal, leading to divide by zero", -# v=1) -# if alpha_c == gamma: -# logg.error("gamma and alpha_c are equal, leading to divide by zero", -# v=1) - -# @jit(nopython=True, fastmath=True, debug=True) +@njit( + locals={ + "res": numba.types.float64[:, ::1], + "eat": numba.types.float64[::1], + "ebt": numba.types.float64[::1], + "egt": numba.types.float64[::1], + }, + fastmath=True) def predict_exp(tau, c0, u0, @@ -204,14 +151,6 @@ def predict_exp(tau, backward=False, rna_only=False): - # check_params(alpha_c, - # alpha, - # beta, - # gamma, - # c0, - # u0, - # s0) - if len(tau) == 0: return np.empty((0, 3)) if backward: @@ -250,7 +189,23 @@ def predict_exp(tau, return res -# @jit(nopython=True, fastmath=True, debug=True) +@njit(locals={ + "exp_sw1": numba.types.float64[:, ::1], + "exp_sw2": numba.types.float64[:, ::1], + "exp_sw3": numba.types.float64[:, ::1], + "exp1": numba.types.float64[:, ::1], + "exp2": numba.types.float64[:, ::1], + "exp3": numba.types.float64[:, ::1], + "exp4": numba.types.float64[:, ::1], + "tau_sw1": numba.types.float64[::1], + "tau_sw2": numba.types.float64[::1], + "tau_sw3": numba.types.float64[::1], + "tau1": numba.types.float64[::1], + "tau2": numba.types.float64[::1], + "tau3": numba.types.float64[::1], + "tau4": numba.types.float64[::1] + }, + fastmath=True) def generate_exp(tau_list, t_sw_array, alpha_c, @@ -429,7 +384,23 @@ def generate_exp(tau_list, return (exp1, exp2, exp3, exp4), (exp_sw1, exp_sw2, exp_sw3) -# @jit(nopython=True, fastmath=True, debug=True) +@njit(locals={ + "exp_sw1": numba.types.float64[:, ::1], + "exp_sw2": numba.types.float64[:, ::1], + "exp_sw3": numba.types.float64[:, ::1], + "exp1": numba.types.float64[:, ::1], + "exp2": numba.types.float64[:, ::1], + "exp3": numba.types.float64[:, ::1], + "exp4": numba.types.float64[:, ::1], + "tau_sw1": numba.types.float64[::1], + "tau_sw2": numba.types.float64[::1], + "tau_sw3": numba.types.float64[::1], + "tau1": numba.types.float64[::1], + "tau2": numba.types.float64[::1], + "tau3": numba.types.float64[::1], + "tau4": numba.types.float64[::1] + }, + fastmath=True) def generate_exp_backward(tau_list, t_sw_array, alpha_c, alpha, beta, gamma, scale_cc=1, model=1): if beta == alpha_c: @@ -535,7 +506,10 @@ def generate_exp_backward(tau_list, t_sw_array, alpha_c, alpha, beta, gamma, return (exp1, exp2, exp3), (exp_sw1, exp_sw2) -# @jit(nopython=True, fastmath=True, debug=True) +@njit(locals={ + "res": numba.types.float64[:, ::1], + }, + fastmath=True) def ss_exp(alpha_c, alpha, beta, gamma, pred_r=True, chrom_open=True): res = np.empty((1, 3)) if not chrom_open: @@ -553,7 +527,13 @@ def ss_exp(alpha_c, alpha, beta, gamma, pred_r=True, chrom_open=True): return res -# @jit(nopython=True, fastmath=True, debug=True) +@njit(locals={ + "ss1": numba.types.float64[:, ::1], + "ss2": numba.types.float64[:, ::1], + "ss3": numba.types.float64[:, ::1], + "ss4": numba.types.float64[:, ::1] + }, + fastmath=True) def compute_ss_exp(alpha_c, alpha, beta, gamma, model=0): if model == 0: ss1 = ss_exp(alpha_c, alpha, beta, gamma, pred_r=False) @@ -574,7 +554,7 @@ def compute_ss_exp(alpha_c, alpha, beta, gamma, model=0): return np.vstack((ss1, ss2, ss3, ss4)) -# @jit(nopython=True, fastmath=True, debug=True) +@njit(fastmath=True) def velocity_equations(c, u, s, alpha_c, alpha, beta, gamma, scale_cc=1, pred_r=True, chrom_open=True, rna_only=False): if rna_only: @@ -594,7 +574,30 @@ def velocity_equations(c, u, s, alpha_c, alpha, beta, gamma, scale_cc=1, return alpha_c - alpha_c * c, np.zeros(len(u)), np.zeros(len(u)) -# @jit(nopython=True, fastmath=True, debug=True) +@njit(locals={ + "state0": numba.types.boolean[::1], + "state1": numba.types.boolean[::1], + "state2": numba.types.boolean[::1], + "state3": numba.types.boolean[::1], + "tau1": numba.types.float64[::1], + "tau2": numba.types.float64[::1], + "tau3": numba.types.float64[::1], + "tau4": numba.types.float64[::1], + "exp_list": numba.types.Tuple((numba.types.float64[:, ::1], + numba.types.float64[:, ::1], + numba.types.float64[:, ::1], + numba.types.float64[:, ::1])), + "exp_sw_list": numba.types.Tuple((numba.types.float64[:, ::1], + numba.types.float64[:, ::1], + numba.types.float64[:, ::1])), + "c": numba.types.float64[::1], + "u": numba.types.float64[::1], + "s": numba.types.float64[::1], + "vc_vec": numba.types.float64[::1], + "vu_vec": numba.types.float64[::1], + "vs_vec": numba.types.float64[::1] + }, + fastmath=True) def compute_velocity(t, t_sw_array, state, @@ -1331,15 +1334,6 @@ def predict_exp_ten(self, backward=False, rna_only=False): - #TODO: Check params?? - # check_params(alpha_c, - # alpha, - # beta, - # gamma, - # c0, - # u0, - # s0) - if scale_cc is None: scale_cc = torch.tensor(1.0, requires_grad=True, device=self.device,