Skip to content

Commit 7f977d6

Browse files
committed
Adds tests for dpctl.tensor.repeat
1 parent 6972c46 commit 7f977d6

File tree

1 file changed

+142
-1
lines changed

1 file changed

+142
-1
lines changed

dpctl/tests/test_usm_ndarray_manipulation.py

Lines changed: 142 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,11 @@
1717

1818
import numpy as np
1919
import pytest
20-
from helper import get_queue_or_skip
2120
from numpy.testing import assert_, assert_array_equal, assert_raises_regex
2221

2322
import dpctl
2423
import dpctl.tensor as dpt
24+
from dpctl.tests.helper import get_queue_or_skip
2525

2626

2727
def test_permute_dims_incorrect_type():
@@ -1103,3 +1103,144 @@ def test_finfo_object():
11031103
assert isinstance(fi.dtype, dpt.dtype)
11041104
assert isinstance(str(fi), str)
11051105
assert isinstance(repr(fi), str)
1106+
1107+
1108+
def test_repeat_scalar_sequence_agreement():
1109+
get_queue_or_skip()
1110+
1111+
x = dpt.arange(5, dtype="i4")
1112+
expected_res = dpt.empty(10, dtype="i4")
1113+
expected_res[1::2], expected_res[::2] = x, x
1114+
1115+
# scalar case
1116+
reps = 2
1117+
res = dpt.repeat(x, reps)
1118+
assert dpt.all(res == expected_res)
1119+
1120+
# tuple
1121+
reps = (2, 2, 2, 2, 2)
1122+
res = dpt.repeat(x, reps)
1123+
assert dpt.all(res == expected_res)
1124+
1125+
1126+
def test_repeat_as_broadcasting():
1127+
get_queue_or_skip()
1128+
1129+
reps = 5
1130+
x = dpt.arange(reps, dtype="i4")
1131+
x1 = x[:, dpt.newaxis]
1132+
expected_res = dpt.broadcast_to(x1, (reps, reps))
1133+
1134+
res = dpt.repeat(x1, reps, axis=1)
1135+
assert dpt.all(res == expected_res)
1136+
1137+
x2 = x[dpt.newaxis, :]
1138+
expected_res = dpt.broadcast_to(x2, (reps, reps))
1139+
1140+
res = dpt.repeat(x2, reps, axis=0)
1141+
assert dpt.all(res == expected_res)
1142+
1143+
1144+
def test_repeat_axes():
1145+
get_queue_or_skip()
1146+
1147+
reps = 2
1148+
x = dpt.reshape(dpt.arange(5 * 10, dtype="i4"), (5, 10))
1149+
expected_res = dpt.empty((x.shape[0] * 2, x.shape[1]), x.dtype)
1150+
expected_res[::2, :], expected_res[1::2] = x, x
1151+
res = dpt.repeat(x, reps, axis=0)
1152+
assert dpt.all(res == expected_res)
1153+
1154+
expected_res = dpt.empty((x.shape[0], x.shape[1] * 2), x.dtype)
1155+
expected_res[:, ::2], expected_res[:, 1::2] = x, x
1156+
res = dpt.repeat(x, reps, axis=1)
1157+
assert dpt.all(res == expected_res)
1158+
1159+
1160+
def test_repeat_size_0_outputs():
1161+
get_queue_or_skip()
1162+
1163+
x = dpt.ones((3, 0, 5), dtype="i4")
1164+
reps = 10
1165+
res = dpt.repeat(x, reps, axis=0)
1166+
assert res.size == 0
1167+
assert res.shape == (30, 0, 5)
1168+
1169+
res = dpt.repeat(x, reps, axis=1)
1170+
assert res.size == 0
1171+
assert res.shape == (3, 0, 5)
1172+
1173+
res = dpt.repeat(x, (2, 2, 2), axis=0)
1174+
assert res.size == 0
1175+
assert res.shape == (6, 0, 5)
1176+
1177+
x = dpt.ones((3, 2, 5))
1178+
res = dpt.repeat(x, 0, axis=1)
1179+
assert res.size == 0
1180+
assert res.shape == (3, 0, 5)
1181+
1182+
x = dpt.ones((3, 2, 5))
1183+
res = dpt.repeat(x, (0, 0), axis=1)
1184+
assert res.size == 0
1185+
assert res.shape == (3, 0, 5)
1186+
1187+
1188+
def test_repeat_strides():
1189+
get_queue_or_skip()
1190+
1191+
reps = 2
1192+
x = dpt.reshape(dpt.arange(10 * 10, dtype="i4"), (10, 10))
1193+
x1 = x[:, ::-2]
1194+
expected_res = dpt.empty((10, 10), dtype="i4")
1195+
expected_res[:, ::2], expected_res[:, 1::2] = x1, x1
1196+
res = dpt.repeat(x1, reps, axis=1)
1197+
assert dpt.all(res == expected_res)
1198+
res = dpt.repeat(x1, (reps,) * x1.shape[1], axis=1)
1199+
assert dpt.all(res == expected_res)
1200+
1201+
x1 = x[::-2, :]
1202+
expected_res = dpt.empty((10, 10), dtype="i4")
1203+
expected_res[::2, :], expected_res[1::2, :] = x1, x1
1204+
res = dpt.repeat(x1, reps, axis=0)
1205+
assert dpt.all(res == expected_res)
1206+
res = dpt.repeat(x1, (reps,) * x1.shape[0], axis=0)
1207+
assert dpt.all(res == expected_res)
1208+
1209+
1210+
def test_repeat_arg_validation():
1211+
get_queue_or_skip()
1212+
1213+
x = dict()
1214+
with pytest.raises(TypeError):
1215+
dpt.repeat(x, 2)
1216+
1217+
# axis must be 0 for scalar
1218+
x = dpt.empty(())
1219+
with pytest.raises(ValueError):
1220+
dpt.repeat(x, 2, axis=1)
1221+
1222+
# x.ndim must be 1 for axis=None
1223+
x = dpt.empty((5, 10))
1224+
with pytest.raises(ValueError):
1225+
dpt.repeat(x, 2, axis=None)
1226+
1227+
# repeats must be positive
1228+
x = dpt.empty(1)
1229+
with pytest.raises(ValueError):
1230+
dpt.repeat(x, -2)
1231+
1232+
# repeats must be integers
1233+
with pytest.raises(TypeError):
1234+
dpt.repeat(x, 2.0)
1235+
1236+
# repeats tuple must be the same length as axis
1237+
with pytest.raises(ValueError):
1238+
dpt.repeat(x, (1, 2))
1239+
1240+
# repeats tuple elements must be positive
1241+
with pytest.raises(ValueError):
1242+
dpt.repeat(x, (-1,))
1243+
1244+
# repeats must be int or tuple
1245+
with pytest.raises(TypeError):
1246+
dpt.repeat(x, dict())

0 commit comments

Comments
 (0)