@@ -1143,6 +1143,144 @@ def test_complex(self, xp):
1143
1143
assert_raises ((ValueError , TypeError ), xp .i0 , a )
1144
1144
1145
1145
1146
+ class TestInterp :
1147
+ @pytest .mark .parametrize (
1148
+ "dtype_x" , get_all_dtypes (no_bool = True , no_complex = True )
1149
+ )
1150
+ @pytest .mark .parametrize ("dtype_y" , get_all_dtypes (no_bool = True ))
1151
+ def test_all_dtypes (self , dtype_x , dtype_y ):
1152
+ x = numpy .linspace (0.1 , 9.9 , 20 ).astype (dtype_x )
1153
+ xp = numpy .linspace (0.0 , 10.0 , 5 ).astype (dtype_x )
1154
+ fp = (xp * 1.5 + 1 ).astype (dtype_y )
1155
+
1156
+ ix = dpnp .array (x )
1157
+ ixp = dpnp .array (xp )
1158
+ ifp = dpnp .array (fp )
1159
+
1160
+ expected = numpy .interp (x , xp , fp )
1161
+ result = dpnp .interp (ix , ixp , ifp )
1162
+ assert_dtype_allclose (result , expected )
1163
+
1164
+ @pytest .mark .parametrize (
1165
+ "dtype_x" , get_all_dtypes (no_bool = True , no_complex = True )
1166
+ )
1167
+ @pytest .mark .parametrize ("dtype_y" , get_complex_dtypes ())
1168
+ def test_complex_fp (self , dtype_x , dtype_y ):
1169
+ x = numpy .array ([0.25 , 0.75 ], dtype = dtype_x )
1170
+ xp = numpy .array ([0.0 , 1.0 ], dtype = dtype_x )
1171
+ fp = numpy .array ([1 + 1j , 3 + 3j ], dtype = dtype_y )
1172
+
1173
+ ix = dpnp .array (x )
1174
+ ixp = dpnp .array (xp )
1175
+ ifp = dpnp .array (fp )
1176
+
1177
+ expected = numpy .interp (x , xp , fp )
1178
+ result = dpnp .interp (ix , ixp , ifp )
1179
+ assert_dtype_allclose (result , expected )
1180
+
1181
+ @pytest .mark .parametrize (
1182
+ "dtype" , get_all_dtypes (no_bool = True , no_complex = True )
1183
+ )
1184
+ def test_left_right_args (self , dtype ):
1185
+ x = numpy .array ([- 1 , 0 , 1 , 2 , 3 , 4 , 5 , 6 ], dtype = dtype )
1186
+ xp = numpy .array ([0 , 3 , 6 ], dtype = dtype )
1187
+ fp = numpy .array ([0 , 9 , 18 ], dtype = dtype )
1188
+
1189
+ ix = dpnp .array (x )
1190
+ ixp = dpnp .array (xp )
1191
+ ifp = dpnp .array (fp )
1192
+
1193
+ expected = numpy .interp (x , xp , fp , left = - 40 , right = 40 )
1194
+ result = dpnp .interp (ix , ixp , ifp , left = - 40 , right = 40 )
1195
+ assert_dtype_allclose (result , expected )
1196
+
1197
+ @pytest .mark .parametrize ("val" , [numpy .nan , numpy .inf , - numpy .inf ])
1198
+ def test_naninf (self , val ):
1199
+ x = numpy .array ([0 , 1 , 2 , val ])
1200
+ xp = numpy .array ([0 , 1 , 2 ])
1201
+ fp = numpy .array ([10 , 20 , 30 ])
1202
+
1203
+ ix = dpnp .array (x )
1204
+ ixp = dpnp .array (xp )
1205
+ ifp = dpnp .array (fp )
1206
+
1207
+ expected = numpy .interp (x , xp , fp )
1208
+ result = dpnp .interp (ix , ixp , ifp )
1209
+ assert_dtype_allclose (result , expected )
1210
+
1211
+ def test_empty_x (self ):
1212
+ x = numpy .array ([])
1213
+ xp = numpy .array ([0 , 1 ])
1214
+ fp = numpy .array ([10 , 20 ])
1215
+
1216
+ ix = dpnp .array (x )
1217
+ ixp = dpnp .array (xp )
1218
+ ifp = dpnp .array (fp )
1219
+
1220
+ expected = numpy .interp (x , xp , fp )
1221
+ result = dpnp .interp (ix , ixp , ifp )
1222
+ assert_dtype_allclose (result , expected )
1223
+
1224
+ @pytest .mark .parametrize ("dtype" , get_float_dtypes ())
1225
+ def test_period (self , dtype ):
1226
+ x = numpy .array ([- 180 , 0 , 180 ], dtype = dtype )
1227
+ xp = numpy .array ([- 90 , 0 , 90 ], dtype = dtype )
1228
+ fp = numpy .array ([0 , 1 , 0 ], dtype = dtype )
1229
+
1230
+ ix = dpnp .array (x )
1231
+ ixp = dpnp .array (xp )
1232
+ ifp = dpnp .array (fp )
1233
+
1234
+ expected = numpy .interp (x , xp , fp , period = 180 )
1235
+ result = dpnp .interp (ix , ixp , ifp , period = 180 )
1236
+ assert_dtype_allclose (result , expected )
1237
+
1238
+ def test_errors (self ):
1239
+ x = dpnp .array ([0.5 ])
1240
+
1241
+ # xp and fp have different lengths
1242
+ xp = dpnp .array ([0 ])
1243
+ fp = dpnp .array ([1 , 2 ])
1244
+ assert_raises (ValueError , dpnp .interp , x , xp , fp )
1245
+
1246
+ # xp is not 1D
1247
+ xp = dpnp .array ([[0 , 1 ]])
1248
+ fp = dpnp .array ([1 , 2 ])
1249
+ assert_raises (ValueError , dpnp .interp , x , xp , fp )
1250
+
1251
+ # fp is not 1D
1252
+ xp = dpnp .array ([0 , 1 ])
1253
+ fp = dpnp .array ([[1 , 2 ]])
1254
+ assert_raises (ValueError , dpnp .interp , x , xp , fp )
1255
+
1256
+ # xp and fp are empty
1257
+ xp = dpnp .array ([])
1258
+ fp = dpnp .array ([])
1259
+ assert_raises (ValueError , dpnp .interp , x , xp , fp )
1260
+
1261
+ # x complex
1262
+ x_complex = dpnp .array ([1 + 2j ])
1263
+ xp = dpnp .array ([0.0 , 2.0 ])
1264
+ fp = dpnp .array ([0.0 , 1.0 ])
1265
+ assert_raises (TypeError , dpnp .interp , x_complex , xp , fp )
1266
+
1267
+ # period is zero
1268
+ x = dpnp .array ([1.0 ])
1269
+ xp = dpnp .array ([0.0 , 2.0 ])
1270
+ fp = dpnp .array ([0.0 , 1.0 ])
1271
+ assert_raises (ValueError , dpnp .interp , x , xp , fp , period = 0 )
1272
+
1273
+ # period has a different SYCL queue
1274
+ q1 = dpctl .SyclQueue ()
1275
+ q2 = dpctl .SyclQueue ()
1276
+
1277
+ x = dpnp .array ([1.0 ], sycl_queue = q1 )
1278
+ xp = dpnp .array ([0.0 , 2.0 ], sycl_queue = q1 )
1279
+ fp = dpnp .array ([0.0 , 1.0 ], sycl_queue = q1 )
1280
+ period = dpnp .array ([180 ], sycl_queue = q2 )
1281
+ assert_raises (ValueError , dpnp .interp , x , xp , fp , period = period )
1282
+
1283
+
1146
1284
@pytest .mark .parametrize (
1147
1285
"rhs" , [[[1 , 2 , 3 ], [4 , 5 , 6 ]], [2.0 , 1.5 , 1.0 ], 3 , 0.3 ]
1148
1286
)
0 commit comments