Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Nov 25, 2024
1 parent 5c767c0 commit fa48701
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 11 deletions.
6 changes: 1 addition & 5 deletions cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ def project_data(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray:

def solve_projections(
X_list: list,
factors: list[np.ndarray],
weights: np.ndarray = None,
full_tensor: np.ndarray,
) -> list[np.ndarray]:
"""
Takes a list of 3D tensors of C x C x LR, a means matrix, factors of
Expand All @@ -29,10 +28,7 @@ def solve_projections(
and solves for the projection matrices for each tensor as well as
reconstruct the data based on the projection matrices.
"""
A, B, C, D = factors

projections: list[np.ndarray] = []
full_tensor = tl.cp_tensor.cp_to_tensor((weights, [A, B, C, D]))

for i, mat in enumerate(X_list):
manifold = Stiefel(mat.shape[0], full_tensor.shape[1])
Expand Down
8 changes: 2 additions & 6 deletions cellcommunicationpf2/tests/test_OP.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest


@pytest.mark.xfail(reason="The project method hasn't been completed yet")
def test_project_data():
"""
Tests that the dimensions are correct and that the method is able to run without errors.
Expand Down Expand Up @@ -56,12 +57,7 @@ def test_project_data_output():
# Call the project_data method using the recreated tensors to get the projected_X that gets solved by our method
projections_recreated = solve_projections(
recreated_tensors,
[
np.zeros((obs, rank)),
np.zeros((rank, rank)),
np.zeros((rank, rank)),
np.zeros((LR, rank)),
],
projected_X,
)

# Assert that the projections are the same
Expand Down

0 comments on commit fa48701

Please sign in to comment.