Skip to content

Commit

Permalink
Update segment.py
Browse files Browse the repository at this point in the history
Handle segmentation for multiple bandwidths appropriately
  • Loading branch information
lmanan authored Feb 27, 2024
1 parent f21d015 commit e8c8edb
Showing 1 changed file with 3 additions and 2 deletions.
5 changes: 3 additions & 2 deletions cellulus/segment.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def segment(inference_config: InferenceConfig) -> None:
embeddings_mean = embeddings[
np.newaxis, : dataset_meta_data.num_spatial_dims, ...
].copy()

ds_segmentation[sample, bandwidth_factor, ...] = segmentation
elif inference_config.clustering == "greedy":
if dataset_meta_data.num_spatial_dims == 3:
cluster3d = Cluster3d(
Expand All @@ -151,6 +151,7 @@ def segment(inference_config: InferenceConfig) -> None:
bandwidth=inference_config.bandwidth / (2**bandwidth_factor),
min_object_size=inference_config.min_size,
)
ds_segmentation[sample, bandwidth_factor, ...] = segmentation
elif dataset_meta_data.num_spatial_dims == 2:
cluster2d = Cluster2d(
width=embeddings.shape[-1],
Expand All @@ -165,4 +166,4 @@ def segment(inference_config: InferenceConfig) -> None:
min_object_size=inference_config.min_size,
)

ds_segmentation[sample, bandwidth_factor, ...] = segmentation
ds_segmentation[sample, bandwidth_factor, ...] = segmentation

0 comments on commit e8c8edb

Please sign in to comment.