Skip to content

Commit c31d020

Browse files
committed
floor_divide fixed for signed 0 output
1 parent 92aa81d commit c31d020

File tree

2 files changed

+67
-42
lines changed

2 files changed

+67
-42
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

Lines changed: 31 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -60,24 +60,25 @@ struct FloorDivideFunctor
6060

6161
resT operator()(const argT1 &in1, const argT2 &in2)
6262
{
63-
auto tmp = in1 / in2;
64-
if constexpr (std::is_integral_v<decltype(tmp)>) {
65-
if constexpr (std::is_unsigned_v<decltype(tmp)>) {
66-
return (in2 == argT2(0)) ? resT(0) : tmp;
63+
if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
64+
static_assert(std::is_same_v<argT1, argT2>);
65+
if (in2 == 0) {
66+
return resT(0);
67+
}
68+
auto tmp = in1 / in2;
69+
if constexpr (std::is_unsigned_v<argT1> ||
70+
std::is_unsigned_v<argT2>) {
71+
return tmp;
6772
}
6873
else {
69-
if (in2 == argT2(0)) {
70-
return resT(0);
71-
}
72-
else {
73-
auto rem = in1 % in2;
74-
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
75-
return (tmp - corr);
76-
}
74+
auto rem = in1 % in2;
75+
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
76+
return (tmp - corr);
7777
}
7878
}
7979
else {
80-
return sycl::floor(tmp);
80+
auto tmp = in1 / in2;
81+
return (tmp == 0) ? resT(tmp) : resT(std::floor(tmp));
8182
}
8283
}
8384

@@ -88,26 +89,26 @@ struct FloorDivideFunctor
8889
auto tmp = in1 / in2;
8990
using tmpT = typename decltype(tmp)::element_type;
9091
if constexpr (std::is_integral_v<tmpT>) {
91-
if constexpr (std::is_signed_v<tmpT>) {
92-
auto rem_tmp = in1 % in2;
92+
if constexpr (std::is_unsigned_v<tmpT>) {
9393
#pragma unroll
9494
for (int i = 0; i < vec_sz; ++i) {
9595
if (in2[i] == argT2(0)) {
9696
tmp[i] = tmpT(0);
9797
}
98-
else {
99-
tmpT corr = (rem_tmp[i] != 0 &&
100-
((rem_tmp[i] < 0) != (in2[i] < 0)));
101-
tmp[i] -= corr;
102-
}
10398
}
10499
}
105100
else {
101+
auto rem = in1 % in2;
106102
#pragma unroll
107103
for (int i = 0; i < vec_sz; ++i) {
108-
if (in2[i] == argT2(0)) {
104+
if (in2[i] == 0) {
109105
tmp[i] = tmpT(0);
110106
}
107+
else {
108+
tmpT corr =
109+
(rem[i] != 0 && ((rem[i] < 0) != (in2[i] < 0)));
110+
tmp[i] -= corr;
111+
}
111112
}
112113
}
113114
if constexpr (std::is_same_v<resT, tmpT>) {
@@ -119,16 +120,18 @@ struct FloorDivideFunctor
119120
}
120121
}
121122
else {
122-
sycl::vec<resT, vec_sz> res = sycl::floor(tmp);
123-
if constexpr (std::is_same_v<resT,
124-
typename decltype(res)::element_type>)
125-
{
126-
return res;
123+
#pragma unroll
124+
for (int i = 0; i < vec_sz; ++i) {
125+
if (in2[i] != 0) {
126+
tmp[i] = std::floor(tmp[i]);
127+
}
128+
}
129+
if constexpr (std::is_same_v<resT, tmpT>) {
130+
return tmp;
127131
}
128132
else {
129133
using dpctl::tensor::type_utils::vec_cast;
130-
return vec_cast<resT, typename decltype(res)::element_type,
131-
vec_sz>(res);
134+
return vec_cast<resT, tmpT, vec_sz>(tmp);
132135
}
133136
}
134137
}

dpctl/tests/elementwise/test_floor_divide.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -203,16 +203,6 @@ def test_floor_divide_gh_1247():
203203
dpt.asnumpy(res), np.full(res.shape, -1, dtype=res.dtype)
204204
)
205205

206-
# attempt to invoke sycl::vec overload using a larger array
207-
x = dpt.arange(-64, 65, 1, dtype="i4")
208-
np.testing.assert_array_equal(
209-
dpt.asnumpy(dpt.floor_divide(x, 3)), np.floor_divide(dpt.asnumpy(x), 3)
210-
)
211-
np.testing.assert_array_equal(
212-
dpt.asnumpy(dpt.floor_divide(x, -3)),
213-
np.floor_divide(dpt.asnumpy(x), -3),
214-
)
215-
216206

217207
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
218208
def test_floor_divide_integer_zero(dtype):
@@ -226,10 +216,42 @@ def test_floor_divide_integer_zero(dtype):
226216
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
227217
)
228218

229-
# attempt to invoke sycl::vec overload using a larger array
230-
x = dpt.arange(129, dtype=dtype, sycl_queue=q)
231-
y = dpt.zeros_like(x, sycl_queue=q)
219+
220+
def test_floor_divide_special_cases():
221+
q = get_queue_or_skip()
222+
223+
x = dpt.empty(1, dtype="f4", sycl_queue=q)
224+
y = dpt.empty_like(x)
225+
x[0], y[0] = dpt.inf, dpt.inf
226+
res = dpt.floor_divide(x, y)
227+
with np.errstate(all="ignore"):
228+
res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
229+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
230+
231+
x[0], y[0] = 0.0, -1.0
232+
res = dpt.floor_divide(x, y)
233+
x_np = dpt.asnumpy(x)
234+
y_np = dpt.asnumpy(y)
235+
res_np = np.floor_divide(x_np, y_np)
236+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
237+
238+
res = dpt.floor_divide(y, x)
239+
with np.errstate(all="ignore"):
240+
res_np = np.floor_divide(y_np, x_np)
241+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
242+
243+
x[0], y[0] = -1.0, dpt.inf
232244
res = dpt.floor_divide(x, y)
233245
np.testing.assert_array_equal(
234-
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
246+
dpt.asnumpy(res), np.asarray([-0.0], dtype="f4")
235247
)
248+
249+
res = dpt.floor_divide(y, x)
250+
np.testing.assert_array_equal(
251+
dpt.asnumpy(res), np.asarray([-dpt.inf], dtype="f4")
252+
)
253+
254+
x[0], y[0] = 1.0, dpt.nan
255+
res = dpt.floor_divide(x, y)
256+
res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
257+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)

0 commit comments

Comments
 (0)