Commit 51cda52

Avoid numpy broadcast_to and ndindex in hot loops

1 parent 10105be · commit 51cda52

4 files changed: +39 −21 lines changed

pytensor/tensor/random/basic.py

Lines changed: 12 additions & 17 deletions
@@ -18,6 +18,7 @@
     broadcast_params,
     normalize_size_param,
 )
+from pytensor.tensor.utils import faster_broadcast_to, faster_ndindex
 
 
 # Scipy.stats is considerably slow to import
@@ -976,19 +977,13 @@ def __call__(self, alphas, size=None, **kwargs):
     @classmethod
     def rng_fn(cls, rng, alphas, size):
         if alphas.ndim > 1:
-            if size is None:
-                size = ()
-
-            size = tuple(np.atleast_1d(size))
-
-            if size:
-                alphas = np.broadcast_to(alphas, size + alphas.shape[-1:])
+            if size is not None:
+                alphas = faster_broadcast_to(alphas, size + alphas.shape[-1:])
 
             samples_shape = alphas.shape
             samples = np.empty(samples_shape)
-            for index in np.ndindex(*samples_shape[:-1]):
+            for index in faster_ndindex(samples_shape[:-1]):
                 samples[index] = rng.dirichlet(alphas[index])
-
             return samples
         else:
             return rng.dirichlet(alphas, size=size)
@@ -1800,11 +1795,11 @@ def rng_fn(cls, rng, n, p, size):
             if size is None:
                 n, p = broadcast_params([n, p], [0, 1])
             else:
-                n = np.broadcast_to(n, size)
-                p = np.broadcast_to(p, size + p.shape[-1:])
+                n = faster_broadcast_to(n, size)
+                p = faster_broadcast_to(p, size + p.shape[-1:])
 
             res = np.empty(p.shape, dtype=cls.dtype)
-            for idx in np.ndindex(p.shape[:-1]):
+            for idx in faster_ndindex(p.shape[:-1]):
                 res[idx] = rng.multinomial(n[idx], p[idx])
             return res
         else:
@@ -1978,13 +1973,13 @@ def rng_fn(self, *params):
                 p.shape[:batch_ndim],
             )
 
-        a = np.broadcast_to(a, size + a.shape[batch_ndim:])
+        a = faster_broadcast_to(a, size + a.shape[batch_ndim:])
         if p is not None:
-            p = np.broadcast_to(p, size + p.shape[batch_ndim:])
+            p = faster_broadcast_to(p, size + p.shape[batch_ndim:])
 
         a_indexed_shape = a.shape[len(size) + 1 :]
         out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype)
-        for idx in np.ndindex(size):
+        for idx in faster_ndindex(size):
             out[idx] = rng.choice(
                 a[idx], p=None if p is None else p[idx], size=core_shape, replace=False
             )
@@ -2097,10 +2092,10 @@ def rng_fn(self, rng, x, size):
         if size is None:
             size = x.shape[:batch_ndim]
         else:
-            x = np.broadcast_to(x, size + x.shape[batch_ndim:])
+            x = faster_broadcast_to(x, size + x.shape[batch_ndim:])
 
         out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype)
-        for idx in np.ndindex(size):
+        for idx in faster_ndindex(size):
             out[idx] = rng.permutation(x[idx])
         return out
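
Every hunk above follows the same hot-loop pattern: broadcast the parameters up to the output batch shape, then fill an empty array one batch index at a time. A minimal runnable sketch of that pattern using the two new helpers (the `batched_dirichlet` wrapper is illustrative, not part of the commit):

    import numpy as np
    from pytensor.tensor.utils import faster_broadcast_to, faster_ndindex

    def batched_dirichlet(rng, alphas, size):
        # Broadcast batched alphas to the requested size, keeping the
        # core (last) dimension untouched, as in Dirichlet.rng_fn above.
        if size is not None:
            alphas = faster_broadcast_to(alphas, size + alphas.shape[-1:])
        samples = np.empty(alphas.shape)
        # faster_ndindex replaces np.ndindex as the per-batch iterator.
        for index in faster_ndindex(alphas.shape[:-1]):
            samples[index] = rng.dirichlet(alphas[index])
        return samples

    rng = np.random.default_rng(0)
    draws = batched_dirichlet(rng, np.ones((3, 4)), size=(10, 3))
    assert draws.shape == (10, 3, 4)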

pytensor/tensor/random/utils.py

Lines changed: 2 additions & 1 deletion
@@ -15,6 +15,7 @@
 from pytensor.tensor.math import maximum
 from pytensor.tensor.shape import shape_padleft, specify_shape
 from pytensor.tensor.type import int_dtypes
+from pytensor.tensor.utils import faster_broadcast_to
 from pytensor.tensor.variable import TensorVariable
 
 
@@ -125,7 +126,7 @@ def broadcast_params(
     shapes = params_broadcast_shapes(
         param_shapes, ndims_params, use_pytensor=use_pytensor
     )
-    broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
+    broadcast_to_fn = broadcast_to if use_pytensor else faster_broadcast_to
 
     # zip strict not specified because we are in a hot loop
     bcast_params = [
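
`broadcast_params` is the shared helper that performs this broadcasting once per distribution parameter; with `use_pytensor=False` it now routes through `faster_broadcast_to`. A small usage sketch, with illustrative shapes and semantics as I read them from the hunk above:

    import numpy as np
    from pytensor.tensor.random.utils import broadcast_params

    # n has 0 core dims; p has 1 core dim (the probability vector).
    n = np.array([5, 10, 20])        # batch shape (3,)
    p = np.array([0.2, 0.3, 0.5])    # core shape (3,), batch shape ()
    n_bc, p_bc = broadcast_params([n, p], [0, 1])
    assert n_bc.shape == (3,)
    assert p_bc.shape == (3, 3)  # batch shape (3,) + core shape (3,)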

pytensor/tensor/utils.py

Lines changed: 23 additions & 0 deletions
@@ -1,8 +1,10 @@
 import re
 from collections.abc import Sequence
+from itertools import product
 from typing import cast
 
 import numpy as np
+from numpy import nditer
 
 import pytensor
 from pytensor.graph import FunctionGraph, Variable
@@ -233,3 +235,24 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None:
 
     # TODO: If axis tuple is equivalent to None, return None for more canonicalization?
     return cast(tuple, axis)
+
+
+def faster_broadcast_to(x, shape):
+    # Stripped down core logic of `np.broadcast_to`
+    return nditer(
+        (x,),
+        flags=["multi_index", "zerosize_ok"],
+        op_flags=["readonly"],
+        itershape=shape,
+        order="C",
+    ).itviews[0]
+
+
+def faster_ndindex(shape: Sequence[int]):
+    """Equivalent to `np.ndindex` but usually 10x faster.
+
+    Unlike `np.ndindex`, this function expects a single sequence of integers
+
+    https://github.com/numpy/numpy/issues/28921
+    """
+    return product(*(range(s) for s in shape))
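
Both helpers are meant to be behavior-compatible with their NumPy counterparts on the inputs used in this commit. A quick equivalence check (assumes this commit is installed; the writeability assertion reflects the `readonly` op flag above):

    import numpy as np
    from pytensor.tensor.utils import faster_broadcast_to, faster_ndindex

    x = np.arange(3)

    # Same values and view semantics as np.broadcast_to.
    view = faster_broadcast_to(x, (4, 3))
    np.testing.assert_array_equal(view, np.broadcast_to(x, (4, 3)))
    assert not view.flags.writeable  # readonly op flag -> read-only view

    # Same C-order index stream as np.ndindex, but the shape is passed
    # as a single sequence rather than unpacked arguments.
    assert list(faster_ndindex((2, 3))) == list(np.ndindex(2, 3))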

tests/tensor/random/test_basic.py

Lines changed: 2 additions & 3 deletions
@@ -746,9 +746,8 @@ def test_mvnormal_cov_decomposition_method(method, psd):
     ],
 )
 def test_dirichlet_samples(alphas, size):
-    def dirichlet_test_fn(mean=None, cov=None, size=None, random_state=None):
-        if size is None:
-            size = ()
+    # FIXME: Is this just testing itself against itself?
+    def dirichlet_test_fn(alphas, size, random_state):
         return dirichlet.rng_fn(random_state, alphas, size)
 
     compare_sample_values(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn)
