diff --git a/cellcommunicationpf2/cc_pf2.py b/cellcommunicationpf2/cc_pf2.py index 92a9872..c3921d4 100644 --- a/cellcommunicationpf2/cc_pf2.py +++ b/cellcommunicationpf2/cc_pf2.py @@ -1,5 +1,4 @@ import numpy as np -import tensorly as tl from pymanopt.manifolds import Stiefel from pymanopt import Problem from pymanopt.optimizers import ConjugateGradient @@ -37,13 +36,13 @@ def solve_projections( @pymanopt.function.autograd(manifold) def objective_function(proj): - a_mat_recon = anp.einsum("ab,cd,acg->bdg", proj.T, proj.T, a_lhs) + a_mat_recon = anp.einsum("ba,dc,acg->bdg", proj, proj, a_lhs) return anp.sum(anp.square(a_mat - a_mat_recon)) problem = Problem(manifold=manifold, cost=objective_function) # Solve the problem - solver = ConjugateGradient(verbosity=1) + solver = ConjugateGradient(verbosity=1, min_gradient_norm=1e-9, min_step_size=1e-12) proj = solver.run(problem).point U, _, Vt = np.linalg.svd(proj, full_matrices=False) diff --git a/cellcommunicationpf2/tests/test_OP.py b/cellcommunicationpf2/tests/test_OP.py index b67b923..29a0ea6 100644 --- a/cellcommunicationpf2/tests/test_OP.py +++ b/cellcommunicationpf2/tests/test_OP.py @@ -1,6 +1,5 @@ import numpy as np from ..cc_pf2 import project_data, solve_projections -import pytest def test_project_data(): @@ -26,44 +25,6 @@ def test_project_data(): assert projected_X.shape == (rank, rank, LR) -def test_project_data_output_proj_data(): - """ - Tests that the project data method is actually able to solve for the correct optimal projection matrix. - Asserts that the projected data through the solved matrices is the same as the input projectedX. - """ - # Define dimensions - num_tensors = 3 - cells = 20 - LR = 10 - obs = 5 - rank = 5 - # Generate a random projected tensor - projected_X = np.random.rand(obs, rank, rank, LR) - - # Generate a random set of projection matrices - projections = [ - np.linalg.qr(np.random.rand(cells, rank))[0] for _ in range(num_tensors) - ] - - # Recreate the original tensor using the projection matrices and projected tensor - recreated_tensors = [] - for i in range(num_tensors): - Q = projections[i] - A = projected_X[i, :, :, :] - B = project_data(A, Q.T) - recreated_tensors.append(B) - - # Call the project_data method using the recreated tensors to get the projected_X that gets solved by our method - projections_recreated = solve_projections( - recreated_tensors, - projected_X, - ) - - # Assert that the projected tensors are the same - for i in range(num_tensors): - assert np.allclose(project_data(recreated_tensors[i], projections_recreated[i]), projected_X[i]) - - def test_project_data_output_proj_matrix(): """ Tests that the project data method is actually able to solve for the correct optimal projection matrix. @@ -72,11 +33,11 @@ def test_project_data_output_proj_matrix(): # Define dimensions num_tensors = 3 cells = 20 - LR = 10 - obs = 5 + variables = 10 + obs = 20 rank = 5 # Generate a random projected tensor - projected_X = np.random.rand(obs, rank, rank, LR) + projected_X = np.random.rand(obs, rank, rank, variables) # Generate a random set of projection matrices projections = [ @@ -99,5 +60,6 @@ def test_project_data_output_proj_matrix(): # Assert that the projections are the same for i in range(num_tensors): - assert np.allclose(projections[i], projections_recreated[i]) or np.allclose(projections[i], -projections_recreated[i]) + sign_correct = np.sign(projections[i][0, 0] * projections_recreated[i][0, 0]) + np.testing.assert_allclose(projections[i], projections_recreated[i] * sign_correct, atol=1e-9)