Skip to content

Commit

Permalink
Finished initialization method first draft
Browse files Browse the repository at this point in the history
  • Loading branch information
Nathaniel-github committed Dec 5, 2024
1 parent daeda01 commit 18368bb
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 67 deletions.
32 changes: 32 additions & 0 deletions cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,38 @@
from pymanopt import Problem
from pymanopt.manifolds import Stiefel
from pymanopt.optimizers import TrustRegions
from typing import Optional
from sklearn.utils.extmath import randomized_svd


def flatten_tensor_list(tensor_list: list):
"""
Flatten a list of 3D tensors from A x B x B x C to a matrix of (A*B*B) x C
"""

# Reshape each tensor to a 2D matrix
# This will stack rows of each B x B tensor into a single row
reshaped_tensors = [tensor.reshape(-1, tensor.shape[-1]) for tensor in tensor_list]

# Vertically stack these matrices
flattened_matrix = np.vstack(reshaped_tensors)

return flattened_matrix


def init(
X_list: list,
rank: int,
random_state: Optional[int] = None,
) -> list[np.ndarray]:
"""
Initializes the factors for the CP decomposition of a list of 3D tensors
"""
data_matrix = flatten_tensor_list(X_list)

_, _, C = randomized_svd(data_matrix, rank, random_state=random_state)
factors = [np.ones((X_list[0].shape[0], rank)), np.eye(rank), np.eye(rank), C.T]
return factors


def project_data(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray:
Expand Down
67 changes: 0 additions & 67 deletions cellcommunicationpf2/tests/test_OP.py

This file was deleted.

0 comments on commit 18368bb

Please sign in to comment.