Skip to content

Commit 8e47175

Browse files
author
rui.tao
committed
增加中文人物关系抽取模型
1 parent af8b250 commit 8e47175

6 files changed

+155
-8
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import sys, json
2+
import torch
3+
import os
4+
import numpy as np
5+
import opennre
6+
from opennre import encoder, model, framework
7+
import argparse
8+
9+
parser = argparse.ArgumentParser()
10+
parser.add_argument('--mask_entity', action='store_true', help='Mask entity mentions')
11+
args = parser.parse_args()
12+
13+
# Some basic settings
14+
root_path = '.'
15+
sys.path.append(root_path)
16+
if not os.path.exists('ckpt'):
17+
os.mkdir('ckpt')
18+
ckpt = 'ckpt/people_chinese_bert_softmax.pth.tar'
19+
20+
# Check data
21+
rel2id = json.load(open(os.path.join(root_path, 'benchmark/people-relation/people-relation_rel2id.json')))
22+
23+
# Define the sentence encoder
24+
sentence_encoder = opennre.encoder.BERTEncoder(
25+
max_length=80,
26+
pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch'),
27+
mask_entity=args.mask_entity
28+
)
29+
30+
# Define the model
31+
model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
32+
33+
# Define the whole training framework
34+
framework = opennre.framework.SentenceRE(
35+
train_path=os.path.join(root_path, 'benchmark/people-relation/people-relation_train.txt'),
36+
val_path=os.path.join(root_path, 'benchmark/people-relation/people-relation_val.txt'),
37+
test_path=os.path.join(root_path, 'benchmark/people-relation/people-relation_val.txt'),
38+
model=model,
39+
ckpt=ckpt,
40+
batch_size=64, # Modify the batch size w.r.t. your device
41+
max_epoch=3,
42+
lr=2e-5,
43+
opt='adamw'
44+
)
45+
46+
# Train the model
47+
framework.train_model()
48+
49+
# Test the model
50+
framework.load_state_dict(torch.load(ckpt)['state_dict'])
51+
result = framework.eval_model(framework.test_loader)
52+
53+
# Print the result
54+
print('Accuracy on test set: {}'.format(result['acc']))
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import sys, json
2+
import torch
3+
import os
4+
import numpy as np
5+
import opennre
6+
from opennre import encoder, model, framework
7+
import argparse
8+
9+
parser = argparse.ArgumentParser()
10+
parser.add_argument('--mask_entity', action='store_true', help='Mask entity mentions')
11+
args = parser.parse_args()
12+
13+
# Some basic settings
14+
root_path = '.'
15+
sys.path.append(root_path)
16+
if not os.path.exists('ckpt'):
17+
os.mkdir('ckpt')
18+
ckpt = 'ckpt/test_chinese_bert_softmax.pth.tar'
19+
20+
# Check data
21+
rel2id = json.load(open(os.path.join(root_path, 'benchmark/test_chinese/test_chinese_rel2id.json')))
22+
23+
# Define the sentence encoder
24+
sentence_encoder = opennre.encoder.BERTEncoder(
25+
max_length=80,
26+
pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch'),
27+
mask_entity=args.mask_entity
28+
)
29+
30+
# Define the model
31+
model = opennre.model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
32+
33+
# Define the whole training framework
34+
framework = opennre.framework.SentenceRE(
35+
train_path=os.path.join(root_path, 'benchmark/test_chinese/test_chinese_train.txt'),
36+
val_path=os.path.join(root_path, 'benchmark/test_chinese/test_chinese_val.txt'),
37+
test_path=os.path.join(root_path, 'benchmark/test_chinese/test_chinese_val.txt'),
38+
model=model,
39+
ckpt=ckpt,
40+
batch_size=64, # Modify the batch size w.r.t. your device
41+
max_epoch=7,
42+
lr=2e-5,
43+
opt='adamw'
44+
)
45+
46+
# Train the model
47+
framework.train_model()
48+
49+
# Test the model
50+
framework.load_state_dict(torch.load(ckpt)['state_dict'])
51+
result = framework.eval_model(framework.test_loader)
52+
53+
# Print the result
54+
print('Accuracy on test set: {}'.format(result['acc']))

example/train_wiki80_bert_softmax.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
sys.path.append(root_path)
1616
if not os.path.exists('ckpt'):
1717
os.mkdir('ckpt')
18-
ckpt = 'ckpt/wiki80_bert_softmax.pth.tar'
18+
ckpt = 'ckpt/wiki80_bert_softmax_7epoch.pth.tar'
1919

2020
# Check data
2121
opennre.download_wiki80(root_path=root_path)
@@ -40,7 +40,7 @@
4040
model=model,
4141
ckpt=ckpt,
4242
batch_size=64, # Modify the batch size w.r.t. your device
43-
max_epoch=10,
43+
max_epoch=7,
4444
lr=2e-5,
4545
opt='adamw'
4646
)

example/train_wiki80_cnn_softmax.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@
4343
model=model,
4444
ckpt=ckpt,
4545
batch_size=32,
46-
max_epoch=100,
46+
max_epoch=2,
4747
lr=0.1,
4848
weight_decay=1e-5,
4949
opt='sgd'

opennre/framework/data_loader.py

+2
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ def __init__(self, path, rel2id, tokenizer, entpair_as_bag=False, bag_size=None,
140140
self.bag_name = []
141141
self.facts = {}
142142
for idx, item in enumerate(self.data):
143+
143144
fact = (item['h']['id'], item['t']['id'], item['relation'])
145+
144146
if item['relation'] != 'NA':
145147
self.facts[fact] = 1
146148
if entpair_as_bag:

opennre/pretrain.py

+42-5
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
import json
88
import numpy as np
99

10-
default_root_path = os.path.join(os.getenv('HOME'), '.opennre')
10+
default_root_path = os.path.join(os.getenv('openNRE'), '.')
1111

1212
def check_root(root_path=default_root_path):
1313
if not os.path.exists(root_path):
1414
os.mkdir(root_path)
1515
os.mkdir(os.path.join(root_path, 'benchmark'))
1616
os.mkdir(os.path.join(root_path, 'pretrain'))
17-
os.mkdir(os.path.join(root_path, 'pretrain/nre'))
17+
os.mkdir(os.path.join(root_path, 'ckpt'))
1818

1919
def download_wiki80(root_path=default_root_path):
2020
check_root()
@@ -49,14 +49,18 @@ def download_bert_base_uncased(root_path=default_root_path):
4949
os.system('wget -P ' + os.path.join(root_path, 'pretrain/bert-base-uncased') + ' http://193.112.16.83:8080/opennre/pretrain/bert-base-uncased/vocab.txt')
5050

5151
def download_pretrain(model_name, root_path=default_root_path):
52-
ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
52+
ckpt = os.path.join(root_path, 'ckpt/' + model_name + '.pth.tar')
5353
if not os.path.exists(ckpt):
54-
os.system('wget -P ' + os.path.join(root_path, 'pretrain/nre') + ' http://193.112.16.83:8080/opennre/pretrain/nre/' + model_name + '.pth.tar')
54+
print("*"*20)
55+
print("下载ckpt")
56+
os.system('wget -P ' + os.path.join(root_path, 'ckpt/') + ' http://193.112.16.83:8080/opennre/ckpt/' + model_name + '.pth.tar')
5557

5658
def get_model(model_name, root_path=default_root_path):
5759
check_root()
58-
ckpt = os.path.join(root_path, 'pretrain/nre/' + model_name + '.pth.tar')
60+
ckpt = os.path.join(root_path, 'ckpt/' + model_name + '.pth.tar')
61+
5962
if model_name == 'wiki80_cnn_softmax':
63+
print("*"*20+"taorui")
6064
download_pretrain(model_name)
6165
download_glove()
6266
download_wiki80()
@@ -86,5 +90,38 @@ def get_model(model_name, root_path=default_root_path):
8690
m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
8791
m.load_state_dict(torch.load(ckpt)['state_dict'])
8892
return m
93+
elif model_name == 'test_chinese_bert_softmax':
94+
download_pretrain(model_name)
95+
download_bert_base_uncased()
96+
download_wiki80()
97+
rel2id = json.load(open(os.path.join(root_path, 'benchmark/test_chinese/test_chinese_rel2id.json')))
98+
sentence_encoder = encoder.BERTEncoder(
99+
max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch'))
100+
m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
101+
m.load_state_dict(torch.load(ckpt)['state_dict'])
102+
return m
103+
104+
elif model_name == 'people_chinese_bert_softmax':
105+
download_pretrain(model_name)
106+
download_bert_base_uncased()
107+
download_wiki80()
108+
rel2id = json.load(open(os.path.join(root_path, 'benchmark/people-relation/people-relation_rel2id.json')))
109+
sentence_encoder = encoder.BERTEncoder(
110+
max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch'))
111+
m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
112+
m.load_state_dict(torch.load(ckpt)['state_dict'])
113+
return m
114+
115+
elif model_name == 'people_delunknown_chinese_bert_softmax':
116+
download_pretrain(model_name)
117+
download_bert_base_uncased()
118+
download_wiki80()
119+
rel2id = json.load(open(os.path.join(root_path, 'benchmark/people-relation-delunknow/people-relation_rel2id.json')))
120+
sentence_encoder = encoder.BERTEncoder(
121+
max_length=80, pretrain_path=os.path.join(root_path, 'pretrain/chinese_wwm_pytorch'))
122+
m = model.SoftmaxNN(sentence_encoder, len(rel2id), rel2id)
123+
m.load_state_dict(torch.load(ckpt)['state_dict'])
124+
return m
125+
89126
else:
90127
raise NotImplementedError

0 commit comments

Comments
 (0)