From ca12a46033d0f35af0ed33327592eaca70177748 Mon Sep 17 00:00:00 2001 From: Aaron Meyer Date: Mon, 25 Nov 2024 13:11:49 -0800 Subject: [PATCH] Remove means vector --- cellcommunicationpf2/cc_pf2.py | 1 - cellcommunicationpf2/tests/test_OP.py | 6 +----- 2 files changed, 1 insertion(+), 6 deletions(-) diff --git a/cellcommunicationpf2/cc_pf2.py b/cellcommunicationpf2/cc_pf2.py index b4bda9d..d50d2f2 100644 --- a/cellcommunicationpf2/cc_pf2.py +++ b/cellcommunicationpf2/cc_pf2.py @@ -22,7 +22,6 @@ def project_tensor(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray: def project_data( X_list: list, - means: np.ndarray, factors: list[np.ndarray], weights: np.ndarray = None, full_tensor: np.ndarray = None, diff --git a/cellcommunicationpf2/tests/test_OP.py b/cellcommunicationpf2/tests/test_OP.py index 7edd1d7..a5c1c4e 100644 --- a/cellcommunicationpf2/tests/test_OP.py +++ b/cellcommunicationpf2/tests/test_OP.py @@ -19,9 +19,6 @@ def test_project_data(): # Generate random X_list X_list = [np.random.rand(cells, cells, LR) for _ in range(num_tensors)] - # Generate random means matrix - means = np.random.rand(cells) - # Generate random factors: A (obs x rank), B (C x rank), C (C x rank), D (LR x rank) A = np.random.rand(obs, rank) B = np.random.rand(rank, rank) @@ -30,7 +27,7 @@ def test_project_data(): factors = [A, B, C, D] # Call the project_data method - projections, projected_X = project_data(X_list, means, factors) + projections, projected_X = project_data(X_list, factors) # Assertions assert len(projections) == num_tensors @@ -70,7 +67,6 @@ def test_project_data_output(): # Call the project_data method using the recreated tensors to get the projected_X that gets solved by our method projections_recreated, _ = project_data( recreated_tensors, - np.zeros(cells), [ np.zeros((obs, rank)), np.zeros((rank, rank)),