Skip to content

Commit

Permalink
Merge pull request #2 from meyer-lab/4d_init
Browse files Browse the repository at this point in the history
PF2 Init Method
  • Loading branch information
andrewram4287 authored Dec 6, 2024
2 parents daeda01 + 040ab74 commit 411bd55
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 1 deletion.
32 changes: 32 additions & 0 deletions cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,38 @@
from pymanopt import Problem
from pymanopt.manifolds import Stiefel
from pymanopt.optimizers import TrustRegions
from typing import Optional
from sklearn.utils.extmath import randomized_svd


def flatten_tensor_list(tensor_list: list):
"""
Flatten a list of 3D tensors from A x B x B x C to a matrix of (A*B*B) x C
"""

# Reshape each tensor to a 2D matrix
# This will stack rows of each B x B tensor into a single row
reshaped_tensors = [tensor.reshape(-1, tensor.shape[-1]) for tensor in tensor_list]

# Vertically stack these matrices
flattened_matrix = np.vstack(reshaped_tensors)

return flattened_matrix


def init(
X_list: list,
rank: int,
random_state: Optional[int] = None,
) -> list[np.ndarray]:
"""
Initializes the factors for the CP decomposition of a list of 3D tensors
"""
data_matrix = flatten_tensor_list(X_list)

_, _, C = randomized_svd(data_matrix, rank, random_state=random_state)
factors = [np.ones((len(X_list), rank)), np.eye(rank), np.eye(rank), C.T]
return factors


def project_data(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,29 @@
import numpy as np

from ..cc_pf2 import project_data, solve_projections
from ..cc_pf2 import project_data, solve_projections, init


def test_init():
"""
Tests that the dimensions are correct and that the method is able to run without errors.
"""

# Define dimensions
obs = 3
cells = 20
LR = 10
rank = 5

# Generate random X_list
X_list = [np.random.rand(cells, cells, LR) for _ in range(obs)]

# Call the init method
factors = init(X_list, rank)

assert factors[0].shape == (obs, rank)
assert factors[1].shape == (rank, rank)
assert factors[2].shape == (rank, rank)
assert factors[3].shape == (LR, rank)


def test_project_data():
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ dependencies = [
"scipy>=1.12",
"pymanopt>=2.2.1",
"autograd>=1.7.0",
"scikit-learn>=1.4.2",
]

readme = "README.md"
Expand Down
8 changes: 8 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@ coverage==7.6.1
# via pytest-cov
iniconfig==2.0.0
# via pytest
joblib==1.4.2
# via scikit-learn
nodeenv==1.9.1
# via pyright
numpy==2.0.2
# via autograd
# via cellcommunicationpf2
# via pymanopt
# via scikit-learn
# via scipy
packaging==24.1
# via pytest
Expand All @@ -33,8 +36,13 @@ pyright==1.1.383
pytest==8.3.3
# via pytest-cov
pytest-cov==6.0.0
scikit-learn==1.5.2
# via cellcommunicationpf2
scipy==1.14.1
# via cellcommunicationpf2
# via pymanopt
# via scikit-learn
threadpoolctl==3.5.0
# via scikit-learn
typing-extensions==4.12.2
# via pyright
8 changes: 8 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,21 @@
-e file:.
autograd==1.7.0
# via cellcommunicationpf2
joblib==1.4.2
# via scikit-learn
numpy==2.0.2
# via autograd
# via cellcommunicationpf2
# via pymanopt
# via scikit-learn
# via scipy
pymanopt==2.2.1
# via cellcommunicationpf2
scikit-learn==1.5.2
# via cellcommunicationpf2
scipy==1.14.1
# via cellcommunicationpf2
# via pymanopt
# via scikit-learn
threadpoolctl==3.5.0
# via scikit-learn

0 comments on commit 411bd55

Please sign in to comment.