diff --git a/mergekit/merge.py b/mergekit/merge.py index 60189f44..291f8486 100644 --- a/mergekit/merge.py +++ b/mergekit/merge.py @@ -136,6 +136,19 @@ def run_merge( "Chat template specified but no tokenizer found. Chat template will not be saved." ) + # Copy feature_extractor if it is a whisper model + if options.copy_feature_extractor and arch_info.definition.expected_model_type == "whisper": + try: + _copy_feature_extractor( + merge_config, out_path, trust_remote_code=options.trust_remote_code + ) + except Exception as e: + logging.error( + "Failed to copy feature_extractor. The merge was still successful, just copy it from somewhere else.", + exc_info=e, + ) + + if tokenizer: logging.info("Saving tokenizer") _set_chat_template(tokenizer, merge_config) @@ -229,6 +242,36 @@ def _copy_tokenizer( tokenizer.save_pretrained(out_path, safe_serialization=True) +def _copy_feature_extractor( + merge_config: MergeConfiguration, out_path: str, trust_remote_code: bool = False +): + donor_model = merge_config.base_model or (merge_config.referenced_models()[0]) + + if (os.path.exists( + os.path.join(donor_model.model.path, "preprocessor_config.json") + ) + ): + logging.info(f"Copying feature_extractor from {donor_model}") + + for file_name in [ + "preprocessor_config.json", + ]: + if os.path.exists(os.path.join(donor_model.model.path, file_name)): + shutil.copy( + os.path.join(donor_model.model.path, file_name), + os.path.join(out_path, file_name), + ) + return + + # fallback: try actually loading the feature_extractor and saving it + logging.info(f"Reserializing feature_extractor from {donor_model}") + feature_extractor = transformers.AutoFeatureExtractor.from_pretrained( + donor_model.model.path, + revision=donor_model.model.revision, + trust_remote_code=trust_remote_code, + ) + _set_chat_template(feature_extractor, merge_config) + feature_extractor.save_pretrained(out_path, safe_serialization=True) def _model_out_config( config: MergeConfiguration, arch_info: ArchitectureInfo,