Commit d513837

Merge pull request aws#39 from awslabs/lda_topic_modeling
Final Drafts of LDA Topic Modeling Notebooks
2 parents 7ccb202 + 3477489 commit d513837

13 files changed: +1778 -1723 lines changed

introduction_to_amazon_algorithms/README.md

Lines changed: 5 additions & 4 deletions
@@ -3,10 +3,11 @@
33
This directory includes introductory examples to Amazon SageMaker Algorithms that we have developed so far. It seeks to provide guidance and examples on basic functionality rather than a detailed scientific review or an implementation on complex, real-world data.
44

55
Example Notebooks include:
6-
- *linear_mnist*: Predicts whether a handwritten digit from the MNIST dataset is a 0 or not using a binary classifier from Amazon SageMaker Linear Learner.
76
- *factorization_machines_mnist*: Predicts whether a handwritten digit from the MNIST dataset is a 0 or not using a binary classifier from Amazon SageMaker Factorization Machines.
8-
- *pca_mnist*: Uses Amazon SageMaker Principal Components Analysis (PCA) to calculate eigendigits from MNIST.
7+
- *lda_topic_modeling*: Topic modeling using Amazon SageMaker Latent Dirichlet Allocation (LDA) on a synthetic dataset.
8+
- *linear_mnist*: Predicts whether a handwritten digit from the MNIST dataset is a 0 or not using a binary classifier from Amazon SageMaker Linear Learner.
99
- *ntm_synthetic*: Uses Amazon SageMaker Neural Topic Model (NTM) to uncover topics in documents from a synthetic data source, where topic distributions are known.
10-
- *xgboost_mnist*: Uses Amazon SageMaker XGBoost to classify handwritten digits from the MNIST dataset into one of the ten digits using a multi-class classifier. Both single machine and distributed use-cases are presented.
11-
- *xgboost_abalone*: Predicts the age of abalone ([Abalone dataset](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html)) using regression from Amazon SageMaker XGBoost.
10+
- *pca_mnist*: Uses Amazon SageMaker Principal Components Analysis (PCA) to calculate eigendigits from MNIST.
1211
- *seq2seq*: Seq2Seq algorithm is built on top of [Sockeye](https://github.com/awslabs/sockeye), a sequence-to-sequence framework for Neural Machine Translation based on MXNet. SageMaker Seq2Seq implements state-of-the-art encoder-decoder architectures which can also be used for tasks like Abstractive Summarization in addition to Machine Translation.
12+
- *xgboost_abalone*: Predicts the age of abalone ([Abalone dataset](https://www.csie.ntu.edu.tw/~cjlin/libsvmtools/datasets/regression.html)) using regression from Amazon SageMaker XGBoost.
13+
- *xgboost_mnist*: Uses Amazon SageMaker XGBoost to classify handwritten digits from the MNIST dataset into one of the ten digits using a multi-class classifier. Both single machine and distributed use-cases are presented.

lda_topic_modeling/LDA - Rosetta Stone.ipynb renamed to introduction_to_amazon_algorithms/lda_topic_modeling/LDA-Introduction.ipynb

Lines changed: 290 additions & 260 deletions
Large diffs are not rendered by default.

lda_topic_modeling/README.md renamed to introduction_to_amazon_algorithms/lda_topic_modeling/README.md

Lines changed: 3 additions & 9 deletions
@@ -1,18 +1,12 @@
11
# Latent Dirichlet Allocation and Topic Modeling
22

3-
Example notebooks on using Amazon SageMaker to train and use LDA models.
3+
An introductory notebook on using Amazon SageMaker to train and use LDA models.
44

55
<p align="center">
6-
<img src="https://github.com/awslabs/im-notebook-templates/blob/lda_topic_modeling/lda_topic_modeling/img/img_documents.png">
7-
<img src="https://github.com/awslabs/im-notebook-templates/blob/lda_topic_modeling/lda_topic_modeling/img/img_topics.png">
6+
<img src="https://github.com/awslabs/amazon-sagemaker-examples/blob/lda_topic_modeling/introduction_to_amazon_algorithms/lda_topic_modeling/img/img_documents.png">
7+
<img src="https://github.com/awslabs/amazon-sagemaker-examples/blob/lda_topic_modeling/introduction_to_amazon_algorithms/lda_topic_modeling/img/img_topics.png">
88
</p>
99

10-
* **LDA - Rosetta Stone** - An end-to-end example of generating training data,
11-
uploading to an S3 bucket, training an LDA model, turning the model into an
12-
endpoint, and inferring topic mixtures using the endpoint.
13-
* **LDA - Science** - A deep dive into the science of LDA using Amazon
14-
SageMaker.
15-
1610
## References
1711

1812
The example used in these notebooks comes from the following paper:

lda_topic_modeling/generate_example_data.py renamed to introduction_to_amazon_algorithms/lda_topic_modeling/generate_example_data.py

Lines changed: 53 additions & 11 deletions
@@ -1,10 +1,20 @@
1+
# Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. A copy of the License is located at
4+
#
5+
# http://aws.amazon.com/apache2.0/
6+
#
7+
# or in the "license" file accompanying this file. This file is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License.
8+
19
import matplotlib
210
import matplotlib.pyplot as plt
311
import matplotlib.cm as cm
412
import numpy as np
513
import scipy as sp
614
import scipy.stats
715

16+
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec
17+
818
def generate_griffiths_data(num_documents=5000, average_document_length=150,
919
num_topics=5, alpha=None, eta=None, seed=0):
1020
"""Returns example documents from Griffiths-Steyvers [1].
@@ -46,7 +56,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
4656
theta : Numpy NDArray
4757
A matrix of size `num_documents` x `num_topics` equal to the topic
4858
mixtures used to generate the output `documents`.
49-
59+
5060
References
5161
----------
5262
[1] Thomas L Griffiths and Mark Steyvers. "Finding Scientific Topics."
@@ -56,7 +66,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
5666
"""
5767
vocabulary_size = 25
5868
image_dim = int(np.sqrt(vocabulary_size))
59-
69+
6070
# perform checks on input
6171
assert num_topics in [5,10], 'Example data only available for 5 or 10 topics'
6272
if alpha:
@@ -75,7 +85,7 @@ def generate_griffiths_data(num_documents=5000, average_document_length=150,
7585
dirichlet_eta = sp.stats.dirichlet(eta)
7686

7787
# initialize a known topic-word distribution (beta) using eta. these are
78-
# the "row" and "column" topics, respectively. when num_topics = 5 only
88+
# the "row" and "column" topics, respectively. when num_topics = 5 only
7989
# create the col topics. when num_topics = 10 add the row topics as well
8090
#
8191
beta = np.zeros((num_topics,image_dim,image_dim), dtype=float)
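The comment in this hunk describes "row" and "column" topics on the 5x5 word grid. As a rough illustration only, the sketch below lays out such topics under the assumption that each topic is uniform over its bar; the function in this diff instead initializes `beta` using `eta`, in code this hunk does not show. The helper name `make_bar_topics` is hypothetical and not part of the commit.

```python
import numpy as np

def make_bar_topics(num_topics=5, image_dim=5):
    """Hypothetical helper: uniform 'bar' topics on an image_dim x image_dim grid.

    Assumption: each topic puts equal mass on one column (or one row) of the
    grid. The real generator derives beta from a Dirichlet prior instead.
    """
    assert num_topics in (image_dim, 2 * image_dim)
    beta = np.zeros((num_topics, image_dim, image_dim), dtype=float)

    # the first image_dim topics are the column topics
    for k in range(image_dim):
        beta[k, :, k] = 1.0 / image_dim

    # with twice as many topics, add the row topics as well
    if num_topics == 2 * image_dim:
        for k in range(image_dim):
            beta[image_dim + k, k, :] = 1.0 / image_dim

    # flatten each grid into a distribution over the vocabulary
    return beta.reshape(num_topics, image_dim * image_dim)

topics = make_bar_topics(num_topics=10)
print(topics.shape)
```

Each row of `topics` sums to one, so it can be read directly as a word distribution over the 25-word vocabulary.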
@@ -111,22 +121,22 @@ def plot_lda(data, nrows, ncols, with_colorbar=True, cmap=cm.viridis):
111121
fig, ax = plt.subplots(nrows, ncols, figsize=(ncols,nrows))
112122
vmin = 0
113123
vmax = data.max()
114-
124+
115125
V = len(data[0])
116126
n = int(np.sqrt(V))
117127
for i in range(nrows):
118128
for j in range(ncols):
119129
index = i*ncols + j
120-
130+
121131
if nrows > 1:
122132
im = ax[i,j].matshow(data[index].reshape(n,n), cmap=cmap, vmin=vmin, vmax=vmax)
123133
else:
124134
im = ax[j].matshow(data[index].reshape(n,n), cmap=cmap, vmin=vmin, vmax=vmax)
125-
135+
126136
for axi in ax.ravel():
127137
axi.set_xticks([])
128138
axi.set_yticks([])
129-
139+
130140
if with_colorbar:
131141
fig.colorbar(im, ax=ax.ravel().tolist(), orientation='horizontal', fraction=0.2)
132142
return fig
@@ -136,18 +146,50 @@ def match_estimated_topics(topics_known, topics_estimated):
136146
K, V = topics_known.shape
137147
permutation = -1*np.ones(K, dtype=int)
138148
unmatched_estimated_topics = []
139-
149+
140150
for estimated_topic_index, t in enumerate(topics_estimated):
141151
matched_known_topic_index = np.argmin([np.linalg.norm(known_topic - t) for known_topic in topics_known])
142152
if permutation[matched_known_topic_index] == -1:
143153
permutation[matched_known_topic_index] = estimated_topic_index
144154
else:
145155
unmatched_estimated_topics.append(estimated_topic_index)
146-
156+
147157
for estimated_topic_index in unmatched_estimated_topics:
148158
for i in range(K):
149159
if permutation[i] == -1:
150160
permutation[i] = estimated_topic_index
151161
break
152-
153-
return permutation, (topics_estimated[permutation,:]).copy()
162+
163+
return permutation, (topics_estimated[permutation,:]).copy()
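To see what the greedy matching in `match_estimated_topics` does, here is a minimal self-contained sketch that repeats its logic on toy data. It uses `dtype=int` rather than `np.int`, which recent NumPy releases no longer provide. The toy matrices `known` and `estimated` are made up for illustration.

```python
import numpy as np

def match_estimated_topics(topics_known, topics_estimated):
    K, V = topics_known.shape
    permutation = -1 * np.ones(K, dtype=int)
    unmatched = []

    # greedy pass: each estimated topic claims its nearest known topic,
    # provided that slot has not already been taken
    for est_idx, t in enumerate(topics_estimated):
        nearest = np.argmin([np.linalg.norm(known - t) for known in topics_known])
        if permutation[nearest] == -1:
            permutation[nearest] = est_idx
        else:
            unmatched.append(est_idx)

    # leftover estimated topics fill the remaining empty slots in order
    for est_idx in unmatched:
        for i in range(K):
            if permutation[i] == -1:
                permutation[i] = est_idx
                break

    return permutation, topics_estimated[permutation, :].copy()

known = np.eye(3)
# shuffled and slightly smoothed copies of the known topics
estimated = known[[2, 0, 1]] * 0.9 + 0.1 / 3
perm, reordered = match_estimated_topics(known, estimated)
print(perm)
```

Because the pass is greedy, the result depends on the order of the estimated topics and is not guaranteed to minimize the total distance. An optimal assignment, for example via `scipy.optimize.linear_sum_assignment`, would be an alternative, but it is not what this commit implements.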
164+
165+
def _document_with_topic(fig, gsi, index, document, topic_mixture=None,
166+
vmin=0, vmax=32):
167+
ax_doc = fig.add_subplot(gsi[:5,:])
168+
ax_doc.matshow(document.reshape(5,5), cmap='gray_r',
169+
vmin=vmin, vmax=vmax)
170+
ax_doc.set_xticks([])
171+
ax_doc.set_yticks([])
172+
173+
if topic_mixture is not None:
174+
ax_topic = fig.add_subplot(gsi[-1,:])
175+
ax_topic.matshow(topic_mixture.reshape(1,-1), cmap='Reds',
176+
vmin=0, vmax=1)
177+
ax_topic.set_xticks([])
178+
ax_topic.set_yticks([])
179+
180+
def plot_lda_topics(documents, nrows, ncols, with_colorbar=True,
181+
topic_mixtures=None, cmap='viridis', dpi=160):
182+
fig = plt.figure()
183+
gs = GridSpec(nrows, ncols)
184+
185+
vmin, vmax = (0, documents.max())
186+
187+
for i in range(nrows):
188+
for j in range(ncols):
189+
index = i*ncols + j
190+
gsi = GridSpecFromSubplotSpec(6, 5, subplot_spec=gs[i,j])
191+
_document_with_topic(fig, gsi, index, documents[index],
192+
topic_mixture=None if topic_mixtures is None else topic_mixtures[index],
193+
vmin=vmin, vmax=vmax)
194+
195+
return fig
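The nested layout used by `plot_lda_topics` can be sketched on its own: each cell of the outer grid is split into a 6x5 sub-grid, with the top five rows showing the document and the last row showing its topic mixture. The data below is randomly generated for illustration, and the `Agg` backend is an assumption made so the sketch runs outside a notebook.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: headless backend, no display needed
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec

rng = np.random.default_rng(0)
nrows, ncols = 2, 3
# made-up word counts over a 25-word vocabulary and made-up topic mixtures
documents = rng.integers(0, 10, size=(nrows * ncols, 25))
topic_mixtures = rng.dirichlet(np.ones(5), size=nrows * ncols)

fig = plt.figure()
outer = GridSpec(nrows, ncols)
vmax = documents.max()

for i in range(nrows):
    for j in range(ncols):
        index = i * ncols + j
        inner = GridSpecFromSubplotSpec(6, 5, subplot_spec=outer[i, j])

        # top five rows: the document drawn as a 5x5 image of word counts
        ax_doc = fig.add_subplot(inner[:5, :])
        ax_doc.matshow(documents[index].reshape(5, 5), cmap="gray_r",
                       vmin=0, vmax=vmax)

        # last row: the topic mixture drawn as a 1 x num_topics strip
        ax_topic = fig.add_subplot(inner[-1, :])
        ax_topic.matshow(topic_mixtures[index].reshape(1, -1), cmap="Reds",
                         vmin=0, vmax=1)

        for ax in (ax_doc, ax_topic):
            ax.set_xticks([])
            ax.set_yticks([])
```

Calling `fig.add_subplot` for both axes keeps every panel on the same figure even when another figure happens to be current.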
