@@ -1,7 +1,8 @@
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from typing import Any, cast
 
 import numpy as np
+from numpy import broadcast_shapes, empty
 
 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
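Importing `broadcast_shapes` and `empty` as bare names, rather than reaching through `np.`, is presumably a hot-loop micro-optimization: a global name lookup is slightly cheaper than an attribute lookup on the module object. A quick, illustrative way to measure the difference:

```python
import timeit

# Both spellings resolve to the same function; the difference is one
# attribute lookup on the numpy module object per call.
setup = "import numpy as np; from numpy import empty"
t_attr = timeit.timeit("np.empty(4)", setup=setup, number=500_000)
t_name = timeit.timeit("empty(4)", setup=setup, number=500_000)
print(f"np.empty: {t_attr:.3f}s, empty: {t_name:.3f}s")
```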
@@ -22,12 +23,111 @@
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
+    faster_broadcast_to,
+    faster_ndindex,
     import_func_from_string,
     safe_signature,
 )
 from pytensor.tensor.variable import TensorVariable
 
 
+def _vectorize_node_perform(
+    core_node: Apply,
+    batch_bcast_patterns: Sequence[tuple[bool, ...]],
+    batch_ndim: int,
+    impl: str | None,
+) -> Callable:
+    """Create a vectorized `perform` function for a given core node.
+
+    Similar in behavior to `np.vectorize`, but specialized for PyTensor's `Blockwise` `Op`.
+    """
+
+    storage_map = {var: [None] for var in core_node.inputs + core_node.outputs}
+    core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl)
+    single_in = len(core_node.inputs) == 1
+    core_input_storage = [storage_map[inp] for inp in core_node.inputs]
+    core_output_storage = [storage_map[out] for out in core_node.outputs]
+    core_storage = core_input_storage + core_output_storage
+
+    def vectorized_perform(
+        *args,
+        batch_bcast_patterns=batch_bcast_patterns,
+        batch_ndim=batch_ndim,
+        single_in=single_in,
+        core_thunk=core_thunk,
+        core_input_storage=core_input_storage,
+        core_output_storage=core_output_storage,
+        core_storage=core_storage,
+    ):
+        if single_in:
+            batch_shape = args[0].shape[:batch_ndim]
+        else:
+            _check_runtime_broadcast_core(args, batch_bcast_patterns, batch_ndim)
+            batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args))
+            args = list(args)
+            for i, arg in enumerate(args):
+                if arg.shape[:batch_ndim] != batch_shape:
+                    args[i] = faster_broadcast_to(
+                        arg, batch_shape + arg.shape[batch_ndim:]
+                    )
+
+        ndindex_iterator = faster_ndindex(batch_shape)
+        # Call once to get the output shapes
+        try:
+            # TODO: Pass core shape as input like BlockwiseWithCoreShape does?
+            index0 = next(ndindex_iterator)
+        except StopIteration:
+            raise NotImplementedError("vectorize with zero size not implemented")
+        else:
+            for core_input, arg in zip(core_input_storage, args):
+                core_input[0] = np.asarray(arg[index0])
+            core_thunk()
+            outputs = tuple(
+                empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
+                for core_output in core_output_storage
+            )
+            for output, core_output in zip(outputs, core_output_storage):
+                output[index0] = core_output[0]
+
+        for index in ndindex_iterator:
+            for core_input, arg in zip(core_input_storage, args):
+                core_input[0] = np.asarray(arg[index])
+            core_thunk()
+            for output, core_output in zip(outputs, core_output_storage):
+                output[index] = core_output[0]
+
+        # Clear storage
+        for core_val in core_storage:
+            core_val[0] = None
+        return outputs
+
+    return vectorized_perform
+
+
+def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ndim):
+    # `strict` is not passed to `zip` because we are in a hot loop
+    # We zip together the dimension lengths of each input and their broadcast patterns
+    for dim_lengths_and_bcast in zip(
+        *[
+            zip(input.shape[:batch_ndim], batch_bcast_pattern)
+            for input, batch_bcast_pattern in zip(
+                numerical_inputs, batch_bcast_patterns
+            )
+        ],
+    ):
+        # If any entry has a dim_length != 1 while another entry has a dim_length
+        # of 1 with broadcastable=False, we have runtime broadcasting.
+        if (
+            any(d != 1 for d, _ in dim_lengths_and_bcast)
+            and (1, False) in dim_lengths_and_bcast
+        ):
+            raise ValueError(
+                "Runtime broadcasting not allowed. "
+                "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
+                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+            )
+
+
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
 
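For intuition, here is a minimal NumPy-only sketch of the pattern `_vectorize_node_perform` implements above: broadcast the batch dimensions, run the core function once to learn the output shapes and dtypes, preallocate the outputs, then fill them batch index by batch index. The `vectorize_core` helper and its signature are illustrative, not part of PyTensor's API; the real implementation additionally reuses a compiled core thunk and PyTensor's faster broadcast/ndindex helpers.

```python
import numpy as np

def vectorize_core(core_fn, *args, batch_ndim):
    """Apply `core_fn` over leading batch dims, preallocating outputs.

    `core_fn` maps core inputs to a tuple of core outputs.
    """
    batch_shape = np.broadcast_shapes(*(a.shape[:batch_ndim] for a in args))
    args = [np.broadcast_to(a, batch_shape + a.shape[batch_ndim:]) for a in args]
    indices = np.ndindex(*batch_shape)
    # First call reveals output shapes/dtypes (a zero-size batch would raise
    # StopIteration here, mirroring the NotImplementedError in the diff).
    index0 = next(indices)
    first = core_fn(*(a[index0] for a in args))
    outputs = tuple(np.empty(batch_shape + o.shape, o.dtype) for o in first)
    for out, o in zip(outputs, first):
        out[index0] = o
    for index in indices:  # fill the remaining batch entries in place
        for out, o in zip(outputs, core_fn(*(a[index] for a in args))):
            out[index] = o
    return outputs

rng = np.random.default_rng(0)
A = rng.normal(size=(4, 3, 3))
v = rng.normal(size=(1, 3))  # batch dim of length 1 broadcasts against 4
(y,) = vectorize_core(lambda a, b: (a @ b,), A, v, batch_ndim=1)
assert np.allclose(y, np.einsum("bij,bj->bi", A, np.broadcast_to(v, (4, 3))))
```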
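The runtime-broadcast guard in `_check_runtime_broadcast_core` is a pure shape computation. A hypothetical standalone version (`has_runtime_broadcast` is not a PyTensor function) that returns a bool instead of raising:

```python
def has_runtime_broadcast(shapes, bcast_patterns, batch_ndim):
    # For each batch dimension, pair every input's runtime length with its
    # static broadcastable flag; a runtime 1 may only broadcast if flagged.
    for dim_lengths_and_bcast in zip(
        *[
            zip(shape[:batch_ndim], bcast)
            for shape, bcast in zip(shapes, bcast_patterns)
        ]
    ):
        if (
            any(d != 1 for d, _ in dim_lengths_and_bcast)
            and (1, False) in dim_lengths_and_bcast
        ):
            return True
    return False

# (5, 3) against (1, 3): the leading 1 was NOT declared broadcastable -> flagged.
assert has_runtime_broadcast([(5, 3), (1, 3)], [(False,), (False,)], batch_ndim=1)
# Same shapes, but the 1 is statically broadcastable -> allowed.
assert not has_runtime_broadcast([(5, 3), (1, 3)], [(False,), (True,)], batch_ndim=1)
```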
@@ -308,91 +408,74 @@ def L_op(self, inputs, outs, ograds):
 
         return rval
 
-    def _create_node_gufunc(self, node) -> None:
+    def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.
 
         If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
         Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.
 
         The gufunc is stored in the tag of the node.
         """
-        gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
-
-        if gufunc_spec is not None:
-            gufunc = import_func_from_string(gufunc_spec[0])
-            if gufunc is None:
+        batch_ndim = self.batch_ndim(node)
+        batch_bcast_patterns = [
+            inp.type.broadcastable[:batch_ndim] for inp in node.inputs
+        ]
+        if (
+            gufunc_spec := self.gufunc_spec
+            or getattr(self.core_op, "gufunc_spec", None)
+        ) is not None:
+            core_func = import_func_from_string(gufunc_spec[0])
+            if core_func is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")
 
-        else:
-            # Wrap core_op perform method in numpy vectorize
-            n_outs = len(self.outputs_sig)
-            core_node = self._create_dummy_core_node(node.inputs)
-            inner_outputs_storage = [[None] for _ in range(n_outs)]
-
-            def core_func(
-                *inner_inputs,
-                core_node=core_node,
-                inner_outputs_storage=inner_outputs_storage,
-            ):
-                self.core_op.perform(
-                    core_node,
-                    [np.asarray(inp) for inp in inner_inputs],
-                    inner_outputs_storage,
-                )
-
-                if n_outs == 1:
-                    return inner_outputs_storage[0][0]
-                else:
-                    return tuple(r[0] for r in inner_outputs_storage)
+            if len(node.outputs) == 1:
+
+                def gufunc(
+                    *inputs,
+                    batch_bcast_patterns=batch_bcast_patterns,
+                    batch_ndim=batch_ndim,
+                ):
+                    _check_runtime_broadcast_core(
+                        inputs, batch_bcast_patterns, batch_ndim
+                    )
+                    return (core_func(*inputs),)
+            else:
 
-            gufunc = np.vectorize(core_func, signature=self.signature)
+                def gufunc(
+                    *inputs,
+                    batch_bcast_patterns=batch_bcast_patterns,
+                    batch_ndim=batch_ndim,
+                ):
+                    _check_runtime_broadcast_core(
+                        inputs, batch_bcast_patterns, batch_ndim
+                    )
+                    return core_func(*inputs)
+        else:
+            core_node = self._create_dummy_core_node(node.inputs)  # type: ignore
+            gufunc = _vectorize_node_perform(
+                core_node,
+                batch_bcast_patterns=batch_bcast_patterns,
+                batch_ndim=self.batch_ndim(node),
+                impl=impl,
+            )
 
-        node.tag.gufunc = gufunc
+        return gufunc
 
     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
+        batch_bcast = [pt_inp.type.broadcastable[:batch_ndim] for pt_inp in node.inputs]
+        _check_runtime_broadcast_core(inputs, batch_bcast, batch_ndim)
 
-        # strict=False because we are in a hot loop
-        for dims_and_bcast in zip(
-            *[
-                zip(
-                    input.shape[:batch_ndim],
-                    sinput.type.broadcastable[:batch_ndim],
-                    strict=False,
-                )
-                for input, sinput in zip(inputs, node.inputs, strict=False)
-            ],
-            strict=False,
-        ):
-            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
-                raise ValueError(
-                    "Runtime broadcasting not allowed. "
-                    "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
-                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
-                )
+    def prepare_node(self, node, storage_map, compute_map, impl=None):
+        node.tag.gufunc = self._create_node_gufunc(node, impl=impl)
 
     def perform(self, node, inputs, output_storage):
-        gufunc = getattr(node.tag, "gufunc", None)
-
-        if gufunc is None:
-            # Cache it once per node
-            self._create_node_gufunc(node)
+        try:
             gufunc = node.tag.gufunc
-
-        self._check_runtime_broadcast(node, inputs)
-
-        res = gufunc(*inputs)
-        if not isinstance(res, tuple):
-            res = (res,)
-
-        # strict=False because we are in a hot loop
-        for node_out, out_storage, r in zip(
-            node.outputs, output_storage, res, strict=False
-        ):
-            out_dtype = getattr(node_out, "dtype", None)
-            if out_dtype and out_dtype != r.dtype:
-                r = np.asarray(r, dtype=out_dtype)
-            out_storage[0] = r
+        except AttributeError:
+            gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None)
+        for out_storage, result in zip(output_storage, gufunc(*inputs)):
+            out_storage[0] = result
 
     def __str__(self):
         if self.name is None:
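The new `prepare_node`/`perform` pair relies on a small caching idiom: the gufunc is built once per node when the function is compiled, and the `try`/`except AttributeError` in `perform` keeps the hot path to a single attribute load (EAFP is cheaper than `getattr` with a default when the attribute is almost always present). A generic sketch of the idiom, with hypothetical names:

```python
class _Tag:
    """Stand-in for a node's `tag`: a plain attribute namespace."""

def _expensive_build(node):
    print(f"building gufunc for {node}")
    return lambda *inputs: tuple(2 * i for i in inputs)

def perform(node, tag, inputs):
    try:
        fn = tag.fn  # hot path: a plain attribute load, no default handling
    except AttributeError:  # cold path: first call on an unprepared node
        fn = tag.fn = _expensive_build(node)
    return fn(*inputs)

tag = _Tag()
perform("node-A", tag, (1, 2))  # prints "building gufunc for node-A" once
perform("node-A", tag, (3, 4))  # cached; no rebuild
```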