Commit 623c707 (1 parent: e7f4068)

7 files changed: +165 −19

app/exp_config.py (+3 −3)

@@ -3,8 +3,8 @@
 TEST_PERCENT = 0.1
 MAX_SEQ_LENGTH = 512
 
-SYSTEM_CONFIG = {'per_device_train_batch_size': 16,
-                 'gradient_accumulation_steps': 8}
+SYSTEM_CONFIG = {'per_device_train_batch_size': 8,
+                 'gradient_accumulation_steps': 16}
 
 # logging_steps: logging when training roberta
 # the frequency of evaluating on the val set for roberta
@@ -14,4 +14,4 @@
 
 TRAIN_DEBUG_CONFIG = {'epoch': 1,
                       'logging_steps': 1,
-                      'eval_steps': 2}
+                      'eval_steps': 2}
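Note on the SYSTEM_CONFIG change: halving per_device_train_batch_size while doubling gradient_accumulation_steps trades GPU memory for extra accumulation steps, and the effective batch size per optimizer update is unchanged. A minimal sketch of the arithmetic, assuming HuggingFace-Trainer-style semantics on a single GPU (the helper name is hypothetical, not part of the repo):

    def effective_batch_size(per_device_train_batch_size,
                             gradient_accumulation_steps,
                             n_gpus=1):
        # one optimizer step consumes this many examples in total
        return per_device_train_batch_size * gradient_accumulation_steps * n_gpus

    assert effective_batch_size(16, 8) == effective_batch_size(8, 16) == 128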

app/run_pos_dep_constit_ner.sh (new file, +14)

@@ -0,0 +1,14 @@
+# bash run_pos_dep_constit_ner.sh
+set -e
+
+# pre-trained model
+bash train_on_cn_novel_dep.sh '0' 1 0 0 15
+bash train_on_cn_novel_pos.sh '0' 1 0 0 15
+bash train_on_cn_novel_constit.sh '0' 1 0 0 15
+bash train_on_cn_novel_ner.sh '0' 1 0 0 15
+
+# non-pre-trained model (re-initialized weights)
+bash train_on_cn_novel_dep.sh '0' 1 0 1 15
+bash train_on_cn_novel_pos.sh '0' 1 0 1 15
+bash train_on_cn_novel_constit.sh '0' 1 0 1 15
+bash train_on_cn_novel_ner.sh '0' 1 0 1 15
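(Per the task scripts below, the five positional arguments map to char_freq_ranges, is_change_apply_to_test, is_debug, re_init_weights, and repeat, so the two blocks differ only in re_init_weights: 0 keeps the pre-trained weights, 1 re-initializes them. Each configuration is repeated 15 times.)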

app/train_cn_roberta.py (+23 −5)

@@ -61,7 +61,9 @@ def args_parse():
                             'None',
                             'likelihood_rank',
                             'pos',
-                            'dep'
+                            'dep',
+                            'constit',  # phrase structure / constituency tree
+                            'ner'
                         ],
                         required=True)
     parser.add_argument('--is_change_apply_to_test', type=int, default=0)
@@ -95,6 +97,9 @@ def main():
         if x in semantic_change:
             is_change_apply_to_test = True
 
+    if 'pos' in semantic_change or 'dep' in semantic_change:
+        assert char_freq_ranges == [0]
+
     # read char frequencies
     char_freq_rank = {}
     with open(char_freq_txt_path, 'r') as f:
@@ -116,6 +121,23 @@ def main():
         classifier_name += '_no_pretrain'
 
     semantic_change_str = '_'.join(semantic_change)
+
+    # build the save path up front
+    # so finished runs can be skipped before any training happens
+    save_name = f'{dataset_name}_{classifier_name}_{semantic_change_str}' \
+                f'_{is_change_apply_to_train}_{is_change_apply_to_test}'
+    if is_debug:
+        save_name = save_name + '_debug.csv'
+    else:
+        save_name = save_name + '.csv'
+    save_path = os.path.join(save_dir, save_name)
+
+    if os.path.isfile(save_path):
+        print("=" * 78)
+        print(f"{save_path} exists. Skipping training!")
+        print("=" * 78)
+        return
+
     semantic_modifier = SemanticModifier(semantic_change, char_freq_rank=char_freq_rank)
 
     tokenizer = AutoTokenizer.from_pretrained(hugginface_model_id)
@@ -240,10 +262,6 @@ def main():
         shutil.rmtree(tmp_ckpts_dir)
         print(f"Remove temp dir {tmp_ckpts_dir} SUCCESS!!!!")
 
-    # save path
-    save_path = os.path.join(save_dir,
-                             f'{dataset_name}_{classifier_name}_{semantic_change_str}_{is_change_apply_to_train}'
-                             f'_{is_change_apply_to_test}.csv')
     exp_recorder.save_to_disk(save_path)
 
     # # Save model
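Moving the save-path construction ahead of training serves two purposes: a run whose result CSV already exists returns immediately, which makes batch drivers such as run_pos_dep_constit_ner.sh cheap to re-run after an interruption, and debug runs now write to a separate '_debug.csv' file instead of overwriting real results.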

app/train_on_cn_novel_constit.sh (new file, +28)

@@ -0,0 +1,28 @@
+# bash train_on_cn_novel_constit.sh '0' 1 1 0 1
+
+char_freq_ranges=${1:-0}
+is_change_apply_to_test=${2:-1}
+is_debug=${3:-0}
+re_init_weights=${4:-0}
+repeat=${5:-1}
+is_change_apply_to_train=${6:-1}
+
+data_dir=../data/5billion
+save_dir=../result/
+dataset_name=cn_novel_5billion
+classifier_name=cn_roberta
+char_freq_txt_path=../data/sort_char.txt
+semantic_change='constit'
+
+python3.6 train_cn_roberta.py --classifier_name $classifier_name \
+    --dataset_name $dataset_name \
+    --data_dir $data_dir \
+    --save_dir $save_dir \
+    --char_freq_txt_path $char_freq_txt_path \
+    --is_debug $is_debug \
+    --repeat $repeat \
+    --char_freq_ranges $char_freq_ranges \
+    --semantic_change $semantic_change \
+    --is_change_apply_to_test $is_change_apply_to_test \
+    --re_init_weights $re_init_weights \
+    --is_change_apply_to_train $is_change_apply_to_train
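(All six positional parameters use bash ${N:-default} expansion, so each is optional; invoked with no arguments the script runs with char_freq_ranges=0, is_change_apply_to_test=1, is_debug=0, re_init_weights=0, repeat=1, is_change_apply_to_train=1.)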

app/train_on_cn_novel_ner.sh (new file, +28)

@@ -0,0 +1,28 @@
+# bash train_on_cn_novel_ner.sh '0' 1 1 0 1
+
+char_freq_ranges=${1:-0}
+is_change_apply_to_test=${2:-1}
+is_debug=${3:-0}
+re_init_weights=${4:-0}
+repeat=${5:-1}
+is_change_apply_to_train=${6:-1}
+
+data_dir=../data/5billion
+save_dir=../result/
+dataset_name=cn_novel_5billion
+classifier_name=cn_roberta
+char_freq_txt_path=../data/sort_char.txt
+semantic_change='ner'
+
+python3.6 train_cn_roberta.py --classifier_name $classifier_name \
+    --dataset_name $dataset_name \
+    --data_dir $data_dir \
+    --save_dir $save_dir \
+    --char_freq_txt_path $char_freq_txt_path \
+    --is_debug $is_debug \
+    --repeat $repeat \
+    --char_freq_ranges $char_freq_ranges \
+    --semantic_change $semantic_change \
+    --is_change_apply_to_test $is_change_apply_to_test \
+    --re_init_weights $re_init_weights \
+    --is_change_apply_to_train $is_change_apply_to_train
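(This script is identical to train_on_cn_novel_constit.sh apart from the usage comment and semantic_change='ner'; the same positional defaults apply.)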

core/semantic_modifier.py (+69 −11)

@@ -1,19 +1,32 @@
+import os
 import ipdb
 import copy
 import random
 import numpy as np
+import glob
 import spacy
 from tqdm import tqdm
+import pickle
+import hashlib
+import benepar
 
 
 class SemanticModifier:
     def __init__(self, semantic_change, char_freq_rank=None):
         self.semantic_change = semantic_change
         self.char_freq_rank = char_freq_rank
         self.max_freq = max(self.char_freq_rank.values())
-        if 'pos' or 'dep' in semantic_change:
+        if 'pos' in semantic_change or \
+                'dep' in semantic_change or \
+                'constit' in semantic_change or \
+                'ner' in semantic_change:
             # self.spacy_parser = spacy.load("zh_core_web_sm")
             self.spacy_parser = spacy.load("zh_core_web_trf")
+            # if 'constit' in semantic_change:
+            benepar_model = 'benepar_zh2'
+            self.spacy_parser.add_pipe("benepar", config={"model": benepar_model})
+            print(f'Spacy add benepar pipe done! Model: {benepar_model}')
+            self.spacy_results = {}
         else:
             self.spacy_parser = None
 
@@ -89,21 +102,66 @@ def change_texts(self, texts, char_freq_range):
                     appear_set.add(x)
                 split_text = new_split_text
 
-            if 'pos' in self.semantic_change or 'dep' in self.semantic_change:
+            if 'pos' in self.semantic_change or \
+                    'dep' in self.semantic_change or \
+                    'constit' in self.semantic_change or \
+                    'ner' in self.semantic_change:
                 # from spacy.lang.zh.examples import sentences
                 # example_sentence = sentences[0]
-                parse_res = self.spacy_parser(text.replace(' ', ''))
-                new_text = []
-                for token in parse_res:
-                    if 'pos' in self.semantic_change:
-                        new_text.append(token.pos_)
-                    elif 'dep' in self.semantic_change:
-                        new_text.append(token.dep_)
+                to_parse_text = text.replace(' ', '')
+                model_name = self.spacy_parser.meta['name'] + '_' + self.spacy_parser.meta['lang']
+                text_md5 = hashlib.md5(f'{model_name}_{to_parse_text}'.encode()).hexdigest()
+
+                if text_md5 in self.spacy_results:
+                    split_text = self.spacy_results[text_md5]
+                else:
+                    text_pickle_path = f'../spacy_temp/{text_md5}.pkl'
+                    if os.path.isfile(text_pickle_path):
+                        parse_res = pickle.load(open(text_pickle_path, 'rb'))
+                    else:
+                        parse_res = self.spacy_parser(to_parse_text)
+                        pickle.dump(parse_res, open(text_pickle_path, 'wb'))
+                    new_text = []
+                    # Reference: https://spacy.io/usage/linguistic-features#dependency-parse
+
+                    if 'pos' in self.semantic_change or 'dep' in self.semantic_change:
+                        for token in parse_res:
+                            if 'pos' in self.semantic_change:
+                                new_text.append(token.pos_)
+                            elif 'dep' in self.semantic_change:
+                                new_text.append(token.dep_)
+                                new_text.append(str(token.idx))
+                                new_text.append(str(token.head.idx))
+                                # new_text.append(token.head.text)
+                            else:
+                                raise Exception
+                        new_text = ' '.join(new_text)
+                    elif 'constit' in self.semantic_change:
+                        new_text = []
+                        tokens = [str(x) for x in parse_res]
+                        tokens = ''.join(tokens)
+                        tokens_set = set(list(tokens))
+                        for sen in parse_res.sents:
+                            parse_string = sen._.parse_string
+                            new_text.append(parse_string)
+                        new_text = '<s>'.join(new_text)
+                        for token in tokens_set:
+                            new_text = new_text.replace(str(token), '')
+                        new_text = new_text.replace(' ', '')
+                    elif 'ner' in self.semantic_change:
+                        new_text = []
+                        for ent in parse_res.ents:
+                            new_text.append(ent.label_)
+                            new_text.append(str(ent.start_char))
+                            new_text.append(str(ent.end_char))
+                        new_text = ' '.join(new_text)
                     else:
                         raise Exception
-                split_text = new_text
+                    split_text = new_text
+                    self.spacy_results[text_md5] = split_text
             # the pos/dep tag count will not match the original character count: the Chinese text is word-segmented first, so the tag sequence is a bit shorter
-            processed_texts.append(' '.join(split_text))
+
+            processed_texts.append(split_text)
 
         assert len(processed_texts) == len(texts)

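The new parsing back end caches results twice: in memory, keyed by an MD5 of the model name plus the text, and on disk as pickles under ../spacy_temp/ (held in the repo by the .gitkeep below), so re-running an experiment never re-parses the same text. Before SemanticModifier can load the pipeline, the spaCy and benepar models must be installed; a minimal setup sketch, assuming the standard spacy/benepar distribution channels (this snippet is not part of the commit):

    import benepar  # importing benepar registers its spaCy component factory
    import spacy

    # one-time downloads:
    #   python -m spacy download zh_core_web_trf   (run in a shell)
    benepar.download('benepar_zh2')  # Chinese constituency model, fetched via NLTK

    nlp = spacy.load('zh_core_web_trf')
    nlp.add_pipe('benepar', config={'model': 'benepar_zh2'})

    doc = nlp('我爱自然语言处理')
    for sent in doc.sents:
        print(sent._.parse_string)  # bracketed constituency parse, as consumed by the 'constit' branch

For reference, the 'constit' branch then deletes the terminal characters from each parse string, keeping only the bracketed label skeleton, while the 'ner' branch emits entity labels with character offsets.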
spacy_temp/.gitkeep (new file)

Whitespace-only changes (an empty placeholder that keeps the spaCy parse-cache directory in the repo).
