Skip to content

Commit 5d51953

Browse files
roesta07Rojan ShrestharicardoV94
authored
Fix error in docstring example for MatrixNormal (#7599)
Co-authored-by: Rojan Shrestha <[email protected]> Co-authored-by: Ricardo Vieira <[email protected]>
1 parent 6cdfc30 commit 5d51953

File tree

1 file changed

+16
-12
lines changed

1 file changed

+16
-12
lines changed

pymc/distributions/multivariate.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1791,14 +1791,21 @@ class MatrixNormal(Continuous):
17911791
Examples
17921792
--------
17931793
Define a matrixvariate normal variable for given row and column covariance
1794-
matrices::
1794+
matrices.
17951795
1796-
colcov = np.array([[1.0, 0.5], [0.5, 2]])
1797-
rowcov = np.array([[1, 0, 0], [0, 4, 0], [0, 0, 16]])
1798-
m = rowcov.shape[0]
1799-
n = colcov.shape[0]
1800-
mu = np.zeros((m, n))
1801-
vals = pm.MatrixNormal("vals", mu=mu, colcov=colcov, rowcov=rowcov)
1796+
.. code:: python
1797+
1798+
import pymc as pm
1799+
import numpy as np
1800+
import pytensor.tensor as pt
1801+
1802+
with pm.Model() as model:
1803+
colcov = np.array([[1.0, 0.5], [0.5, 2]])
1804+
rowcov = np.array([[1, 0, 0], [0, 4, 0], [0, 0, 16]])
1805+
m = rowcov.shape[0]
1806+
n = colcov.shape[0]
1807+
mu = np.zeros((m, n))
1808+
vals = pm.MatrixNormal("vals", mu=mu, colcov=colcov, rowcov=rowcov)
18021809
18031810
Above, the ith row in vals has a variance that is scaled by 4^i.
18041811
Alternatively, row or column cholesky matrices could be substituted for
@@ -1827,16 +1834,13 @@ class MatrixNormal(Continuous):
18271834
with pm.Model() as model:
18281835
# Setup right cholesky matrix
18291836
sd_dist = pm.HalfCauchy.dist(beta=2.5, shape=3)
1830-
colchol_packed = pm.LKJCholeskyCov('colcholpacked', n=3, eta=2,
1831-
sd_dist=sd_dist)
1832-
colchol = pm.expand_packed_triangular(3, colchol_packed)
1833-
1837+
colchol,_,_ = pm.LKJCholeskyCov('colchol', n=3, eta=2,sd_dist=sd_dist)
18341838
# Setup left covariance matrix
18351839
scale = pm.LogNormal('scale', mu=np.log(true_scale), sigma=0.5)
18361840
rowcov = pt.diag([scale**(2*i) for i in range(m)])
18371841
18381842
vals = pm.MatrixNormal('vals', mu=mu, colchol=colchol, rowcov=rowcov,
1839-
observed=data)
1843+
observed=data)
18401844
"""
18411845

18421846
rv_op = matrixnormal

0 commit comments

Comments
 (0)