Skip to content

Commit

Permalink
Fixed broken methods
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathaniel-github committed Nov 26, 2024
1 parent fa48701 commit 3bc6c49
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def project_data(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray:
Projects a 3D tensor of C x C x LR with a projection matrix of C x CES
along both C dimensions to form a resulting tensor of CES x CES x LR.
"""
return np.einsum("ab,cd,bdg->acg", proj_matrix, proj_matrix, tensor)
return np.einsum("ab,cd,acg->bdg", proj_matrix, proj_matrix, tensor)


def solve_projections(
Expand All @@ -37,7 +37,7 @@ def solve_projections(

@pymanopt.function.autograd(manifold)
def objective_function(proj):
a_mat_recon = anp.einsum("ab,cd,bdg->acg", proj, proj, a_lhs)
a_mat_recon = anp.einsum("ab,cd,acg->bdg", proj.T, proj.T, a_lhs)
return anp.sum(anp.square(a_mat - a_mat_recon))

problem = Problem(manifold=manifold, cost=objective_function)
Expand Down
6 changes: 3 additions & 3 deletions cellcommunicationpf2/tests/test_OP.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import pytest


@pytest.mark.xfail(reason="The project method hasn't been completed yet")
def test_project_data():
"""
Tests that the dimensions are correct and that the method is able to run without errors.
Expand All @@ -22,9 +21,10 @@ def test_project_data():
proj_matrix = np.linalg.qr(np.random.rand(cells, rank))[0]

# Call the project_data method
print(proj_matrix.shape)
projected_X = project_data(X_mat, proj_matrix)

assert projected_X.shape == (obs, rank, rank, LR)
assert projected_X.shape == (rank, rank, LR)


@pytest.mark.xfail(reason="The project method hasn't been completed yet")
Expand Down Expand Up @@ -64,6 +64,6 @@ def test_project_data_output():
for i in range(num_tensors):
difference_sum = np.sum(np.abs(projections[i] - projections_recreated[i]))
print(
f"Projection {i} difference sum: {difference_sum}. Sum of projections in absolute: {np.sum(np.abs(projections[i]))}. Sum of projections_recreated in absolute: {np.sum(np.abs(projections_recreated[i]))}"
f"Projection {i} difference sum: {difference_sum}. Sum of projections in absolute: {np.sum(np.abs(projections[i]))}"
)
assert np.allclose(projections[i], projections_recreated[i])

0 comments on commit 3bc6c49

Please sign in to comment.