11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import warnings
15
+
14
16
from typing import (
15
17
Callable ,
16
18
Dict ,
@@ -169,19 +171,17 @@ def change_rv_size(
169
171
def extract_rv_and_value_vars (
170
172
var : TensorVariable ,
171
173
) -> Tuple [TensorVariable , TensorVariable ]:
172
- """Extract a random variable and its corresponding value variable from a generic
173
- `TensorVariable`.
174
+ """Return a random variable and it's observations or value variable, or ``None``.
174
175
175
176
Parameters
176
177
==========
177
178
var
178
- A variable corresponding to a `RandomVariable`.
179
+ A variable corresponding to a `` RandomVariable` `.
179
180
180
181
Returns
181
182
=======
182
- The first value in the tuple is the `RandomVariable`, and the second is the
183
- measure-space variable that corresponds with the latter (i.e. the "value"
184
- variable).
183
+ The first value in the tuple is the ``RandomVariable``, and the second is the
184
+ measure/log-likelihood value variable that corresponds with the latter.
185
185
186
186
"""
187
187
if not var .owner :
@@ -195,7 +195,7 @@ def extract_rv_and_value_vars(
195
195
196
196
197
197
def extract_obs_data (x : TensorVariable ) -> np .ndarray :
198
- """Extract data observed symbolic variables.
198
+ """Extract data from observed symbolic variables.
199
199
200
200
Raises
201
201
------
@@ -331,17 +331,24 @@ def transform_replacements(var, replacements):
331
331
rv_var , rv_value_var = extract_rv_and_value_vars (var )
332
332
333
333
if rv_value_var is None :
334
+ warnings .warn (
335
+ f"No value variable found for { rv_var } ; "
336
+ "the random variable will not be replaced."
337
+ )
334
338
return []
335
339
336
340
transform = getattr (rv_value_var .tag , "transform" , None )
337
341
338
342
if transform is None or not apply_transforms :
339
343
replacements [var ] = rv_value_var
340
- return []
344
+ # In case the value variable is itself a graph, we walk it for
345
+ # potential replacements
346
+ return [rv_value_var ]
341
347
342
348
trans_rv_value = transform .backward (rv_var , rv_value_var )
343
349
replacements [var ] = trans_rv_value
344
350
351
+ # Walk the transformed variable and make replacements
345
352
return [trans_rv_value ]
346
353
347
354
return replace_rvs_in_graphs (graphs , transform_replacements , initial_replacements , ** kwargs )
0 commit comments