diff --git a/cellcommunicationpf2/cc_pf2.py b/cellcommunicationpf2/cc_pf2.py index dd248f3..bb17ba6 100644 --- a/cellcommunicationpf2/cc_pf2.py +++ b/cellcommunicationpf2/cc_pf2.py @@ -34,7 +34,7 @@ def init( data_matrix = flatten_tensor_list(X_list) _, _, C = randomized_svd(data_matrix, rank, random_state=random_state) - factors = [np.ones((X_list[0].shape[0], rank)), np.eye(rank), np.eye(rank), C.T] + factors = [np.ones((len(X_list), rank)), np.eye(rank), np.eye(rank), C.T] return factors diff --git a/cellcommunicationpf2/tests/test_ccpf2.py b/cellcommunicationpf2/tests/test_ccpf2.py index a52594e..052af01 100644 --- a/cellcommunicationpf2/tests/test_ccpf2.py +++ b/cellcommunicationpf2/tests/test_ccpf2.py @@ -9,17 +9,18 @@ def test_init(): """ # Define dimensions + obs = 3 cells = 20 LR = 10 rank = 5 # Generate random X_list - X_list = [np.random.rand(cells, cells, LR) for _ in range(3)] + X_list = [np.random.rand(cells, cells, LR) for _ in range(obs)] # Call the init method factors = init(X_list, rank) - assert factors[0].shape == (cells, rank) + assert factors[0].shape == (obs, rank) assert factors[1].shape == (rank, rank) assert factors[2].shape == (rank, rank) assert factors[3].shape == (LR, rank)