Skip to content

Commit

Permalink
Rewriting somewhat
Browse files Browse the repository at this point in the history
  • Loading branch information
aarmey committed Nov 25, 2024
1 parent ca12a46 commit 5c767c0
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 60 deletions.
51 changes: 11 additions & 40 deletions cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,19 @@
import autograd.numpy as anp


def project_data(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray:
    """
    Projects a 3D tensor of C x C x LR with a projection matrix of C x CES
    along both C dimensions to form a resulting tensor of CES x CES x LR.

    Every slice along the last axis is mapped to
    ``proj_matrix.T @ tensor[:, :, i] @ proj_matrix``.

    Parameters
    ----------
    tensor:
        Array of shape (C, C, LR).
    proj_matrix:
        Array of shape (C, CES).

    Returns
    -------
    Array of shape (CES, CES, LR).
    """
    # Sum over the C axis of each projection matrix (indices b and d), which is
    # its FIRST axis. Writing the subscripts as "ab,cd" would instead contract
    # the CES axis and fail (or silently transpose) for a C x CES projection.
    return np.einsum("ba,dc,bdg->acg", proj_matrix, proj_matrix, tensor)


def project_data(
def solve_projections(
X_list: list,
factors: list[np.ndarray],
weights: np.ndarray = None,
full_tensor: np.ndarray = None,
) -> tuple[list[np.ndarray], np.ndarray]:
) -> list[np.ndarray]:
"""
Takes a list of 3D tensors of C x C x LR, a means matrix, factors of
A: obs x rank
Expand All @@ -38,47 +32,24 @@ def project_data(
A, B, C, D = factors

projections: list[np.ndarray] = []
projected_X = np.empty((A.shape[0], B.shape[0], C.shape[0], D.shape[0]))

if full_tensor is None:
weights = np.ones(A.shape[1]) if weights is None else weights
full_tensor = tl.cp_tensor.cp_to_tensor((weights, [A, B, C, D]))
full_tensor = tl.cp_tensor.cp_to_tensor((weights, [A, B, C, D]))

for i, mat in enumerate(X_list):
lhs = full_tensor[i, :, :, :]
cells = mat.shape[0]
ces = lhs.shape[0]

manifold = Stiefel(cells, ces)
manifold = Stiefel(mat.shape[0], full_tensor.shape[1])
a_mat = anp.asarray(mat)
a_lhs = anp.asarray(lhs)
a_lhs = anp.asarray(full_tensor[i, :, :, :])

@pymanopt.function.autograd(manifold)
def objective_function(proj):
a_mat_recon = anp.zeros_like(a_mat)
for j in range(a_lhs.shape[2]):
slice = anp.dot(anp.dot(proj, a_lhs[:, :, j]), proj.T)
tensor = np.zeros((*slice.shape, a_lhs.shape[2]))
# Create a mask of zeros with 1 at index j along last axis
mask = np.zeros(a_lhs.shape[2])
mask[j] = 1

# Broadcast the mask and multiply with the expanded matrix
a_mat_recon = anp.add(
a_mat_recon, np.expand_dims(slice, axis=-1) * mask
)

dif = a_mat - a_mat_recon
return anp.sum(anp.abs(dif))
a_mat_recon = anp.einsum("ab,cd,bdg->acg", proj, proj, a_lhs)
return anp.sum(anp.square(a_mat - a_mat_recon))

problem = Problem(manifold=manifold, cost=objective_function)

# Solve the problem
solver = ConjugateGradient(verbosity=2)
solver = ConjugateGradient(verbosity=1)
proj = solver.run(problem).point

projections.append(proj)

projected_X[i, :, :, :] = project_tensor(mat, proj)

return projections, projected_X
return projections
28 changes: 8 additions & 20 deletions cellcommunicationpf2/tests/test_OP.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,32 @@
import numpy as np
from ..cc_pf2 import project_data, project_tensor
from ..cc_pf2 import project_data, solve_projections
import pytest


@pytest.mark.skip(reason="The project method hasn't been completed yet")
def test_project_data():
"""
Tests that the dimensions are correct and that the method is able to run without errors.
"""

# Define dimensions
num_tensors = 3
cells = 20
LR = 10
obs = 5
rank = 5

# Generate random X_list
X_list = [np.random.rand(cells, cells, LR) for _ in range(num_tensors)]
X_mat = np.random.rand(cells, cells, LR)

# Generate random factors: A (obs x rank), B (C x rank), C (C x rank), D (LR x rank)
A = np.random.rand(obs, rank)
B = np.random.rand(rank, rank)
C = np.random.rand(rank, rank)
D = np.random.rand(LR, rank)
factors = [A, B, C, D]
# Projection matrix
proj_matrix = np.linalg.qr(np.random.rand(cells, rank))[0]

# Call the project_data method
projections, projected_X = project_data(X_list, factors)

# Assertions
assert len(projections) == num_tensors
for proj in projections:
assert proj.shape == (cells, rank)
projected_X = project_data(X_mat, proj_matrix)

assert projected_X.shape == (obs, rank, rank, LR)


@pytest.mark.skip(reason="The project method hasn't been completed yet")
@pytest.mark.xfail(reason="The project method hasn't been completed yet")
def test_project_data_output():
"""
Tests that the project data method is actually able to solve for the correct optimal projection matrix.
Expand All @@ -61,19 +50,18 @@ def test_project_data_output():
for i in range(num_tensors):
Q = projections[i]
A = projected_X[i, :, :, :]
B = project_tensor(A, Q.T)
B = project_data(A, Q.T)
recreated_tensors.append(B)

# Call the project_data method using the recreated tensors to get the projected_X that gets solved by our method
projections_recreated, _ = project_data(
projections_recreated = solve_projections(
recreated_tensors,
[
np.zeros((obs, rank)),
np.zeros((rank, rank)),
np.zeros((rank, rank)),
np.zeros((LR, rank)),
],
full_tensor=projected_X,
)

# Assert that the projections are the same
Expand Down

0 comments on commit 5c767c0

Please sign in to comment.