Skip to content

Commit bac27a1

Browse files
authored
Update train_text_to_image.py
1 parent bd96a08 commit bac27a1

File tree

1 file changed

+18
-2
lines changed

1 file changed

+18
-2
lines changed

examples/text_to_image/train_text_to_image.py

+18-2
Original file line numberDiff line numberDiff line change
@@ -499,6 +499,15 @@ def parse_args():
499499
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
500500
),
501501
)
502+
parser.add_argument(
503+
"--image_interpolation_mode",
504+
type=str,
505+
default="lanczos",
506+
choices=[
507+
f.lower() for f in dir(transforms.InterpolationMode) if not f.startswith("__") and not f.endswith("__")
508+
],
509+
help="The image interpolation method to use for resizing images.",
510+
)
502511

503512
args = parser.parse_args()
504513
env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
@@ -787,10 +796,17 @@ def tokenize_captions(examples, is_train=True):
787796
)
788797
return inputs.input_ids
789798

790-
# Preprocessing the datasets.
799+
# Get the specified interpolation method from the args
800+
interpolation = getattr(transforms.InterpolationMode, args.image_interpolation_mode.upper(), None)
801+
802+
# Raise an error if the interpolation method is invalid
803+
if interpolation is None:
804+
raise ValueError(f"Unsupported interpolation mode {args.image_interpolation_mode}.")
805+
806+
# Data preprocessing transformations
791807
train_transforms = transforms.Compose(
792808
[
793-
transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.BILINEAR),
809+
transforms.Resize(args.resolution, interpolation=interpolation), # Use dynamic interpolation method
794810
transforms.CenterCrop(args.resolution) if args.center_crop else transforms.RandomCrop(args.resolution),
795811
transforms.RandomHorizontalFlip() if args.random_flip else transforms.Lambda(lambda x: x),
796812
transforms.ToTensor(),

0 commit comments

Comments
 (0)