@@ -35,17 +35,18 @@ def download_bert_base_uncased():
35
35
check_root ()
36
36
if not os .path .exists (os .path .join (root_path , 'pretrain/bert-base-uncased' )):
37
37
os .mkdir (os .path .join (root_path , 'pretrain/bert-base-uncased' ))
38
- os .system ('wget -P ' + os .path .join (root_path , 'pretrain/bert-base-uncased' ) + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/bert_config .json' )
38
+ os .system ('wget -P ' + os .path .join (root_path , 'pretrain/bert-base-uncased' ) + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/config .json' )
39
39
os .system ('wget -P ' + os .path .join (root_path , 'pretrain/bert-base-uncased' ) + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/pytorch_model.bin' )
40
40
os .system ('wget -P ' + os .path .join (root_path , 'pretrain/bert-base-uncased' ) + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/vocab.txt' )
41
41
42
42
def download_pretrain (model_name ):
43
43
ckpt = os .path .join (root_path , 'pretrain/nre/' + model_name + '.pth.tar' )
44
44
if not os .path .exists (ckpt ):
45
- os .system ('wget -P ' + ckpt + ' http://193.112.16.83:8080/opennre/pretrain/nre/' + model_name + '.pth.tar' )
45
+ os .system ('wget -P ' + os . path . join ( root_path , 'pretrain/nre' ) + ' http://193.112.16.83:8080/opennre/pretrain/nre/' + model_name + '.pth.tar' )
46
46
47
47
def get_model (model_name ):
48
48
check_root ()
49
+ ckpt = os .path .join (root_path , 'pretrain/nre/' + model_name + '.pth.tar' )
49
50
if model_name == 'wiki80_cnn_softmax' :
50
51
download_pretrain (model_name )
51
52
download_glove ()
@@ -72,7 +73,7 @@ def get_model(model_name):
72
73
download_wiki80 ()
73
74
rel2id = json .load (open (os .path .join (root_path , 'benchmark/wiki80/wiki80_rel2id.json' )))
74
75
sentence_encoder = encoder .BERTEncoder (
75
- max_length = 80 , pretrain_path = 'pretrain/bert-base-uncased' )
76
+ max_length = 80 , pretrain_path = os . path . join ( root_path , 'pretrain/bert-base-uncased' ) )
76
77
m = model .SoftmaxNN (sentence_encoder , len (rel2id ), rel2id )
77
78
m .load_state_dict (torch .load (ckpt )['state_dict' ])
78
79
return m
0 commit comments