Skip to content

Commit d8f5a4d

Browse files
committed
Update dist parameter hints
1 parent 664a447 commit d8f5a4d

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

pymc/distributions/continuous.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def polyagamma_cdf(*args, **kwargs):
8585
normal_lcdf,
8686
zvalue,
8787
)
88-
from pymc.distributions.distribution import Continuous
88+
from pymc.distributions.distribution import DIST_PARAMETER_TYPES, Continuous
8989
from pymc.distributions.shape_utils import rv_size_is_none
9090
from pymc.math import invlogit, logdiffexp, logit
9191
from pymc.util import UNSET
@@ -692,12 +692,12 @@ class TruncatedNormal(BoundedContinuous):
692692
@classmethod
693693
def dist(
694694
cls,
695-
mu: Optional[Union[float, np.ndarray]] = None,
696-
sigma: Optional[Union[float, np.ndarray]] = None,
697-
tau: Optional[Union[float, np.ndarray]] = None,
698-
sd: Optional[Union[float, np.ndarray]] = None,
699-
lower: Optional[Union[float, np.ndarray]] = None,
700-
upper: Optional[Union[float, np.ndarray]] = None,
695+
mu: Optional[DIST_PARAMETER_TYPES] = None,
696+
sigma: Optional[DIST_PARAMETER_TYPES] = None,
697+
tau: Optional[DIST_PARAMETER_TYPES] = None,
698+
sd: Optional[DIST_PARAMETER_TYPES] = None,
699+
lower: Optional[DIST_PARAMETER_TYPES] = None,
700+
upper: Optional[DIST_PARAMETER_TYPES] = None,
701701
transform: str = "auto",
702702
*args,
703703
**kwargs,

pymc/distributions/distribution.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,10 @@
1919

2020
from abc import ABCMeta
2121
from functools import singledispatch
22-
from typing import Callable, Iterable, Optional, Sequence
22+
from typing import Callable, Iterable, Optional, Sequence, Union
2323

2424
import aesara
25+
import numpy as np
2526

2627
from aeppl.logprob import _logcdf, _logprob
2728
from aesara import tensor as at
@@ -57,6 +58,8 @@
5758
"NoDistribution",
5859
]
5960

61+
DIST_PARAMETER_TYPES = Union[np.ndarray, int, float, TensorVariable]
62+
6063
vectorized_ppc = contextvars.ContextVar(
6164
"vectorized_ppc", default=None
6265
) # type: contextvars.ContextVar[Optional[Callable]]

0 commit comments

Comments
 (0)