Skip to content

Commit 82ccb3b

Browse files
ExpectationMaxmichaelosthege
authored andcommitted
Keep broadcasting information in make_shared_replacements
It seems like broadcasting information gets lost when applying `pm.make_shared_replacements`, leading to problems with the metropolis sampler. Potentially related issues below: - #1083 - #1304 - #1983 This fix was previously suggested in the following issue: - #3337 It could be that further adaptations are necessary as indicated in the issue. Strangely, this does not seem to lead to problems when using NUTS.
1 parent 53b642e commit 82ccb3b

File tree

3 files changed

+33
-1
lines changed

3 files changed

+33
-1
lines changed

RELEASE-NOTES.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
### Maintenance
1313
- The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util` (see [#4509](https://github.com/pymc-devs/pymc3/pull/4509)).
1414
- Remove float128 dtype support (see [#4514](https://github.com/pymc-devs/pymc3/pull/4514)).
15+
- `pm.make_shared_replacements` now retains broadcasting information which fixes issues with Metropolis samplers (see [#4492](https://github.com/pymc-devs/pymc3/pull/4492)).
16+
+ ...
1517

1618
## PyMC3 3.11.1 (12 February 2021)
1719

pymc3/aesaraf.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,12 @@ def make_shared_replacements(vars, model):
238238
Dict of variable -> new shared variable
239239
"""
240240
othervars = set(model.vars) - set(vars)
241-
return {var: aesara.shared(var.tag.test_value, var.name + "_shared") for var in othervars}
241+
return {
242+
var: aesara.shared(
243+
var.tag.test_value, var.name + "_shared", broadcastable=var.broadcastable
244+
)
245+
for var in othervars
246+
}
242247

243248

244249
def join_nonshared_inputs(xs, vars, shared, make_shared=False):

pymc3/tests/test_aesaraf.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,38 @@
2121

2222
from aesara.tensor.type import TensorType
2323

24+
import pymc3 as pm
25+
2426
from pymc3.aesaraf import _conversion_map, take_along_axis
2527
from pymc3.vartypes import int_types
2628

2729
FLOATX = str(aesara.config.floatX)
2830
INTX = str(_conversion_map[FLOATX])
2931

3032

33+
class TestBroadcasting:
34+
def test_make_shared_replacements(self):
35+
"""Check if pm.make_shared_replacements preserves broadcasting."""
36+
37+
with pm.Model() as test_model:
38+
test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10))
39+
test2 = pm.Normal("test2", mu=0.0, sigma=1.0, shape=(10, 1))
40+
41+
# Replace test1 with a shared variable, keep test 2 the same
42+
replacement = pm.make_shared_replacements([test_model.test2], test_model)
43+
assert test_model.test1.broadcastable == replacement[test_model.test1].broadcastable
44+
45+
def test_metropolis_sampling(self):
46+
"""Check if the Metropolis sampler can handle broadcasting."""
47+
with pm.Model() as test_model:
48+
test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10))
49+
test2 = pm.Normal("test2", mu=test1, sigma=1.0, shape=(10, 10))
50+
51+
step = pm.Metropolis()
52+
# This should fail immediately if broadcasting does not work.
53+
pm.sample(tune=5, draws=7, cores=1, step=step, compute_convergence_checks=False)
54+
55+
3156
def _make_along_axis_idx(arr_shape, indices, axis):
3257
# compute dimensions to iterate over
3358
if str(indices.dtype) not in int_types:

0 commit comments

Comments
 (0)