Skip to content

Commit 0e31948

Browse files
committed
address reviewer's comments
1 parent bac7662 commit 0e31948

File tree

5 files changed

+64
-54
lines changed

5 files changed

+64
-54
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/ceil.hpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,6 @@ template <typename argT, typename resT> struct CeilFunctor
7171
else {
7272
return sycl::ceil(in);
7373
}
74-
// return sycl::ceil(in);
7574
}
7675
};
7776

@@ -90,7 +89,14 @@ template <typename T> struct CeilOutputType
9089
{
9190
using value_type = typename std::disjunction< // disjunction is C++17
9291
// feature, supported by DPC++
93-
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
92+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
93+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
94+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
95+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
96+
td_ns::TypeMapResultEntry<T, std::int8_t>,
97+
td_ns::TypeMapResultEntry<T, std::int16_t>,
98+
td_ns::TypeMapResultEntry<T, std::int32_t>,
99+
td_ns::TypeMapResultEntry<T, std::int64_t>,
94100
td_ns::TypeMapResultEntry<T, sycl::half>,
95101
td_ns::TypeMapResultEntry<T, float>,
96102
td_ns::TypeMapResultEntry<T, double>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor.hpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,6 @@ template <typename argT, typename resT> struct FloorFunctor
7171
else {
7272
return sycl::floor(in);
7373
}
74-
// return sycl::floor(in);
7574
}
7675
};
7776

@@ -94,7 +93,14 @@ template <typename T> struct FloorOutputType
9493
{
9594
using value_type = typename std::disjunction< // disjunction is C++17
9695
// feature, supported by DPC++
97-
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
96+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
97+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
98+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
99+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
100+
td_ns::TypeMapResultEntry<T, std::int8_t>,
101+
td_ns::TypeMapResultEntry<T, std::int16_t>,
102+
td_ns::TypeMapResultEntry<T, std::int32_t>,
103+
td_ns::TypeMapResultEntry<T, std::int64_t>,
98104
td_ns::TypeMapResultEntry<T, sycl::half>,
99105
td_ns::TypeMapResultEntry<T, float>,
100106
td_ns::TypeMapResultEntry<T, double>,

dpctl/tensor/libtensor/include/kernels/elementwise_functions/trunc.hpp

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -71,7 +71,6 @@ template <typename argT, typename resT> struct TruncFunctor
7171
else {
7272
return sycl::trunc(in);
7373
}
74-
// return sycl::trunc(in);
7574
}
7675
};
7776

@@ -94,7 +93,14 @@ template <typename T> struct TruncOutputType
9493
{
9594
using value_type = typename std::disjunction< // disjunction is C++17
9695
// feature, supported by DPC++
97-
td_ns::TypeMapResultEntry<T, bool, sycl::half>,
96+
td_ns::TypeMapResultEntry<T, std::uint8_t>,
97+
td_ns::TypeMapResultEntry<T, std::uint16_t>,
98+
td_ns::TypeMapResultEntry<T, std::uint32_t>,
99+
td_ns::TypeMapResultEntry<T, std::uint64_t>,
100+
td_ns::TypeMapResultEntry<T, std::int8_t>,
101+
td_ns::TypeMapResultEntry<T, std::int16_t>,
102+
td_ns::TypeMapResultEntry<T, std::int32_t>,
103+
td_ns::TypeMapResultEntry<T, std::int64_t>,
98104
td_ns::TypeMapResultEntry<T, sycl::half>,
99105
td_ns::TypeMapResultEntry<T, float>,
100106
td_ns::TypeMapResultEntry<T, double>,

dpctl/tests/elementwise/test_floor_ceil_trunc.py

Lines changed: 27 additions & 48 deletions
Original file line number | Diff line number | Diff line change
@@ -24,27 +24,24 @@
2424
import dpctl.tensor as dpt
2525
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2626

27-
from .utils import _map_to_device_dtype, _no_complex_dtypes
27+
from .utils import _map_to_device_dtype, _real_value_dtypes
2828

2929
_all_funcs = [(np.floor, dpt.floor), (np.ceil, dpt.ceil), (np.trunc, dpt.trunc)]
3030

3131

32-
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
33-
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
34-
def test_floor_ceil_trunc_out_type(np_call, dpt_call, dtype):
32+
@pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc])
33+
@pytest.mark.parametrize("dtype", _real_value_dtypes)
34+
def test_floor_ceil_trunc_out_type(dpt_call, dtype):
3535
q = get_queue_or_skip()
3636
skip_if_dtype_not_supported(dtype, q)
37-
if dtype == "b1":
38-
skip_if_dtype_not_supported("f2", q)
3937

40-
X = dpt.asarray(0.1, dtype=dtype, sycl_queue=q)
41-
expected_dtype = np_call(np.array(0.1, dtype=dtype)).dtype
42-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
38+
arg_dt = np.dtype(dtype)
39+
X = dpt.asarray(0.1, dtype=arg_dt, sycl_queue=q)
40+
expected_dtype = _map_to_device_dtype(arg_dt, q.sycl_device)
4341
assert dpt_call(X).dtype == expected_dtype
4442

4543
X = dpt.asarray(0.1, dtype=dtype, sycl_queue=q)
46-
expected_dtype = np_call(np.array(0.1, dtype=dtype)).dtype
47-
expected_dtype = _map_to_device_dtype(expected_dtype, q.sycl_device)
44+
expected_dtype = _map_to_device_dtype(arg_dt, q.sycl_device)
4845
Y = dpt.empty_like(X, dtype=expected_dtype)
4946
dpt_call(X, out=Y)
5047
assert_allclose(dpt.asnumpy(dpt_call(X)), dpt.asnumpy(Y))
@@ -73,12 +70,10 @@ def test_floor_ceil_trunc_usm_type(np_call, dpt_call, usm_type):
7370

7471

7572
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
76-
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
73+
@pytest.mark.parametrize("dtype", _real_value_dtypes)
7774
def test_floor_ceil_trunc_order(np_call, dpt_call, dtype):
7875
q = get_queue_or_skip()
7976
skip_if_dtype_not_supported(dtype, q)
80-
if dtype == "b1":
81-
skip_if_dtype_not_supported("f2", q)
8277

8378
arg_dt = np.dtype(dtype)
8479
input_shape = (10, 10, 10, 10)
@@ -90,17 +85,12 @@ def test_floor_ceil_trunc_order(np_call, dpt_call, dtype):
9085
for perms in itertools.permutations(range(4)):
9186
U = dpt.permute_dims(X[:, ::-1, ::-1, :], perms)
9287
Y = dpt_call(U, order=ord)
93-
with np.errstate(all="ignore"):
94-
expected_Y = np_call(dpt.asnumpy(U))
95-
tol = 8 * max(
96-
dpt.finfo(Y.dtype).resolution,
97-
np.finfo(expected_Y.dtype).resolution,
98-
)
99-
assert_allclose(dpt.asnumpy(Y), expected_Y, atol=tol, rtol=tol)
88+
expected_Y = np_call(dpt.asnumpy(U))
89+
assert_allclose(dpt.asnumpy(Y), expected_Y)
10090

10191

102-
@pytest.mark.parametrize("callable", [dpt.floor, dpt.ceil, dpt.trunc])
103-
def test_floor_ceil_trunc_errors(callable):
92+
@pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc])
93+
def test_floor_ceil_trunc_errors(dpt_call):
10494
get_queue_or_skip()
10595
try:
10696
gpu_queue = dpctl.SyclQueue("gpu")
@@ -116,7 +106,7 @@ def test_floor_ceil_trunc_errors(callable):
116106
assert_raises_regex(
117107
TypeError,
118108
"Input and output allocation queues are not compatible",
119-
callable,
109+
dpt_call,
120110
x,
121111
y,
122112
)
@@ -126,41 +116,39 @@ def test_floor_ceil_trunc_errors(callable):
126116
assert_raises_regex(
127117
TypeError,
128118
"The shape of input and output arrays are inconsistent",
129-
callable,
119+
dpt_call,
130120
x,
131121
y,
132122
)
133123

134124
x = dpt.zeros(2)
135125
y = x
136126
assert_raises_regex(
137-
TypeError, "Input and output arrays have memory overlap", callable, x, y
127+
TypeError, "Input and output arrays have memory overlap", dpt_call, x, y
138128
)
139129

140130
x = dpt.zeros(2, dtype="float32")
141131
y = np.empty_like(x)
142132
assert_raises_regex(
143-
TypeError, "output array must be of usm_ndarray type", callable, x, y
133+
TypeError, "output array must be of usm_ndarray type", dpt_call, x, y
144134
)
145135

146136

147-
@pytest.mark.parametrize("callable", [dpt.floor, dpt.ceil, dpt.trunc])
148-
@pytest.mark.parametrize("dtype", _no_complex_dtypes)
149-
def test_floor_ceil_trunc_error_dtype(callable, dtype):
137+
@pytest.mark.parametrize("dpt_call", [dpt.floor, dpt.ceil, dpt.trunc])
138+
@pytest.mark.parametrize("dtype", _real_value_dtypes)
139+
def test_floor_ceil_trunc_error_dtype(dpt_call, dtype):
150140
q = get_queue_or_skip()
151141
skip_if_dtype_not_supported(dtype, q)
152-
if dtype == "b1":
153-
skip_if_dtype_not_supported("f2", q)
154142

155143
x = dpt.zeros(5, dtype=dtype)
156-
y = dpt.empty_like(x, dtype="int16")
144+
y = dpt.empty_like(x, dtype="b1")
157145
assert_raises_regex(
158-
TypeError, "Output array of type.*is needed", callable, x, y
146+
TypeError, "Output array of type.*is needed", dpt_call, x, y
159147
)
160148

161149

162150
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
163-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
151+
@pytest.mark.parametrize("dtype", _real_value_dtypes)
164152
def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
165153
q = get_queue_or_skip()
166154
skip_if_dtype_not_supported(dtype, q)
@@ -172,29 +160,23 @@ def test_floor_ceil_trunc_contig(np_call, dpt_call, dtype):
172160
X = dpt.asarray(np.repeat(Xnp, n_rep), dtype=dtype, sycl_queue=q)
173161
Y = dpt_call(X)
174162

175-
tol = 8 * dpt.finfo(Y.dtype).resolution
176-
assert_allclose(
177-
dpt.asnumpy(Y), np.repeat(np_call(Xnp), n_rep), atol=tol, rtol=tol
178-
)
163+
assert_allclose(dpt.asnumpy(Y), np.repeat(np_call(Xnp), n_rep))
179164

180165
Z = dpt.empty_like(X, dtype=dtype)
181166
dpt_call(X, out=Z)
182167

183-
assert_allclose(
184-
dpt.asnumpy(Z), np.repeat(np_call(Xnp), n_rep), atol=tol, rtol=tol
185-
)
168+
assert_allclose(dpt.asnumpy(Z), np.repeat(np_call(Xnp), n_rep))
186169

187170

188171
@pytest.mark.parametrize("np_call, dpt_call", _all_funcs)
189-
@pytest.mark.parametrize("dtype", ["f2", "f4", "f8"])
172+
@pytest.mark.parametrize("dtype", _real_value_dtypes)
190173
def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype):
191174
q = get_queue_or_skip()
192175
skip_if_dtype_not_supported(dtype, q)
193176

194177
np.random.seed(42)
195178
strides = np.array([-4, -3, -2, -1, 1, 2, 3, 4])
196179
sizes = np.arange(2, 100)
197-
tol = 8 * dpt.finfo(dtype).resolution
198180

199181
for ii in sizes:
200182
Xnp = np.random.uniform(low=-99.9, high=99.9, size=ii)
@@ -205,8 +187,6 @@ def test_floor_ceil_trunc_strided(np_call, dpt_call, dtype):
205187
assert_allclose(
206188
dpt.asnumpy(dpt_call(X[::jj])),
207189
Ynp[::jj],
208-
atol=tol,
209-
rtol=tol,
210190
)
211191

212192

@@ -221,8 +201,7 @@ def test_floor_ceil_trunc_special_cases(np_call, dpt_call, dtype):
221201
xf = np.array(x, dtype=dtype)
222202
yf = dpt.asarray(xf, dtype=dtype, sycl_queue=q)
223203

224-
with np.errstate(all="ignore"):
225-
Y_np = np_call(xf)
204+
Y_np = np_call(xf)
226205

227206
tol = 8 * dpt.finfo(dtype).resolution
228207
assert_allclose(dpt.asnumpy(dpt_call(yf)), Y_np, atol=tol, rtol=tol)

dpctl/tests/elementwise/utils.py

Lines changed: 13 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -31,6 +31,19 @@
3131
"f4",
3232
"f8",
3333
]
34+
_real_value_dtypes = [
35+
"i1",
36+
"u1",
37+
"i2",
38+
"u2",
39+
"i4",
40+
"u4",
41+
"i8",
42+
"u8",
43+
"f2",
44+
"f4",
45+
"f8",
46+
]
3447
_all_dtypes = _no_complex_dtypes + [
3548
"c8",
3649
"c16",

0 commit comments

Comments
 (0)