Skip to content

Commit

Permalink
updated_func
Browse files Browse the repository at this point in the history
  • Loading branch information
siyuh committed Aug 20, 2024
1 parent f7e4a13 commit 163a601
Showing 1 changed file with 9 additions and 1 deletion.
10 changes: 9 additions & 1 deletion starfysh/starfysh.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,7 @@ def model_ct_exp(
model,
adata,
visium_args,
poe = False,
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu')
):
"""
Expand All @@ -1127,14 +1128,21 @@ def model_ct_exp(
sig_means = torch.Tensor(visium_args.sig_mean_norm.values).to(device)
anchor_idx = torch.Tensor(visium_args.pure_idx).to(device)
x_in = torch.Tensor(adata.to_df().values).to(device)
if poe:
y_in = torch.Tensor(visium_args.get_img_patches()).float().to(device)


model.eval()
model = model.to(device)
pred_exprs = {}

for ct_idx, cell_type in enumerate(adata.uns['cell_types']):
# Get inference outputs for the given cell type
inference_outputs = model.inference(x_in)

if poe:
inference_outputs = model.inference(x_in,y_in)
else:
inference_outputs = model.inference(x_in)
inference_outputs['qz'] = inference_outputs['qz_m_ct'][:, ct_idx, :]

# Get generative outputs
Expand Down

0 comments on commit 163a601

Please sign in to comment.