Skip to content

Commit

Permalink
dynamically adjust indexes for saving results
Browse files Browse the repository at this point in the history
  • Loading branch information
sophiamaedler committed Nov 29, 2024
1 parent 2ed029f commit ab1fdd9
Showing 1 changed file with 6 additions and 4 deletions.
10 changes: 6 additions & 4 deletions src/scportrait/pipeline/classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -404,18 +404,20 @@ def inference(self,
#save the results for each batch into the memory mapped array at the specified indices
features[ix:(ix+batch_size)] = result.numpy()
cell_ids[ix:(ix+batch_size)] = class_id.unsqueeze(1)

ix += batch_size

for i in range(len(dataloader) - 1):
if i % 10 == 0:
self.log(f"processing batch {i}")
x, label, id = next(data_iter)

r = model_fun(x.to(self.config["inference_device"]))
result = r

#save the results for each batch into the memory mapped array at the specified indices
features[ix:(ix+batch_size)] = r.cpu().detach().numpy()
cell_ids[ix:(ix+batch_size)] = label.unsqueeze(1)
features[ix:(ix+r.shape[0])] = r.cpu().detach().numpy()
cell_ids[ix:(ix+r.shape[0])] = label.unsqueeze(1)

ix += r.shape[0]

if hasattr(self.config, "log_transform"):
if self.config["log_transform"]:
Expand Down

0 comments on commit ab1fdd9

Please sign in to comment.