-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathplotting.py
80 lines (66 loc) · 2.2 KB
/
plotting.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import numpy as np
import plotly
import plotly.graph_objs as go
from plotly.tools import set_credentials_file
from sklearn.manifold import TSNE
def gen_plotly_specs(datas, hover_text, cat):
    """Build one 3-D scatter trace per distinct category.

    Parameters
    ----------
    datas : array-like, shape (n_points, >=3)
        Point coordinates; columns 0, 1 and 2 are used as x, y and z.
    hover_text : array-like, length n_points
        Per-point text shown on hover.
    cat : array-like, length n_points
        Category of each point. One trace (legend entry) is created per
        unique value, in the sorted order returned by ``np.unique``.

    Returns
    -------
    list of plotly.graph_objs.Scatter3d
        One trace per category.
    """
    # Coerce to arrays so the element-wise comparison and fancy indexing below
    # work even when callers pass plain Python lists (a list compared to a
    # scalar yields a single bool, which would silently mis-select points).
    datas = np.asarray(datas)
    hover_text = np.asarray(hover_text)
    cat = np.asarray(cat)

    data = []
    for cur_cat in np.unique(cat):
        idx = np.flatnonzero(cat == cur_cat)
        cur_pts = datas[idx]
        # creating scatter plot for a single category
        trace = go.Scatter3d(
            x=cur_pts[:, 0],
            y=cur_pts[:, 1],
            z=cur_pts[:, 2],
            mode='markers',
            marker=dict(
                size=10,
                line=dict(width=0.0),
                opacity=0.8,
            ),
            # Trace names are strings; category values may be numeric numpy
            # scalars, so convert explicitly.
            name=str(cur_cat),
            text=hover_text[idx],
        )
        data.append(trace)
    return data
def _subsample_per_category(data, cat, source, max_points_per_category):
    """Keep at most ``max_points_per_category`` rows of each category.

    Rows are regrouped by category (sorted, as by ``np.unique``); within a
    category the first rows in input order are kept. ``data``, ``cat`` and
    ``source`` are indexed with the same positions so they stay aligned.
    """
    if cat.size == 0:
        return data, cat, source
    if isinstance(data, (list, tuple)):
        data = np.asarray(data)
    keep = np.concatenate([
        np.flatnonzero(cat == c)[:max_points_per_category]
        for c in np.unique(cat)
    ])
    return data[keep], cat[keep], source[keep]


def tsne_plotly(data, cat, labels, source, username, api_key, seed=0,
                max_points_per_category=250, max_label_length=64,
                subsample=False):
    """Embed ``data`` in 3-D with t-SNE and write two interactive scatter plots.

    Two HTML files are produced with ``plotly.offline.plot``:
    ``topics-scatter.html`` (points grouped by topic label, hover shows the
    source) and ``source-scatter.html`` (points grouped by source, hover
    shows the topic label).

    Parameters
    ----------
    data : array-like, shape (n_samples, n_features)
        Feature vectors to embed.
    cat : array-like of int, length n_samples
        Index into ``labels`` for each sample.
    labels : sequence
        Topic label names (presumably strings); each is truncated to
        ``max_label_length`` items.
    source : array-like, length n_samples
        Source value of each sample.
    username, api_key : str
        Passed to ``plotly.tools.set_credentials_file``.
    seed : int
        ``random_state`` for t-SNE.
    max_points_per_category : int
        Per-category cap used when ``subsample`` is true.
    max_label_length : int
        Maximum length kept from each label.
    subsample : bool
        If true, keep at most ``max_points_per_category`` samples per value of
        ``cat`` *before* running t-SNE. Defaults to False, i.e. all samples
        are embedded and plotted.

    Returns
    -------
    list of [traces, filename]
        The generated trace lists and the HTML file each was written to.
    """
    print("Plotting data...")
    set_credentials_file(username=username, api_key=api_key)

    cat = np.asarray(cat)
    source = np.asarray(source)

    # Subsample before t-SNE so the embedding, categories and sources all
    # refer to the same rows.
    if subsample:
        data, cat, source = _subsample_per_category(
            data, cat, source, max_points_per_category)

    model = TSNE(n_components=3, random_state=seed, verbose=1)
    reduced = model.fit_transform(data)

    labels = np.asarray([lbl[:max_label_length] for lbl in labels])
    sample_labels = labels[cat]
    plot_params = [
        [reduced, source, sample_labels, 'topics-scatter.html'],
        [reduced, sample_labels, source, 'source-scatter.html'],
    ]

    # Shared figure layout: zero margins around the plot area.
    layout = go.Layout(margin=dict(l=0, r=0, b=0, t=0))

    figures = []
    for points, hover_text, cats, fname in plot_params:
        fig = gen_plotly_specs(datas=points, hover_text=hover_text, cat=cats)
        plotly.offline.plot({
            "data": fig,
            "layout": layout,
        }, filename=fname)
        figures.append([fig, fname])
    return figures