|
| 1 | +import matplotlib |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import matplotlib.cm as cm |
| 4 | +import numpy as np |
| 5 | +import scipy as sp |
| 6 | +import scipy.stats |
| 7 | + |
| 8 | +from matplotlib.gridspec import GridSpec, GridSpecFromSubplotSpec |
| 9 | + |
def generate_griffiths_data(num_documents=5000, average_document_length=150,
                            num_topics=5, alpha=None, eta=None, seed=0):
    """Returns example documents from Griffiths-Steyvers [1].

    Given an `alpha` and `eta`, the Dirichlet priors for the topic and
    topic-word distributions respectively, this function generates sample
    document word counts according to the Latent Dirichlet Allocation (LDA)
    model. The vocabulary has 25 words, viewed as a 5 x 5 image; each topic
    puts all of its mass on a single column (or row) of that image.

    Note that this function reseeds NumPy's global random number generator
    with `seed`.

    Parameters
    ----------
    num_documents : int
        (Default: 5000) The number of example documents to create using LDA.
    average_document_length : int
        (Default: 150) The average number of words in each document. The
        document length is sampled from a Poisson distribution with this mean.
    num_topics : int
        (Default: 5) Can be set to either 5 or 10. The number of known topics.
        With 5 topics only the "column" topics are created; with 10 topics
        the "row" topics are added as well.
    alpha : Numpy NDArray
        (Default: None) An array of length `num_topics` representing a given
        Dirichlet topic prior. If `None` is provided then a uniform
        prior with every entry equal to `1 / num_topics` is used.
    eta : Numpy NDArray
        (Default: None) An array of length 5 (the image side length)
        representing a given Dirichlet topic-word prior over the non-zero
        entries of each topic. If `None` is provided then `[100] * 5` is
        used, which makes each topic close to uniform over its entries.
    seed : int
        (Default: 0) The random number generator seed.

    Returns
    -------
    alpha : Numpy NDArray
        A vector of length `num_topics` equal to the Dirichlet prior used to
        generate documents.
    beta : Numpy NDArray
        A matrix of size `num_topics` x 25 equal to the topic-word probability
        matrix used to generate documents.
    documents : Numpy NDArray
        A matrix of size `num_documents` x 25 equal to the word counts of the
        documents generated by the LDA model defined by `alpha` and `beta`.
    thetas : Numpy NDArray
        A matrix of size `num_documents` x `num_topics` equal to the topic
        mixtures used to generate the output `documents`.

    Raises
    ------
    ValueError
        If `num_topics` is not 5 or 10, if `alpha` does not have length
        `num_topics`, or if `eta` does not have length 5.

    References
    ----------
    [1] Thomas L Griffiths and Mark Steyvers. "Finding Scientific Topics."
    Proceedings of the National Academy of Sciences, 101(suppl 1):5228–5235,
    2004.

    """
    vocabulary_size = 25
    image_dim = int(np.sqrt(vocabulary_size))

    # perform checks on input. explicit `is not None` tests are required
    # because the truth value of a multi-element NumPy array is ambiguous
    if num_topics not in (5, 10):
        raise ValueError('Example data only available for 5 or 10 topics')
    if alpha is not None and len(alpha) != num_topics:
        raise ValueError('len(alpha) must be equal to num_topics')
    if eta is not None and len(eta) != image_dim:
        raise ValueError('len(eta) must be equal to %d' % image_dim)

    # initialize Dirichlet alpha and eta distributions if not provided. here,
    # the eta distribution is only across `image_dim` elements since each
    # topic-word distribution will only have `image_dim` non-zero entries
    np.random.seed(seed=seed)
    if alpha is None:
        alpha = np.ones(num_topics, dtype=float) / num_topics
    if eta is None:
        eta = [100]*image_dim  # make it close to a uniform distribution
    dirichlet_alpha = sp.stats.dirichlet(alpha)
    dirichlet_eta = sp.stats.dirichlet(eta)

    # initialize a known topic-word distribution (beta) using eta. the first
    # `image_dim` topics each fill one column of the image; when
    # num_topics = 10 the next `image_dim` topics each fill one row
    beta = np.zeros((num_topics, image_dim, image_dim), dtype=float)
    for i in range(image_dim):
        beta[i, :, i] = dirichlet_eta.rvs(size=1)[0]
    if num_topics == 10:
        for i in range(image_dim):
            beta[i + image_dim, i, :] = dirichlet_eta.rvs(size=1)[0]
    beta = beta.reshape(num_topics, vocabulary_size)

    # generate documents using the LDA model / process. the order of the
    # random draws below is kept fixed so that a given seed always produces
    # the same data
    document_lengths = sp.stats.poisson(average_document_length).rvs(size=num_documents)
    documents = np.zeros((num_documents, vocabulary_size), dtype=float)
    thetas = dirichlet_alpha.rvs(size=num_documents)  # precompute topic distributions for performance
    for m in range(num_documents):
        document_length = document_lengths[m]

        # one one-hot topic draw per word, precomputed for performance
        topic = sp.stats.multinomial.rvs(1, thetas[m], size=document_length)
        topic_indices = topic.argmax(axis=1)

        # generate word counts within document: each word is a one-hot draw
        # from the topic-word distribution of the word's topic
        for topic_index in topic_indices:
            word = sp.stats.multinomial.rvs(1, beta[topic_index], size=1)[0]
            documents[m] += word

    return alpha, beta, documents, thetas
| 110 | + |
def plot_lda(data, nrows, ncols, with_colorbar=True, cmap=cm.viridis):
    """Helper function for plotting an `nrows` x `ncols` grid of images.

    Each of the first `nrows * ncols` rows of `data` is reshaped into a
    square image and drawn with a color scale shared across all images
    (from 0 to `data.max()`).

    Parameters
    ----------
    data : Numpy NDArray
        A matrix whose rows are flattened square images. The image side
        length is taken as the integer square root of the row length.
    nrows, ncols : int
        The shape of the grid of subplots.
    with_colorbar : bool
        (Default: True) Whether to add a horizontal colorbar shared by all
        of the subplots.
    cmap : matplotlib colormap
        (Default: cm.viridis) The colormap used to draw each image.

    Returns
    -------
    fig : matplotlib Figure
        The figure containing the grid of images.

    """
    # squeeze=False always returns a 2D array of axes, so grids with a single
    # row and/or a single column are indexed the same way as any other grid
    fig, ax = plt.subplots(nrows, ncols, figsize=(ncols, nrows), squeeze=False)
    vmin = 0
    vmax = data.max()

    V = len(data[0])
    n = int(np.sqrt(V))

    # row-major traversal: the axis at (i, j) shows data[i*ncols + j]
    for index, axi in enumerate(ax.ravel()):
        im = axi.matshow(data[index].reshape(n, n), cmap=cmap,
                         vmin=vmin, vmax=vmax)
        axi.set_xticks([])
        axi.set_yticks([])

    if with_colorbar:
        # every image shares vmin/vmax so any one of them defines the colorbar
        fig.colorbar(im, ax=ax.ravel().tolist(), orientation='horizontal',
                     fraction=0.2)
    return fig
| 135 | + |
def match_estimated_topics(topics_known, topics_estimated):
    """A dumb but fast way to match known topics to estimated topics.

    Each estimated topic is greedily assigned, in order, to the known topic
    closest to it in Euclidean distance. If that known topic has already
    been claimed by an earlier estimated topic, the estimated topic is set
    aside and later assigned to the lowest-indexed known topic that is still
    unclaimed.

    Parameters
    ----------
    topics_known : Numpy NDArray
        A K x V matrix of known topic-word distributions.
    topics_estimated : Numpy NDArray
        A matrix of estimated topic-word distributions, one topic per row.

    Returns
    -------
    permutation : Numpy NDArray
        An integer vector of length K where `permutation[k]` is the index of
        the estimated topic matched to known topic `k`. An entry remains -1
        if no estimated topic was left to assign to that known topic.
    topics_matched : Numpy NDArray
        A copy of the rows of `topics_estimated` reordered by `permutation`.

    """
    K, _ = topics_known.shape
    permutation = np.full(K, -1, dtype=int)  # -1 marks an unclaimed known topic
    unmatched_estimated_topics = []

    for estimated_topic_index, t in enumerate(topics_estimated):
        distances = [np.linalg.norm(known_topic - t) for known_topic in topics_known]
        matched_known_topic_index = np.argmin(distances)
        if permutation[matched_known_topic_index] == -1:
            permutation[matched_known_topic_index] = estimated_topic_index
        else:
            unmatched_estimated_topics.append(estimated_topic_index)

    # hand the leftover estimated topics to the still-unclaimed known topics,
    # lowest index first. zip stops at whichever list runs out first
    free_slots = np.flatnonzero(permutation == -1)
    for slot, estimated_topic_index in zip(free_slots, unmatched_estimated_topics):
        permutation[slot] = estimated_topic_index

    return permutation, (topics_estimated[permutation, :]).copy()
| 156 | + |
def _document_with_topic(fig, gsi, index, document, topic_mixture=None,
                         vmin=0, vmax=32):
    """Draw one document image and, optionally, its topic mixture below it.

    Parameters
    ----------
    fig : matplotlib Figure
        The figure that the new axes are added to.
    gsi : grid spec
        The sub-grid for this document. Rows 0-4 are used for the document
        image and the last row for the topic mixture strip.
    index : int
        The index of the document. Currently unused; kept so that existing
        callers continue to work.
    document : Numpy NDArray
        The flattened square document image (word counts).
    topic_mixture : Numpy NDArray
        (Default: None) The topic mixture of the document. When given it is
        drawn as a single-row strip on a color scale from 0 to 1.
    vmin, vmax : number
        (Default: 0, 32) The limits of the gray scale used for the document.

    """
    # side length of the square image (5 for the 25-word example vocabulary)
    n = int(np.sqrt(document.size))

    ax_doc = fig.add_subplot(gsi[:5, :])
    ax_doc.matshow(document.reshape(n, n), cmap='gray_r',
                   vmin=vmin, vmax=vmax)
    ax_doc.set_xticks([])
    ax_doc.set_yticks([])

    if topic_mixture is not None:
        # add the axis to `fig` itself rather than to pyplot's current
        # figure, which is not guaranteed to be `fig`
        ax_topic = fig.add_subplot(gsi[-1, :])
        ax_topic.matshow(topic_mixture.reshape(1, -1), cmap='Reds',
                         vmin=0, vmax=1)
        ax_topic.set_xticks([])
        ax_topic.set_yticks([])
| 171 | + |
def plot_lda_topics(documents, nrows, ncols, with_colorbar=True,
                    topic_mixtures=None, cmap='Viridis', dpi=160):
    """Plot a grid of documents, each with an optional topic mixture strip.

    The first `nrows * ncols` rows of `documents` are drawn in row-major
    order on a gray scale shared across all documents (from 0 to
    `documents.max()`).

    Parameters
    ----------
    documents : Numpy NDArray
        A matrix whose rows are flattened document images (word counts).
    nrows, ncols : int
        The shape of the grid of documents.
    with_colorbar : bool
        (Default: True) Currently unused.
    topic_mixtures : Numpy NDArray
        (Default: None) A matrix whose row `i` is the topic mixture of
        document `i`. When `None`, only the documents are drawn.
    cmap : str
        (Default: 'Viridis') Currently unused.
    dpi : int
        (Default: 160) Currently unused.

    Returns
    -------
    fig : matplotlib Figure
        The figure containing the grid of documents.

    """
    fig = plt.figure()
    gs = GridSpec(nrows, ncols)

    vmin, vmax = (0, documents.max())

    for i in range(nrows):
        for j in range(ncols):
            index = i*ncols + j
            # 6 x 5 sub-grid per cell: five rows for the document image and
            # one row for the topic mixture strip
            gsi = GridSpecFromSubplotSpec(6, 5, subplot_spec=gs[i, j])
            # the default topic_mixtures=None must not be indexed
            topic_mixture = None if topic_mixtures is None else topic_mixtures[index]
            _document_with_topic(fig, gsi, index, documents[index],
                                 topic_mixture=topic_mixture,
                                 vmin=vmin, vmax=vmax)

    return fig
0 commit comments