Commit 02566b6
Use subclassed OpFromGraph to represent pad Op
1 parent a98b8ae commit 02566b6

File tree

    pytensor/tensor/pad.py
    tests/tensor/test_pad.py

2 files changed (+85, -27)

pytensor/tensor/pad.py

Lines changed: 74 additions & 27 deletions
@@ -1,6 +1,7 @@
 from collections.abc import Callable
-from typing import Literal
+from typing import Literal, cast
 
+from pytensor.compile.builders import OpFromGraph
 from pytensor.scan import scan
 from pytensor.tensor import TensorLike
 from pytensor.tensor.basic import (
@@ -34,6 +35,19 @@
 ]
 stat_funcs = {"maximum": pt_max, "minimum": pt_min, "mean": mean}
 
+allowed_kwargs = {
+    "edge": [],
+    "wrap": [],
+    "constant": ["constant_values"],
+    "linear_ramp": ["end_values"],
+    "maximum": ["stat_length"],
+    "mean": ["stat_length"],
+    "median": ["stat_length"],
+    "minimum": ["stat_length"],
+    "reflect": ["reflect_type"],
+    "symmetric": ["reflect_type"],
+}
+
 
 def _slice_at_axis(sl: slice, axis: int) -> tuple[slice, ...]:
     """
@@ -225,17 +239,20 @@ def _get_stats(
 
 
 def _stat_pad(
-    x: TensorVariable, pad_width: TensorVariable, stat_func, stat_length=None
+    x: TensorVariable,
+    pad_width: TensorVariable,
+    stat_func: Callable,
+    stat_length: TensorVariable | None,
 ):
     padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
     if stat_length is None:
-        stat_length = [[None, None]] * padded.ndim
+        stat_length = [[None, None]] * padded.ndim  # type: ignore
     else:
         stat_length = broadcast_to(stat_length, as_tensor((padded.ndim, 2)))
 
     for axis in range(padded.ndim):
         width_pair = pad_width[axis]
-        length_pair = stat_length[axis]
+        length_pair = stat_length[axis]  # type: ignore
         dim_shape = padded.shape[axis]
 
         left_stat, right_stat = _get_stats(
@@ -311,6 +328,10 @@ def inner_func(i, x):
         # Delay creation of this function to here because we want to use the axis global inside the scan
         def inner_func(i, x):
             return switch(eq(i % 2, 0), flip(x, axis=axis), x)
+    else:
+        raise ValueError(
+            "You should not have gotten here. Open an issue on github!"
+        )  # pragma no cover
 
     size = x.shape[axis]
     repeats, (left_remainder, right_remainder) = pt_divmod(pad_width[axis], size)
@@ -330,55 +351,81 @@ def inner_func(i, x):
         return x
 
 
-def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
-    allowed_kwargs = {
-        "edge": [],
-        "wrap": [],
-        "constant": ["constant_values"],
-        "linear_ramp": ["end_values"],
-        "maximum": ["stat_length"],
-        "mean": ["stat_length"],
-        "median": ["stat_length"],
-        "minimum": ["stat_length"],
-        "reflect": ["reflect_type"],
-        "symmetric": ["reflect_type"],
-    }
+class Pad(OpFromGraph):
+    """
+    Wrapper Op for Pad graphs
+    """
 
+
+def pad(x: TensorLike, pad_width: TensorLike, mode: PadMode = "constant", **kwargs):
     if any(value not in allowed_kwargs[mode] for value in kwargs.keys()):
         raise ValueError(
             f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
         )
-    x = as_tensor(x)
-    pad_width = as_tensor(pad_width)
+    x = as_tensor(x, name="x")
+    pad_width = as_tensor(pad_width, name="pad_width")
+    inputs = [x, pad_width]
+    attrs = {}
 
     if mode == "constant":
-        constant_values = as_tensor(kwargs.pop("constant_values", 0))
-        return _constant_pad(x, pad_width, constant_values)
+        constant_values = as_tensor(
+            kwargs.pop("constant_values", 0), name="constant_values"
+        )
+        inputs += [constant_values]
+        outputs = _constant_pad(x, pad_width, constant_values)
+
     elif mode == "edge":
-        return _edge_pad(x, pad_width)
+        outputs = _edge_pad(x, pad_width)
+
     elif mode in ["maximum", "minimum", "mean", "median"]:
         if mode == "median":
             # TODO: pt.quantile? pt.median?
             raise NotImplementedError("Median padding not implemented")
-        stat_func = stat_funcs[mode]
-        return _stat_pad(x, pad_width, stat_func, **kwargs)
+        stat_func = cast(Callable, stat_funcs[mode])
+        stat_length = kwargs.get("stat_length")
+        if stat_length is not None:
+            stat_length = as_tensor(stat_length, name="stat_length")
+            inputs += [stat_length]
+
+        attrs.update(
+            {"stat_func": stat_func, "stat_length_input": stat_length is not None}
+        )
+        outputs = _stat_pad(x, pad_width, stat_func, stat_length)
+
     elif mode == "linear_ramp":
         end_values = kwargs.pop("end_values", 0)
-        return _linear_ramp_pad(x, pad_width, end_values)
+        end_values = as_tensor(end_values)
+
+        inputs += [end_values]
+        outputs = _linear_ramp_pad(x, pad_width, end_values)
+
     elif mode == "wrap":
-        return _looping_pad(x, pad_width, kind="wrap")
+        attrs.update({"kind": "wrap"})
+        outputs = _looping_pad(x, pad_width, kind="wrap")
+
     elif mode == "symmetric":
         reflect_type = kwargs.pop("reflect_type", "even")
         if reflect_type == "odd":
             raise NotImplementedError("Odd reflection not implemented")
-        return _looping_pad(x, pad_width, kind="symmetric")
+
+        attrs.update({"kind": reflect_type})
+        outputs = _looping_pad(x, pad_width, kind="symmetric")
+
     elif mode == "reflect":
        reflect_type = kwargs.pop("reflect_type", "even")
        if reflect_type == "odd":
            raise NotImplementedError("Odd reflection not implemented")
+        attrs.update({"reflect_type": reflect_type})
        raise NotImplementedError("Reflect padding not implemented")
     else:
         raise ValueError(f"Invalid mode: {mode}")
 
+    op = Pad(inputs=inputs, outputs=[outputs])(*inputs)  # type: ignore
+
+    setattr(op, "pad_mode", mode)
+    for pad_arg, value in attrs.items():
+        setattr(op, pad_arg, value)
+    return op
+
 
 __all__ = ["pad"]
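
For context, a minimal usage sketch (not part of this commit) of what the wrapper enables: pad() now returns the output variable of a single Pad node, and the mode metadata attached via setattr can be read straight off the result. The matrix shape and values below are illustrative only.

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.pad import Pad, pad

x = pt.matrix("x")
# pad_width=2 pads both sides of every axis, matching np.pad semantics
z = pad(x, pad_width=2, mode="constant", constant_values=0)

assert isinstance(z.owner.op, Pad)  # the whole pad graph is one OpFromGraph node
assert z.pad_mode == "constant"     # attribute attached by pad()

f = pytensor.function([x], z)
out = f(np.ones((3, 3), dtype=pytensor.config.floatX))
assert out.shape == (7, 7)          # 2 + 3 + 2 along each axis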

tests/tensor/test_pad.py

Lines changed: 11 additions & 0 deletions
@@ -26,6 +26,8 @@ def test_constant_pad(
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="constant", constant_values=constant)
     z = pad(x, pad_width, mode="constant", constant_values=constant)
+    assert z.pad_mode == "constant"
+
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@@ -41,6 +43,8 @@ def test_edge_pad(size: tuple, pad_width: int | tuple[int, ...]):
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="edge")
     z = pad(x, pad_width, mode="edge")
+    assert z.pad_mode == "edge"
+
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@@ -61,6 +65,8 @@ def test_linear_ramp_pad(
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="linear_ramp", end_values=end_values)
     z = pad(x, pad_width, mode="linear_ramp", end_values=end_values)
+    assert z.pad_mode == "linear_ramp"
+
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@@ -83,6 +89,9 @@ def test_stat_pad(
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode=stat, stat_length=stat_length)
     z = pad(x, pad_width, mode=stat, stat_length=stat_length)
+    assert z.pad_mode == stat
+    assert z.stat_length_input == (stat_length is not None)
+
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@@ -98,6 +107,7 @@ def test_wrap_pad(size: tuple, pad_width: int | tuple[int, ...]):
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="wrap")
     z = pad(x, pad_width, mode="wrap")
+    assert z.pad_mode == "wrap"
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
@@ -118,6 +128,7 @@ def test_symmetric_pad(size, pad_width, reflect_type):
     x = np.random.normal(size=size).astype(floatX)
     expected = np.pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)
     z = pad(x, pad_width, mode="symmetric", reflect_type=reflect_type)
+    assert z.pad_mode == "symmetric"
     f = pytensor.function([], z, mode="FAST_COMPILE")
 
     np.testing.assert_allclose(expected, f(), atol=ATOL, rtol=RTOL)
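
A hypothetical follow-up illustration (not in this commit): because the padding graph is now wrapped in a dedicated Op subclass, downstream graph passes can detect it by type instead of pattern-matching the underlying slice/concatenate operations. The helper name below is assumed, not part of the codebase.

from pytensor.tensor.pad import Pad

def is_pad_node(apply_node) -> bool:
    # Assumed helper: True if this Apply node wraps a pad graph,
    # i.e. its op is the Pad subclass of OpFromGraph from this commit.
    return isinstance(apply_node.op, Pad)

# e.g. on a FunctionGraph `fgraph` (assumed in scope):
# pad_nodes = [node for node in fgraph.apply_nodes if is_pad_node(node)]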
