Skip to content

Commit

Permalink
Merge pull request #1 from meyer-lab/op_for_4d
Browse files Browse the repository at this point in the history
OP for 4D tensor
  • Loading branch information
aarmey authored Nov 26, 2024
2 parents 5ba945f + 47758f4 commit c4039ad
Show file tree
Hide file tree
Showing 9 changed files with 146 additions and 58 deletions.
3 changes: 0 additions & 3 deletions cellcommunicationpf2/cc-pf2.py

This file was deleted.

56 changes: 56 additions & 0 deletions cellcommunicationpf2/cc_pf2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import autograd.numpy as anp
import numpy as np
import pymanopt
from pymanopt import Problem
from pymanopt.manifolds import Stiefel
from pymanopt.optimizers import ConjugateGradient


def project_data(tensor: np.ndarray, proj_matrix: np.ndarray) -> np.ndarray:
    """
    Applies one projection matrix to both leading modes of a 3D tensor.

    Given a tensor of C x C x LR and a projection matrix of C x CES, the
    result is the CES x CES x LR tensor whose entry (p, q, k) is the sum over
    i and j of proj_matrix[i, p] * proj_matrix[j, q] * tensor[i, j, k].
    """
    # i, j: the two C modes being contracted; p, q: projected modes; k: LR mode.
    subscripts = "ip,jq,ijk->pqk"
    return np.einsum(subscripts, proj_matrix, proj_matrix, tensor)


def solve_projections(
    X_list: list,
    full_tensor: np.ndarray,
) -> list[np.ndarray]:
    """
    Solves for one projection matrix per tensor in X_list.

    X_list holds 3D tensors of C x C x LR. full_tensor is a 4D array indexed
    by tensor along its first axis, with each slice being CES x CES x LR.
    For every index i, a C x CES matrix is optimized on the Stiefel manifold
    so that expanding full_tensor[i] along both CES modes with that matrix
    matches X_list[i] with the smallest sum of squared differences.

    Returns the list of fitted projection matrices, in the order of X_list.
    """
    projections: list[np.ndarray] = []

    for idx, observed in enumerate(X_list):
        # Search space: matrices with C rows and CES columns.
        stiefel = Stiefel(observed.shape[0], full_tensor.shape[1])
        target = anp.asarray(observed)
        core = anp.asarray(full_tensor[idx, :, :, :])

        @pymanopt.function.autograd(stiefel)
        def reconstruction_error(proj):
            # Expand the CES x CES x LR slice back out to C x C x LR.
            rebuilt = anp.einsum("ba,dc,acg->bdg", proj, proj, core)
            residual = target - rebuilt
            return anp.sum(anp.square(residual))

        problem = Problem(manifold=stiefel, cost=reconstruction_error)

        # Solve the problem
        optimizer = ConjugateGradient(
            verbosity=1, min_gradient_norm=1e-9, min_step_size=1e-12
        )
        fitted = optimizer.run(problem).point

        # Replace the solution by U @ Vt from its thin SVD, which discards the
        # singular values and so leaves a matrix with orthonormal columns.
        left, _, right_t = np.linalg.svd(fitted, full_matrices=False)
        projections.append(left @ right_t)

    return projections
45 changes: 1 addition & 44 deletions cellcommunicationpf2/figures/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,13 @@

import sys
import time
from string import ascii_letters

import datashader as ds
import datashader.transfer_functions as tf
import matplotlib
import numpy as np
import seaborn as sns
from matplotlib import gridspec
from matplotlib import pyplot as plt
from matplotlib.axes import Axes
from matplotlib.figure import Figure
from string import ascii_letters
from ..tensor import reorder_table

matplotlib.use("AGG")

Expand Down Expand Up @@ -89,41 +84,3 @@ def genFigure():
)

print(f"Figure {sys.argv[1]} is done after {time.time() - start} seconds.\n")


def ds_show(result: tf.Image, ax: plt.Axes):
    """Draw a datashader image on the given matplotlib axis."""
    spread = tf.dynspread(result, threshold=0.95, max_px=5)
    on_white = tf.set_background(spread, "white")

    # Reverse the row order before handing the pixels to imshow.
    flipped = on_white.data[::-1]

    # Split each pixel value into its three lowest bytes
    # (presumably the R, G and B channels -- confirm against datashader).
    low_byte = flipped & 0x0000FF
    mid_byte = (flipped & 0x00FF00) >> 8
    high_byte = (flipped & 0xFF0000) >> 16

    ax.imshow(np.dstack([low_byte, mid_byte, high_byte]))


def get_canvas(points: np.ndarray) -> ds.Canvas:
    """Build a 900 x 900 canvas whose ranges cover the points plus 5% padding."""
    lower = np.nanmin(points, axis=0)
    assert lower.size == 2
    upper = np.nanmax(points, axis=0)

    # Pad each side by 5% of the extent, then snap the bounds to integers.
    margin = 0.05 * (upper - lower)
    mins = np.round(lower - margin)
    maxs = np.round(upper + margin)

    return ds.Canvas(
        plot_width=900,
        plot_height=900,
        x_range=(mins[0], maxs[0]),
        y_range=(mins[1], maxs[1]),
    )



7 changes: 1 addition & 6 deletions cellcommunicationpf2/figures/figure1.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,11 @@
Figure 1: XX
"""

import pandas as pd
from .common import getSetup, subplotLabel


def makeFigure():
    """Set up the Figure 1 layout, label its subplots, and return the figure."""
    # getSetup is called with (12, 12) and (3, 3); presumably the figure size
    # and the subplot grid -- confirm against common.getSetup.
    axes, fig = getSetup((12, 12), (3, 3))
    subplotLabel(axes)

    return fig
Empty file.
67 changes: 67 additions & 0 deletions cellcommunicationpf2/tests/test_OP.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import numpy as np

from ..cc_pf2 import project_data, solve_projections


def test_project_data():
    """
    Tests that project_data runs without errors, returns a tensor with the
    expected dimensions, and agrees with the per-slice matrix product
    proj_matrix.T @ X @ proj_matrix.
    """
    # Seeded generator so that any failure can be reproduced exactly.
    rng = np.random.default_rng(0)

    # Define dimensions
    cells = 20
    LR = 10
    rank = 5

    # Generate a random C x C x LR tensor
    X_mat = rng.random((cells, cells, LR))

    # Projection matrix: the Q factor of a random cells x rank matrix
    proj_matrix = np.linalg.qr(rng.random((cells, rank)))[0]

    # Call the project_data method
    projected_X = project_data(X_mat, proj_matrix)

    assert projected_X.shape == (rank, rank, LR)

    # Projecting both C modes is the same as P^T @ X_k @ P for every LR slice.
    for k in range(LR):
        np.testing.assert_allclose(
            projected_X[:, :, k],
            proj_matrix.T @ X_mat[:, :, k] @ proj_matrix,
            atol=1e-12,
        )


def test_project_data_output_proj_matrix():
    """
    Tests that solve_projections recovers the projection matrices that were
    used to build its input tensors.
    Asserts that the solved matrices equal the originals up to an overall sign.
    """
    # Define dimensions
    num_tensors = 3
    cells = 20
    variables = 10
    obs = 20
    rank = 5

    # Random projected tensor of obs x rank x rank x variables
    projected_X = np.random.rand(obs, rank, rank, variables)

    # Reference projection matrices: Q factors of random cells x rank matrices
    projections = [
        np.linalg.qr(np.random.rand(cells, rank))[0] for _ in range(num_tensors)
    ]

    # Expand the first num_tensors slices back to cells x cells x variables by
    # projecting with the transposed reference matrices.
    recreated_tensors = [
        project_data(core, Q.T) for Q, core in zip(projections, projected_X)
    ]

    # Solve for the projection matrices from the expanded tensors
    projections_recreated = solve_projections(
        recreated_tensors,
        projected_X,
    )

    # The objective is unchanged when a projection is negated, so align the
    # overall sign using the first entry before comparing.
    for i, expected in enumerate(projections):
        solved = projections_recreated[i]
        sign_correct = np.sign(expected[0, 0] * solved[0, 0])
        np.testing.assert_allclose(expected, solved * sign_correct, atol=1e-9)
8 changes: 5 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,16 @@ license = "MIT"
requires-python = ">= 3.11"

dependencies = [
"numpy>=1.26",
"numpy>=1.21,<2.1",
"scipy>=1.12",
"tensorly>=0.8.1",
"tensorly~=0.9.0",
"tqdm>=4.66",
"cupy-cuda12x>=13.0",
"scikit-learn>=1.4.2",
"pacmap>=0.7.2",
"tlviz>=0.1.1",
"pymanopt>=2.2.1",
"autograd>=1.7.0",
]

readme = "README.md"
Expand Down Expand Up @@ -57,4 +59,4 @@ select = [
"I",
# Unused arguments
"ARG",
]
]
9 changes: 8 additions & 1 deletion requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
-e file:.
annoy==1.17.3
# via pacmap
autograd==1.7.0
# via cellcommunicationpf2
certifi==2024.8.30
# via requests
charset-normalizer==3.3.2
Expand Down Expand Up @@ -45,6 +47,7 @@ nodeenv==1.9.1
numba==0.60.0
# via pacmap
numpy==2.0.2
# via autograd
# via cellcommunicationpf2
# via contourpy
# via cupy-cuda12x
Expand All @@ -53,6 +56,7 @@ numpy==2.0.2
# via pacmap
# via pandas
# via patsy
# via pymanopt
# via scikit-learn
# via scipy
# via statsmodels
Expand All @@ -76,6 +80,8 @@ pillow==10.4.0
# via matplotlib
pluggy==1.5.0
# via pytest
pymanopt==2.2.1
# via cellcommunicationpf2
pyparsing==3.1.4
# via matplotlib
pyright==1.1.383
Expand All @@ -94,6 +100,7 @@ scikit-learn==1.5.2
# via pacmap
scipy==1.14.1
# via cellcommunicationpf2
# via pymanopt
# via scikit-learn
# via statsmodels
# via tensorly
Expand All @@ -103,7 +110,7 @@ six==1.16.0
# via python-dateutil
statsmodels==0.14.4
# via tlviz
tensorly==0.8.1
tensorly==0.9.0
# via cellcommunicationpf2
threadpoolctl==3.5.0
# via scikit-learn
Expand Down
9 changes: 8 additions & 1 deletion requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
-e file:.
annoy==1.17.3
# via pacmap
autograd==1.7.0
# via cellcommunicationpf2
certifi==2024.8.30
# via requests
charset-normalizer==3.3.2
Expand Down Expand Up @@ -39,6 +41,7 @@ matplotlib==3.9.2
numba==0.60.0
# via pacmap
numpy==2.0.2
# via autograd
# via cellcommunicationpf2
# via contourpy
# via cupy-cuda12x
Expand All @@ -47,6 +50,7 @@ numpy==2.0.2
# via pacmap
# via pandas
# via patsy
# via pymanopt
# via scikit-learn
# via scipy
# via statsmodels
Expand All @@ -67,6 +71,8 @@ patsy==0.5.6
# via statsmodels
pillow==10.4.0
# via matplotlib
pymanopt==2.2.1
# via cellcommunicationpf2
pyparsing==3.1.4
# via matplotlib
python-dateutil==2.9.0.post0
Expand All @@ -81,6 +87,7 @@ scikit-learn==1.5.2
# via pacmap
scipy==1.14.1
# via cellcommunicationpf2
# via pymanopt
# via scikit-learn
# via statsmodels
# via tensorly
Expand All @@ -90,7 +97,7 @@ six==1.16.0
# via python-dateutil
statsmodels==0.14.4
# via tlviz
tensorly==0.8.1
tensorly==0.9.0
# via cellcommunicationpf2
threadpoolctl==3.5.0
# via scikit-learn
Expand Down

0 comments on commit c4039ad

Please sign in to comment.