5
5
import scipy as sp
6
6
import scipy .stats
7
7
8
+ from matplotlib .gridspec import GridSpec , GridSpecFromSubplotSpec
9
+
8
10
def generate_griffiths_data (num_documents = 5000 , average_document_length = 150 ,
9
11
num_topics = 5 , alpha = None , eta = None , seed = 0 ):
10
12
"""Returns example documents from Griffiths-Steyvers [1].
@@ -46,7 +48,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
46
48
theta : Numpy NDArray
47
49
A matrix of size `num_documents` x `num_topics` equal to the topic
48
50
mixtures used to generate the output `documents`.
49
-
51
+
50
52
References
51
53
----------
52
54
[1] Thomas L Griffiths and Mark Steyvers. "Finding Scientific Topics."
@@ -56,7 +58,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
56
58
"""
57
59
vocabulary_size = 25
58
60
image_dim = np .int (np .sqrt (vocabulary_size ))
59
-
61
+
60
62
# perform checks on input
61
63
assert num_topics in [5 ,10 ], 'Example data only available for 5 or 10 topics'
62
64
if alpha :
@@ -75,7 +77,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
75
77
dirichlet_eta = sp .stats .dirichlet (eta )
76
78
77
79
# initialize a known topic-word distribution (beta) using eta. these are
78
- # the "row" and "column" topics, respectively. when num_topics = 5 only
80
+ # the "row" and "column" topics, respectively. when num_topics = 5 only
79
81
# create the col topics. when num_topics = 10 add the row topics as well
80
82
#
81
83
beta = np .zeros ((num_topics ,image_dim ,image_dim ), dtype = np .float )
@@ -111,22 +113,22 @@ def plot_lda(data, nrows, ncols, with_colorbar=True, cmap=cm.viridis):
111
113
fig , ax = plt .subplots (nrows , ncols , figsize = (ncols ,nrows ))
112
114
vmin = 0
113
115
vmax = data .max ()
114
-
116
+
115
117
V = len (data [0 ])
116
118
n = int (np .sqrt (V ))
117
119
for i in range (nrows ):
118
120
for j in range (ncols ):
119
121
index = i * ncols + j
120
-
122
+
121
123
if nrows > 1 :
122
124
im = ax [i ,j ].matshow (data [index ].reshape (n ,n ), cmap = cmap , vmin = vmin , vmax = vmax )
123
125
else :
124
126
im = ax [j ].matshow (data [index ].reshape (n ,n ), cmap = cmap , vmin = vmin , vmax = vmax )
125
-
127
+
126
128
for axi in ax .ravel ():
127
129
axi .set_xticks ([])
128
130
axi .set_yticks ([])
129
-
131
+
130
132
if with_colorbar :
131
133
fig .colorbar (im , ax = ax .ravel ().tolist (), orientation = 'horizontal' , fraction = 0.2 )
132
134
return fig
@@ -136,18 +138,50 @@ def match_estimated_topics(topics_known, topics_estimated):
136
138
K , V = topics_known .shape
137
139
permutation = - 1 * np .ones (K , dtype = np .int )
138
140
unmatched_estimated_topics = []
139
-
141
+
140
142
for estimated_topic_index , t in enumerate (topics_estimated ):
141
143
matched_known_topic_index = np .argmin ([np .linalg .norm (known_topic - t ) for known_topic in topics_known ])
142
144
if permutation [matched_known_topic_index ] == - 1 :
143
145
permutation [matched_known_topic_index ] = estimated_topic_index
144
146
else :
145
147
unmatched_estimated_topics .append (estimated_topic_index )
146
-
148
+
147
149
for estimated_topic_index in unmatched_estimated_topics :
148
150
for i in range (K ):
149
151
if permutation [i ] == - 1 :
150
152
permutation [i ] = estimated_topic_index
151
153
break
152
-
153
- return permutation , (topics_estimated [permutation ,:]).copy ()
154
+
155
+ return permutation , (topics_estimated [permutation ,:]).copy ()
156
+
157
+ def _document_with_topic (fig , gsi , index , document , topic_mixture = None ,
158
+ vmin = 0 , vmax = 32 ):
159
+ ax_doc = fig .add_subplot (gsi [:5 ,:])
160
+ ax_doc .matshow (document .reshape (5 ,5 ), cmap = 'gray_r' ,
161
+ vmin = vmin , vmax = vmax )
162
+ ax_doc .set_xticks ([])
163
+ ax_doc .set_yticks ([])
164
+
165
+ if topic_mixture is not None :
166
+ ax_topic = plt .subplot (gsi [- 1 ,:])
167
+ ax_topic .matshow (topic_mixture .reshape (1 ,- 1 ), cmap = 'Reds' ,
168
+ vmin = 0 , vmax = 1 )
169
+ ax_topic .set_xticks ([])
170
+ ax_topic .set_yticks ([])
171
+
172
+ def plot_lda_topics (documents , nrows , ncols , with_colorbar = True ,
173
+ topic_mixtures = None , cmap = 'Viridis' , dpi = 160 ):
174
+ fig = plt .figure ()
175
+ gs = GridSpec (nrows , ncols )
176
+
177
+ vmin , vmax = (0 , documents .max ())
178
+
179
+ for i in range (nrows ):
180
+ for j in range (ncols ):
181
+ index = i * ncols + j
182
+ gsi = GridSpecFromSubplotSpec (6 , 5 , subplot_spec = gs [i ,j ])
183
+ _document_with_topic (fig , gsi , index , documents [index ],
184
+ topic_mixture = topic_mixtures [index ],
185
+ vmin = vmin , vmax = vmax )
186
+
187
+ return fig
0 commit comments