Skip to content

Commit 4e85676

Browse files
Pytensor-native interpolation functions (#1141)
* add interpolate.py * Add jax dispatch for `searchsorted` * Import user-facing functions in `tensor.__init__`
1 parent 83c6b44 commit 4e85676

File tree

5 files changed

+327
-0
lines changed

5 files changed

+327
-0
lines changed

pytensor/link/jax/dispatch/extra_ops.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
FillDiagonalOffset,
1111
RavelMultiIndex,
1212
Repeat,
13+
SearchsortedOp,
1314
Unique,
1415
UnravelIndex,
1516
)
@@ -130,3 +131,13 @@ def jax_funcify_FillDiagonalOffset(op, **kwargs):
130131
# return filldiagonaloffset
131132

132133
raise NotImplementedError("flatiter not implemented in JAX")
134+
135+
136+
@jax_funcify.register(SearchsortedOp)
137+
def jax_funcify_SearchsortedOp(op, **kwargs):
138+
side = op.side
139+
140+
def searchsorted(a, v, side=side, sorter=None):
141+
return jnp.searchsorted(a=a, v=v, side=side, sorter=sorter)
142+
143+
return searchsorted

pytensor/tensor/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ def _get_vector_length_Constant(op: Op | Variable, var: Constant) -> int:
128128
from pytensor.tensor.basic import *
129129
from pytensor.tensor.blas import batched_dot, batched_tensordot
130130
from pytensor.tensor.extra_ops import *
131+
from pytensor.tensor.interpolate import interp, interpolate1d
131132
from pytensor.tensor.io import *
132133
from pytensor.tensor.math import *
133134
from pytensor.tensor.pad import pad

pytensor/tensor/interpolate.py

Lines changed: 200 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,200 @@
1+
from collections.abc import Callable
2+
from difflib import get_close_matches
3+
from typing import Literal, get_args
4+
5+
from pytensor import Variable
6+
from pytensor.tensor.basic import as_tensor_variable, switch
7+
from pytensor.tensor.extra_ops import searchsorted
8+
from pytensor.tensor.functional import vectorize
9+
from pytensor.tensor.math import clip, eq, le
10+
from pytensor.tensor.sort import argsort
11+
12+
13+
InterpolationMethod = Literal["linear", "nearest", "first", "last", "mean"]
14+
valid_methods = get_args(InterpolationMethod)
15+
16+
17+
def pad_or_return(x, idx, output, left_pad, right_pad, extrapolate):
18+
if extrapolate:
19+
return output
20+
21+
n = x.shape[0]
22+
23+
return switch(eq(idx, 0), left_pad, switch(eq(idx, n), right_pad, output))
24+
25+
26+
def _linear_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
27+
clip_idx = clip(idx, 1, x.shape[0] - 1)
28+
29+
slope = (x_hat - x[clip_idx - 1]) / (x[clip_idx] - x[clip_idx - 1])
30+
y_hat = y[clip_idx - 1] + slope * (y[clip_idx] - y[clip_idx - 1])
31+
32+
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
33+
34+
35+
def _nearest_neighbor_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
36+
clip_idx = clip(idx, 1, x.shape[0] - 1)
37+
38+
left_distance = x_hat - x[clip_idx - 1]
39+
right_distance = x[clip_idx] - x_hat
40+
y_hat = switch(le(left_distance, right_distance), y[clip_idx - 1], y[clip_idx])
41+
42+
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
43+
44+
45+
def _stepwise_first_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
46+
clip_idx = clip(idx - 1, 0, x.shape[0] - 1)
47+
y_hat = y[clip_idx]
48+
49+
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
50+
51+
52+
def _stepwise_last_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
53+
clip_idx = clip(idx, 0, x.shape[0] - 1)
54+
y_hat = y[clip_idx]
55+
56+
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
57+
58+
59+
def _stepwise_mean_interp1d(x, y, x_hat, idx, left_pad, right_pad, extrapolate=True):
60+
clip_idx = clip(idx, 1, x.shape[0] - 1)
61+
y_hat = (y[clip_idx - 1] + y[clip_idx]) / 2
62+
63+
return pad_or_return(x, idx, y_hat, left_pad, right_pad, extrapolate)
64+
65+
66+
def interpolate1d(
67+
x: Variable,
68+
y: Variable,
69+
method: InterpolationMethod = "linear",
70+
left_pad: Variable | None = None,
71+
right_pad: Variable | None = None,
72+
extrapolate: bool = True,
73+
) -> Callable[[Variable], Variable]:
74+
"""
75+
Create a function to interpolate one-dimensional data.
76+
77+
Parameters
78+
----------
79+
x : TensorLike
80+
Input data used to create an interpolation function. Data will be sorted to be monotonically increasing.
81+
y: TensorLike
82+
Output data used to create an interpolation function. Must have the same shape as `x`.
83+
method : InterpolationMethod, optional
84+
Method for interpolation. The following methods are available:
85+
- 'linear': Linear interpolation
86+
- 'nearest': Nearest neighbor interpolation
87+
- 'first': Stepwise interpolation using the closest value to the left of the query point
88+
- 'last': Stepwise interpolation using the closest value to the right of the query point
89+
- 'mean': Stepwise interpolation using the mean of the two closest values to the query point
90+
left_pad: TensorLike, optional
91+
Value to return inputs `x_hat < x[0]`. Default is `y[0]`. Ignored if ``extrapolate == True``; in this
92+
case, values `x_hat < x[0]` will be extrapolated from the endpoints of `x` and `y`.
93+
right_pad: TensorLike, optional
94+
Value to return for inputs `x_hat > x[-1]`. Default is `y[-1]`. Ignored if ``extrapolate == True``; in this
95+
case, values `x_hat > x[-1]` will be extrapolated from the endpoints of `x` and `y`.
96+
extrapolate: bool
97+
Whether to extend the request interpolation function beyond the range of the input-output pairs specified in
98+
`x` and `y.` If False, constant values will be returned for such inputs.
99+
100+
Returns
101+
-------
102+
interpolation_func: OpFromGraph
103+
A function that can be used to interpolate new data. The function takes a single input `x_hat` and returns
104+
the interpolated value `y_hat`. The input `x_hat` must be a 1d array.
105+
106+
"""
107+
x = as_tensor_variable(x)
108+
y = as_tensor_variable(y)
109+
110+
sort_idx = argsort(x)
111+
x = x[sort_idx]
112+
y = y[sort_idx]
113+
114+
if left_pad is None:
115+
left_pad = y[0] # type: ignore
116+
else:
117+
left_pad = as_tensor_variable(left_pad)
118+
if right_pad is None:
119+
right_pad = y[-1] # type: ignore
120+
else:
121+
right_pad = as_tensor_variable(right_pad)
122+
123+
def _scalar_interpolate1d(x_hat):
124+
idx = searchsorted(x, x_hat)
125+
126+
if x.ndim != 1 or y.ndim != 1:
127+
raise ValueError("Inputs must be 1d")
128+
129+
if method == "linear":
130+
y_hat = _linear_interp1d(
131+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
132+
)
133+
elif method == "nearest":
134+
y_hat = _nearest_neighbor_interp1d(
135+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
136+
)
137+
elif method == "first":
138+
y_hat = _stepwise_first_interp1d(
139+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
140+
)
141+
elif method == "mean":
142+
y_hat = _stepwise_mean_interp1d(
143+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
144+
)
145+
elif method == "last":
146+
y_hat = _stepwise_last_interp1d(
147+
x, y, x_hat, idx, left_pad, right_pad, extrapolate=extrapolate
148+
)
149+
else:
150+
raise NotImplementedError(
151+
f"Unknown interpolation method: {method}. "
152+
f"Did you mean {get_close_matches(method, valid_methods)}?"
153+
)
154+
155+
return y_hat
156+
157+
return vectorize(_scalar_interpolate1d, signature="()->()")
158+
159+
160+
def interp(x, xp, fp, left=None, right=None, period=None):
161+
"""
162+
One-dimensional linear interpolation. Similar to ``pytensor.interpolate.interpolate1d``, but with a signature that
163+
matches ``np.interp``
164+
165+
Parameters
166+
----------
167+
x : TensorLike
168+
The x-coordinates at which to evaluate the interpolated values.
169+
170+
xp : TensorLike
171+
The x-coordinates of the data points, must be increasing if argument `period` is not specified. Otherwise,
172+
`xp` is internally sorted after normalizing the periodic boundaries with ``xp = xp % period``.
173+
174+
fp : TensorLike
175+
The y-coordinates of the data points, same length as `xp`.
176+
177+
left : float, optional
178+
Value to return for `x < xp[0]`. Default is `fp[0]`.
179+
180+
right : float, optional
181+
Value to return for `x > xp[-1]`. Default is `fp[-1]`.
182+
183+
period : None
184+
Not supported. Included to ensure the signature of this function matches ``numpy.interp``.
185+
186+
Returns
187+
-------
188+
y : Variable
189+
The interpolated values, same shape as `x`.
190+
"""
191+
192+
xp = as_tensor_variable(xp)
193+
fp = as_tensor_variable(fp)
194+
x = as_tensor_variable(x)
195+
196+
f = interpolate1d(
197+
xp, fp, method="linear", left_pad=left, right_pad=right, extrapolate=False
198+
)
199+
200+
return f(x)

tests/link/jax/test_extra_ops.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from pytensor.graph.fg import FunctionGraph
77
from pytensor.graph.op import get_test_value
88
from pytensor.tensor import extra_ops as pt_extra_ops
9+
from pytensor.tensor.sort import argsort
910
from pytensor.tensor.type import matrix, tensor
1011
from tests.link.jax.test_basic import compare_jax_and_py
1112

@@ -55,6 +56,13 @@ def test_extra_ops():
5556
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
5657
)
5758

59+
v = ptb.as_tensor_variable(6.0)
60+
sorted_idx = argsort(a.ravel())
61+
62+
out = pt_extra_ops.searchsorted(a.ravel()[sorted_idx], v)
63+
fgraph = FunctionGraph([a], [out])
64+
compare_jax_and_py(fgraph, [a_test])
65+
5866

5967
@pytest.mark.xfail(reason="Jitted JAX does not support dynamic shapes")
6068
def test_bartlett_dynamic_shape():

tests/tensor/test_interpolate.py

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
import numpy as np
2+
import pytest
3+
from numpy.testing import assert_allclose
4+
5+
import pytensor
6+
import pytensor.tensor as pt
7+
from pytensor.tensor.interpolate import (
8+
InterpolationMethod,
9+
interp,
10+
interpolate1d,
11+
valid_methods,
12+
)
13+
14+
15+
floatX = pytensor.config.floatX
16+
17+
18+
def test_interp():
19+
xp = [1.0, 2.0, 3.0]
20+
fp = [3.0, 2.0, 0.0]
21+
22+
x = [0, 1, 1.5, 2.72, 3.14]
23+
24+
out = interp(x, xp, fp).eval()
25+
np_out = np.interp(x, xp, fp)
26+
27+
assert_allclose(out, np_out)
28+
29+
30+
def test_interp_padded():
31+
xp = [1.0, 2.0, 3.0]
32+
fp = [3.0, 2.0, 0.0]
33+
34+
assert interp(3.14, xp, fp, right=-99.0).eval() == -99.0
35+
assert_allclose(
36+
interp([-1.0, -2.0, -3.0], xp, fp, left=1000.0).eval(), [1000.0, 1000.0, 1000.0]
37+
)
38+
assert_allclose(
39+
interp([-1.0, 10.0], xp, fp, left=-10, right=10).eval(), [-10, 10.0]
40+
)
41+
42+
43+
@pytest.mark.parametrize("method", valid_methods, ids=str)
44+
@pytest.mark.parametrize(
45+
"left_pad, right_pad", [(None, None), (None, 100), (-100, None), (-100, 100)]
46+
)
47+
def test_interpolate_scalar_no_extrapolate(
48+
method: InterpolationMethod, left_pad, right_pad
49+
):
50+
x = np.linspace(-2, 6, 10)
51+
y = np.sin(x)
52+
53+
f_op = interpolate1d(
54+
x, y, method, extrapolate=False, left_pad=left_pad, right_pad=right_pad
55+
)
56+
x_hat_pt = pt.dscalar("x_hat")
57+
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN")
58+
59+
# Data points should be returned exactly, except when method == mean
60+
if method not in ["mean", "first"]:
61+
assert f(x[3]) == y[3]
62+
elif method == "first":
63+
assert f(x[3]) == y[2]
64+
else:
65+
# method == 'mean
66+
assert f(x[3]) == (y[2] + y[3]) / 2
67+
68+
# When extrapolate=False, points beyond the data envelope should be constant
69+
left_pad = y[0] if left_pad is None else left_pad
70+
right_pad = y[-1] if right_pad is None else right_pad
71+
72+
assert f(-10) == left_pad
73+
assert f(100) == right_pad
74+
75+
76+
@pytest.mark.parametrize("method", valid_methods, ids=str)
77+
def test_interpolate_scalar_extrapolate(method: InterpolationMethod):
78+
x = np.linspace(-2, 6, 10)
79+
y = np.sin(x)
80+
81+
f_op = interpolate1d(x, y, method)
82+
x_hat_pt = pt.dscalar("x_hat")
83+
f = pytensor.function([x_hat_pt], f_op(x_hat_pt), mode="FAST_RUN")
84+
85+
left_test_point = -5
86+
right_test_point = 100
87+
if method == "linear":
88+
# Linear will compute a slope from the endpoints and continue it
89+
left_slope = (left_test_point - x[0]) / (x[1] - x[0])
90+
right_slope = (right_test_point - x[-2]) / (x[-1] - x[-2])
91+
assert f(left_test_point) == y[0] + left_slope * (y[1] - y[0])
92+
assert f(right_test_point) == y[-2] + right_slope * (y[-1] - y[-2])
93+
94+
elif method == "mean":
95+
left_expected = (y[0] + y[1]) / 2
96+
right_expected = (y[-1] + y[-2]) / 2
97+
assert f(left_test_point) == left_expected
98+
assert f(right_test_point) == right_expected
99+
100+
else:
101+
assert f(left_test_point) == y[0]
102+
assert f(right_test_point) == y[-1]
103+
104+
# For interior points, "first" and "last" should disagree. First should take the left side of the interval,
105+
# and last should take the right.
106+
interior_point = x[3] + 0.1
107+
assert f(interior_point) == (y[4] if method == "last" else y[3])

0 commit comments

Comments
 (0)