Skip to content

Commit 93e87c1

Browse files
author
Junpeng Lao
committed
Add test for _repr_latex_
with small bug fix in MvStudentT
1 parent 2def14c commit 93e87c1

File tree

2 files changed

+17
-3
lines changed

2 files changed

+17
-3
lines changed

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ class MvStudentT(Continuous):
292292
def __init__(self, nu, Sigma, mu=None, *args, **kwargs):
293293
super(MvStudentT, self).__init__(*args, **kwargs)
294294
self.nu = nu = tt.as_tensor_variable(nu)
295-
self.mu = tt.zeros(Sigma.shape[0]) if mu is None else tt.as_tensor_variable(mu)
295+
mu = tt.zeros(Sigma.shape[0]) if mu is None else tt.as_tensor_variable(mu)
296296
self.Sigma = Sigma = tt.as_tensor_variable(Sigma)
297297

298298
self.mean = self.median = self.mode = self.mu = mu

pymc3/tests/test_distributions.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
from ..model import Model, Point, Potential
88
from ..blocking import DictToVarBijection, DictToArrayBijection, ArrayOrdering
99
from ..distributions import (DensityDist, Categorical, Multinomial, VonMises, Dirichlet,
10-
MvStudentT, MvNormal, ZeroInflatedPoisson,
10+
MvStudentT, MvNormal, ZeroInflatedPoisson, GaussianRandomWalk,
1111
ZeroInflatedNegativeBinomial, Constant, Poisson, Bernoulli, Beta,
12-
BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto,
12+
BetaBinomial, HalfStudentT, StudentT, Weibull, Pareto, NormalMixture,
1313
InverseGamma, Gamma, Cauchy, HalfCauchy, Lognormal, Laplace,
1414
NegativeBinomial, Geometric, Exponential, ExGaussian, Normal,
1515
Flat, LKJCorr, Wald, ChiSquared, HalfNormal, DiscreteUniform,
@@ -822,3 +822,17 @@ def ref_pdf(value):
822822
)
823823

824824
self.pymc3_matches_scipy(TestedInterpolated, R, {}, ref_pdf)
825+
826+
827+
def test_repr_latex_():
828+
with Model():
829+
x0 = Binomial('Discrete', p=.5, n=10)
830+
x1 = Normal('Continuous', mu=0., sd=1.)
831+
x2 = GaussianRandomWalk('Timeseries', mu=x1, sd=1., shape=2)
832+
x3 = MvStudentT('Multivariate', nu=5, mu=x2, Sigma=np.diag(np.ones(2)), shape=2)
833+
x4 = NormalMixture('Mixture', w=np.array([.5, .5]), mu=x3, sd=x0)
834+
assert x0._repr_latex_()=='$Discrete \\sim \\text{Binomial}(\\mathit{n}=10, \\mathit{p}=0.5)$'
835+
assert x1._repr_latex_()=='$Continuous \\sim \\text{Normal}(\\mathit{mu}=0.0, \\mathit{sd}=1.0)$'
836+
assert x2._repr_latex_()=='$Timeseries \\sim \\text{GaussianRandomWalk}(\\mathit{mu}=Continuous, \\mathit{sd}=1.0)$'
837+
assert x3._repr_latex_()=='$Multivariate \\sim \\text{MvStudentT}(\\mathit{nu}=5, \\mathit{mu}=Timeseries, \\mathit{Sigma}=array)$'
838+
assert x4._repr_latex_()=='$Mixture \\sim \\text{NormalMixture}(\\mathit{w}=array, \\mathit{mu}=Multivariate, \\mathit{sigma}=f(Discrete))$'

0 commit comments

Comments
 (0)