-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathupload_to_hub.py
33 lines (27 loc) · 1.07 KB
/
upload_to_hub.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
from argparse import ArgumentParser
from transformers import AutoModel, AutoTokenizer
from sentence_transformers.models import Pooling, Transformer
from sentence_transformers import SentenceTransformer
def parse_args(argv=None):
    """Parse the command-line options for the Hub upload.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with the parsed options.
    """
    parser = ArgumentParser(
        description=(
            "Wrap a transformer checkpoint in a pooled SentenceTransformer "
            "and upload it to the Hugging Face Hub."
        )
    )
    # Required: without it Transformer() would be called with None and fail
    # with an obscure error instead of a clear usage message.
    parser.add_argument(
        "--checkpoint",
        required=True,
        help="Local path or Hub id of the transformer checkpoint to wrap.",
    )
    parser.add_argument(
        "--model_name",
        default="mfaq",
        help="Repository name to upload to (default: %(default)s).",
    )
    parser.add_argument(
        "--organization",
        default="clips",
        help="Hub organization owning the repository (default: %(default)s).",
    )
    parser.add_argument(
        "--exist_ok",
        action="store_true",
        help="Allow uploading into a repository that already exists.",
    )
    parser.add_argument(
        "--replace_model_card",
        action="store_true",
        help="Replace an existing model card with the auto-generated one.",
    )
    parser.add_argument(
        "--max_seq_length",
        type=int,
        default=128,
        help="Maximum input length in tokens (default: %(default)s).",
    )
    parser.add_argument(
        "--pooling_mode",
        default="mean",
        choices=("mean", "max", "cls"),
        help="Token-embedding pooling strategy (default: %(default)s).",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Build a Transformer + Pooling SentenceTransformer and push it to the Hub.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    args = parse_args(argv)

    # add_pooling_layer=False is forwarded to the underlying Hugging Face model
    # so its own pooler head is not built; pooling is done by the Pooling
    # module below instead.
    transformer = Transformer(
        args.checkpoint,
        max_seq_length=args.max_seq_length,
        model_args={"add_pooling_layer": False},
        tokenizer_name_or_path=args.checkpoint,
    )
    pooling = Pooling(
        transformer.get_word_embedding_dimension(),
        pooling_mode=args.pooling_mode,
    )
    st = SentenceTransformer(modules=[transformer, pooling])
    st.save_to_hub(
        args.model_name,
        organization=args.organization,
        exist_ok=args.exist_ok,
        replace_model_card=args.replace_model_card,
    )


if __name__ == "__main__":
    main()