# See the License for the specific language governing permissions and
# limitations under the License.

+ from collections.abc import Mapping
from functools import singledispatch
- from typing import Optional
+ from typing import Dict, Optional, Union

import aesara.tensor as at
import numpy as np

from aesara import config
from aesara.gradient import disconnected_grad
from aesara.graph.basic import Constant, clone, graph_inputs, io_toposort
+ from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, compute_test_value
from aesara.graph.type import CType
+ from aesara.tensor.random.op import RandomVariable
+ from aesara.tensor.random.opt import local_subtensor_rv_lift
from aesara.tensor.subtensor import (
    AdvancedIncSubtensor,
    AdvancedIncSubtensor1,
@@ -107,7 +111,7 @@ def _get_scaling(total_size, shape, ndim):

def logpt(
    var: TensorVariable,
-     rv_value: Optional[TensorVariable] = None,
+     rv_values: Optional[Union[TensorVariable, Dict[TensorVariable, TensorVariable]]] = None,
    *,
    jacobian: bool = True,
    scaling: bool = True,
@@ -127,10 +131,10 @@ def logpt(
    ==========
    var
        The `RandomVariable` output that determines the log-likelihood graph.
-     rv_value
-         The variable that represents the value of `var` in its log-likelihood.
-         If no `rv_value` is provided, ``var.tag.value_var`` will be checked
-         and, when available, used.
+     rv_values
+         A variable, or ``dict`` of variables, that represents the value of
+         `var` in its log-likelihood. If no value variable is provided,
+         ``var.tag.value_var`` will be checked and, when available, used.
    jacobian
        Whether or not to include the Jacobian term.
    scaling
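As a quick illustration of the new `rv_values` argument, here is a minimal usage sketch; the toy graph and the value-variable names below are illustrative assumptions, not part of this change:

```python
import aesara.tensor as at

# Toy graph in which `x` depends on another random variable, `mu`.
mu_rv = at.random.normal(0.0, 1.0, name="mu")
x_rv = at.random.normal(mu_rv, 1.0, name="x")

# Value variables that stand in for the random variables in the log-likelihood.
mu_value = mu_rv.type()
x_value = x_rv.type()

# A single value variable still works, but a dict lets every random variable
# in the graph be paired with its own value variable.
x_logp = logpt(x_rv, {x_rv: x_value, mu_rv: mu_value})
```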
@@ -143,16 +147,17 @@ def logpt(
        Sum the log-likelihood.

    """
+     if not isinstance(rv_values, Mapping):
+         rv_values = {var: rv_values} if rv_values is not None else {}

    rv_var, rv_value_var = extract_rv_and_value_vars(var)

-     if rv_value is None:
+     rv_value = rv_values.get(rv_var, rv_value_var)

-         if rv_var is not None and rv_value_var is None:
-             raise ValueError(f"No value variable specified or associated with {rv_var}")
+     if rv_var is not None and rv_value is None:
+         raise ValueError(f"No value variable specified or associated with {rv_var}")

-         rv_value = rv_value_var
-     else:
+     if rv_value is not None:
        rv_value = at.as_tensor(rv_value)

    if rv_var is not None:
@@ -163,12 +168,12 @@ def logpt(
            rv_value_var = rv_value

    if rv_var is None:
-
        if var.owner is not None:
            return _logp(
                var.owner.op,
-                 rv_value,
-                 var.owner.inputs,
+                 var,
+                 rv_values,
+                 *var.owner.inputs,
                jacobian=jacobian,
                scaling=scaling,
                transformed=transformed,
@@ -189,10 +194,13 @@ def logpt(
    # Ultimately, with a graph containing only random variables and
    # "deterministics", we can simply replace all the random variables with
    # their value variables and be done.
+     tmp_rv_values = rv_values.copy()
+     tmp_rv_values[rv_var] = rv_var
+
    if not cdf:
-         logp_var = _logp(rv_node.op, rv_var, *dist_params, **kwargs)
+         logp_var = _logp(rv_node.op, rv_var, tmp_rv_values, *dist_params, **kwargs)
    else:
-         logp_var = _logcdf(rv_node.op, rv_var, *dist_params, **kwargs)
+         logp_var = _logcdf(rv_node.op, rv_var, tmp_rv_values, *dist_params, **kwargs)

    transform = getattr(rv_value_var.tag, "transform", None) if rv_value_var else None
@@ -204,10 +212,13 @@ def logpt(
        logp_var += transformed_jacobian

    # Replace random variables with their value variables
+     replacements = rv_values.copy()
+     replacements.update({rv_var: rv_value, rv_value_var: rv_value})
+
    (logp_var,), _ = rvs_to_value_vars(
        (logp_var,),
        apply_transforms=transformed and not cdf,
-         initial_replacements={rv_var: rv_value, rv_value_var: rv_value},
+         initial_replacements=replacements,
    )

    if sum:
@@ -231,15 +242,24 @@ def logpt(


@singledispatch
- def _logp(op: Op, value: TensorVariable, *dist_params, **kwargs):
+ def _logp(
+     op: Op,
+     var: TensorVariable,
+     rvs_to_values: Dict[TensorVariable, TensorVariable],
+     *inputs: TensorVariable,
+     **kwargs,
+ ):
    """Create a log-likelihood graph.

    This function dispatches on the type of `op`, which should be a subclass
    of `RandomVariable`. If you want to implement new log-likelihood graphs
    for a `RandomVariable`, register a new function on this dispatcher.

+     The default assumes that the log-likelihood of a term is zero.
+
    """
-     return at.zeros_like(value)
+     value_var = rvs_to_values.get(var, var)
+     return at.zeros_like(value_var)
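To make the dispatch mechanism concrete, here is a hedged sketch of registering a log-likelihood for a hypothetical `MyNormalRV` `Op`; per the call site in `logpt`, a `RandomVariable`'s distribution parameters are passed positionally after `rvs_to_values`:

```python
@_logp.register(MyNormalRV)  # `MyNormalRV` is a hypothetical `RandomVariable` subclass
def my_normal_logp(op, var, rvs_to_values, mu, sigma, **kwargs):
    # Use the value variable associated with `var`, falling back to `var` itself.
    value = rvs_to_values.get(var, var)
    # Normal log-density expressed as an Aesara graph.
    return -0.5 * ((value - mu) / sigma) ** 2 - at.log(sigma) - 0.5 * at.log(2 * np.pi)
```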


def convert_indices(indices, entry):
@@ -256,39 +276,70 @@ def convert_indices(indices, entry):
        return entry


- def index_from_subtensor(idx_list, indices):
+ def indices_from_subtensor(idx_list, indices):
    """Compute a usable index tuple from the inputs of a ``*Subtensor*`` ``Op``."""
-     index = tuple(tuple(convert_indices(indices, idx) for idx in idx_list) if idx_list else indices)
-     if len(index) == 1:
-         index = index[0]
-     return index
+     return tuple(
+         tuple(convert_indices(list(indices), idx) for idx in idx_list) if idx_list else indices
+     )


@_logp.register(IncSubtensor)
@_logp.register(AdvancedIncSubtensor)
@_logp.register(AdvancedIncSubtensor1)
- def incsubtensor_logp(op, value, inputs, **kwargs):
-     rv_var, rv_values, *indices = inputs
+ def incsubtensor_logp(op, var, rvs_to_values, indexed_rv_var, rv_values, *indices, **kwargs):

-     index = index_from_subtensor(getattr(op, "idx_list", None), indices)
+     index = indices_from_subtensor(getattr(op, "idx_list", None), indices)

    _, (new_rv_var,) = clone(
-         tuple(v for v in graph_inputs((rv_var,)) if not isinstance(v, Constant)),
-         (rv_var,),
+         tuple(v for v in graph_inputs((indexed_rv_var,)) if not isinstance(v, Constant)),
+         (indexed_rv_var,),
        copy_inputs=False,
        copy_orphans=False,
    )
    new_values = at.set_subtensor(disconnected_grad(new_rv_var)[index], rv_values)
-     logp_var = logpt(rv_var, new_values, **kwargs)
+     logp_var = logpt(indexed_rv_var, new_values, **kwargs)

    return logp_var


@_logp.register(Subtensor)
@_logp.register(AdvancedSubtensor)
@_logp.register(AdvancedSubtensor1)
- def subtensor_logp(op, value, *inputs, **kwargs):
-     raise NotImplementedError()
+ def subtensor_logp(op, var, rvs_to_values, indexed_rv_var, *indices, **kwargs):
+
+     index = indices_from_subtensor(getattr(op, "idx_list", None), indices)
+
+     rv_value = rvs_to_values.get(var, getattr(var.tag, "value_var", None))
+
+     if indexed_rv_var.owner and isinstance(indexed_rv_var.owner.op, RandomVariable):
+
+         # We need to lift the index operation through the random variable so
+         # that we have a new random variable consisting of only the relevant
+         # subset of variables per the index.
+         var_copy = var.owner.clone().default_output()
+         fgraph = FunctionGraph(
+             [i for i in graph_inputs((indexed_rv_var,)) if not isinstance(i, Constant)],
+             [var_copy],
+             clone=False,
+         )
+
+         (lifted_var,) = local_subtensor_rv_lift.transform(fgraph, fgraph.outputs[0].owner)
+
+         new_rvs_to_values = rvs_to_values.copy()
+         new_rvs_to_values[lifted_var] = rv_value
+
+         logp_var = logpt(lifted_var, new_rvs_to_values, **kwargs)
+
+         for idx_var in index:
+             logp_var += logpt(idx_var, rvs_to_values, **kwargs)
+
+     # TODO: We could add the constant case (i.e. `indexed_rv_var.owner is None`)
+     else:
+         raise NotImplementedError(
+             f"`Subtensor` log-likelihood not implemented for {indexed_rv_var.owner}"
+         )
+
+     return logp_var
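A hedged sketch of what the lifting branch above makes possible: taking the log-likelihood of an indexed random variable, with the index lifted through the `RandomVariable` so that only the selected component contributes. The toy variables are illustrative only:

```python
import aesara.tensor as at

x_rv = at.random.normal(0.0, 1.0, size=3, name="x")  # a vector of three normals
x0_rv = x_rv[0]  # a `Subtensor` of the random variable

x0_value = at.scalar("x0")

# The index is lifted through the `RandomVariable`, so this should be the
# log-likelihood of a single scalar normal evaluated at `x0_value`.
x0_logp = logpt(x0_rv, x0_value)
```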


def logcdf(*args, **kwargs):
@@ -297,7 +348,7 @@ def logcdf(*args, **kwargs):


@singledispatch
- def _logcdf(op, value, *args, **kwargs):
+ def _logcdf(op, values, *args, **kwargs):
    """Create a log-CDF graph.

    This function dispatches on the type of `op`, which should be a subclass