Skip to content

Commit

Permalink
Add in a repeated solve of up to 5 iterations
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathaniel-github committed Dec 5, 2024
1 parent daeda01 commit 51d6792
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def project_data(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray:

def solve_projections(
X_list: list,
full_tensor: np.ndarray,
projected_X: np.ndarray,
) -> list[np.ndarray]:
"""
Takes a list of 3D tensors of C x C x LR, a means matrix, factors of
Expand All @@ -30,9 +30,10 @@ def solve_projections(
projections: list[np.ndarray] = []

for i, mat in enumerate(X_list):
manifold = Stiefel(mat.shape[0], full_tensor.shape[1])
max_iter = 5
manifold = Stiefel(mat.shape[0], projected_X.shape[1])
a_mat = anp.asarray(mat)
a_lhs = anp.asarray(full_tensor[i, :, :, :])
a_lhs = anp.asarray(projected_X[i, :, :, :])

@pymanopt.function.autograd(manifold)
def objective_function(proj):
Expand All @@ -49,7 +50,20 @@ def objective_function(proj):

U, _, Vt = np.linalg.svd(proj, full_matrices=False)
proj = U @ Vt


# Check if the projection matrix is correct
A = X_list[i]
B = project_data(A, proj)

# While this projection matrix is not solved rerun the solver till max_iter
while not np.allclose(B, projected_X[i, :, :, :]) and max_iter > 0:
proj = solver.run(problem).point
U, _, Vt = np.linalg.svd(proj, full_matrices=False)
proj = U @ Vt
A = X_list[i]
B = project_data(A, proj)
max_iter -= 1

projections.append(proj)

return projections

0 comments on commit 51d6792

Please sign in to comment.