From 0edbe87985ab270c246e967507e0ec45d7acb659 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 4 Apr 2025 14:49:54 +0100 Subject: [PATCH 1/2] Initial commit to make codebase robust to use of non-standard meta_dict names Signed-off-by: Ben Murray --- monai/apps/deepgrow/transforms.py | 21 +++++++++++++++---- monai/apps/detection/transforms/dictionary.py | 5 ++++- monai/data/meta_obj.py | 8 +++++++ monai/transforms/intensity/dictionary.py | 4 +++- monai/transforms/io/dictionary.py | 1 + monai/transforms/meta_utility/dictionary.py | 7 ++++++- monai/transforms/post/dictionary.py | 3 +++ monai/transforms/utils.py | 3 +++ 8 files changed, 45 insertions(+), 7 deletions(-) diff --git a/monai/apps/deepgrow/transforms.py b/monai/apps/deepgrow/transforms.py index 721c0db489..20ff15e5e1 100644 --- a/monai/apps/deepgrow/transforms.py +++ b/monai/apps/deepgrow/transforms.py @@ -19,6 +19,7 @@ import torch from monai.config import IndexSelection, KeysCollection, NdarrayOrTensor +from monai.data.meta_obj import get_meta_dict_name from monai.networks.layers import GaussianFilter from monai.transforms import Resize, SpatialCrop from monai.transforms.transform import MapTransform, Randomizable, Transform @@ -546,7 +547,12 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num): def __call__(self, data): d = dict(data) - meta_dict_key = self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}" + meta_dict_key = self.meta_keys + if not meta_dict_key: + candidate_meta_key = f"{self.ref_image}_{self.meta_key_postfix}" + meta_dict = d.get(candidate_meta_key, None) + if meta_dict is None: + meta_dict_key = get_meta_dict_name(self.ref_image, d) if meta_dict_key not in d: raise RuntimeError(f"Missing meta_dict {meta_dict_key} in data!") if "spatial_shape" not in d[meta_dict_key]: @@ -742,7 +748,10 @@ def __init__( def __call__(self, data: Any) -> dict: d = dict(data) guidance = d[self.guidance] - meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"] + # meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"] + meta_dict: dict = d.get(self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}", None) + if meta_dict is None: + meta_dict = d[get_meta_dict_name(self.ref_image, d)] current_shape = d[self.ref_image].shape[1:] cropped_shape = meta_dict[self.cropped_shape_key][1:] factor = np.divide(current_shape, cropped_shape) @@ -852,7 +861,8 @@ def __init__( def __call__(self, data: Any) -> dict: d = dict(data) - meta_dict: dict = d[f"{self.ref_image}_{self.meta_key_postfix}"] + meta_dict: dict = d.get(f"{self.ref_image}_{self.meta_key_postfix}", + d[get_meta_dict_name(self.ref_image, d)]) for key, mode, align_corners, meta_key in self.key_iterator(d, self.mode, self.align_corners, self.meta_keys): image = d[key] @@ -969,5 +979,8 @@ def __call__(self, data): for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix): img_slice, idx = self._apply(d[key], guidance) d[key] = img_slice - d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx + # d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx + if meta_key not in d: + meta_key = get_meta_dict_name(key, d) + d[meta_key] = idx return d diff --git a/monai/apps/detection/transforms/dictionary.py b/monai/apps/detection/transforms/dictionary.py index 52b1a7d15d..23c632c57a 100644 --- a/monai/apps/detection/transforms/dictionary.py +++ b/monai/apps/detection/transforms/dictionary.py @@ -41,6 +41,7 @@ from monai.config import KeysCollection, SequenceStr from monai.config.type_definitions import DtypeLike, NdarrayOrTensor from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image +from monai.data.meta_obj import get_meta_dict_name from monai.data.meta_tensor import MetaTensor, get_track_meta from monai.data.utils import orientation_ras_lps from monai.transforms import Flip, RandFlip, RandZoom, Rotate90, SpatialCrop, Zoom @@ -308,7 +309,9 @@ def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> tuple[Ndarray elif meta_key in d: meta_dict = d[meta_key] else: - raise ValueError(f"{meta_key} is not found. Please check whether it is the correct the image meta key.") + meta_key = get_meta_dict_name(self.box_ref_image_keys, d) + if meta_key not in d: + raise ValueError(f"{self.image_meta_key} is not found. Please check whether it is the correct the image meta key.") if "affine" not in meta_dict: raise ValueError( f"'affine' is not found in {meta_key}. \ diff --git a/monai/data/meta_obj.py b/monai/data/meta_obj.py index 15e6e8be15..ee59e27bde 100644 --- a/monai/data/meta_obj.py +++ b/monai/data/meta_obj.py @@ -242,3 +242,11 @@ def is_batch(self) -> bool: def is_batch(self, val: bool) -> None: """Set whether object is part of batch or not.""" self._is_batch = val + + +def get_meta_dict_name(key, dictionary): + for kv, kd in dictionary.items(): + if isinstance(kd, dict): + if kd.get("tensor_name", None) == key: + return kv + return None diff --git a/monai/transforms/intensity/dictionary.py b/monai/transforms/intensity/dictionary.py index f2b1a2fd40..a9d982b07a 100644 --- a/monai/transforms/intensity/dictionary.py +++ b/monai/transforms/intensity/dictionary.py @@ -24,7 +24,7 @@ from monai.config import DtypeLike, KeysCollection from monai.config.type_definitions import NdarrayOrTensor -from monai.data.meta_obj import get_track_meta +from monai.data.meta_obj import get_meta_dict_name, get_track_meta from monai.transforms.intensity.array import ( AdjustContrast, ClipIntensityPercentiles, @@ -358,6 +358,8 @@ def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]: d, self.factor_key, self.meta_keys, self.meta_key_postfix ): meta_key = meta_key or f"{key}_{meta_key_postfix}" + if meta_key not in d: + meta_key = get_meta_dict_name(key, d) factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None offset = None if factor is None else self.shifter.offset * factor d[key] = self.shifter(d[key], offset=offset) diff --git a/monai/transforms/io/dictionary.py b/monai/transforms/io/dictionary.py index be1e78db8a..52d105d439 100644 --- a/monai/transforms/io/dictionary.py +++ b/monai/transforms/io/dictionary.py @@ -169,6 +169,7 @@ def __call__(self, data, reader: ImageReader | None = None): f"loader must return a tuple or list (because image_only=False was used), got {type(data)}." ) d[key] = data[0] + data[1]['tensor_name'] = key if not isinstance(data[1], dict): raise ValueError(f"metadata must be a dict, got {type(data[1])}.") meta_key = meta_key or f"{key}_{meta_key_postfix}" diff --git a/monai/transforms/meta_utility/dictionary.py b/monai/transforms/meta_utility/dictionary.py index ed752bb2d7..a053e8f07c 100644 --- a/monai/transforms/meta_utility/dictionary.py +++ b/monai/transforms/meta_utility/dictionary.py @@ -23,6 +23,7 @@ import torch from monai.config.type_definitions import KeysCollection, NdarrayOrTensor +from monai.data.meta_obj import get_meta_dict_name from monai.data.meta_tensor import MetaTensor from monai.transforms.inverse import InvertibleTransform from monai.transforms.transform import MapTransform @@ -77,7 +78,10 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Nd _ = self.get_most_recent_transform(d, key) # do the inverse im = d[key] - meta = d.pop(PostFix.meta(key), None) + if PostFix.meta(key) in d: + meta = d.pop(PostFix.meta(key), None) + else: + meta = d.pop(get_meta_dict_name(key, d)) transforms = d.pop(PostFix.transforms(key), None) im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore d[key] = im @@ -101,6 +105,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N for key in self.key_iterator(d): self.push_transform(d, key) im = d[key] + meta = d.pop(PostFix.meta(key), None) transforms = d.pop(PostFix.transforms(key), None) im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore diff --git a/monai/transforms/post/dictionary.py b/monai/transforms/post/dictionary.py index 7e1e074f71..779e6f235f 100644 --- a/monai/transforms/post/dictionary.py +++ b/monai/transforms/post/dictionary.py @@ -28,6 +28,7 @@ from monai import config from monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike from monai.data.csv_saver import CSVSaver +from monai.data.meta_obj import get_meta_dict_name from monai.data.meta_tensor import MetaTensor from monai.transforms.inverse import InvertibleTransform from monai.transforms.post.array import ( @@ -797,6 +798,8 @@ def __call__(self, data): if meta_key is None and meta_key_postfix is not None: meta_key = f"{key}_{meta_key_postfix}" meta_data = d[meta_key] if meta_key is not None else None + if meta_data is None: + meta_data = d.get(get_meta_dict_name(key, d), None) self.saver.save(data=d[key], meta_data=meta_data) if self.flush: self.saver.finalize() diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 1ff0abc27c..16c80298d1 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -27,6 +27,7 @@ import monai from monai.config import DtypeLike, IndexSelection from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor +from monai.data.meta_obj import get_meta_dict_name from monai.data.utils import to_affine_nd from monai.networks.layers import GaussianFilter from monai.networks.utils import meshgrid_ij @@ -2144,6 +2145,8 @@ def sync_meta_info(key, data_dict, t: bool = True): # update meta dicts meta_dict_key = PostFix.meta(key) + if meta_dict_key not in d: + meta_dict_key = get_meta_dict_name(key, d) if meta_dict_key not in d: d[meta_dict_key] = monai.data.MetaTensor.get_default_meta() if not isinstance(d[key], monai.data.MetaTensor): From a81140cbf57f4bcd3d78658bdbcd46bd6a556cf5 Mon Sep 17 00:00:00 2001 From: Ben Murray Date: Fri, 4 Apr 2025 16:03:46 +0100 Subject: [PATCH 2/2] Bugfix for sync_meta_info Signed-off-by: Ben Murray --- monai/transforms/utils.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/monai/transforms/utils.py b/monai/transforms/utils.py index 16c80298d1..091fa318f0 100644 --- a/monai/transforms/utils.py +++ b/monai/transforms/utils.py @@ -2144,9 +2144,11 @@ def sync_meta_info(key, data_dict, t: bool = True): d = dict(data_dict) # update meta dicts - meta_dict_key = PostFix.meta(key) - if meta_dict_key not in d: + default_meta_dict_key = PostFix.meta(key) + if default_meta_dict_key not in d: meta_dict_key = get_meta_dict_name(key, d) + else: + meta_dict_key = default_meta_dict_key if meta_dict_key not in d: d[meta_dict_key] = monai.data.MetaTensor.get_default_meta() if not isinstance(d[key], monai.data.MetaTensor):