Skip to content

Commit

Permalink
Fix small shape error with obs factor matrix
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathaniel-github committed Dec 5, 2024
1 parent 994863d commit f18af58
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def init(
data_matrix = flatten_tensor_list(X_list)

_, _, C = randomized_svd(data_matrix, rank, random_state=random_state)
factors = [np.ones((X_list[0].shape[0], rank)), np.eye(rank), np.eye(rank), C.T]
factors = [np.ones((len(X_list), rank)), np.eye(rank), np.eye(rank), C.T]
return factors


Expand Down
5 changes: 3 additions & 2 deletions cellcommunicationpf2/tests/test_ccpf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,18 @@ def test_init():
"""

# Define dimensions
obs = 3
cells = 20
LR = 10
rank = 5

# Generate random X_list
X_list = [np.random.rand(cells, cells, LR) for _ in range(3)]
X_list = [np.random.rand(cells, cells, LR) for _ in range(obs)]

# Call the init method
factors = init(X_list, rank)

assert factors[0].shape == (cells, rank)
assert factors[0].shape == (obs, rank)
assert factors[1].shape == (rank, rank)
assert factors[2].shape == (rank, rank)
assert factors[3].shape == (LR, rank)
Expand Down

0 comments on commit f18af58

Please sign in to comment.