|
7 | 7 | import json
|
8 | 8 | import numpy as np
|
9 | 9 |
|
10 |
| -default_root_path = os.path.join(os.getenv('HOME'), '.opennre') |
| 10 | +default_root_path = os.path.join(os.getenv('openNRE'), '.') |
11 | 11 |
|
12 | 12 | def check_root(root_path=default_root_path):
|
13 | 13 | if not os.path.exists(root_path):
|
14 | 14 | os.mkdir(root_path)
|
15 | 15 | os.mkdir(os.path.join(root_path, 'benchmark'))
|
16 | 16 | os.mkdir(os.path.join(root_path, 'pretrain'))
|
17 |
| - os.mkdir(os.path.join(root_path, 'pretrain/nre')) |
| 17 | + os.mkdir(os.path.join(root_path, 'ckpt')) |
18 | 18 |
|
19 | 19 | def download_wiki80(root_path=default_root_path):
|
20 | 20 | check_root()
|
@@ -49,14 +49,18 @@ def download_bert_base_uncased(root_path=default_root_path):
|
49 | 49 | os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/vocab.txt')
|
50 | 50 |
|
51 | 51 | def download_pretrain(model_name, root_path=default_root_path):
|
52 |
| - ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar') |
| 52 | + ckpt = os.path.join(root_path, 'ckpt/' + model_name + '.pth.tar') |
53 | 53 | if not os.path.exists(ckpt):
|
54 |
| - os.system('wget -P ' + os.path.join(root_path, 'pretrain/nre') + ' http://193.112.16.83:8080/opennre/pretrain/nre/' + model_name + '.pth.tar') |
| 54 | + print("*"*20) |
| 55 | + print("下载ckpt") |
| 56 | + os.system('wget -P ' + os.path.join(root_path, 'ckpt/') + ' http://193.112.16.83:8080/opennre/ckpt/' + model_name + '.pth.tar') |
55 | 57 |
|
56 | 58 | def get_model(model_name, root_path=default_root_path):
|
57 | 59 | check_root()
|
58 |
| - ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar') |
| 60 | + ckpt = os.path.join(root_path, 'ckpt/' + model_name + '.pth.tar') |
| 61 | + |
59 | 62 | if model_name == 'wiki80_cnn_softmax':
|
| 63 | + print("*"*20+"taorui") |
60 | 64 | download_pretrain(model_name)
|
61 | 65 | download_glove()
|
62 | 66 | download_wiki80()
|
@@ -86,5 +90,38 @@ def get_model(model_name, root_path=default_root_path):
|
86 | 90 | m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
|
87 | 91 | m.load_state_dict(torch.load(ckpt)['state_dict'])
|
88 | 92 | return m
|
| 93 | + elif model_name == 'test_chinese_bert_softmax': |
| 94 | + download_pretrain(model_name) |
| 95 | + download_bert_base_uncased() |
| 96 | + download_wiki80() |
| 97 | + rel2id = json.load(open(os.path.join(root_path, 'benchmark/test_chinese/test_chinese_rel2id.json'))) |
| 98 | + sentence_encoder = encoder.BERTEncoder( |
| 99 | + max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch')) |
| 100 | + m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) |
| 101 | + m.load_state_dict(torch.load(ckpt)['state_dict']) |
| 102 | + return m |
| 103 | + |
| 104 | + elif model_name == 'people_chinese_bert_softmax': |
| 105 | + download_pretrain(model_name) |
| 106 | + download_bert_base_uncased() |
| 107 | + download_wiki80() |
| 108 | + rel2id = json.load(open(os.path.join(root_path, 'benchmark/people-relation/people-relation_rel2id.json'))) |
| 109 | + sentence_encoder = encoder.BERTEncoder( |
| 110 | + max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch')) |
| 111 | + m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) |
| 112 | + m.load_state_dict(torch.load(ckpt)['state_dict']) |
| 113 | + return m |
| 114 | + |
| 115 | + elif model_name == 'people_delunknown_chinese_bert_softmax': |
| 116 | + download_pretrain(model_name) |
| 117 | + download_bert_base_uncased() |
| 118 | + download_wiki80() |
| 119 | + rel2id = json.load(open(os.path.join(root_path, 'benchmark/people-relation-delunknow/people-relation_rel2id.json'))) |
| 120 | + sentence_encoder = encoder.BERTEncoder( |
| 121 | + max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch')) |
| 122 | + m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id) |
| 123 | + m.load_state_dict(torch.load(ckpt)['state_dict']) |
| 124 | + return m |
| 125 | + |
89 | 126 | else:
|
90 | 127 | raise NotImplementedError
|
0 commit comments