Skip to content

Commit 5302125

Browse files
1
1 parent d563077 commit 5302125

15 files changed

+431
-334
lines changed

app/ngram_interpret.py

+4-1
Original file line number | Diff line number | Diff line change
@@ -115,6 +115,9 @@ def compute_per_sentence_attr_score(target_df, class_pos, ngram, save_path, lang
115115

116116

117117
# python ngram_interpret.py --path '../result/interpret/interpret_cn_novel_5billion_cn_roberta_debug_0_text_len_128_debug_N_10000_use_all_zero_bs_token_attr.csv'
118+
# python ngram_interpret.py --path '../result/interpret/interpret_cn_novel_5billion_cn_roberta_debug_0_text_len_128_debug_N_10000_use_all_zero_bs_token_attr.csv'
119+
# python ngram_interpret.py --path '../result/interpret/interpret_cn_novel_5billion_cn_roberta_debug_0_text_len_128_debug_N_10000_use_all_zero_bs_token_attr.csv'
120+
118121
# python ngram_interpret.py --path '../result/interpret/interpret_en_grover_en_roberta_debug_0_text_len_256_debug_N_800_use_pad_bs_token_attr.csv'
119122
# python ngram_interpret.py --path '../result/interpret/interpret_en_writing_prompt_en_roberta_debug_0_text_len_128_debug_N_800_use_pad_bs_token_attr.csv'
120123
# python ngram_interpret.py --path '../result/interpret/interpret_en_grover_en_roberta_debug_0_text_len_256_debug_N_10000_use_all_zero_bs_token_attr.csv'
@@ -128,7 +131,7 @@ def main():
128131
language = basename.split('_')[1]
129132

130133
df = pd.read_csv(path)
131-
ngrams = [1, 2, 3, 4, 5, 6]
134+
ngrams = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]
132135
# ngrams = [1, 2, 3, 4, 5, 6]
133136

134137
# token_freq_dict = collections.Counter(df['token'].values)

app/ngram_interpret.sh

+11
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,11 @@
1+
# bash ngram_interpret.sh
2+
3+
set -e
4+
5+
python ngram_interpret.py --path ../result/interpret/interpret_cn_novel_5billion_cn_roberta_debug_0_text_len_128_debug_N_10000_use_all_zero_bs_token_attr.csv
6+
python ngram_interpret.py --path ../result/interpret/interpret_cn_novel_5billion_cn_roberta_debug_0_text_len_128_debug_N_10000_use_pad_bs_token_attr.csv
7+
python ngram_interpret.py --path ../result/interpret/interpret_en_grover_en_roberta_debug_0_text_len_256_debug_N_10000_use_all_zero_bs_token_attr.csv
8+
python ngram_interpret.py --path ../result/interpret/interpret_en_grover_en_roberta_debug_0_text_len_256_debug_N_10000_use_pad_bs_token_attr.csv
9+
python ngram_interpret.py --path ../result/interpret/interpret_en_writing_prompt_en_roberta_debug_0_text_len_128_debug_N_10000_use_all_zero_bs_token_attr.csv
10+
python ngram_interpret.py --path ../result/interpret/interpret_en_writing_prompt_en_roberta_debug_0_text_len_128_debug_N_10000_use_pad_bs_token_attr.csv
11+

app/run_all.sh

+14-14
Original file line number | Diff line number | Diff line change
@@ -5,22 +5,22 @@
55
#repeat=${5:-1}
66

77
# apply to test (pretrain)
8-
bash train_on_cn_novel_origin.sh '16 256 1024 0' 1 0 0 15
9-
bash train_on_cn_novel_reorder_shuffle.sh '16 256 1024 0' 1 0 0 15
10-
bash train_on_cn_novel_char_deduplicate.sh '16 256 1024 0' 1 0 0 15
11-
bash train_on_cn_novel_reorder_freq_high2low.sh '16 256 1024 0' 1 0 0 15
12-
bash train_on_cn_novel_reorder_freq_low2high.sh '16 256 1024 0' 1 0 0 15
13-
bash train_on_cn_novel_reorder_shuffle+deduplicate.sh '16 256 1024 0' 1 0 0 15
14-
bash train_on_cn_novel_likelihood_rank.sh '16 256 1024 0' 1 0 0 15
8+
bash train_on_cn_novel_origin.sh '16 32 64 128 256 512 0' 1 0 0 15
9+
bash train_on_cn_novel_reorder_shuffle.sh '16 32 64 128 256 512 0' 1 0 0 15
10+
bash train_on_cn_novel_char_deduplicate.sh '16 32 64 128 256 512 0' 1 0 0 15
11+
#bash train_on_cn_novel_reorder_freq_high2low.sh '16 32 64 128 256 512 0' 1 0 0 15
12+
#bash train_on_cn_novel_reorder_freq_low2high.sh '16 32 64 128 256 512 0' 1 0 0 15
13+
bash train_on_cn_novel_reorder_shuffle+deduplicate.sh '16 32 64 128 256 512 0' 1 0 0 15
14+
bash train_on_cn_novel_likelihood_rank.sh '16 32 64 128 256 512 0' 1 0 0 15
1515

1616
# apply to test (not pretrain)
17-
bash train_on_cn_novel_origin.sh '16 256 1024 0' 1 0 1 15 # non pre-train
18-
bash train_on_cn_novel_reorder_shuffle.sh '16 256 1024 0' 1 0 1 15
19-
bash train_on_cn_novel_char_deduplicate.sh '16 256 1024 0' 1 0 1 15
20-
bash train_on_cn_novel_reorder_freq_high2low.sh '16 256 1024 0' 1 0 1 15
21-
bash train_on_cn_novel_reorder_freq_low2high.sh '16 256 1024 0' 1 0 1 15
22-
bash train_on_cn_novel_reorder_shuffle+deduplicate.sh '16 256 1024 0' 1 0 1 15
23-
bash train_on_cn_novel_likelihood_rank.sh '16 256 1024 0' 1 0 1 15
17+
bash train_on_cn_novel_origin.sh '16 32 64 128 256 512 0' 1 0 1 15 # non pre-train
18+
bash train_on_cn_novel_reorder_shuffle.sh '16 32 64 128 256 512 0' 1 0 1 15
19+
bash train_on_cn_novel_char_deduplicate.sh '16 32 64 128 256 512 0' 1 0 1 15
20+
#bash train_on_cn_novel_reorder_freq_high2low.sh '16 32 64 128 256 512 0' 1 0 1 15
21+
#bash train_on_cn_novel_reorder_freq_low2high.sh '16 32 64 128 256 512 0' 1 0 1 15
22+
bash train_on_cn_novel_reorder_shuffle+deduplicate.sh '16 32 64 128 256 512 0' 1 0 1 15
23+
2424

2525
#char_freq_ranges=${1:-0}
2626
#is_debug=${2:-0}

app/run_en_all_no_spacy.sh

+10-10
Original file line number | Diff line number | Diff line change
@@ -10,18 +10,18 @@
1010

1111

1212
# Grover
13-
# bash train_on_en.sh 'en_grover' '400 800 1600 0' 1 0 0 15 1 'None'
14-
# bash train_on_en.sh 'en_grover' '400 800 1600 0' 1 0 1 15 1 'None' # NON PRETRAIN
15-
# bash train_on_en.sh 'en_grover' '16 256 1024 0' 1 0 0 15 1 'reorder_shuffle char_deduplicate'
16-
# bash train_on_en.sh 'en_grover' '16 256 1024 0' 1 0 0 15 1 'reorder_freq_low2high'
17-
# bash train_on_en.sh 'en_grover' '16 256 1024 0' 1 0 0 15 1 'reorder_freq_high2low'
13+
# bash train_on_en.sh 'en_grover' '400 800 1600 3200 6400 12800 0' 1 0 0 15 1 'None'
14+
# bash train_on_en.sh 'en_grover' '400 800 1600 3200 6400 12800 0' 1 0 0 15 1 'char_deduplicate'
15+
# bash train_on_en.sh 'en_grover' '400 800 1600 3200 6400 12800 0' 1 0 0 15 1 'reorder_shuffle'
16+
# bash train_on_en.sh 'en_grover' '400 800 1600 3200 6400 12800 0' 1 0 0 15 1 'reorder_shuffle char_deduplicate'
17+
# bash train_on_en.sh 'en_grover' '400 800 1600 3200 6400 12800 0' 1 0 0 15 1 'likelihood_rank'
18+
1819

1920
# en_writing_prompt
20-
# bash train_on_en.sh 'en_writing_prompt' '20 40 80 160 0' 1 0 0 15 1 'None'
21-
# bash train_on_en.sh 'en_writing_prompt' '20 40 80 160 0' 1 0 1 15 1 'None' # NON PRETRAIN
22-
# bash train_on_en.sh 'en_writing_prompt' '16 256 1024 0' 1 0 0 15 1 'reorder_shuffle char_deduplicate'
23-
# bash train_on_en.sh 'en_writing_prompt' '16 256 1024 0' 1 0 0 15 1 'reorder_freq_low2high' # a8c02b65
24-
# bash train_on_en.sh 'en_writing_prompt' '16 256 1024 0' 1 0 0 15 1 'reorder_freq_high2low'
21+
# bash train_on_en.sh 'en_writing_prompt' '10 20 40 80 160 320 0' 1 0 0 15 1 'None'
22+
# bash train_on_en.sh 'en_writing_prompt' '10 20 40 80 160 320 0' 1 0 0 15 1 'char_deduplicate'
23+
# bash train_on_en.sh 'en_writing_prompt' '10 20 40 80 160 320 0' 1 0 0 15 1 'reorder_shuffle'
24+
# bash train_on_en.sh 'en_writing_prompt' '10 20 40 80 160 320 0' 1 0 0 15 1 'reorder_shuffle char_deduplicate'
2525

2626
dataset_name=${1:-0}
2727
is_debug=${2:-0}

app/run_story_interpret_all.sh

+20
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,20 @@
1+
# bash run_story_interpret_all.sh
2+
3+
set -e
4+
#bash save_model_for_interpret.sh en_grover en_roberta 0
5+
#bash save_model_for_interpret.sh en_writing_prompt en_roberta 0
6+
# bash run_story_interpret.sh 500 1 128 cn_novel_5billion interpret_cn_novel_5billion_cn_roberta_debug_0 bert 100 1
7+
# bash run_story_interpret.sh 800 1 256 en_grover interpret_en_grover_en_roberta_debug_0 roberta 40 1
8+
# bash run_story_interpret.sh 800 1 128 en_writing_prompt interpret_en_writing_prompt_en_roberta_debug_0 roberta 100 1
9+
10+
#bash run_story_interpret.sh 10000 1 128 cn_novel_5billion interpret_cn_novel_5billion_cn_roberta_debug_0 bert 100 0
11+
#bash run_story_interpret.sh 10000 1 256 en_grover interpret_en_grover_en_roberta_debug_0 roberta 40 0
12+
#bash run_story_interpret.sh 10000 1 128 en_writing_prompt interpret_en_writing_prompt_en_roberta_debug_0 roberta 100 0
13+
14+
#bash run_story_interpret.sh 10000 1 256 en_grover interpret_en_grover_en_roberta_debug_0 roberta 500 0
15+
#bash run_story_interpret.sh 10000 1 128 cn_novel_5billion interpret_cn_novel_5billion_cn_roberta_debug_0 bert 500 0
16+
#bash run_story_interpret.sh 10000 1 128 en_writing_prompt interpret_en_writing_prompt_en_roberta_debug_0 roberta 500 0
17+
18+
bash run_story_interpret.sh 10000 1 256 en_grover interpret_en_grover_en_roberta_debug_0 roberta 500 1
19+
bash run_story_interpret.sh 10000 1 128 cn_novel_5billion interpret_cn_novel_5billion_cn_roberta_debug_0 bert 500 1
20+
bash run_story_interpret.sh 10000 1 128 en_writing_prompt interpret_en_writing_prompt_en_roberta_debug_0 roberta 500 1

core/interpreter.py

+4-2
Original file line number | Diff line number | Diff line change
@@ -110,12 +110,14 @@ def interpret_encoded_inputs(self,
110110
target=all_1_labels,
111111
baselines=all_pad_embedding,
112112
n_steps=self.n_steps,
113-
return_convergence_delta=True)
113+
return_convergence_delta=True,
114+
internal_batch_size=32)
114115
else:
115116
label1_attributions_ig, label1_delta = self.ig.attribute(inputs=input_embedding,
116117
target=all_1_labels,
117118
n_steps=self.n_steps,
118-
return_convergence_delta=True)
119+
return_convergence_delta=True,
120+
internal_batch_size=32)
119121
label1_attributions_ig = label1_attributions_ig.detach().cpu()
120122
label1_delta = label1_delta.detach().cpu()
121123
# print("label1_delta: ", label1_delta)

latex/LREC_corruption_example.tgn

+1-1
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)