
Commit 10105be

Don't specify zip strict kwarg in hot loops
It seems to add a non-trivial 100 ns of overhead.
1 parent 5335a68 commit 10105be
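A minimal micro-benchmark sketch (not part of the commit) of how the overhead mentioned above could be measured; absolute numbers depend on interpreter and machine, and `strict=` requires Python 3.10+:

import timeit

a = list(range(10))
b = list(range(10))

# Same work, with and without the keyword argument this commit drops from hot loops.
plain = timeit.timeit(lambda: list(zip(a, b)), number=1_000_000)
with_kwarg = timeit.timeit(lambda: list(zip(a, b, strict=False)), number=1_000_000)

print(f"zip(a, b):               {plain:.3f} s")
print(f"zip(a, b, strict=False): {with_kwarg:.3f} s")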

20 files changed (+60 lines, -72 lines)
pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -130,7 +130,7 @@ exclude = ["doc/", "pytensor/_version.py"]
 docstring-code-format = true

 [tool.ruff.lint]
-select = ["B905", "C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
+select = ["C", "E", "F", "I", "UP", "W", "RUF", "PERF", "PTH", "ISC", "T20", "NPY201"]
 ignore = ["C408", "C901", "E501", "E741", "RUF012", "PERF203", "ISC001"]
 unfixable = [
     # zip-strict: the auto-fix adds `strict=False` but we might want `strict=True` instead
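For context (not part of the diff): B905 is ruff's zip-without-explicit-strict rule, so removing it from `select` is what lets the bare `zip(...)` calls in the files below pass linting. A hedged illustration of what the rule flags:

# Flagged by B905 when the rule is enabled: no explicit strict= argument.
pairs = list(zip(["a", "b"], [1, 2]))

# Not flagged: the caller states explicitly whether a length mismatch should raise.
pairs = list(zip(["a", "b"], [1, 2], strict=True))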

pytensor/compile/builders.py
Lines changed: 2 additions & 3 deletions

@@ -873,7 +873,6 @@ def clone(self):

     def perform(self, node, inputs, outputs):
         variables = self.fn(*inputs)
-        assert len(variables) == len(outputs)
-        # strict=False because asserted above
-        for output, variable in zip(outputs, variables, strict=False):
+        # zip strict not specified because we are in a hot loop
+        for output, variable in zip(outputs, variables):
             output[0] = variable

pytensor/compile/function/types.py
Lines changed: 4 additions & 2 deletions

@@ -924,7 +924,8 @@ def __call__(self, *args, output_subset=None, **kwargs):

         # Reinitialize each container's 'provided' counter
         if trust_input:
-            for arg_container, arg in zip(input_storage, args, strict=False):
+            # zip strict not specified because we are in a hot loop
+            for arg_container, arg in zip(input_storage, args):
                 arg_container.storage[0] = arg
         else:
             for arg_container in input_storage:
@@ -934,7 +935,8 @@ def __call__(self, *args, output_subset=None, **kwargs):
                     raise TypeError("Too many parameter passed to pytensor function")

             # Set positional arguments
-            for arg_container, arg in zip(input_storage, args, strict=False):
+            # zip strict not specified because we are in a hot loop
+            for arg_container, arg in zip(input_storage, args):
                 # See discussion about None as input
                 # https://groups.google.com/group/theano-dev/browse_thread/thread/920a5e904e8a8525/4f1b311a28fc27e5
                 if arg is None:
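For orientation (not part of the diff): `Function.__call__` runs on every evaluation of a compiled pytensor function, which is why a fixed per-call cost matters here. A rough usage sketch, assuming the standard pytensor API:

import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
f = pytensor.function([x], 2 * x)  # compiles to a Function object

print(f(3.0))         # each call goes through Function.__call__
f.trust_input = True  # takes the trust_input branch patched above, skipping input checks
print(f(4.0))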

pytensor/ifelse.py
Lines changed: 4 additions & 4 deletions

@@ -305,8 +305,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                # strict=False because we are in a hot loop
-                for out, t in zip(outputs, input_true_branch, strict=False):
+                # zip strict not specified because we are in a hot loop
+                for out, t in zip(outputs, input_true_branch):
                     compute_map[out][0] = 1
                     val = storage_map[t][0]
                     if self.as_view:
@@ -326,8 +326,8 @@ def thunk():
             if len(ls) > 0:
                 return ls
             else:
-                # strict=False because we are in a hot loop
-                for out, f in zip(outputs, inputs_false_branch, strict=False):
+                # zip strict not specified because we are in a hot loop
+                for out, f in zip(outputs, inputs_false_branch):
                     compute_map[out][0] = 1
                     # can't view both outputs unless destroyhandler
                     # improves

pytensor/link/basic.py
Lines changed: 6 additions & 6 deletions

@@ -539,14 +539,14 @@ def make_thunk(self, **kwargs):

         def f():
             for inputs in input_lists[1:]:
-                # strict=False because we are in a hot loop
-                for input1, input2 in zip(inputs0, inputs, strict=False):
+                # zip strict not specified because we are in a hot loop
+                for input1, input2 in zip(inputs0, inputs):
                     input2.storage[0] = copy(input1.storage[0])
             for x in to_reset:
                 x[0] = None
             pre(self, [input.data for input in input_lists[0]], order, thunk_groups)
-            # strict=False because we are in a hot loop
-            for i, (thunks, node) in enumerate(zip(thunk_groups, order, strict=False)):
+            # zip strict not specified because we are in a hot loop
+            for i, (thunks, node) in enumerate(zip(thunk_groups, order)):
                 try:
                     wrapper(self.fgraph, i, node, *thunks)
                 except Exception:
@@ -668,8 +668,8 @@ def thunk(
                 # since the error may come from any of them?
                 raise_with_op(self.fgraph, output_nodes[0], thunk)

-            # strict=False because we are in a hot loop
-            for o_storage, o_val in zip(thunk_outputs, outputs, strict=False):
+            # zip strict not specified because we are in a hot loop
+            for o_storage, o_val in zip(thunk_outputs, outputs):
                 o_storage[0] = o_val

         thunk.inputs = thunk_inputs

pytensor/link/c/basic.py
Lines changed: 6 additions & 10 deletions

@@ -1988,27 +1988,23 @@ def make_thunk(self, **kwargs):
         )

         def f():
-            # strict=False because we are in a hot loop
-            for input1, input2 in zip(i1, i2, strict=False):
+            # zip strict not specified because we are in a hot loop
+            for input1, input2 in zip(i1, i2):
                 # Set the inputs to be the same in both branches.
                 # The copy is necessary in order for inplace ops not to
                 # interfere.
                 input2.storage[0] = copy(input1.storage[0])
-            for thunk1, thunk2, node1, node2 in zip(
-                thunks1, thunks2, order1, order2, strict=False
-            ):
-                for output, storage in zip(node1.outputs, thunk1.outputs, strict=False):
+            for thunk1, thunk2, node1, node2 in zip(thunks1, thunks2, order1, order2):
+                for output, storage in zip(node1.outputs, thunk1.outputs):
                     if output in no_recycling:
                         storage[0] = None
-                for output, storage in zip(node2.outputs, thunk2.outputs, strict=False):
+                for output, storage in zip(node2.outputs, thunk2.outputs):
                     if output in no_recycling:
                         storage[0] = None
                 try:
                     thunk1()
                     thunk2()
-                    for output1, output2 in zip(
-                        thunk1.outputs, thunk2.outputs, strict=False
-                    ):
+                    for output1, output2 in zip(thunk1.outputs, thunk2.outputs):
                         self.checker(output1, output2)
                 except Exception:
                     raise_with_op(fgraph, node1)

pytensor/link/numba/dispatch/basic.py
Lines changed: 2 additions & 2 deletions

@@ -312,10 +312,10 @@ def py_perform_return(inputs):
     else:

         def py_perform_return(inputs):
-            # strict=False because we are in a hot loop
+            # zip strict not specified because we are in a hot loop
             return tuple(
                 out_type.filter(out[0])
-                for out_type, out in zip(output_types, py_perform(inputs), strict=False)
+                for out_type, out in zip(output_types, py_perform(inputs))
             )

     @numba_njit

pytensor/link/numba/dispatch/cython_support.py
Lines changed: 1 addition & 4 deletions

@@ -166,10 +166,7 @@ def __wrapper_address__(self):
     def __call__(self, *args, **kwargs):
         # no strict argument because of the JIT
         # TODO: check
-        args = [
-            dtype(arg)
-            for arg, dtype in zip(args, self._signature.arg_dtypes)  # noqa: B905
-        ]
+        args = [dtype(arg) for arg, dtype in zip(args, self._signature.arg_dtypes)]
         if self.has_pyx_skip_dispatch():
             output = self._pyfunc(*args[:-1], **kwargs)
         else:

pytensor/link/numba/dispatch/extra_ops.py
Lines changed: 1 addition & 1 deletion

@@ -186,7 +186,7 @@ def ravelmultiindex(*inp):
         new_arr = arr.T.astype(np.float64).copy()
         for i, b in enumerate(new_arr):
             # no strict argument to this zip because numba doesn't support it
-            for j, (d, v) in enumerate(zip(shape, b)):  # noqa: B905
+            for j, (d, v) in enumerate(zip(shape, b)):
                 if v < 0 or v >= d:
                     mode_fn(new_arr, i, j, v, d)

pytensor/link/numba/dispatch/slinalg.py
Lines changed: 1 addition & 1 deletion

@@ -183,7 +183,7 @@ def block_diag(*arrs):

     r, c = 0, 0
     # no strict argument because it is incompatible with numba
-    for arr, shape in zip(arrs, shapes):  # noqa: B905
+    for arr, shape in zip(arrs, shapes):
         rr, cc = shape
         out[r : r + rr, c : c + cc] = arr
         r += rr

pytensor/link/numba/dispatch/subtensor.py
Lines changed: 5 additions & 5 deletions

@@ -219,7 +219,7 @@ def advanced_subtensor_multiple_vector(x, *idxs):
         shape_aft = x_shape[after_last_axis:]
         out_shape = (*shape_bef, *idx_shape, *shape_aft)
         out_buffer = np.empty(out_shape, dtype=x.dtype)
-        for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+        for i, scalar_idxs in enumerate(zip(*vec_idxs)):
             out_buffer[(*none_slices, i)] = x[(*none_slices, *scalar_idxs)]
         return out_buffer
@@ -253,7 +253,7 @@ def advanced_set_subtensor_multiple_vector(x, y, *idxs):
         y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])

         for outer in np.ndindex(x_shape[:first_axis]):
-            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):
                 out[(*outer, *scalar_idxs)] = y[(*outer, i)]
         return out
@@ -275,7 +275,7 @@ def advanced_inc_subtensor_multiple_vector(x, y, *idxs):
         y = np.broadcast_to(y, x_shape[:first_axis] + x_shape[last_axis:])

         for outer in np.ndindex(x_shape[:first_axis]):
-            for i, scalar_idxs in enumerate(zip(*vec_idxs)):  # noqa: B905
+            for i, scalar_idxs in enumerate(zip(*vec_idxs)):
                 out[(*outer, *scalar_idxs)] += y[(*outer, i)]
         return out
@@ -314,7 +314,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
         if not len(idxs) == len(vals):
             raise ValueError("The number of indices and values must match.")
         # no strict argument because incompatible with numba
-        for idx, val in zip(idxs, vals):  # noqa: B905
+        for idx, val in zip(idxs, vals):
             x[idx] = val
         return x
     else:
@@ -342,7 +342,7 @@ def advancedincsubtensor1_inplace(x, vals, idxs):
             raise ValueError("The number of indices and values must match.")
         # no strict argument because unsupported by numba
         # TODO: this doesn't come up in tests
-        for idx, val in zip(idxs, vals):  # noqa: B905
+        for idx, val in zip(idxs, vals):
             x[idx] += val
         return x
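As the comments in the numba dispatch files above note, numba's nopython mode accepts plain `zip()` but not its `strict=` keyword, so these call sites cannot carry the argument at all. A small sketch of the supported pattern, assuming a recent numba release:

import numpy as np
from numba import njit

@njit
def dot_sum(xs, ys):
    total = 0.0
    # plain zip compiles under nopython mode; zip(..., strict=...) would not
    for x, y in zip(xs, ys):
        total += x * y
    return total

print(dot_sum(np.arange(3.0), np.arange(3.0)))  # 5.0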

pytensor/link/utils.py
Lines changed: 2 additions & 2 deletions

@@ -207,8 +207,8 @@ def streamline_nice_errors_f():
         for x in no_recycling:
             x[0] = None
         try:
-            # strict=False because we are in a hot loop
-            for thunk, node in zip(thunks, order, strict=False):
+            # zip strict not specified because we are in a hot loop
+            for thunk, node in zip(thunks, order):
                 thunk()
         except Exception:
             raise_with_op(fgraph, node, thunk)

pytensor/scalar/basic.py
Lines changed: 2 additions & 2 deletions

@@ -4416,8 +4416,8 @@ def make_node(self, *inputs):

     def perform(self, node, inputs, output_storage):
         outputs = self.py_perform_fn(*inputs)
-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, outputs, strict=False):
+        # zip strict not specified because we are in a hot loop
+        for storage, out_val in zip(output_storage, outputs):
             storage[0] = out_val

     def grad(self, inputs, output_grads):

pytensor/scalar/loop.py
Lines changed: 2 additions & 2 deletions

@@ -196,8 +196,8 @@ def perform(self, node, inputs, output_storage):
         for i in range(n_steps):
             carry = inner_fn(*carry, *constant)

-        # strict=False because we are in a hot loop
-        for storage, out_val in zip(output_storage, carry, strict=False):
+        # zip strict not specified because we are in a hot loop
+        for storage, out_val in zip(output_storage, carry):
             storage[0] = out_val

     @property

pytensor/tensor/basic.py
Lines changed: 2 additions & 2 deletions

@@ -3589,8 +3589,8 @@ def perform(self, node, inp, out):

         # Make sure the output is big enough
         out_s = []
-        # strict=False because we are in a hot loop
-        for xdim, ydim in zip(x_s, y_s, strict=False):
+        # zip strict not specified because we are in a hot loop
+        for xdim, ydim in zip(x_s, y_s):
             if xdim == ydim:
                 outdim = xdim
             elif xdim == 1:

pytensor/tensor/elemwise.py
Lines changed: 5 additions & 5 deletions

@@ -712,9 +712,9 @@ def perform(self, node, inputs, output_storage):
         if nout == 1:
             variables = [variables]

-        # strict=False because we are in a hot loop
+        # zip strict not specified because we are in a hot loop
         for i, (variable, storage, nout) in enumerate(
-            zip(variables, output_storage, node.outputs, strict=False)
+            zip(variables, output_storage, node.outputs)
         ):
             storage[0] = variable = np.asarray(variable, dtype=nout.dtype)
@@ -729,11 +729,11 @@ def perform(self, node, inputs, output_storage):

     @staticmethod
     def _check_runtime_broadcast(node, inputs):
-        # strict=False because we are in a hot loop
+        # zip strict not specified because we are in a hot loop
         for dims_and_bcast in zip(
             *[
-                zip(input.shape, sinput.type.broadcastable, strict=False)
-                for input, sinput in zip(inputs, node.inputs, strict=False)
+                zip(input.shape, sinput.type.broadcastable)
+                for input, sinput in zip(inputs, node.inputs)
             ],
             strict=False,
         ):

pytensor/tensor/random/basic.py
Lines changed: 2 additions & 2 deletions

@@ -1865,8 +1865,8 @@ def rng_fn(cls, rng, p, size):
         # to `p.shape[:-1]` in the call to `vsearchsorted` below.
         if len(size) < (p.ndim - 1):
             raise ValueError("`size` is incompatible with the shape of `p`")
-        # strict=False because we are in a hot loop
-        for s, ps in zip(reversed(size), reversed(p.shape[:-1]), strict=False):
+        # zip strict not specified because we are in a hot loop
+        for s, ps in zip(reversed(size), reversed(p.shape[:-1])):
             if s == 1 and ps != 1:
                 raise ValueError("`size` is incompatible with the shape of `p`")

pytensor/tensor/random/utils.py
Lines changed: 6 additions & 7 deletions

@@ -44,8 +44,8 @@ def params_broadcast_shapes(
     max_fn = maximum if use_pytensor else max

     rev_extra_dims: list[int] = []
-    # strict=False because we are in a hot loop
-    for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False):
+    # zip strict not specified because we are in a hot loop
+    for ndim_param, param_shape in zip(ndims_params, param_shapes):
         # We need this in order to use `len`
         param_shape = tuple(param_shape)
         extras = tuple(param_shape[: (len(param_shape) - ndim_param)])
@@ -64,12 +64,12 @@ def max_bcast(x, y):

     extra_dims = tuple(reversed(rev_extra_dims))

-    # strict=False because we are in a hot loop
+    # zip strict not specified because we are in a hot loop
     bcast_shapes = [
         (extra_dims + tuple(param_shape)[-ndim_param:])
         if ndim_param > 0
         else extra_dims
-        for ndim_param, param_shape in zip(ndims_params, param_shapes, strict=False)
+        for ndim_param, param_shape in zip(ndims_params, param_shapes)
     ]

     return bcast_shapes
@@ -127,10 +127,9 @@ def broadcast_params(
     )
     broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to

-    # strict=False because we are in a hot loop
+    # zip strict not specified because we are in a hot loop
     bcast_params = [
-        broadcast_to_fn(param, shape)
-        for shape, param in zip(shapes, params, strict=False)
+        broadcast_to_fn(param, shape) for shape, param in zip(shapes, params)
     ]

     return bcast_params

pytensor/tensor/shape.py
Lines changed: 2 additions & 4 deletions

@@ -447,10 +447,8 @@ def perform(self, node, inp, out_):
             raise AssertionError(
                 f"SpecifyShape: Got {x.ndim} dimensions (shape {x.shape}), expected {ndim} dimensions with shape {tuple(shape)}."
             )
-        # strict=False because we are in a hot loop
-        if not all(
-            xs == s for xs, s in zip(x.shape, shape, strict=False) if s is not None
-        ):
+        # zip strict not specified because we are in a hot loop
+        if not all(xs == s for xs, s in zip(x.shape, shape) if s is not None):
             raise AssertionError(
                 f"SpecifyShape: Got shape {x.shape}, expected {tuple(int(s) if s is not None else None for s in shape)}."
             )

pytensor/tensor/type.py
Lines changed: 4 additions & 7 deletions

@@ -261,10 +261,10 @@ def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
                 " PyTensor C code does not support that.",
             )

-        # strict=False because we are in a hot loop
+        # zip strict not specified because we are in a hot loop
         if not all(
             ds == ts if ts is not None else True
-            for ds, ts in zip(data.shape, self.shape, strict=False)
+            for ds, ts in zip(data.shape, self.shape)
         ):
             raise TypeError(
                 f"The type's shape ({self.shape}) is not compatible with the data's ({data.shape})"
@@ -333,17 +333,14 @@ def in_same_class(self, otype):
         return False

     def is_super(self, otype):
-        # strict=False because we are in a hot loop
+        # zip strict not specified because we are in a hot loop
         if (
             isinstance(otype, type(self))
             and otype.dtype == self.dtype
             and otype.ndim == self.ndim
             # `otype` is allowed to be as or more shape-specific than `self`,
             # but not less
-            and all(
-                sb == ob or sb is None
-                for sb, ob in zip(self.shape, otype.shape, strict=False)
-            )
+            and all(sb == ob or sb is None for sb, ob in zip(self.shape, otype.shape))
         ):
             return True