Skip to content

Commit 0f8f453

Browse files
Remove support for non-scalar period
1 parent b97a92e commit 0f8f453

File tree

4 files changed

+8
-30
lines changed

4 files changed

+8
-30
lines changed

dpnp/dpnp_iface_mathematical.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2813,7 +2813,7 @@ def interp(x, xp, fp, left=None, right=None, period=None):
28132813
28142814
Default: ``fp[-1]``.
28152815
2816-
period : {None, scalar, dpnp.ndarray, usm_ndarray}, optional
2816+
period : {None, scalar}, optional
28172817
A period for the x-coordinates. This parameter allows the proper
28182818
interpolation of angular x-coordinates. Parameters `left` and `right`
28192819
are ignored if `period` is specified.
@@ -2902,9 +2902,12 @@ def interp(x, xp, fp, left=None, right=None, period=None):
29022902
fp = dpnp.asarray(fp, dtype=out_dtype, order="C")
29032903

29042904
if period is not None:
2905-
period = _validate_interp_param(period, "period", exec_q, usm_type)
2905+
if not dpnp.isscalar(period):
2906+
raise TypeError(f"period must be a scalar, but got {type(period)}")
29062907
if period == 0:
29072908
raise ValueError("period must be a non-zero value")
2909+
period = _validate_interp_param(period, "period", exec_q, usm_type)
2910+
29082911
period = dpnp.abs(period)
29092912

29102913
# left/right are ignored when period is specified

dpnp/tests/test_mathematical.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,16 +1273,6 @@ def test_errors(self):
12731273
# period is not scalar or 0-dim
12741274
assert_raises(TypeError, dpnp.interp, x, xp, fp, period=[180])
12751275

1276-
# period has a different SYCL queue
1277-
q1 = dpctl.SyclQueue()
1278-
q2 = dpctl.SyclQueue()
1279-
1280-
x = dpnp.array([1.0], sycl_queue=q1)
1281-
xp = dpnp.array([0.0, 2.0], sycl_queue=q1)
1282-
fp = dpnp.array([0.0, 1.0], sycl_queue=q1)
1283-
period = dpnp.array([180], sycl_queue=q2)
1284-
assert_raises(ValueError, dpnp.interp, x, xp, fp, period=period)
1285-
12861276
# left is not scalar or 0-dim
12871277
left = dpnp.array([1.0])
12881278
assert_raises(ValueError, dpnp.interp, x, xp, fp, left=left)
@@ -1292,6 +1282,8 @@ def test_errors(self):
12921282
assert_raises(ValueError, dpnp.interp, x, xp, fp, left=left)
12931283

12941284
# left has a different SYCL queue
1285+
q1 = dpctl.SyclQueue()
1286+
q2 = dpctl.SyclQueue()
12951287
left = dpnp.array(1.0, sycl_queue=q2)
12961288
if q1 != q2:
12971289
assert_raises(ValueError, dpnp.interp, x, xp, fp, left=left)

dpnp/tests/test_sycl_queue.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1464,9 +1464,7 @@ def test_interp(device, left, right, period):
14641464

14651465
l = None if left is None else dpnp.array(left, sycl_queue=x.sycl_queue)
14661466
r = None if right is None else dpnp.array(right, sycl_queue=x.sycl_queue)
1467-
p = None if period is None else dpnp.array(period, sycl_queue=x.sycl_queue)
1468-
1469-
result = dpnp.interp(x, xp, fp, left=l, right=r, period=p)
1467+
result = dpnp.interp(x, xp, fp, left=l, right=r, period=period)
14701468

14711469
assert_sycl_queue_equal(result.sycl_queue, x.sycl_queue)
14721470

dpnp/tests/test_usm_type.py

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1311,21 +1311,6 @@ def test_left_right(self, usm_type_x, usm_type_left, usm_type_right):
13111311
]
13121312
)
13131313

1314-
@pytest.mark.parametrize("usm_type_x", list_of_usm_types)
1315-
@pytest.mark.parametrize("usm_type_period", list_of_usm_types)
1316-
def test_period(self, usm_type_x, usm_type_period):
1317-
x = dpnp.linspace(0.1, 9.9, 20, usm_type=usm_type_x)
1318-
xp = dpnp.linspace(0.0, 10.0, 5, usm_type=usm_type_x)
1319-
fp = dpnp.array(xp * 2 + 1, usm_type=usm_type_x)
1320-
period = dpnp.array(10.0, usm_type=usm_type_period)
1321-
1322-
result = dpnp.interp(x, xp, fp, period=period)
1323-
1324-
assert period.usm_type == usm_type_period
1325-
assert result.usm_type == du.get_coerced_usm_type(
1326-
[x.usm_type, xp.usm_type, fp.usm_type, period.usm_type]
1327-
)
1328-
13291314

13301315
@pytest.mark.parametrize("usm_type", list_of_usm_types)
13311316
class TestLinAlgebra:

0 commit comments

Comments
 (0)