@@ -499,6 +499,15 @@ def parse_args():
499
499
" more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
500
500
),
501
501
)
502
+ parser .add_argument (
503
+ "--image_interpolation_mode" ,
504
+ type = str ,
505
+ default = "lanczos" ,
506
+ choices = [
507
+ f .lower () for f in dir (transforms .InterpolationMode ) if not f .startswith ("__" ) and not f .endswith ("__" )
508
+ ],
509
+ help = "The image interpolation method to use for resizing images." ,
510
+ )
502
511
503
512
args = parser .parse_args ()
504
513
env_local_rank = int (os .environ .get ("LOCAL_RANK" , - 1 ))
@@ -787,10 +796,17 @@ def tokenize_captions(examples, is_train=True):
787
796
)
788
797
return inputs .input_ids
789
798
790
- # Preprocessing the datasets.
799
+ # Get the specified interpolation method from the args
800
+ interpolation = getattr (transforms .InterpolationMode , args .image_interpolation_mode .upper (), None )
801
+
802
+ # Raise an error if the interpolation method is invalid
803
+ if interpolation is None :
804
+ raise ValueError (f"Unsupported interpolation mode { args .image_interpolation_mode } ." )
805
+
806
+ # Data preprocessing transformations
791
807
train_transforms = transforms .Compose (
792
808
[
793
- transforms .Resize (args .resolution , interpolation = transforms . InterpolationMode . BILINEAR ),
809
+ transforms .Resize (args .resolution , interpolation = interpolation ), # Use dynamic interpolation method
794
810
transforms .CenterCrop (args .resolution ) if args .center_crop else transforms .RandomCrop (args .resolution ),
795
811
transforms .RandomHorizontalFlip () if args .random_flip else transforms .Lambda (lambda x : x ),
796
812
transforms .ToTensor (),
0 commit comments