Skip to content

Commit 62d9b59

Browse files
brandonwillard
authored and twiecki committed
Add a starting point option to DictToArrayBijection.rmap
1 parent 3bf68b6 commit 62d9b59

File tree

3 files changed

+30
-38
lines changed

3 files changed

+30
-38
lines changed

pymc3/blocking.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,17 @@
1919
"""
2020
import collections
2121

22-
from typing import Dict, List, Optional, Union
22+
from functools import partial
23+
from typing import Callable, Dict, Optional, TypeVar
2324

2425
import numpy as np
2526

2627
__all__ = ["DictToArrayBijection"]
2728

29+
30+
T = TypeVar("T")
31+
PointType = Dict[str, np.ndarray]
32+
2833
# `point_map_info` is a tuple of tuples containing `(name, shape, dtype)` for
2934
# each of the raveled variables.
3035
RaveledVars = collections.namedtuple("RaveledVars", "data, point_map_info")
@@ -38,7 +43,7 @@ class DictToArrayBijection:
3843
"""
3944

4045
@staticmethod
41-
def map(var_dict: Dict[str, np.ndarray]) -> RaveledVars:
46+
def map(var_dict: PointType) -> RaveledVars:
4247
"""Map a dictionary of names and variables to a concatenated 1D array space."""
4348
vars_info = tuple((v, k, v.shape, v.dtype) for k, v in var_dict.items())
4449
raveled_vars = [v[0].ravel() for v in vars_info]
@@ -50,42 +55,41 @@ def map(var_dict: Dict[str, np.ndarray]) -> RaveledVars:
5055

5156
@staticmethod
5257
def rmap(
53-
array: RaveledVars, as_list: Optional[bool] = False
54-
) -> Union[Dict[str, np.ndarray], List[np.ndarray]]:
58+
array: RaveledVars,
59+
start_point: Optional[PointType] = None,
60+
) -> PointType:
5561
"""Map 1D concatenated array to a dictionary of variables in their original spaces.
5662
5763
Parameters
5864
==========
5965
array
6066
The array to map.
61-
as_list
62-
When ``True``, return a list of the original variables instead of a
63-
``dict`` keyed each variable's name.
67+
start_point
68+
An optional dictionary of initial values.
69+
6470
"""
65-
if as_list:
66-
res = []
71+
if start_point:
72+
res = dict(start_point)
6773
else:
6874
res = {}
6975

7076
if not isinstance(array, RaveledVars):
71-
raise TypeError("`apt` must be a `RaveledVars` type")
77+
raise TypeError("`array` must be a `RaveledVars` type")
7278

7379
last_idx = 0
7480
for name, shape, dtype in array.point_map_info:
7581
arr_len = np.prod(shape, dtype=int)
7682
var = array.data[last_idx : last_idx + arr_len].reshape(shape).astype(dtype)
77-
if as_list:
78-
res.append(var)
79-
else:
80-
res[name] = var
83+
res[name] = var
8184
last_idx += arr_len
8285

8386
return res
8487

8588
@classmethod
86-
def mapf(cls, f):
87-
"""
88-
function f: DictSpace -> T to ArraySpace -> T
89+
def mapf(cls, f: Callable[[PointType], T], start_point: Optional[PointType] = None) -> T:
90+
"""Create a callable that first maps back to ``dict`` inputs and then applies a function.
91+
92+
function f: DictSpace -> T to ArraySpace -> T
8993
9094
Parameters
9195
----------
@@ -95,7 +99,7 @@ def mapf(cls, f):
9599
-------
96100
f: array -> T
97101
"""
98-
return Compose(f, cls.rmap)
102+
return Compose(f, partial(cls.rmap, start_point=start_point))
99103

100104

101105
class Compose:

pymc3/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def __call__(self, grad_vars, grad_out=None, extra_vars=None):
467467
raise ValueError("Extra values are not set.")
468468

469469
if isinstance(grad_vars, RaveledVars):
470-
grad_vars = DictToArrayBijection.rmap(grad_vars, as_list=True)
470+
grad_vars = list(DictToArrayBijection.rmap(grad_vars).values())
471471

472472
cost, *grads = self._aesara_function(*grad_vars)
473473

pymc3/step_methods/arraystep.py

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -141,12 +141,12 @@ def __init__(self, vars, fs, allvars=False, blocked=True):
141141

142142
def step(self, point: Dict[str, np.ndarray]):
143143

144-
inputs = [DictToArrayBijection.mapf(x) for x in self.fs]
144+
partial_funcs_and_point = [DictToArrayBijection.mapf(x, start_point=point) for x in self.fs]
145145
if self.allvars:
146-
inputs.append(point)
146+
partial_funcs_and_point.append(point)
147147

148-
apoint = DictToArrayBijection.map(point)
149-
step_res = self.astep(apoint, *inputs)
148+
apoint = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
149+
step_res = self.astep(apoint, *partial_funcs_and_point)
150150

151151
if self.generates_stats:
152152
apoint_new, stats = step_res
@@ -157,7 +157,7 @@ def step(self, point: Dict[str, np.ndarray]):
157157
# We assume that the mapping has stayed the same
158158
apoint_new = RaveledVars(apoint_new, apoint.point_map_info)
159159

160-
point_new = DictToArrayBijection.rmap(apoint_new)
160+
point_new = DictToArrayBijection.rmap(apoint_new, start_point=point)
161161

162162
if self.generates_stats:
163163
return point_new, stats
@@ -190,14 +190,10 @@ def __init__(self, vars, shared, blocked=True):
190190

191191
def step(self, point):
192192

193-
# Remove shared variables from the sample point
194-
point_no_shared = point.copy()
195193
for name, shared_var in self.shared.items():
196194
shared_var.set_value(point[name])
197-
if name in point_no_shared:
198-
del point_no_shared[name]
199195

200-
q = DictToArrayBijection.map(point_no_shared)
196+
q = DictToArrayBijection.map({v.name: point[v.name] for v in self.vars})
201197

202198
step_res = self.astep(q)
203199

@@ -210,15 +206,7 @@ def step(self, point):
210206
# We assume that the mapping has stayed the same
211207
apoint = RaveledVars(apoint, q.point_map_info)
212208

213-
# We need to re-add the shared variables to the new sample point
214-
a_point = DictToArrayBijection.rmap(apoint)
215-
new_point = {}
216-
for name in point.keys():
217-
shared_value = self.shared.get(name, None)
218-
if shared_value is not None:
219-
new_point[name] = shared_value.get_value()
220-
else:
221-
new_point[name] = a_point[name]
209+
new_point = DictToArrayBijection.rmap(apoint, start_point=point)
222210

223211
if self.generates_stats:
224212
return new_point, stats

0 commit comments

Comments
 (0)