Skip to content

Make MONAI robust to non-default meta dict names #8411

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 2 commits into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import torch

from monai.config import IndexSelection, KeysCollection, NdarrayOrTensor
from monai.data.meta_obj import get_meta_dict_name
from monai.networks.layers import GaussianFilter
from monai.transforms import Resize, SpatialCrop
from monai.transforms.transform import MapTransform, Randomizable, Transform
Expand Down Expand Up @@ -546,7 +547,12 @@ def _apply(self, pos_clicks, neg_clicks, factor, slice_num):

def __call__(self, data):
d = dict(data)
meta_dict_key = self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"
meta_dict_key = self.meta_keys
if not meta_dict_key:
candidate_meta_key = f"{self.ref_image}_{self.meta_key_postfix}"
meta_dict = d.get(candidate_meta_key, None)
if meta_dict is None:
meta_dict_key = get_meta_dict_name(self.ref_image, d)
if meta_dict_key not in d:
raise RuntimeError(f"Missing meta_dict {meta_dict_key} in data!")
if "spatial_shape" not in d[meta_dict_key]:
Expand Down Expand Up @@ -742,7 +748,10 @@ def __init__(
def __call__(self, data: Any) -> dict:
d = dict(data)
guidance = d[self.guidance]
meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"]
# meta_dict: dict = d[self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}"]
meta_dict: dict = d.get(self.meta_keys or f"{self.ref_image}_{self.meta_key_postfix}", None)
if meta_dict is None:
meta_dict = d[get_meta_dict_name(self.ref_image, d)]
current_shape = d[self.ref_image].shape[1:]
cropped_shape = meta_dict[self.cropped_shape_key][1:]
factor = np.divide(current_shape, cropped_shape)
Expand Down Expand Up @@ -852,7 +861,8 @@ def __init__(

def __call__(self, data: Any) -> dict:
d = dict(data)
meta_dict: dict = d[f"{self.ref_image}_{self.meta_key_postfix}"]
meta_dict: dict = d.get(f"{self.ref_image}_{self.meta_key_postfix}",
d[get_meta_dict_name(self.ref_image, d)])

for key, mode, align_corners, meta_key in self.key_iterator(d, self.mode, self.align_corners, self.meta_keys):
image = d[key]
Expand Down Expand Up @@ -969,5 +979,8 @@ def __call__(self, data):
for key, meta_key, meta_key_postfix in self.key_iterator(d, self.meta_keys, self.meta_key_postfix):
img_slice, idx = self._apply(d[key], guidance)
d[key] = img_slice
d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx
# d[meta_key or f"{key}_{meta_key_postfix}"]["slice_idx"] = idx
if meta_key not in d:
meta_key = get_meta_dict_name(key, d)
d[meta_key] = idx
return d
5 changes: 4 additions & 1 deletion monai/apps/detection/transforms/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from monai.config import KeysCollection, SequenceStr
from monai.config.type_definitions import DtypeLike, NdarrayOrTensor
from monai.data.box_utils import COMPUTE_DTYPE, BoxMode, clip_boxes_to_image
from monai.data.meta_obj import get_meta_dict_name
from monai.data.meta_tensor import MetaTensor, get_track_meta
from monai.data.utils import orientation_ras_lps
from monai.transforms import Flip, RandFlip, RandZoom, Rotate90, SpatialCrop, Zoom
Expand Down Expand Up @@ -308,7 +309,9 @@ def extract_affine(self, data: Mapping[Hashable, torch.Tensor]) -> tuple[Ndarray
elif meta_key in d:
meta_dict = d[meta_key]
else:
raise ValueError(f"{meta_key} is not found. Please check whether it is the correct the image meta key.")
meta_key = get_meta_dict_name(self.box_ref_image_keys, d)
if meta_key not in d:
raise ValueError(f"{self.image_meta_key} is not found. Please check whether it is the correct the image meta key.")
if "affine" not in meta_dict:
raise ValueError(
f"'affine' is not found in {meta_key}. \
Expand Down
8 changes: 8 additions & 0 deletions monai/data/meta_obj.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,3 +242,11 @@ def is_batch(self) -> bool:
def is_batch(self, val: bool) -> None:
"""Set whether object is part of batch or not."""
self._is_batch = val


def get_meta_dict_name(key, dictionary):
for kv, kd in dictionary.items():
if isinstance(kd, dict):
if kd.get("tensor_name", None) == key:
return kv
return None
4 changes: 3 additions & 1 deletion monai/transforms/intensity/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from monai.config import DtypeLike, KeysCollection
from monai.config.type_definitions import NdarrayOrTensor
from monai.data.meta_obj import get_track_meta
from monai.data.meta_obj import get_meta_dict_name, get_track_meta
from monai.transforms.intensity.array import (
AdjustContrast,
ClipIntensityPercentiles,
Expand Down Expand Up @@ -358,6 +358,8 @@ def __call__(self, data) -> dict[Hashable, NdarrayOrTensor]:
d, self.factor_key, self.meta_keys, self.meta_key_postfix
):
meta_key = meta_key or f"{key}_{meta_key_postfix}"
if meta_key not in d:
meta_key = get_meta_dict_name(key, d)
factor: float | None = d[meta_key].get(factor_key) if meta_key in d else None
offset = None if factor is None else self.shifter.offset * factor
d[key] = self.shifter(d[key], offset=offset)
Expand Down
1 change: 1 addition & 0 deletions monai/transforms/io/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,7 @@ def __call__(self, data, reader: ImageReader | None = None):
f"loader must return a tuple or list (because image_only=False was used), got {type(data)}."
)
d[key] = data[0]
data[1]['tensor_name'] = key
if not isinstance(data[1], dict):
raise ValueError(f"metadata must be a dict, got {type(data[1])}.")
meta_key = meta_key or f"{key}_{meta_key_postfix}"
Expand Down
7 changes: 6 additions & 1 deletion monai/transforms/meta_utility/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import torch

from monai.config.type_definitions import KeysCollection, NdarrayOrTensor
from monai.data.meta_obj import get_meta_dict_name
from monai.data.meta_tensor import MetaTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.transform import MapTransform
Expand Down Expand Up @@ -77,7 +78,10 @@ def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, Nd
_ = self.get_most_recent_transform(d, key)
# do the inverse
im = d[key]
meta = d.pop(PostFix.meta(key), None)
if PostFix.meta(key) in d:
meta = d.pop(PostFix.meta(key), None)
else:
meta = d.pop(get_meta_dict_name(key, d))
transforms = d.pop(PostFix.transforms(key), None)
im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore
d[key] = im
Expand All @@ -101,6 +105,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
for key in self.key_iterator(d):
self.push_transform(d, key)
im = d[key]

meta = d.pop(PostFix.meta(key), None)
transforms = d.pop(PostFix.transforms(key), None)
im = MetaTensor(im, meta=meta, applied_operations=transforms) # type: ignore
Expand Down
3 changes: 3 additions & 0 deletions monai/transforms/post/dictionary.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from monai import config
from monai.config.type_definitions import KeysCollection, NdarrayOrTensor, PathLike
from monai.data.csv_saver import CSVSaver
from monai.data.meta_obj import get_meta_dict_name
from monai.data.meta_tensor import MetaTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.post.array import (
Expand Down Expand Up @@ -797,6 +798,8 @@ def __call__(self, data):
if meta_key is None and meta_key_postfix is not None:
meta_key = f"{key}_{meta_key_postfix}"
meta_data = d[meta_key] if meta_key is not None else None
if meta_data is None:
meta_data = d.get(get_meta_dict_name(key, d), None)
self.saver.save(data=d[key], meta_data=meta_data)
if self.flush:
self.saver.finalize()
Expand Down
7 changes: 6 additions & 1 deletion monai/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import monai
from monai.config import DtypeLike, IndexSelection
from monai.config.type_definitions import NdarrayOrTensor, NdarrayTensor
from monai.data.meta_obj import get_meta_dict_name
from monai.data.utils import to_affine_nd
from monai.networks.layers import GaussianFilter
from monai.networks.utils import meshgrid_ij
Expand Down Expand Up @@ -2143,7 +2144,11 @@ def sync_meta_info(key, data_dict, t: bool = True):
d = dict(data_dict)

# update meta dicts
meta_dict_key = PostFix.meta(key)
default_meta_dict_key = PostFix.meta(key)
if default_meta_dict_key not in d:
meta_dict_key = get_meta_dict_name(key, d)
else:
meta_dict_key = default_meta_dict_key
if meta_dict_key not in d:
d[meta_dict_key] = monai.data.MetaTensor.get_default_meta()
if not isinstance(d[key], monai.data.MetaTensor):
Expand Down
Loading