@@ -52,77 +52,59 @@ namespace tu_ns = dpctl::tensor::type_utils;
52
52
template <typename argT1, typename argT2, typename resT>
53
53
struct FloorDivideFunctor
54
54
{
55
-
56
- using supports_sg_loadstore = std::negation<
57
- std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
58
- using supports_vec = std::negation<
59
- std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
55
+ using supports_sg_loadstore = std::true_type;
56
+ using supports_vec = std::true_type;
60
57
61
58
resT operator ()(const argT1 &in1, const argT2 &in2)
62
59
{
63
60
if constexpr (std::is_integral_v<argT1> || std::is_integral_v<argT2>) {
64
- static_assert (std::is_same_v<argT1, argT2>);
65
- if (in2 == 0 ) {
61
+ if (in2 == argT2 (0 )) {
66
62
return resT (0 );
67
63
}
68
- auto tmp = in1 / in2;
69
- if constexpr (std::is_unsigned_v<argT1> ||
70
- std::is_unsigned_v<argT2>) {
71
- return tmp;
64
+ if constexpr (std::is_signed_v<argT1> || std::is_signed_v<argT2>) {
65
+ auto div = in1 / in2;
66
+ auto mod = in1 % in2;
67
+ auto corr = (mod != 0 && l_xor (mod < 0 , in2 < 0 ));
68
+ return (div - corr);
72
69
}
73
70
else {
74
- auto rem = in1 % in2;
75
- auto corr = (rem != 0 && ((rem < 0 ) != (in2 < 0 )));
76
- return (tmp - corr);
71
+ return (in1 / in2);
77
72
}
78
73
}
79
74
else {
80
- auto tmp = in1 / in2;
81
- return (tmp == 0 ) ? resT (tmp) : resT (std::floor (tmp ));
75
+ auto div = in1 / in2;
76
+ return (div == resT ( 0 )) ? div : resT (std::floor (div ));
82
77
}
83
78
}
84
79
85
80
template <int vec_sz>
86
81
sycl::vec<resT, vec_sz> operator ()(const sycl::vec<argT1, vec_sz> &in1,
87
82
const sycl::vec<argT2, vec_sz> &in2)
88
83
{
89
- auto tmp = in1 / in2;
90
- using tmpT = typename decltype (tmp)::element_type;
91
- if constexpr (std::is_integral_v<tmpT>) {
92
- if constexpr (std::is_unsigned_v<tmpT>) {
84
+ if constexpr (std::is_integral_v<resT>) {
85
+ sycl::vec<resT, vec_sz> res;
93
86
#pragma unroll
94
- for (int i = 0 ; i < vec_sz; ++i) {
95
- if (in2[i] == argT2 (0 )) {
96
- tmp[i] = tmpT (0 );
97
- }
87
+ for (int i = 0 ; i < vec_sz; ++i) {
88
+ if (in2[i] == argT2 (0 )) {
89
+ res[i] = resT (0 );
98
90
}
99
- }
100
- else {
101
- auto rem = in1 % in2;
102
- #pragma unroll
103
- for (int i = 0 ; i < vec_sz; ++i) {
104
- if (in2[i] == 0 ) {
105
- tmp[i] = tmpT (0 );
106
- }
107
- else {
108
- tmpT corr =
109
- (rem[i] != 0 && ((rem[i] < 0 ) != (in2[i] < 0 )));
110
- tmp[i] -= corr;
91
+ else {
92
+ res[i] = in1[i] / in2[i];
93
+ if constexpr (std::is_signed_v<resT>) {
94
+ auto mod = in1[i] % in2[i];
95
+ auto corr = (mod != 0 && l_xor (mod < 0 , in2[i] < 0 ));
96
+ res[i] -= corr;
111
97
}
112
98
}
113
99
}
114
- if constexpr (std::is_same_v<resT, tmpT>) {
115
- return tmp;
116
- }
117
- else {
118
- using dpctl::tensor::type_utils::vec_cast;
119
- return vec_cast<resT, tmpT, vec_sz>(tmp);
120
- }
100
+ return res;
121
101
}
122
102
else {
103
+ auto tmp = in1 / in2;
104
+ using tmpT = typename decltype (tmp)::element_type;
123
105
#pragma unroll
124
106
for (int i = 0 ; i < vec_sz; ++i) {
125
- if (in2[i] != 0 ) {
107
+ if (in2[i] != argT2 ( 0 ) ) {
126
108
tmp[i] = std::floor (tmp[i]);
127
109
}
128
110
}
@@ -135,6 +117,12 @@ struct FloorDivideFunctor
135
117
}
136
118
}
137
119
}
120
+
121
+ private:
122
+ bool l_xor (bool b1, bool b2) const
123
+ {
124
+ return (b1 != b2);
125
+ }
138
126
};
139
127
140
128
template <typename argT1,
0 commit comments