# Reconstructed from the additions that commit 18368bb9 ("Finished
# initialization method first draft") makes to cellcommunicationpf2/cc_pf2.py.
from typing import Optional

import numpy as np


def flatten_tensor_list(tensor_list: list) -> np.ndarray:
    """
    Flatten a list of tensors into one 2D matrix.

    Every tensor keeps its last axis as the column axis; all leading axes are
    collapsed into rows. The resulting matrices are stacked vertically, so a
    list of A tensors of shape B x B x C becomes a matrix of (A*B*B) x C.
    Tensors may differ in their leading dimensions, but must agree on the size
    of the last axis.

    Parameters
    ----------
    tensor_list
        Non-empty sequence of array-likes sharing the same last-axis length.

    Returns
    -------
    np.ndarray
        2D matrix whose rows are the concatenated rows of every tensor.

    Raises
    ------
    ValueError
        If ``tensor_list`` is empty, or if the last-axis lengths disagree
        (raised by ``np.vstack``).
    """
    # len() rather than truthiness: a NumPy array passed as the "list" has an
    # ambiguous truth value.
    if len(tensor_list) == 0:
        raise ValueError("tensor_list must contain at least one tensor")

    arrays = [np.asarray(tensor) for tensor in tensor_list]

    # Collapse every leading axis into rows; the last axis stays as columns.
    as_matrices = [arr.reshape(-1, arr.shape[-1]) for arr in arrays]

    return np.vstack(as_matrices)


def init(
    X_list: list,
    rank: int,
    random_state: Optional[int] = None,
) -> list[np.ndarray]:
    """
    Initializes the factors for the CP decomposition of a list of 3D tensors.

    The last-mode factor comes from a randomized SVD of the flattened data;
    the remaining factors are constants (all-ones and identity matrices).

    Parameters
    ----------
    X_list
        Non-empty list of tensors that share the size of their last axis.
    rank
        Number of components; must be between 1 and the smaller dimension of
        the flattened data matrix.
    random_state
        Seed forwarded to ``randomized_svd`` for reproducibility.

    Returns
    -------
    list[np.ndarray]
        Four factors with shapes ``(len(X_list), rank)``, ``(rank, rank)``,
        ``(rank, rank)`` and ``(C, rank)``, where C is the last-axis length.

    Raises
    ------
    ValueError
        If ``X_list`` is empty or ``rank`` is out of range.
    """
    if rank < 1:
        raise ValueError(f"rank must be a positive integer, got {rank}")

    data_matrix = flatten_tensor_list(X_list)

    # The SVD cannot supply more right singular vectors than the smaller
    # matrix dimension; beyond that the last factor would end up with fewer
    # than `rank` columns and disagree with the identity factors.
    max_rank = min(data_matrix.shape)
    if rank > max_rank:
        raise ValueError(
            f"rank ({rank}) cannot exceed the smaller dimension of the "
            f"flattened data matrix ({max_rank})"
        )

    # Imported at the point of use; only this function needs scikit-learn.
    from sklearn.utils.extmath import randomized_svd

    # Rows of `right_vectors` are the leading right singular vectors.
    _, _, right_vectors = randomized_svd(
        data_matrix, rank, random_state=random_state
    )

    # NOTE(review): the first factor is sized by the number of tensors. The
    # first draft used X_list[0].shape[0] (rows of the first tensor), which
    # only matches when that happens to equal len(X_list) - confirm with the
    # consumer of these factors.
    return [
        np.ones((len(X_list), rank)),
        np.eye(rank),
        np.eye(rank),
        right_vectors.T,
    ]
import numpy as np

from ..cc_pf2 import project_data, solve_projections


def test_project_data():
    """
    Smoke test: project_data runs on random input and the result has shape
    (rank, rank, LR).
    """
    n_cells = 20
    n_lr = 10
    n_components = 5

    # Random cells x cells x LR tensor.
    interaction_tensor = np.random.rand(n_cells, n_cells, n_lr)

    # Q factor of a random cells x rank matrix.
    orthonormal_basis, _ = np.linalg.qr(np.random.rand(n_cells, n_components))

    print(orthonormal_basis.shape)
    result = project_data(interaction_tensor, orthonormal_basis)

    assert result.shape == (n_components, n_components, n_lr)


def test_project_data_output_proj_matrix():
    """
    Round-trip check: tensors are built from known projection matrices, then
    solve_projections must return those same matrices, up to an overall sign
    per matrix.
    """
    num_tensors = 3
    cells = 20
    variables = 10
    obs = 20
    rank = 5

    # Random projected tensor; only its first num_tensors slices are used to
    # build the inputs below.
    projected_X = np.random.rand(obs, rank, rank, variables)

    # One random Q factor (cells x rank) per tensor.
    projections = [
        np.linalg.qr(np.random.rand(cells, rank))[0] for _ in range(num_tensors)
    ]

    # Build each input tensor from its slice of projected_X by applying the
    # transposed projection matrix.
    recreated_tensors = [
        project_data(projected_X[idx], basis.T)
        for idx, basis in enumerate(projections)
    ]

    projections_recreated = solve_projections(recreated_tensors, projected_X)

    for idx, expected in enumerate(projections):
        solved = projections_recreated[idx]
        # Align the overall sign using the top-left entries before comparing.
        sign_correct = np.sign(expected[0, 0] * solved[0, 0])
        np.testing.assert_allclose(expected, solved * sign_correct, atol=1e-9)