Skip to content

Commit 3ad6d8b

Browse files
committed
Added tests for in-place remainder and pow
Fixed in-place remainder for devices that do not support 64-bit floating point data types
1 parent f2b335d commit 3ad6d8b

File tree

3 files changed

+126
-11
lines changed

3 files changed

+126
-11
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/remainder.hpp

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,11 @@ template <typename fnT, typename T1, typename T2> struct RemainderStridedFactory
316316
template <typename argT, typename resT> struct RemainderInplaceFunctor
317317
{
318318

319-
using supports_sg_loadstore = std::negation<
320-
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
321-
using supports_vec = std::negation<
322-
std::disjunction<tu_ns::is_complex<argT>, tu_ns::is_complex<resT>>>;
319+
using supports_sg_loadstore = std::true_type;
320+
using supports_vec = std::true_type;
321+
322+
// functor is only well-defined when argT and resT are the same
323+
static_assert(std::is_same_v<argT, resT>);
323324

324325
void operator()(resT &res, const argT &in)
325326
{
@@ -331,7 +332,7 @@ template <typename argT, typename resT> struct RemainderInplaceFunctor
331332
if constexpr (std::is_signed_v<argT> || std::is_signed_v<resT>) {
332333
auto tmp = res;
333334
res %= in;
334-
if (res != 0 && l_xor(tmp < 0, in < 0)) {
335+
if (res != resT(0) && l_xor(tmp < 0, in < 0)) {
335336
res += in;
336337
}
337338
}
@@ -347,7 +348,7 @@ template <typename argT, typename resT> struct RemainderInplaceFunctor
347348
}
348349
}
349350
else {
350-
res = std::copysign(0, in);
351+
res = sycl::copysign(resT(0), in);
351352
}
352353
}
353354
}
@@ -384,7 +385,7 @@ template <typename argT, typename resT> struct RemainderInplaceFunctor
384385
}
385386
}
386387
else {
387-
res[i] = std::copysign(0, in[i]);
388+
res[i] = sycl::copysign(resT(0), in[i]);
388389
}
389390
}
390391
}
@@ -444,8 +445,10 @@ struct RemainderInplaceContigFactory
444445
{
445446
fnT get()
446447
{
447-
if constexpr (std::is_same_v<typename RemainderOutputType<T1, T2>::value_type,
448-
void>) {
448+
if constexpr (std::is_same_v<
449+
typename RemainderOutputType<T1, T2>::value_type,
450+
void>)
451+
{
449452
fnT fn = nullptr;
450453
return fn;
451454
}
@@ -484,8 +487,10 @@ struct RemainderInplaceStridedFactory
484487
{
485488
fnT get()
486489
{
487-
if constexpr (std::is_same_v<typename RemainderOutputType<T1, T2>::value_type,
488-
void>) {
490+
if constexpr (std::is_same_v<
491+
typename RemainderOutputType<T1, T2>::value_type,
492+
void>)
493+
{
489494
fnT fn = nullptr;
490495
return fn;
491496
}

dpctl/tests/elementwise/test_pow.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._type_utils import _can_cast
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526

2627
from .utils import _all_dtypes, _compare_dtypes, _usm_types
@@ -152,3 +153,60 @@ def test_pow_python_scalar(arr_dt):
152153
assert isinstance(R, dpt.usm_ndarray)
153154
R = dpt.pow(sc, X)
154155
assert isinstance(R, dpt.usm_ndarray)
156+
157+
158+
@pytest.mark.parametrize("dtype", _all_dtypes[1:])
159+
def test_pow_inplace_python_scalar(dtype):
160+
q = get_queue_or_skip()
161+
skip_if_dtype_not_supported(dtype, q)
162+
X = dpt.ones((10, 10), dtype=dtype, sycl_queue=q)
163+
dt_kind = X.dtype.kind
164+
if dt_kind in "ui":
165+
X **= int(1)
166+
elif dt_kind == "f":
167+
X **= float(1)
168+
elif dt_kind == "c":
169+
X **= complex(1)
170+
171+
172+
@pytest.mark.parametrize("op1_dtype", _all_dtypes[1:])
173+
@pytest.mark.parametrize("op2_dtype", _all_dtypes[1:])
174+
def test_pow_inplace_dtype_matrix(op1_dtype, op2_dtype):
175+
q = get_queue_or_skip()
176+
skip_if_dtype_not_supported(op1_dtype, q)
177+
skip_if_dtype_not_supported(op2_dtype, q)
178+
179+
sz = 127
180+
ar1 = dpt.ones(sz, dtype=op1_dtype)
181+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
182+
183+
dev = q.sycl_device
184+
_fp16 = dev.has_aspect_fp16
185+
_fp64 = dev.has_aspect_fp64
186+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
187+
ar1 **= ar2
188+
assert (
189+
dpt.asnumpy(ar1) == np.full(ar1.shape, 1, dtype=ar1.dtype)
190+
).all()
191+
192+
ar3 = dpt.ones(sz, dtype=op1_dtype)
193+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
194+
195+
ar3[::-1] *= ar4[::2]
196+
assert (
197+
dpt.asnumpy(ar3) == np.full(ar3.shape, 1, dtype=ar3.dtype)
198+
).all()
199+
200+
else:
201+
with pytest.raises(TypeError):
202+
ar1 **= ar2
203+
204+
205+
def test_pow_inplace_basic():
206+
get_queue_or_skip()
207+
208+
x = dpt.arange(10, dtype="i4")
209+
expected = dpt.square(x)
210+
x **= 2
211+
212+
assert dpt.all(x == expected)

dpctl/tests/elementwise/test_remainder.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import dpctl
2323
import dpctl.tensor as dpt
24+
from dpctl.tensor._type_utils import _can_cast
2425
from dpctl.tests.helper import get_queue_or_skip, skip_if_dtype_not_supported
2526

2627
from .utils import _compare_dtypes, _no_complex_dtypes, _usm_types
@@ -206,3 +207,54 @@ def test_remainder_python_scalar(arr_dt):
206207
assert isinstance(R, dpt.usm_ndarray)
207208
R = dpt.remainder(sc, X)
208209
assert isinstance(R, dpt.usm_ndarray)
210+
211+
212+
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:])
213+
def test_remainder_inplace_python_scalar(dtype):
214+
q = get_queue_or_skip()
215+
skip_if_dtype_not_supported(dtype, q)
216+
X = dpt.ones((10, 10), dtype=dtype, sycl_queue=q)
217+
dt_kind = X.dtype.kind
218+
if dt_kind in "ui":
219+
X %= int(1)
220+
elif dt_kind == "f":
221+
X %= float(1)
222+
223+
224+
@pytest.mark.parametrize("op1_dtype", _no_complex_dtypes[1:])
225+
@pytest.mark.parametrize("op2_dtype", _no_complex_dtypes[1:])
226+
def test_remainder_inplace_dtype_matrix(op1_dtype, op2_dtype):
227+
q = get_queue_or_skip()
228+
skip_if_dtype_not_supported(op1_dtype, q)
229+
skip_if_dtype_not_supported(op2_dtype, q)
230+
231+
sz = 127
232+
ar1 = dpt.ones(sz, dtype=op1_dtype)
233+
ar2 = dpt.ones_like(ar1, dtype=op2_dtype)
234+
235+
dev = q.sycl_device
236+
_fp16 = dev.has_aspect_fp16
237+
_fp64 = dev.has_aspect_fp64
238+
if _can_cast(ar2.dtype, ar1.dtype, _fp16, _fp64):
239+
ar1 %= ar2
240+
assert dpt.all(ar1 == dpt.zeros(ar1.shape, dtype=ar1.dtype))
241+
242+
ar3 = dpt.ones(sz, dtype=op1_dtype)
243+
ar4 = dpt.ones(2 * sz, dtype=op2_dtype)
244+
245+
ar3[::-1] %= ar4[::2]
246+
assert dpt.all(ar3 == dpt.zeros(ar3.shape, dtype=ar3.dtype))
247+
248+
else:
249+
with pytest.raises(TypeError):
250+
ar1 %= ar2
251+
252+
253+
def test_remainder_inplace_basic():
254+
get_queue_or_skip()
255+
256+
x = dpt.arange(10, dtype="i4")
257+
expected = x & 1
258+
x %= 2
259+
260+
assert dpt.all(x == expected)

0 commit comments

Comments
 (0)