Skip to content

Commit 564fbd1

Browse files
author
Chris Swierczewski
committed
Add document-topic plotting function
1 parent beb1661 commit 564fbd1

File tree

1 file changed

+45
-11
lines changed

1 file changed

+45
-11
lines changed

lda_topic_modeling/generate_example_data.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import scipy as sp
66
import scipy.stats
77

8+
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
9+
810
def generate_griffiths_data(num_documents=5000, average_document_length=150,
911
num_topics=5, alpha=None, eta=None, seed=0):
1012
"""Returns example documents from Griffiths-Steyvers [1].
@@ -46,7 +48,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
4648
theta : Numpy NDArray
4749
A matrix of size `num_documents` x `num_topics` equal to the topic
4850
mixtures used to generate the output `documents`.
49-
51+
5052
References
5153
----------
5254
[1] Thomas L Griffiths and Mark Steyvers. "Finding Scientific Topics."
@@ -56,7 +58,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
5658
"""
5759
vocabulary_size = 25
5860
image_dim = np.int(np.sqrt(vocabulary_size))
59-
61+
6062
# perform checks on input
6163
assert num_topics in [5,10], 'Example data only available for 5 or 10 topics'
6264
if alpha:
@@ -75,7 +77,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
7577
dirichlet_eta = sp.stats.dirichlet(eta)
7678

7779
# initialize a known topic-word distribution (beta) using eta. these are
78-
# the "row" and "column" topics, respectively. when num_topics = 5 only
80+
# the "row" and "column" topics, respectively. when num_topics = 5 only
7981
# create the col topics. when num_topics = 10 add the row topics as well
8082
#
8183
beta = np.zeros((num_topics,image_dim,image_dim), dtype=np.float)
@@ -111,22 +113,22 @@ def plot_lda(data, nrows, ncols, with_colorbar=True, cmap=cm.viridis):
111113
fig, ax = plt.subplots(nrows, ncols, figsize=(ncols,nrows))
112114
vmin = 0
113115
vmax = data.max()
114-
116+
115117
V = len(data[0])
116118
n = int(np.sqrt(V))
117119
for i in range(nrows):
118120
for j in range(ncols):
119121
index = i*ncols + j
120-
122+
121123
if nrows > 1:
122124
im = ax[i,j].matshow(data[index].reshape(n,n), cmap=cmap, vmin=vmin, vmax=vmax)
123125
else:
124126
im = ax[j].matshow(data[index].reshape(n,n), cmap=cmap, vmin=vmin, vmax=vmax)
125-
127+
126128
for axi in ax.ravel():
127129
axi.set_xticks([])
128130
axi.set_yticks([])
129-
131+
130132
if with_colorbar:
131133
fig.colorbar(im, ax=ax.ravel().tolist(), orientation='horizontal', fraction=0.2)
132134
return fig
@@ -136,18 +138,50 @@ def match_estimated_topics(topics_known, topics_estimated):
136138
K, V = topics_known.shape
137139
permutation = -1*np.ones(K, dtype=np.int)
138140
unmatched_estimated_topics = []
139-
141+
140142
for estimated_topic_index, t in enumerate(topics_estimated):
141143
matched_known_topic_index = np.argmin([np.linalg.norm(known_topic - t) for known_topic in topics_known])
142144
if permutation[matched_known_topic_index] == -1:
143145
permutation[matched_known_topic_index] = estimated_topic_index
144146
else:
145147
unmatched_estimated_topics.append(estimated_topic_index)
146-
148+
147149
for estimated_topic_index in unmatched_estimated_topics:
148150
for i in range(K):
149151
if permutation[i] == -1:
150152
permutation[i] = estimated_topic_index
151153
break
152-
153-
return permutation, (topics_estimated[permutation,:]).copy()
154+
155+
return permutation, (topics_estimated[permutation,:]).copy()
156+
157+
def _document_with_topic(fig, gsi, index, document, topic_mixture=None,
158+
vmin=0, vmax=32):
159+
ax_doc = fig.add_subplot(gsi[:5,:])
160+
ax_doc.matshow(document.reshape(5,5), cmap='gray_r',
161+
vmin=vmin, vmax=vmax)
162+
ax_doc.set_xticks([])
163+
ax_doc.set_yticks([])
164+
165+
if topic_mixture is not None:
166+
ax_topic = plt.subplot(gsi[-1,:])
167+
ax_topic.matshow(topic_mixture.reshape(1,-1), cmap='Reds',
168+
vmin=0, vmax=1)
169+
ax_topic.set_xticks([])
170+
ax_topic.set_yticks([])
171+
172+
def plot_lda_topics(documents, nrows, ncols, with_colorbar=True,
173+
topic_mixtures=None, cmap='Viridis', dpi=160):
174+
fig = plt.figure()
175+
gs = GridSpec(nrows, ncols)
176+
177+
vmin, vmax = (0, documents.max())
178+
179+
for i in range(nrows):
180+
for j in range(ncols):
181+
index = i*ncols + j
182+
gsi = GridSpecFromSubplotSpec(6, 5, subplot_spec=gs[i,j])
183+
_document_with_topic(fig, gsi, index, documents[index],
184+
topic_mixture=topic_mixtures[index],
185+
vmin=vmin, vmax=vmax)
186+
187+
return fig

0 commit comments

Comments
 (0)