Skip to content

Commit 783c5ba

Browse files
Deprecation of testval→initval and related start kwarg
1 parent 814f0da commit 783c5ba

File tree

3 files changed

+86
-14
lines changed

3 files changed

+86
-14
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
## PyMC 3.11.5 (TBD)
44
### Backports
55
+ The `pm.logp(rv, x)` syntax is now available and recommended to make your model code `v4`-ready. Note that this backport is just an alias and much less capable than what's available with `pymc >=4` (see [#5083](https://github.com/pymc-devs/pymc/pulls/5083)).
6+
+ The `pm.Distribution(testval=...)` kwarg was deprecated and will be replaced by `pm.Distribution(initval=...)`in `pymc >=4` (see [#5226](https://github.com/pymc-devs/pymc/pulls/5226)).
7+
+ The `pm.sample(start=...)` kwarg was deprecated and will be replaced by `pm.sample(initvals=...)`in `pymc >=4` (see [#5226](https://github.com/pymc-devs/pymc/pulls/5226)).
68

79
## PyMC3 3.11.4 (20 August 2021)
810

pymc3/distributions/distribution.py

Lines changed: 61 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424

2525
import dill
2626

27+
from deprecat.sphinx import deprecat
28+
2729
if TYPE_CHECKING:
2830
from typing import Optional, Callable
2931

@@ -67,6 +69,7 @@
6769
) # type: contextvars.ContextVar[Optional[Callable]]
6870

6971
PLATFORM = sys.platform
72+
UNSET = object()
7073

7174

7275
class _Unpickling:
@@ -130,15 +133,54 @@ def dist(cls, *args, **kwargs):
130133
dist.__init__(*args, **kwargs)
131134
return dist
132135

136+
@deprecat(
137+
deprecated_args={
138+
"testval": dict(version="3.11.5", reason="replaced by `initval` in PyMC 4.0.0"),
139+
}
140+
)
133141
def __init__(
134-
self, shape, dtype, testval=None, defaults=(), transform=None, broadcastable=None, dims=None
142+
self,
143+
shape,
144+
dtype,
145+
initval=None,
146+
defaults=(),
147+
transform=None,
148+
broadcastable=None,
149+
dims=None,
150+
*,
151+
testval=UNSET,
135152
):
153+
"""Creates a PyMC distribution object.
154+
155+
Parameters
156+
----------
157+
shape : tuple
158+
Output shape of the RV.
159+
Forwarded to the Theano TensorType of this RV.
160+
dtype
161+
Forwarded to the Theano TensorType of this RV.
162+
initval : np.ndarray
163+
Initial value for this RV.
164+
In PyMC 4.0.0 this will no longer assign test values to the tensors.
165+
defaults : tuple
166+
transform : pm.Transform
167+
broadcastable : tuple
168+
Forwarded to the Theano TensorType of this RV.
169+
dims : tuple
170+
Ignored.
171+
testval : np.ndarray
172+
The old way of specifying initial values assigning test-values.
173+
"""
174+
# Handle deprecated kwargs
175+
if testval is not UNSET:
176+
initval = testval
177+
136178
self.shape = np.atleast_1d(shape)
137179
if False in (np.floor(self.shape) == self.shape):
138180
raise TypeError("Expected int elements in shape")
139181
self.dtype = dtype
140182
self.type = TensorType(self.dtype, self.shape, broadcastable)
141-
self.testval = testval
183+
self.testval = initval
142184
self.defaults = defaults
143185
self.transform = transform
144186

@@ -288,7 +330,7 @@ def __init__(
288330
**kwargs,
289331
):
290332
super().__init__(
291-
shape=shape, dtype=dtype, testval=testval, defaults=defaults, *args, **kwargs
333+
shape=shape, dtype=dtype, initval=testval, defaults=defaults, *args, **kwargs
292334
)
293335
self.parent_dist = parent_dist
294336

@@ -353,16 +395,22 @@ class DensityDist(Distribution):
353395
354396
"""
355397

398+
@deprecat(
399+
deprecated_args={
400+
"testval": dict(version="3.11.5", reason="replaced by `initval` in PyMC 4.0.0"),
401+
}
402+
)
356403
def __init__(
357404
self,
358405
logp,
359406
shape=(),
360407
dtype=None,
361-
testval=0,
408+
initval=0,
362409
random=None,
363410
wrap_random_with_dist_shape=True,
364411
check_shape_in_random=True,
365412
*args,
413+
testval=UNSET,
366414
**kwargs,
367415
):
368416
"""
@@ -379,8 +427,8 @@ def __init__(
379427
a value here.
380428
dtype: None, str (Optional)
381429
The dtype of the distribution.
382-
testval: number or array (Optional)
383-
The ``testval`` of the RV's tensor that follow the ``DensityDist``
430+
initval: number or array (Optional)
431+
The ``initval`` of the RV's tensor that follow the ``DensityDist``
384432
distribution.
385433
random: None or callable (Optional)
386434
If ``None``, no random method is attached to the ``DensityDist``
@@ -403,6 +451,8 @@ def __init__(
403451
If ``True``, the shape of the random samples generate in the
404452
``random`` method is checked with the expected return shape. This
405453
test is only performed if ``wrap_random_with_dist_shape is False``.
454+
testval : np.ndarray
455+
The old way of specifying initial values assigning test-values.
406456
args, kwargs: (Optional)
407457
These are passed to the parent class' ``__init__``.
408458
@@ -525,9 +575,13 @@ def __init__(
525575
assert prior.shape == (10, 100, 3)
526576
527577
"""
578+
# Handle deprecated kwargs
579+
if testval is not UNSET:
580+
initval = testval
581+
528582
if dtype is None:
529583
dtype = theano.config.floatX
530-
super().__init__(shape, dtype, testval, *args, **kwargs)
584+
super().__init__(shape, dtype, initval, *args, **kwargs)
531585
self.logp = logp
532586
if type(self.logp) == types.MethodType:
533587
if PLATFORM != "linux":

pymc3/sampling.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
from collections import defaultdict
2525
from copy import copy, deepcopy
26-
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast
26+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Set, Union, cast
2727

2828
import arviz
2929
import numpy as np
@@ -32,6 +32,7 @@
3232
import xarray
3333

3434
from arviz import InferenceData
35+
from deprecat.sphinx import deprecat
3536
from fastprogress.fastprogress import progress_bar
3637

3738
import pymc3 as pm
@@ -232,12 +233,18 @@ def _print_step_hierarchy(s: Step, level=0) -> None:
232233
_log.info(">" * level + f"{s.__class__.__name__}: [{varnames}]")
233234

234235

236+
@deprecat(
237+
deprecated_args={
238+
"start": dict(version="3.11.5", reason="renamed to `initvals` in PyMC v4.0.0"),
239+
"pickle_backend": dict(version="3.11.5", reason="removed in PyMC v4.0.0"),
240+
}
241+
)
235242
def sample(
236243
draws=1000,
237244
step=None,
238245
init="auto",
239246
n_init=200000,
240-
start=None,
247+
initvals: Optional[Union[PointType, Sequence[Optional[PointType]]]] = None,
241248
trace=None,
242249
chain_idx=0,
243250
chains=None,
@@ -251,6 +258,7 @@ def sample(
251258
callback=None,
252259
jitter_max_retries=10,
253260
*,
261+
start=None,
254262
return_inferencedata=None,
255263
idata_kwargs: dict = None,
256264
mp_ctx=None,
@@ -294,11 +302,10 @@ def sample(
294302
users.
295303
n_init : int
296304
Number of iterations of initializer. Only works for 'ADVI' init methods.
297-
start : dict, or array of dict
298-
Starting point in parameter space (or partial point)
299-
Defaults to ``trace.point(-1))`` if there is a trace provided and model.test_point if not
300-
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
301-
overwrite the default.
305+
initvals : optional, dict, array of dict
306+
Dict or list of dicts with initial values to use instead of the defaults.
307+
The keys should be names of transformed random variables.
308+
Initialization methods for NUTS (see ``init`` keyword) can overwrite the default.
302309
trace : backend, list, or MultiTrace
303310
This should be a backend instance, a list of variables to track, or a MultiTrace object
304311
with past values. If a MultiTrace object is given, it must contain samples for the chain
@@ -339,6 +346,11 @@ def sample(
339346
Maximum number of repeated attempts (per chain) at creating an initial matrix with uniform jitter
340347
that yields a finite probability. This applies to ``jitter+adapt_diag`` and ``jitter+adapt_full``
341348
init methods.
349+
start : dict, or array of dict
350+
Starting point in parameter space (or partial point)
351+
Defaults to ``trace.point(-1))`` if there is a trace provided and model.test_point if not
352+
(defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can
353+
overwrite the default.
342354
return_inferencedata : bool, default=False
343355
Whether to return the trace as an :class:`arviz:arviz.InferenceData` (True) object or a `MultiTrace` (False)
344356
Defaults to `False`, but we'll switch to `True` in an upcoming release.
@@ -422,6 +434,10 @@ def sample(
422434
mean sd hdi_3% hdi_97%
423435
p 0.609 0.047 0.528 0.699
424436
"""
437+
# Handle deprecated/forwards-compatible kwargs
438+
if initvals is not None:
439+
start = initvals
440+
425441
model = modelcontext(model)
426442
start = deepcopy(start)
427443
if start is None:

0 commit comments

Comments
 (0)