Skip to content

Commit 02448c4

Browse files
authored
dpctl.tensor.floor_divide fixed for signed 0 output (#1271)
* floor_divide fixed for signed 0 output
* Reduced the number of computations for floor_divide between integers: rather than computing division and modulo for every element of the sycl::vec, the result vector is now initialized and filled per element
1 parent 1cc45e4 commit 02448c4

File tree

2 files changed

+76
-63
lines changed

2 files changed

+76
-63
lines changed

dpctl/tensor/libtensor/include/kernels/elementwise_functions/floor_divide.hpp

Lines changed: 40 additions & 49 deletions
Original file line number | Diff line number | Diff line change
@@ -52,62 +52,60 @@ namespace tu_ns = dpctl::tensor::type_utils;
5252
template <typename argT1, typename argT2, typename resT>
5353
struct FloorDivideFunctor
5454
{
55-
56-
using supports_sg_loadstore = std::negation<
57-
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58-
using supports_vec = std::negation<
59-
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
55+
using supports_sg_loadstore = std::true_type;
56+
using supports_vec = std::true_type;
6057

6158
resT operator()(const argT1 &in1, const argT2 &in2)
6259
{
63-
auto tmp = in1 / in2;
64-
if constexpr (std::is_integral_v<decltype(tmp)>) {
65-
if constexpr (std::is_unsigned_v<decltype(tmp)>) {
66-
return (in2 == argT2(0)) ? resT(0) : tmp;
60+
if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
61+
if (in2 == argT2(0)) {
62+
return resT(0);
63+
}
64+
if constexpr (std::is_signed_v<argT1> || std::is_signed_v<argT2>) {
65+
auto div = in1 / in2;
66+
auto mod = in1 % in2;
67+
auto corr = (mod != 0 && l_xor(mod < 0, in2 < 0));
68+
return (div - corr);
6769
}
6870
else {
69-
if (in2 == argT2(0)) {
70-
return resT(0);
71-
}
72-
else {
73-
auto rem = in1 % in2;
74-
auto corr = (rem != 0 && ((rem < 0) != (in2 < 0)));
75-
return (tmp - corr);
76-
}
71+
return (in1 / in2);
7772
}
7873
}
7974
else {
80-
return sycl::floor(tmp);
75+
auto div = in1 / in2;
76+
return (div == resT(0)) ? div : resT(std::floor(div));
8177
}
8278
}
8379

8480
template <int vec_sz>
8581
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
8682
const sycl::vec<argT2, vec_sz> &in2)
8783
{
88-
auto tmp = in1 / in2;
89-
using tmpT = typename decltype(tmp)::element_type;
90-
if constexpr (std::is_integral_v<tmpT>) {
91-
if constexpr (std::is_signed_v<tmpT>) {
92-
auto rem_tmp = in1 % in2;
84+
if constexpr (std::is_integral_v<resT>) {
85+
sycl::vec<resT, vec_sz> res;
9386
#pragma unroll
94-
for (int i = 0; i < vec_sz; ++i) {
95-
if (in2[i] == argT2(0)) {
96-
tmp[i] = tmpT(0);
97-
}
98-
else {
99-
tmpT corr = (rem_tmp[i] != 0 &&
100-
((rem_tmp[i] < 0) != (in2[i] < 0)));
101-
tmp[i] -= corr;
87+
for (int i = 0; i < vec_sz; ++i) {
88+
if (in2[i] == argT2(0)) {
89+
res[i] = resT(0);
90+
}
91+
else {
92+
res[i] = in1[i] / in2[i];
93+
if constexpr (std::is_signed_v<resT>) {
94+
auto mod = in1[i] % in2[i];
95+
auto corr = (mod != 0 && l_xor(mod < 0, in2[i] < 0));
96+
res[i] -= corr;
10297
}
10398
}
10499
}
105-
else {
100+
return res;
101+
}
102+
else {
103+
auto tmp = in1 / in2;
104+
using tmpT = typename decltype(tmp)::element_type;
106105
#pragma unroll
107-
for (int i = 0; i < vec_sz; ++i) {
108-
if (in2[i] == argT2(0)) {
109-
tmp[i] = tmpT(0);
110-
}
106+
for (int i = 0; i < vec_sz; ++i) {
107+
if (in2[i] != argT2(0)) {
108+
tmp[i] = std::floor(tmp[i]);
111109
}
112110
}
113111
if constexpr (std::is_same_v<resT, tmpT>) {
@@ -118,19 +116,12 @@ struct FloorDivideFunctor
118116
return vec_cast<resT, tmpT, vec_sz>(tmp);
119117
}
120118
}
121-
else {
122-
sycl::vec<resT, vec_sz> res = sycl::floor(tmp);
123-
if constexpr (std::is_same_v<resT,
124-
typename decltype(res)::element_type>)
125-
{
126-
return res;
127-
}
128-
else {
129-
using dpctl::tensor::type_utils::vec_cast;
130-
return vec_cast<resT, typename decltype(res)::element_type,
131-
vec_sz>(res);
132-
}
133-
}
119+
}
120+
121+
private:
122+
bool l_xor(bool b1, bool b2) const
123+
{
124+
return (b1 != b2);
134125
}
135126
};
136127

dpctl/tests/elementwise/test_floor_divide.py

Lines changed: 36 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -203,16 +203,6 @@ def test_floor_divide_gh_1247():
203203
dpt.asnumpy(res), np.full(res.shape, -1, dtype=res.dtype)
204204
)
205205

206-
# attempt to invoke sycl::vec overload using a larger array
207-
x = dpt.arange(-64, 65, 1, dtype="i4")
208-
np.testing.assert_array_equal(
209-
dpt.asnumpy(dpt.floor_divide(x, 3)), np.floor_divide(dpt.asnumpy(x), 3)
210-
)
211-
np.testing.assert_array_equal(
212-
dpt.asnumpy(dpt.floor_divide(x, -3)),
213-
np.floor_divide(dpt.asnumpy(x), -3),
214-
)
215-
216206

217207
@pytest.mark.parametrize("dtype", _no_complex_dtypes[1:9])
218208
def test_floor_divide_integer_zero(dtype):
@@ -226,10 +216,42 @@ def test_floor_divide_integer_zero(dtype):
226216
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
227217
)
228218

229-
# attempt to invoke sycl::vec overload using a larger array
230-
x = dpt.arange(129, dtype=dtype, sycl_queue=q)
231-
y = dpt.zeros_like(x, sycl_queue=q)
219+
220+
def test_floor_divide_special_cases():
221+
q = get_queue_or_skip()
222+
223+
x = dpt.empty(1, dtype="f4", sycl_queue=q)
224+
y = dpt.empty_like(x)
225+
x[0], y[0] = dpt.inf, dpt.inf
226+
res = dpt.floor_divide(x, y)
227+
with np.errstate(all="ignore"):
228+
res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
229+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
230+
231+
x[0], y[0] = 0.0, -1.0
232+
res = dpt.floor_divide(x, y)
233+
x_np = dpt.asnumpy(x)
234+
y_np = dpt.asnumpy(y)
235+
res_np = np.floor_divide(x_np, y_np)
236+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
237+
238+
res = dpt.floor_divide(y, x)
239+
with np.errstate(all="ignore"):
240+
res_np = np.floor_divide(y_np, x_np)
241+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)
242+
243+
x[0], y[0] = -1.0, dpt.inf
232244
res = dpt.floor_divide(x, y)
233245
np.testing.assert_array_equal(
234-
dpt.asnumpy(res), np.zeros(x.shape, dtype=res.dtype)
246+
dpt.asnumpy(res), np.asarray([-0.0], dtype="f4")
235247
)
248+
249+
res = dpt.floor_divide(y, x)
250+
np.testing.assert_array_equal(
251+
dpt.asnumpy(res), np.asarray([-dpt.inf], dtype="f4")
252+
)
253+
254+
x[0], y[0] = 1.0, dpt.nan
255+
res = dpt.floor_divide(x, y)
256+
res_np = np.floor_divide(dpt.asnumpy(x), dpt.asnumpy(y))
257+
np.testing.assert_array_equal(dpt.asnumpy(res), res_np)

0 commit comments

Comments
 (0)