# Compare the dot products of the embeddings to the PMI matrix entries
import argparse
import os
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('pdf')  # select a non-interactive backend before pyplot is imported
import matplotlib.pyplot as plt
import seaborn as sns

from bias_utils import load_BBB_nonzero


def main(args):
    bbb_vecs = load_BBB_nonzero(
        input_dir=os.path.join(args.base_dir, f'data/{args.name}/results'),
        file_stamp=args.file_stamp,
        run_id=args.run_id, only_nonzero=False, match_vectors=None)

    # Generate the PMI matrix (run these once beforehand to produce results/PMI/pmi.csv):
    # python makecooccur.py --data ConditionalEmbeddings/data/ToySun/cooccur --info ConditionalEmbeddings/data/ToySun/info --type word --window-size 2 --out ConditionalEmbeddings/data/ToySun/cooccurs --start 1990 --end 2000 --step 10
    # python PMI_compute.py -vectors None -wlist_dir ConditionalEmbeddings/data/ToySun/results/PMI -bin_dir ConditionalEmbeddings/data/ToySun/cooccurs/word/2 -word_dict_pkl ConditionalEmbeddings/data/ToySun/info/word-dict.pkl -output_dir ConditionalEmbeddings/data/ToySun/results/PMI
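    # pmi.csv schema, inferred from the column accesses below (an assumption, not a
    # documented contract): one row per (decade, word, context) pair with
    #   decade, w_idx, c_idx, w  - decade label, word/context vocabulary indices, word string
    #   #wc, #w, #c, D           - pair count, word count, context count, total pair count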
    pmi_mat = pd.read_csv(os.path.join(args.base_dir, 'data', args.name, 'results', 'PMI', 'pmi.csv'))

    # Build one |V| x |V| PMI matrix per decade: PMI(w,c) = log(#wc * D / (#w * #c))
    pmi_matrices = {}
    for decade in pmi_mat['decade'].unique():
        pmi_decade = pmi_mat.loc[pmi_mat['decade'] == decade].copy()
        pmi_decade['PMI(w,k)'] = pmi_decade.apply(
            lambda row: np.log(row['#wc'] * row['D'] / (row['#w'] * row['#c'])), axis=1)
        # Drop the diagonal
        # pmi_decade = pmi_decade.loc[pmi_decade['w_idx'] != pmi_decade['c_idx']]
        pmi_decade = pmi_decade.pivot_table(index=['w_idx'], columns='c_idx', values='PMI(w,k)')
        pmi_matrices[str(decade)] = pmi_decade.copy()
    # Create the matrix of pairwise dot products for each decade's embeddings.
    # Vectors are left unnormalized (init_sims would turn these into cosine similarities).
    dot_matrices = {}
    for decade, model in bbb_vecs.items():
        # model.init_sims(replace=True)
        m = np.dot(model.vectors, model.vectors.T)
        dot_matrices[decade] = m

    # Correlate the flattened dot-product and PMI matrices for a single decade
    select_decade = '1990'
    df = pd.concat(
        [pd.DataFrame(dot_matrices[select_decade].reshape(-1)),
         pd.DataFrame(pmi_matrices[select_decade].to_numpy().reshape(-1))], axis=1)
    print(df.corr())
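    # Optional cross-check, not in the original script: the same Pearson correlation via
    # numpy, explicitly masking NaN PMI entries (pairs that never co-occur). df.corr()
    # already drops incomplete pairs, so the two numbers should agree.
    flat_dot = dot_matrices[select_decade].reshape(-1)
    flat_pmi = pmi_matrices[select_decade].to_numpy().reshape(-1)
    mask = ~np.isnan(flat_pmi)
    print(np.corrcoef(flat_dot[mask], flat_pmi[mask])[0, 1])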

    # Heat map: dot products and PMI entries side by side
    a = pd.DataFrame(dot_matrices[select_decade])
    a['i'] = np.arange(a.shape[0])
    a = pd.melt(a, id_vars=['i'], var_name='j', value_name='value')
    a['type'] = 'Dot Product'

    b = pmi_matrices[select_decade].copy()
    b['i'] = np.arange(b.shape[0])
    b = pd.melt(b, id_vars=['i'], var_name='j', value_name='value')
    b['type'] = 'PMI'
    df = pd.concat([a, b])

    # Add word names
    vocab = pmi_mat.groupby(['w_idx', 'w']).size().reset_index()[['w_idx', 'w']]
    vocab = vocab.to_dict('split')['data']
    vocab = {i: w for i, w in vocab}
    df['w_i'] = df['i'].apply(lambda i: vocab[i])
    df['w_j'] = df['j'].apply(lambda j: vocab[j])

    # Shared color scale across both panels, centered at zero
    vmin, vmax = df['value'].min(), df['value'].max()

    def facet_heatmap(data, color, **kws):
        data = data.pivot_table(values='value', index='w_i', columns='w_j')
        sns.heatmap(data, cbar=True, vmin=vmin, vmax=vmax, cmap="vlag", center=0)

    g = sns.FacetGrid(df, col='type')
    g.map_dataframe(facet_heatmap)
    g.set_titles(row_template="{row_name}", col_template='{col_name}')
    g.figure.suptitle('')
    g.set_xlabels('')
    g.set_ylabels('')
    g.figure.savefig(os.path.join(args.output_dir, "pmi_dot_eval.png"), dpi=800)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-run_id", type=str, required=True)
    parser.add_argument("-name", type=str, required=True)
    parser.add_argument("-run_location", type=str, choices=['local', 'sherlock'])
    parser.add_argument("-base_dir", type=str, required=False)
    parser.add_argument("-output_dir", type=str, required=False)
    parser.add_argument("-file_stamp", type=str, required=False)
    args = parser.parse_args()

    if args.run_location == 'sherlock':
        args.base_dir = Path('/oak/stanford/groups/deho/legal_nlp/WEB')
    elif args.run_location == 'local':
        args.base_dir = Path(__file__).parent
    args.file_stamp = args.name
    args.output_dir = os.path.join(args.base_dir, 'data', args.name, 'results')

    main(args)
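
# Example invocation (the run_id and corpus name are illustrative, not from the original):
#   python toy_corpus_eval.py -run_id run0 -name ToySun -run_location local
# With -run_location local, base_dir defaults to this file's directory and the heatmap
# is written to data/<name>/results/pmi_dot_eval.png.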