Skip to content

Commit 5cda3d2

Browse files
Add dpnp tests for interp
1 parent 771d3eb commit 5cda3d2

File tree

1 file changed

+138
-0
lines changed

1 file changed

+138
-0
lines changed

dpnp/tests/test_mathematical.py

Lines changed: 138 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,144 @@ def test_complex(self, xp):
11431143
assert_raises((ValueError, TypeError), xp.i0, a)
11441144

11451145

1146+
class TestInterp:
1147+
@pytest.mark.parametrize(
1148+
"dtype_x", get_all_dtypes(no_bool=True, no_complex=True)
1149+
)
1150+
@pytest.mark.parametrize("dtype_y", get_all_dtypes(no_bool=True))
1151+
def test_all_dtypes(self, dtype_x, dtype_y):
1152+
x = numpy.linspace(0.1, 9.9, 20).astype(dtype_x)
1153+
xp = numpy.linspace(0.0, 10.0, 5).astype(dtype_x)
1154+
fp = (xp * 1.5 + 1).astype(dtype_y)
1155+
1156+
ix = dpnp.array(x)
1157+
ixp = dpnp.array(xp)
1158+
ifp = dpnp.array(fp)
1159+
1160+
expected = numpy.interp(x, xp, fp)
1161+
result = dpnp.interp(ix, ixp, ifp)
1162+
assert_dtype_allclose(result, expected)
1163+
1164+
@pytest.mark.parametrize(
1165+
"dtype_x", get_all_dtypes(no_bool=True, no_complex=True)
1166+
)
1167+
@pytest.mark.parametrize("dtype_y", get_complex_dtypes())
1168+
def test_complex_fp(self, dtype_x, dtype_y):
1169+
x = numpy.array([0.25, 0.75], dtype=dtype_x)
1170+
xp = numpy.array([0.0, 1.0], dtype=dtype_x)
1171+
fp = numpy.array([1 + 1j, 3 + 3j], dtype=dtype_y)
1172+
1173+
ix = dpnp.array(x)
1174+
ixp = dpnp.array(xp)
1175+
ifp = dpnp.array(fp)
1176+
1177+
expected = numpy.interp(x, xp, fp)
1178+
result = dpnp.interp(ix, ixp, ifp)
1179+
assert_dtype_allclose(result, expected)
1180+
1181+
@pytest.mark.parametrize(
1182+
"dtype", get_all_dtypes(no_bool=True, no_complex=True)
1183+
)
1184+
def test_left_right_args(self, dtype):
1185+
x = numpy.array([-1, 0, 1, 2, 3, 4, 5, 6], dtype=dtype)
1186+
xp = numpy.array([0, 3, 6], dtype=dtype)
1187+
fp = numpy.array([0, 9, 18], dtype=dtype)
1188+
1189+
ix = dpnp.array(x)
1190+
ixp = dpnp.array(xp)
1191+
ifp = dpnp.array(fp)
1192+
1193+
expected = numpy.interp(x, xp, fp, left=-40, right=40)
1194+
result = dpnp.interp(ix, ixp, ifp, left=-40, right=40)
1195+
assert_dtype_allclose(result, expected)
1196+
1197+
@pytest.mark.parametrize("val", [numpy.nan, numpy.inf, -numpy.inf])
1198+
def test_naninf(self, val):
1199+
x = numpy.array([0, 1, 2, val])
1200+
xp = numpy.array([0, 1, 2])
1201+
fp = numpy.array([10, 20, 30])
1202+
1203+
ix = dpnp.array(x)
1204+
ixp = dpnp.array(xp)
1205+
ifp = dpnp.array(fp)
1206+
1207+
expected = numpy.interp(x, xp, fp)
1208+
result = dpnp.interp(ix, ixp, ifp)
1209+
assert_dtype_allclose(result, expected)
1210+
1211+
def test_empty_x(self):
1212+
x = numpy.array([])
1213+
xp = numpy.array([0, 1])
1214+
fp = numpy.array([10, 20])
1215+
1216+
ix = dpnp.array(x)
1217+
ixp = dpnp.array(xp)
1218+
ifp = dpnp.array(fp)
1219+
1220+
expected = numpy.interp(x, xp, fp)
1221+
result = dpnp.interp(ix, ixp, ifp)
1222+
assert_dtype_allclose(result, expected)
1223+
1224+
@pytest.mark.parametrize("dtype", get_float_dtypes())
1225+
def test_period(self, dtype):
1226+
x = numpy.array([-180, 0, 180], dtype=dtype)
1227+
xp = numpy.array([-90, 0, 90], dtype=dtype)
1228+
fp = numpy.array([0, 1, 0], dtype=dtype)
1229+
1230+
ix = dpnp.array(x)
1231+
ixp = dpnp.array(xp)
1232+
ifp = dpnp.array(fp)
1233+
1234+
expected = numpy.interp(x, xp, fp, period=180)
1235+
result = dpnp.interp(ix, ixp, ifp, period=180)
1236+
assert_dtype_allclose(result, expected)
1237+
1238+
def test_errors(self):
1239+
x = dpnp.array([0.5])
1240+
1241+
# xp and fp have different lengths
1242+
xp = dpnp.array([0])
1243+
fp = dpnp.array([1, 2])
1244+
assert_raises(ValueError, dpnp.interp, x, xp, fp)
1245+
1246+
# xp is not 1D
1247+
xp = dpnp.array([[0, 1]])
1248+
fp = dpnp.array([1, 2])
1249+
assert_raises(ValueError, dpnp.interp, x, xp, fp)
1250+
1251+
# fp is not 1D
1252+
xp = dpnp.array([0, 1])
1253+
fp = dpnp.array([[1, 2]])
1254+
assert_raises(ValueError, dpnp.interp, x, xp, fp)
1255+
1256+
# xp and fp are empty
1257+
xp = dpnp.array([])
1258+
fp = dpnp.array([])
1259+
assert_raises(ValueError, dpnp.interp, x, xp, fp)
1260+
1261+
# x complex
1262+
x_complex = dpnp.array([1 + 2j])
1263+
xp = dpnp.array([0.0, 2.0])
1264+
fp = dpnp.array([0.0, 1.0])
1265+
assert_raises(TypeError, dpnp.interp, x_complex, xp, fp)
1266+
1267+
# period is zero
1268+
x = dpnp.array([1.0])
1269+
xp = dpnp.array([0.0, 2.0])
1270+
fp = dpnp.array([0.0, 1.0])
1271+
assert_raises(ValueError, dpnp.interp, x, xp, fp, period=0)
1272+
1273+
# period has a different SYCL queue
1274+
q1 = dpctl.SyclQueue()
1275+
q2 = dpctl.SyclQueue()
1276+
1277+
x = dpnp.array([1.0], sycl_queue=q1)
1278+
xp = dpnp.array([0.0, 2.0], sycl_queue=q1)
1279+
fp = dpnp.array([0.0, 1.0], sycl_queue=q1)
1280+
period = dpnp.array([180], sycl_queue=q2)
1281+
assert_raises(ValueError, dpnp.interp, x, xp, fp, period=period)
1282+
1283+
11461284
@pytest.mark.parametrize(
11471285
"rhs", [[[1, 2, 3], [4, 5, 6]], [2.0, 1.5, 1.0], 3, 0.3]
11481286
)

0 commit comments

Comments
 (0)