Commit 5fa5c9b

Speedup python implementation of Blockwise

1 parent 51cda52 commit 5fa5c9b

4 files changed: +212 -89 lines changed

pytensor/graph/op.py

Lines changed: 27 additions & 14 deletions

@@ -502,7 +502,7 @@ def make_py_thunk(
         self,
         node: Apply,
         storage_map: StorageMapType,
-        compute_map: ComputeMapType,
+        compute_map: ComputeMapType | None,
         no_recycling: list[Variable],
         debug: bool = False,
     ) -> ThunkType:
@@ -513,25 +513,38 @@ def make_py_thunk(
         """
         node_input_storage = [storage_map[r] for r in node.inputs]
         node_output_storage = [storage_map[r] for r in node.outputs]
-        node_compute_map = [compute_map[r] for r in node.outputs]

         if debug and hasattr(self, "debug_perform"):
             p = node.op.debug_perform
         else:
             p = node.op.perform

-        @is_thunk_type
-        def rval(
-            p=p,
-            i=node_input_storage,
-            o=node_output_storage,
-            n=node,
-            cm=node_compute_map,
-        ):
-            r = p(n, [x[0] for x in i], o)
-            for entry in cm:
-                entry[0] = True
-            return r
+        if compute_map is None:
+
+            @is_thunk_type
+            def rval(
+                p=p,
+                i=node_input_storage,
+                o=node_output_storage,
+                n=node,
+            ):
+                return p(n, [x[0] for x in i], o)
+
+        else:
+            node_compute_map = [compute_map[r] for r in node.outputs]
+
+            @is_thunk_type
+            def rval(
+                p=p,
+                i=node_input_storage,
+                o=node_output_storage,
+                n=node,
+                cm=node_compute_map,
+            ):
+                r = p(n, [x[0] for x in i], o)
+                for entry in cm:
+                    entry[0] = True
+                return r

         rval.inputs = node_input_storage
         rval.outputs = node_output_storage
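The key change is hoisting the `compute_map is None` test out of the thunk: the branch runs once at thunk-creation time, so the no-compute-map path pays neither a per-call `if` nor per-output flag updates. A minimal sketch of this specialize-the-closure-once pattern, using hypothetical names (`make_runner`, `work`, `flags`) for illustration:

def make_runner(work, flags=None):
    if flags is None:
        # No bookkeeping requested: return the bare computation,
        # with no extra Python work per call.
        def run():
            return work()
    else:
        # Bookkeeping requested: fetch the flag cells once here, so the
        # per-call loop only does `entry[0] = True`.
        entries = list(flags)

        def run():
            r = work()
            for entry in entries:
                entry[0] = True
            return r
    return run

The C thunk factory below applies the same idea.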

pytensor/link/c/op.py

Lines changed: 12 additions & 6 deletions

@@ -39,7 +39,7 @@ def make_c_thunk(
         self,
         node: Apply,
         storage_map: StorageMapType,
-        compute_map: ComputeMapType,
+        compute_map: ComputeMapType | None,
         no_recycling: Collection[Variable],
     ) -> CThunkWrapperType:
         """Create a thunk for a C implementation.
@@ -86,11 +86,17 @@ def is_f16(t):
         )
         thunk, node_input_filters, node_output_filters = outputs

-        @is_cthunk_wrapper_type
-        def rval():
-            thunk()
-            for o in node.outputs:
-                compute_map[o][0] = True
+        if compute_map is None:
+            rval = is_cthunk_wrapper_type(thunk)
+
+        else:
+            cm_entries = [compute_map[o] for o in node.outputs]
+
+            @is_cthunk_wrapper_type
+            def rval(thunk=thunk, cm_entries=cm_entries):
+                thunk()
+                for entry in cm_entries:
+                    entry[0] = True

         rval.thunk = thunk
         rval.cthunk = thunk.cthunk
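When `compute_map is None`, the compiled C thunk is now returned as-is, removing a whole Python wrapper frame per call; otherwise the `compute_map[o]` lookups are hoisted into `cm_entries` so the per-call loop only flips flags. For tiny ops the wrapper frame alone is measurable; a rough stand-alone timing sketch (not part of the commit) illustrates the effect:

import timeit

def inner():
    pass

def wrapped():
    inner()  # same work plus one extra Python call frame

print(timeit.timeit(inner))    # bare callable
print(timeit.timeit(wrapped))  # typically measurably slower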

pytensor/tensor/blockwise.py

Lines changed: 151 additions & 68 deletions

@@ -1,7 +1,8 @@
-from collections.abc import Sequence
+from collections.abc import Callable, Sequence
 from typing import Any, cast

 import numpy as np
+from numpy import broadcast_shapes, empty

 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
@@ -22,12 +23,111 @@
 from pytensor.tensor.utils import (
     _parse_gufunc_signature,
     broadcast_static_dim_lengths,
+    faster_broadcast_to,
+    faster_ndindex,
     import_func_from_string,
     safe_signature,
 )
 from pytensor.tensor.variable import TensorVariable


+def _vectorize_node_perform(
+    core_node: Apply,
+    batch_bcast_patterns: Sequence[tuple[bool, ...]],
+    batch_ndim: int,
+    impl: str | None,
+) -> Callable:
+    """Create a vectorized `perform` function for a given core node.
+
+    Similar in behavior to np.vectorize, but specialized for the PyTensor Blockwise Op.
+    """
+
+    storage_map = {var: [None] for var in core_node.inputs + core_node.outputs}
+    core_thunk = core_node.op.make_thunk(core_node, storage_map, None, [], impl=impl)
+    single_in = len(core_node.inputs) == 1
+    core_input_storage = [storage_map[inp] for inp in core_node.inputs]
+    core_output_storage = [storage_map[out] for out in core_node.outputs]
+    core_storage = core_input_storage + core_output_storage
+
+    def vectorized_perform(
+        *args,
+        batch_bcast_patterns=batch_bcast_patterns,
+        batch_ndim=batch_ndim,
+        single_in=single_in,
+        core_thunk=core_thunk,
+        core_input_storage=core_input_storage,
+        core_output_storage=core_output_storage,
+        core_storage=core_storage,
+    ):
+        if single_in:
+            batch_shape = args[0].shape[:batch_ndim]
+        else:
+            _check_runtime_broadcast_core(args, batch_bcast_patterns, batch_ndim)
+            batch_shape = broadcast_shapes(*(arg.shape[:batch_ndim] for arg in args))
+            args = list(args)
+            for i, arg in enumerate(args):
+                if arg.shape[:batch_ndim] != batch_shape:
+                    args[i] = faster_broadcast_to(
+                        arg, batch_shape + arg.shape[batch_ndim:]
+                    )
+
+        ndindex_iterator = faster_ndindex(batch_shape)
+        # Call once to get the output shapes
+        try:
+            # TODO: Pass core shape as input like BlockwiseWithCoreShape does?
+            index0 = next(ndindex_iterator)
+        except StopIteration:
+            raise NotImplementedError("vectorize with zero size not implemented")
+        else:
+            for core_input, arg in zip(core_input_storage, args):
+                core_input[0] = np.asarray(arg[index0])
+            core_thunk()
+            outputs = tuple(
+                empty(batch_shape + core_output[0].shape, dtype=core_output[0].dtype)
+                for core_output in core_output_storage
+            )
+            for output, core_output in zip(outputs, core_output_storage):
+                output[index0] = core_output[0]

+        for index in ndindex_iterator:
+            for core_input, arg in zip(core_input_storage, args):
+                core_input[0] = np.asarray(arg[index])
+            core_thunk()
+            for output, core_output in zip(outputs, core_output_storage):
+                output[index] = core_output[0]
+
+        # Clear storage
+        for core_val in core_storage:
+            core_val[0] = None
+        return outputs
+
+    return vectorized_perform
+
+
+def _check_runtime_broadcast_core(numerical_inputs, batch_bcast_patterns, batch_ndim):
+    # No strict zip because we are in a hot loop
+    # We zip together the dimension lengths of each input and their broadcast patterns
+    for dim_lengths_and_bcast in zip(
+        *[
+            zip(input.shape[:batch_ndim], batch_bcast_pattern)
+            for input, batch_bcast_pattern in zip(
+                numerical_inputs, batch_bcast_patterns
+            )
+        ],
+    ):
+        # If for any dimension one entry has dim_length != 1,
+        # and another a dim_length of 1 with broadcastable=False, we have runtime broadcasting.
+        if (
+            any(d != 1 for d, _ in dim_lengths_and_bcast)
+            and (1, False) in dim_lengths_and_bcast
+        ):
+            raise ValueError(
+                "Runtime broadcasting not allowed. "
+                "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
+                "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
+            )
+
+
 class Blockwise(Op):
     """Generalizes a core `Op` to work with batched dimensions.
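`_vectorize_node_perform` replaces the old `np.vectorize` approach: it builds one thunk for the core op, reuses the same input/output storage cells across the whole batch loop, and preallocates the output arrays once the first core call reveals their shapes and dtypes. A self-contained sketch of the same loop structure in plain NumPy, with an illustrative stand-in core function (names here are not PyTensor's):

import numpy as np

def vectorize_core(core_fn, batch_ndim):
    """Sketch: map `core_fn` over leading batch dims into one preallocated output."""

    def run(*args):
        batch_shape = np.broadcast_shapes(*(a.shape[:batch_ndim] for a in args))
        args = [np.broadcast_to(a, batch_shape + a.shape[batch_ndim:]) for a in args]
        it = np.ndindex(batch_shape)
        idx0 = next(it)  # first core call reveals the output shape and dtype
        first = np.asarray(core_fn(*(a[idx0] for a in args)))
        out = np.empty(batch_shape + first.shape, dtype=first.dtype)
        out[idx0] = first
        for idx in it:
            out[idx] = core_fn(*(a[idx] for a in args))
        return out

    return run

# Example mirroring the new benchmark test: batched "valid" 1D convolution.
conv = vectorize_core(lambda x, y: np.convolve(x, y, mode="valid"), batch_ndim=1)
rng = np.random.default_rng(0)
a, b = rng.normal(size=(7, 128)), rng.normal(size=(7, 20))
assert conv(a, b).shape == (7, 109)  # 109 == 128 - 20 + 1

Unlike this sketch, the real implementation also clears the core storage afterwards so intermediate buffers are not kept alive between calls, and it raises explicitly on zero-sized batch shapes instead of failing on `next`.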
@@ -308,91 +408,74 @@ def L_op(self, inputs, outs, ograds):

         return rval

-    def _create_node_gufunc(self, node) -> None:
+    def _create_node_gufunc(self, node: Apply, impl) -> Callable:
         """Define (or retrieve) the node gufunc used in `perform`.

         If the Blockwise or core_op have a `gufunc_spec`, the relevant numpy or scipy gufunc is used directly.
         Otherwise, we default to `np.vectorize` of the core_op `perform` method for a dummy node.

         The gufunc is stored in the tag of the node.
         """
-        gufunc_spec = self.gufunc_spec or getattr(self.core_op, "gufunc_spec", None)
-
-        if gufunc_spec is not None:
-            gufunc = import_func_from_string(gufunc_spec[0])
-            if gufunc is None:
+        batch_ndim = self.batch_ndim(node)
+        batch_bcast_patterns = [
+            inp.type.broadcastable[:batch_ndim] for inp in node.inputs
+        ]
+        if (
+            gufunc_spec := self.gufunc_spec
+            or getattr(self.core_op, "gufunc_spec", None)
+        ) is not None:
+            core_func = import_func_from_string(gufunc_spec[0])
+            if core_func is None:
                 raise ValueError(f"Could not import gufunc {gufunc_spec[0]} for {self}")

-        else:
-            # Wrap core_op perform method in numpy vectorize
-            n_outs = len(self.outputs_sig)
-            core_node = self._create_dummy_core_node(node.inputs)
-            inner_outputs_storage = [[None] for _ in range(n_outs)]
-
-            def core_func(
-                *inner_inputs,
-                core_node=core_node,
-                inner_outputs_storage=inner_outputs_storage,
-            ):
-                self.core_op.perform(
-                    core_node,
-                    [np.asarray(inp) for inp in inner_inputs],
-                    inner_outputs_storage,
-                )
-
-                if n_outs == 1:
-                    return inner_outputs_storage[0][0]
-                else:
-                    return tuple(r[0] for r in inner_outputs_storage)
+            if len(node.outputs) == 1:
+
+                def gufunc(
+                    *inputs,
+                    batch_bcast_patterns=batch_bcast_patterns,
+                    batch_ndim=batch_ndim,
+                ):
+                    _check_runtime_broadcast_core(
+                        inputs, batch_bcast_patterns, batch_ndim
+                    )
+                    return (core_func(*inputs),)
+            else:

-            gufunc = np.vectorize(core_func, signature=self.signature)
+                def gufunc(
+                    *inputs,
+                    batch_bcast_patterns=batch_bcast_patterns,
+                    batch_ndim=batch_ndim,
+                ):
+                    _check_runtime_broadcast_core(
+                        inputs, batch_bcast_patterns, batch_ndim
+                    )
+                    return core_func(*inputs)
+        else:
+            core_node = self._create_dummy_core_node(node.inputs)  # type: ignore
+            gufunc = _vectorize_node_perform(
+                core_node,
+                batch_bcast_patterns=batch_bcast_patterns,
+                batch_ndim=self.batch_ndim(node),
+                impl=impl,
+            )

-        node.tag.gufunc = gufunc
+        return gufunc

     def _check_runtime_broadcast(self, node, inputs):
         batch_ndim = self.batch_ndim(node)
+        batch_bcast = [pt_inp.type.broadcastable[:batch_ndim] for pt_inp in node.inputs]
+        _check_runtime_broadcast_core(inputs, batch_bcast, batch_ndim)

-        # strict=False because we are in a hot loop
-        for dims_and_bcast in zip(
-            *[
-                zip(
-                    input.shape[:batch_ndim],
-                    sinput.type.broadcastable[:batch_ndim],
-                    strict=False,
-                )
-                for input, sinput in zip(inputs, node.inputs, strict=False)
-            ],
-            strict=False,
-        ):
-            if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
-                raise ValueError(
-                    "Runtime broadcasting not allowed. "
-                    "At least one input has a distinct batch dimension length of 1, but was not marked as broadcastable.\n"
-                    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
-                )
+    def prepare_node(self, node, storage_map, compute_map, impl=None):
+        node.tag.gufunc = self._create_node_gufunc(node, impl=impl)

     def perform(self, node, inputs, output_storage):
-        gufunc = getattr(node.tag, "gufunc", None)
-
-        if gufunc is None:
-            # Cache it once per node
-            self._create_node_gufunc(node)
+        try:
             gufunc = node.tag.gufunc
-
-        self._check_runtime_broadcast(node, inputs)
-
-        res = gufunc(*inputs)
-        if not isinstance(res, tuple):
-            res = (res,)
-
-        # strict=False because we are in a hot loop
-        for node_out, out_storage, r in zip(
-            node.outputs, output_storage, res, strict=False
-        ):
-            out_dtype = getattr(node_out, "dtype", None)
-            if out_dtype and out_dtype != r.dtype:
-                r = np.asarray(r, dtype=out_dtype)
-            out_storage[0] = r
+        except AttributeError:
+            gufunc = node.tag.gufunc = self._create_node_gufunc(node, impl=None)
+        for out_storage, result in zip(output_storage, gufunc(*inputs)):
+            out_storage[0] = result

     def __str__(self):
         if self.name is None:
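The refactored runtime-broadcast guard keeps the old semantics: a batch dimension may only broadcast at runtime if it was statically declared broadcastable. A hedged usage sketch of what this means for callers (assuming `convolve1d` stays a Blockwise at runtime, as the new test below asserts, and that `pt.specify_broadcastable` behaves as its docstring describes):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.signal import convolve1d

a = pt.tensor("a", shape=(None, 128))
b = pt.tensor("b", shape=(None, 20))
fn = pytensor.function([a, b], convolve1d(a, b, mode="valid"))

a_val = np.ones((7, 128))
fn(a_val, np.ones((7, 20)))    # OK: batch dims match
# fn(a_val, np.ones((1, 20)))  # raises: runtime broadcasting not allowed

# If broadcasting `b` along the batch dim is intended, declare it statically:
b1 = pt.specify_broadcastable(b, 0)
fn_b = pytensor.function([a, b], convolve1d(a, b1, mode="valid"))
fn_b(a_val, np.ones((1, 20)))  # allowed: dim 0 of `b` is marked broadcastable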

tests/tensor/test_blockwise.py

Lines changed: 22 additions & 1 deletion

@@ -12,10 +12,11 @@
 from pytensor.graph import Apply, Op
 from pytensor.graph.replace import vectorize_node
 from pytensor.raise_op import assert_op
-from pytensor.tensor import diagonal, log, ones_like, scalar, tensor, vector
+from pytensor.tensor import diagonal, dmatrix, log, ones_like, scalar, tensor, vector
 from pytensor.tensor.blockwise import Blockwise, vectorize_node_fallback
 from pytensor.tensor.nlinalg import MatrixInverse
 from pytensor.tensor.rewriting.blas import specialize_matmul_to_batched_dot
+from pytensor.tensor.signal import convolve1d
 from pytensor.tensor.slinalg import (
     Cholesky,
     Solve,
@@ -484,6 +485,26 @@ def test_batched_mvnormal_logp_and_dlogp(mu_batch_shape, cov_batch_shape, benchmark):
     benchmark(fn, *test_values)


+def test_small_blockwise_performance(benchmark):
+    a = dmatrix(shape=(7, 128))
+    b = dmatrix(shape=(7, 20))
+    out = convolve1d(a, b, mode="valid")
+    fn = pytensor.function([a, b], out, trust_input=True)
+    assert isinstance(fn.maker.fgraph.outputs[0].owner.op, Blockwise)
+
+    rng = np.random.default_rng(495)
+    a_test = rng.normal(size=a.type.shape)
+    b_test = rng.normal(size=b.type.shape)
+    np.testing.assert_allclose(
+        fn(a_test, b_test),
+        [
+            np.convolve(a_test[i], b_test[i], mode="valid")
+            for i in range(a_test.shape[0])
+        ],
+    )
+    benchmark(fn, a_test, b_test)
+
+
 def test_cop_with_params():
     matrix_assert = Blockwise(core_op=assert_op, signature="(x1,x2),()->(x1,x2)")
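The new test doubles as a regression check and a benchmark: it asserts the compiled graph still contains a `Blockwise`, validates the result against a per-row `np.convolve` loop, and only then times the function. Assuming pytest-benchmark is installed (it provides the `benchmark` fixture used here), the case can be run in isolation with `pytest tests/tensor/test_blockwise.py::test_small_blockwise_performance --benchmark-only`.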
