
Commit 5d0acd4

move pad functions to pad.py
1 parent: f3f027d

3 files changed: +134 −122 lines

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -130,6 +130,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
 from pytensor.tensor.extra_ops import *
 from pytensor.tensor.io import *
 from pytensor.tensor.math import *
+from pytensor.tensor.pad import pad
 from pytensor.tensor.shape import (
     reshape,
     shape,
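
With this one-line re-export, pad becomes reachable from the top-level pytensor.tensor namespace (the function binding shadows the new pad submodule's name in the package). A minimal sketch of the new call path, assuming a build that includes this commit; the variable names are illustrative:

import pytensor.tensor as pt

x = pt.matrix("x")
# pad_width=1 is broadcast to one (before, after) pair per axis inside pad()
y = pt.pad(x, pad_width=1, mode="constant", constant_values=0)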

pytensor/tensor/basic.py

Lines changed: 0 additions & 122 deletions
@@ -16,7 +16,6 @@
 import numpy as np
 from numpy.core.multiarray import normalize_axis_index
 from numpy.core.numeric import normalize_axis_tuple
-from numpy.lib.arraypad import _get_edges, _slice_at_axis
 
 import pytensor
 import pytensor.scalar.sharedvar
@@ -51,7 +50,6 @@
     scalar_elemwise,
 )
 from pytensor.tensor.exceptions import NotScalarConstantError
-from pytensor.tensor.extra_ops import broadcast_to, linspace
 from pytensor.tensor.shape import (
     Shape,
     Shape_i,
@@ -63,7 +61,6 @@
     shape_tuple,
     specify_broadcastable,
 )
-from pytensor.tensor.subtensor import set_subtensor
 from pytensor.tensor.type import (
     TensorType,
     discrete_dtypes,
@@ -4345,125 +4342,6 @@ def ix_(*args):
     return tuple(out)
 
 
-def _symbolic_pad(
-    x: TensorVariable, pad_width: TensorVariable
-) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]:
-    pad_width = broadcast_to(pad_width, (x.ndim, 2))
-    new_shape = as_tensor(
-        [pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)]
-    )
-    original_area_slice = tuple(
-        slice(pad_width[i][0], pad_width[i][0] + size) for i, size in enumerate(x.shape)
-    )
-    padded: TensorVariable = set_subtensor(zeros(new_shape)[original_area_slice], x)
-    return padded, original_area_slice, pad_width
-
-
-def _get_padding_slices(
-    dim_shape: TensorVariable, width_pair: tuple[TensorVariable], axis: int
-):
-    left_slice = _slice_at_axis(slice(None, width_pair[0]), axis)
-    right_slice = _slice_at_axis(slice(dim_shape - width_pair[1], None), axis)
-
-    return left_slice, right_slice
-
-
-def _constant_pad(
-    x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable
-):
-    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
-    values = broadcast_to(constant_values, (padded.ndim, 2))
-
-    for axis in range(padded.ndim):
-        width_pair = pad_width[axis]
-        value_pair = values[axis]
-        dim_shape = padded.shape[axis]
-
-        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
-        padded = set_subtensor(padded[left_slice], value_pair[0])
-        padded = set_subtensor(padded[right_slice], value_pair[1])
-
-    return padded
-
-
-def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
-    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
-    for axis in range(padded.ndim):
-        width_pair = pad_width[axis]
-        dim_shape = padded.shape[axis]
-
-        left_edge, right_edge = _get_edges(padded, axis, width_pair)
-        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
-
-        padded = set_subtensor(padded[left_slice], left_edge)
-        padded = set_subtensor(padded[right_slice], right_edge)
-
-    return padded
-
-
-def _linear_ramp_pad(
-    x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable = 0
-):
-    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
-    end_values = broadcast_to(end_values, (padded.ndim, 2))
-    for axis in range(padded.ndim):
-        width_pair = pad_width[axis]
-        end_value_pair = end_values[axis]
-        edge_pair = _get_edges(padded, axis, width_pair)
-        dim_shape = padded.shape[axis]
-        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
-
-        # pt.linspace doesn't have the endpoint kwarg, so need to take one extra step then slice it away
-        left_ramp = linspace(
-            start=end_value_pair[0],
-            end=specify_broadcastable(edge_pair[0], axis).squeeze(axis),
-            steps=width_pair[0] + 1,
-        )[:-1]
-        right_ramp = linspace(
-            start=end_value_pair[1],
-            end=specify_broadcastable(edge_pair[1], axis).squeeze(axis),
-            steps=width_pair[1] + 1,
-        )[:-1]
-        right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)]
-
-        # FIXME: This swapaxes is needed because the shapes of the linspaces don't "rotate" with
-        # the different dimensions. But this makes the non-active dimensions backwards in the padding.
-        padded = set_subtensor(padded[left_slice], swapaxes(left_ramp, 0, axis))
-        padded = set_subtensor(padded[right_slice], swapaxes(right_ramp, 0, axis))
-
-    return padded
-
-
-def pad(x, pad_width, mode="constant", **kwargs):
-    allowed_kwargs = {
-        "empty": [],
-        "edge": [],
-        "wrap": [],
-        "constant": ["constant_values"],
-        "linear_ramp": ["end_values"],
-        "maximum": ["stat_length"],
-        "mean": ["stat_length"],
-        "median": ["stat_length"],
-        "minimum": ["stat_length"],
-        "reflect": ["reflect_type"],
-        "symmetric": ["reflect_type"],
-    }
-
-    if any(value not in allowed_kwargs[mode] for value in kwargs.values()):
-        raise ValueError(
-            f"Invalid keyword arguments for mode '{mode}': {kwargs.keys()}"
-        )
-
-    if mode == "constant":
-        constant_values = kwargs.pop("constant_values", 0)
-        return _constant_pad(x, pad_width, constant_values)
-    elif mode == "edge":
-        return _edge_pad(x, pad_width)
-    elif mode == "linear_ramp":
-        end_values = kwargs.pop("end_values", 0)
-        return _linear_ramp_pad(x, pad_width, end_values)
-
-
 __all__ = [
     "take_along_axis",
     "expand_dims",

pytensor/tensor/pad.py

Lines changed: 133 additions & 0 deletions
@@ -0,0 +1,133 @@
+from numpy.lib.arraypad import _get_edges, _slice_at_axis  # noqa
+
+from pytensor.tensor.basic import (
+    TensorVariable,
+    as_tensor,
+    swapaxes,
+    zeros,
+)
+from pytensor.tensor.extra_ops import broadcast_to, linspace
+from pytensor.tensor.shape import specify_broadcastable
+from pytensor.tensor.subtensor import set_subtensor
+
+
+def _symbolic_pad(
+    x: TensorVariable, pad_width: TensorVariable
+) -> tuple[TensorVariable, tuple[slice, ...], TensorVariable]:
+    pad_width = broadcast_to(pad_width, (x.ndim, 2))
+    new_shape = as_tensor(
+        [pad_width[i][0] + size + pad_width[i][1] for i, size in enumerate(x.shape)]
+    )
+    original_area_slice = tuple(
+        slice(pad_width[i][0], pad_width[i][0] + size) for i, size in enumerate(x.shape)
+    )
+    padded: TensorVariable = set_subtensor(zeros(new_shape)[original_area_slice], x)
+    return padded, original_area_slice, pad_width
+
+
+def _get_padding_slices(
+    dim_shape: TensorVariable, width_pair: tuple[TensorVariable, TensorVariable], axis: int
+):
+    left_slice = _slice_at_axis(slice(None, width_pair[0]), axis)
+    right_slice = _slice_at_axis(slice(dim_shape - width_pair[1], None), axis)
+
+    return left_slice, right_slice
+
+
+def _constant_pad(
+    x: TensorVariable, pad_width: TensorVariable, constant_values: TensorVariable
+):
+    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
+    values = broadcast_to(constant_values, (padded.ndim, 2))
+
+    for axis in range(padded.ndim):
+        width_pair = pad_width[axis]
+        value_pair = values[axis]
+        dim_shape = padded.shape[axis]
+
+        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
+        padded = set_subtensor(padded[left_slice], value_pair[0])
+        padded = set_subtensor(padded[right_slice], value_pair[1])
+
+    return padded
+
+
+def _edge_pad(x: TensorVariable, pad_width: TensorVariable):
+    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
+    for axis in range(padded.ndim):
+        width_pair = pad_width[axis]
+        dim_shape = padded.shape[axis]
+
+        left_edge, right_edge = _get_edges(padded, axis, width_pair)
+        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
+
+        padded = set_subtensor(padded[left_slice], left_edge)
+        padded = set_subtensor(padded[right_slice], right_edge)
+
+    return padded
+
+
+def _linear_ramp_pad(
+    x: TensorVariable, pad_width: TensorVariable, end_values: TensorVariable | int = 0
+):
+    padded, area_slice, pad_width = _symbolic_pad(x, pad_width)
+    end_values = broadcast_to(end_values, (padded.ndim, 2))
+    for axis in range(padded.ndim):
+        width_pair = pad_width[axis]
+        end_value_pair = end_values[axis]
+        edge_pair = _get_edges(padded, axis, width_pair)
+        dim_shape = padded.shape[axis]
+        left_slice, right_slice = _get_padding_slices(dim_shape, width_pair, axis)
+
+        # pt.linspace doesn't have the endpoint kwarg, so need to take one extra step then slice it away
+        left_ramp = linspace(
+            start=end_value_pair[0],
+            end=specify_broadcastable(edge_pair[0], axis).squeeze(axis),
+            steps=width_pair[0] + 1,
+        )[:-1]
+        right_ramp = linspace(
+            start=end_value_pair[1],
+            end=specify_broadcastable(edge_pair[1], axis).squeeze(axis),
+            steps=width_pair[1] + 1,
+        )[:-1]
+        right_ramp = right_ramp[_slice_at_axis(slice(None, None, -1), axis)]
+
+        # FIXME: This swapaxes is needed because the shapes of the linspaces don't "rotate" with
+        # the different dimensions. But this makes the non-active dimensions backwards in the padding.
+        padded = set_subtensor(padded[left_slice], swapaxes(left_ramp, 0, axis))
+        padded = set_subtensor(padded[right_slice], swapaxes(right_ramp, 0, axis))
+
+    return padded
+
+
+def pad(x, pad_width, mode="constant", **kwargs):
+    allowed_kwargs = {
+        "empty": [],
+        "edge": [],
+        "wrap": [],
+        "constant": ["constant_values"],
+        "linear_ramp": ["end_values"],
+        "maximum": ["stat_length"],
+        "mean": ["stat_length"],
+        "median": ["stat_length"],
+        "minimum": ["stat_length"],
+        "reflect": ["reflect_type"],
+        "symmetric": ["reflect_type"],
+    }
+
+    if any(key not in allowed_kwargs[mode] for key in kwargs):
+        raise ValueError(
+            f"Invalid keyword arguments for mode '{mode}': {list(kwargs.keys())}"
+        )
+
+    if mode == "constant":
+        constant_values = kwargs.pop("constant_values", 0)
+        return _constant_pad(x, pad_width, constant_values)
+    elif mode == "edge":
+        return _edge_pad(x, pad_width)
+    elif mode == "linear_ramp":
+        end_values = kwargs.pop("end_values", 0)
+        return _linear_ramp_pad(x, pad_width, end_values)
+
+
+__all__ = ["pad"]
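
A short end-to-end sketch of the new module, checked against NumPy (this assumes a build containing this commit; note that only the "constant", "edge", and "linear_ramp" modes are dispatched in this revision, so the other modes listed in allowed_kwargs fall through and return None):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pytensor.tensor.pad import pad

x = pt.matrix("x")
# Build a symbolic graph that pads one element on every side, repeating edge values
f = pytensor.function([x], pad(x, pad_width=1, mode="edge"))

x_val = np.arange(4.0).reshape(2, 2)
np.testing.assert_allclose(f(x_val), np.pad(x_val, 1, mode="edge"))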
