Skip to content

Commit b93de9e

Browse files
author
Tianyu Gao
committed
fix download bug
1 parent f327a5b commit b93de9e

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

opennre/pretrain.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,18 @@ def download_bert_base_uncased():
3535
check_root()
3636
if not os.path.exists(os.path.join(root_path, 'pretrain/bert-base-uncased')):
3737
os.mkdir(os.path.join(root_path, 'pretrain/bert-base-uncased'))
38-
os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/bert_config.json')
38+
os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/config.json')
3939
os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/pytorch_model.bin')
4040
os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/vocab.txt')
4141

4242
def download_pretrain(model_name):
4343
ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
4444
if not os.path.exists(ckpt):
45-
os.system('wget -P ' + ckpt + ' http://193.112.16.83:8080/opennre/pretrain/nre/' + model_name + '.pth.tar')
45+
os.system('wget -P ' + os.path.join(root_path, 'pretrain/nre') + ' http://193.112.16.83:8080/opennre/pretrain/nre/' + model_name + '.pth.tar')
4646

4747
def get_model(model_name):
4848
check_root()
49+
ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
4950
if model_name == 'wiki80_cnn_softmax':
5051
download_pretrain(model_name)
5152
download_glove()
@@ -72,7 +73,7 @@ def get_model(model_name):
7273
download_wiki80()
7374
rel2id = json.load(open(os.path.join(root_path, 'benchmark/wiki80/wiki80_rel2id.json')))
7475
sentence_encoder = encoder.BERTEncoder(
75-
max_length=80, pretrain_path='pretrain/bert-base-uncased')
76+
max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/bert-base-uncased'))
7677
m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
7778
m.load_state_dict(torch.load(ckpt)['state_dict'])
7879
return m

requirements.txt

-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ idna==2.8
88
jmespath==0.9.4
99
joblib==0.14.0
1010
numpy==1.17.3
11-
pkg-resources==0.0.0
1211
python-dateutil==2.8.0
1312
regex==2019.8.19
1413
requests==2.22.0

0 commit comments

Comments
 (0)