|
17 | 17 |
|
18 | 18 | import numpy as np
|
19 | 19 | import pytest
|
20 |
| -from helper import get_queue_or_skip |
21 | 20 | from numpy.testing import assert_, assert_array_equal, assert_raises_regex
|
22 | 21 |
|
23 | 22 | import dpctl
|
24 | 23 | import dpctl.tensor as dpt
|
| 24 | +from dpctl.tests.helper import get_queue_or_skip |
25 | 25 |
|
26 | 26 |
|
27 | 27 | def test_permute_dims_incorrect_type():
|
@@ -1103,3 +1103,144 @@ def test_finfo_object():
|
1103 | 1103 | assert isinstance(fi.dtype, dpt.dtype)
|
1104 | 1104 | assert isinstance(str(fi), str)
|
1105 | 1105 | assert isinstance(repr(fi), str)
|
| 1106 | + |
| 1107 | + |
| 1108 | +def test_repeat_scalar_sequence_agreement(): |
| 1109 | + get_queue_or_skip() |
| 1110 | + |
| 1111 | + x = dpt.arange(5, dtype="i4") |
| 1112 | + expected_res = dpt.empty(10, dtype="i4") |
| 1113 | + expected_res[1::2], expected_res[::2] = x, x |
| 1114 | + |
| 1115 | + # scalar case |
| 1116 | + reps = 2 |
| 1117 | + res = dpt.repeat(x, reps) |
| 1118 | + assert dpt.all(res == expected_res) |
| 1119 | + |
| 1120 | + # tuple |
| 1121 | + reps = (2, 2, 2, 2, 2) |
| 1122 | + res = dpt.repeat(x, reps) |
| 1123 | + assert dpt.all(res == expected_res) |
| 1124 | + |
| 1125 | + |
| 1126 | +def test_repeat_as_broadcasting(): |
| 1127 | + get_queue_or_skip() |
| 1128 | + |
| 1129 | + reps = 5 |
| 1130 | + x = dpt.arange(reps, dtype="i4") |
| 1131 | + x1 = x[:, dpt.newaxis] |
| 1132 | + expected_res = dpt.broadcast_to(x1, (reps, reps)) |
| 1133 | + |
| 1134 | + res = dpt.repeat(x1, reps, axis=1) |
| 1135 | + assert dpt.all(res == expected_res) |
| 1136 | + |
| 1137 | + x2 = x[dpt.newaxis, :] |
| 1138 | + expected_res = dpt.broadcast_to(x2, (reps, reps)) |
| 1139 | + |
| 1140 | + res = dpt.repeat(x2, reps, axis=0) |
| 1141 | + assert dpt.all(res == expected_res) |
| 1142 | + |
| 1143 | + |
| 1144 | +def test_repeat_axes(): |
| 1145 | + get_queue_or_skip() |
| 1146 | + |
| 1147 | + reps = 2 |
| 1148 | + x = dpt.reshape(dpt.arange(5 * 10, dtype="i4"), (5, 10)) |
| 1149 | + expected_res = dpt.empty((x.shape[0] * 2, x.shape[1]), x.dtype) |
| 1150 | + expected_res[::2, :], expected_res[1::2] = x, x |
| 1151 | + res = dpt.repeat(x, reps, axis=0) |
| 1152 | + assert dpt.all(res == expected_res) |
| 1153 | + |
| 1154 | + expected_res = dpt.empty((x.shape[0], x.shape[1] * 2), x.dtype) |
| 1155 | + expected_res[:, ::2], expected_res[:, 1::2] = x, x |
| 1156 | + res = dpt.repeat(x, reps, axis=1) |
| 1157 | + assert dpt.all(res == expected_res) |
| 1158 | + |
| 1159 | + |
| 1160 | +def test_repeat_size_0_outputs(): |
| 1161 | + get_queue_or_skip() |
| 1162 | + |
| 1163 | + x = dpt.ones((3, 0, 5), dtype="i4") |
| 1164 | + reps = 10 |
| 1165 | + res = dpt.repeat(x, reps, axis=0) |
| 1166 | + assert res.size == 0 |
| 1167 | + assert res.shape == (30, 0, 5) |
| 1168 | + |
| 1169 | + res = dpt.repeat(x, reps, axis=1) |
| 1170 | + assert res.size == 0 |
| 1171 | + assert res.shape == (3, 0, 5) |
| 1172 | + |
| 1173 | + res = dpt.repeat(x, (2, 2, 2), axis=0) |
| 1174 | + assert res.size == 0 |
| 1175 | + assert res.shape == (6, 0, 5) |
| 1176 | + |
| 1177 | + x = dpt.ones((3, 2, 5)) |
| 1178 | + res = dpt.repeat(x, 0, axis=1) |
| 1179 | + assert res.size == 0 |
| 1180 | + assert res.shape == (3, 0, 5) |
| 1181 | + |
| 1182 | + x = dpt.ones((3, 2, 5)) |
| 1183 | + res = dpt.repeat(x, (0, 0), axis=1) |
| 1184 | + assert res.size == 0 |
| 1185 | + assert res.shape == (3, 0, 5) |
| 1186 | + |
| 1187 | + |
| 1188 | +def test_repeat_strides(): |
| 1189 | + get_queue_or_skip() |
| 1190 | + |
| 1191 | + reps = 2 |
| 1192 | + x = dpt.reshape(dpt.arange(10 * 10, dtype="i4"), (10, 10)) |
| 1193 | + x1 = x[:, ::-2] |
| 1194 | + expected_res = dpt.empty((10, 10), dtype="i4") |
| 1195 | + expected_res[:, ::2], expected_res[:, 1::2] = x1, x1 |
| 1196 | + res = dpt.repeat(x1, reps, axis=1) |
| 1197 | + assert dpt.all(res == expected_res) |
| 1198 | + res = dpt.repeat(x1, (reps,) * x1.shape[1], axis=1) |
| 1199 | + assert dpt.all(res == expected_res) |
| 1200 | + |
| 1201 | + x1 = x[::-2, :] |
| 1202 | + expected_res = dpt.empty((10, 10), dtype="i4") |
| 1203 | + expected_res[::2, :], expected_res[1::2, :] = x1, x1 |
| 1204 | + res = dpt.repeat(x1, reps, axis=0) |
| 1205 | + assert dpt.all(res == expected_res) |
| 1206 | + res = dpt.repeat(x1, (reps,) * x1.shape[0], axis=0) |
| 1207 | + assert dpt.all(res == expected_res) |
| 1208 | + |
| 1209 | + |
| 1210 | +def test_repeat_arg_validation(): |
| 1211 | + get_queue_or_skip() |
| 1212 | + |
| 1213 | + x = dict() |
| 1214 | + with pytest.raises(TypeError): |
| 1215 | + dpt.repeat(x, 2) |
| 1216 | + |
| 1217 | + # axis must be 0 for scalar |
| 1218 | + x = dpt.empty(()) |
| 1219 | + with pytest.raises(ValueError): |
| 1220 | + dpt.repeat(x, 2, axis=1) |
| 1221 | + |
| 1222 | + # x.ndim must be 1 for axis=None |
| 1223 | + x = dpt.empty((5, 10)) |
| 1224 | + with pytest.raises(ValueError): |
| 1225 | + dpt.repeat(x, 2, axis=None) |
| 1226 | + |
| 1227 | + # repeats must be positive |
| 1228 | + x = dpt.empty(1) |
| 1229 | + with pytest.raises(ValueError): |
| 1230 | + dpt.repeat(x, -2) |
| 1231 | + |
| 1232 | + # repeats must be integers |
| 1233 | + with pytest.raises(TypeError): |
| 1234 | + dpt.repeat(x, 2.0) |
| 1235 | + |
| 1236 | + # repeats tuple must be the same length as axis |
| 1237 | + with pytest.raises(ValueError): |
| 1238 | + dpt.repeat(x, (1, 2)) |
| 1239 | + |
| 1240 | + # repeats tuple elements must be positive |
| 1241 | + with pytest.raises(ValueError): |
| 1242 | + dpt.repeat(x, (-1,)) |
| 1243 | + |
| 1244 | + # repeats must be int or tuple |
| 1245 | + with pytest.raises(TypeError): |
| 1246 | + dpt.repeat(x, dict()) |
0 commit comments