diff --git a/examples/text_to_image/train_text_to_image.py b/examples/text_to_image/train_text_to_image.py index 8ab136179996..324891dc79ff 100644 --- a/examples/text_to_image/train_text_to_image.py +++ b/examples/text_to_image/train_text_to_image.py @@ -499,6 +499,15 @@ def parse_args(): " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator" ), ) + parser.add_argument( + "--image_interpolation_mode", + type=str, + default="lanczos", + choices=[ + f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__") + ], + help="The image interpolation method to use for resizing images.", + ) args = parser.parse_args() env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) @@ -787,10 +796,17 @@ def tokenize_captions(examples, is_train=True): ) return inputs.input_ids - # Preprocessing the datasets. + # Get the specified interpolation method from the args + interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None) + + # Raise an error if the interpolation method is invalid + if interpolation is None: + raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.") + + # Data preprocessing transformations train_transforms = transforms.Compose( [ - transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution), transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x), transforms.ToTensor(),