Skip to content

Commit 3bf68b6

Browse files
brandonwillardtwiecki
authored andcommitted
Add missing values as unobserved random variables and estimate them during MCMC
1 parent 9636408 commit 3bf68b6

File tree

5 files changed

+238
-138
lines changed

5 files changed

+238
-138
lines changed

pymc3/aesaraf.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from typing import (
1517
Callable,
1618
Dict,
@@ -169,19 +171,17 @@ def change_rv_size(
169171
def extract_rv_and_value_vars(
170172
var: TensorVariable,
171173
) -> Tuple[TensorVariable, TensorVariable]:
172-
"""Extract a random variable and its corresponding value variable from a generic
173-
`TensorVariable`.
174+
"""Return a random variable and it's observations or value variable, or ``None``.
174175
175176
Parameters
176177
==========
177178
var
178-
A variable corresponding to a `RandomVariable`.
179+
A variable corresponding to a ``RandomVariable``.
179180
180181
Returns
181182
=======
182-
The first value in the tuple is the `RandomVariable`, and the second is the
183-
measure-space variable that corresponds with the latter (i.e. the "value"
184-
variable).
183+
The first value in the tuple is the ``RandomVariable``, and the second is the
184+
measure/log-likelihood value variable that corresponds with the latter.
185185
186186
"""
187187
if not var.owner:
@@ -195,7 +195,7 @@ def extract_rv_and_value_vars(
195195

196196

197197
def extract_obs_data(x: TensorVariable) -> np.ndarray:
198-
"""Extract data observed symbolic variables.
198+
"""Extract data from observed symbolic variables.
199199
200200
Raises
201201
------
@@ -331,17 +331,24 @@ def transform_replacements(var, replacements):
331331
rv_var, rv_value_var = extract_rv_and_value_vars(var)
332332

333333
if rv_value_var is None:
334+
warnings.warn(
335+
f"No value variable found for {rv_var}; "
336+
"the random variable will not be replaced."
337+
)
334338
return []
335339

336340
transform = getattr(rv_value_var.tag, "transform", None)
337341

338342
if transform is None or not apply_transforms:
339343
replacements[var] = rv_value_var
340-
return []
344+
# In case the value variable is itself a graph, we walk it for
345+
# potential replacements
346+
return [rv_value_var]
341347

342348
trans_rv_value = transform.backward(rv_var, rv_value_var)
343349
replacements[var] = trans_rv_value
344350

351+
# Walk the transformed variable and make replacements
345352
return [trans_rv_value]
346353

347354
return replace_rvs_in_graphs(graphs, transform_replacements, initial_replacements, **kwargs)

0 commit comments

Comments
 (0)