Skip to content

Commit b989d36

Browse files
committed
impl_minimum_maximum_elementwise_funcs
1 parent acd1a60 commit b989d36

File tree

14 files changed

+1131
-7
lines changed

14 files changed

+1131
-7
lines changed

dpctl/tensor/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,8 @@
135135
logical_not,
136136
logical_or,
137137
logical_xor,
138+
maximum,
139+
minimum,
138140
multiply,
139141
negative,
140142
not_equal,
@@ -274,6 +276,8 @@
274276
"log1p",
275277
"log2",
276278
"log10",
279+
"maximum",
280+
"minimum",
277281
"multiply",
278282
"negative",
279283
"not_equal",

dpctl/tensor/_elementwise_funcs.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1176,6 +1176,66 @@
11761176
_logical_xor_docstring_,
11771177
)
11781178

1179+
# B??: ==== MAXIMUM (x1, x2)
1180+
_maximum_docstring_ = """
1181+
maximum(x1, x2, out=None, order='K')
1182+
1183+
Compares two input arrays `x1` and `x2` and returns
1184+
a new array containing the element-wise maxima.
1185+
1186+
Args:
1187+
x1 (usm_ndarray):
1188+
First input array, expected to have numeric data type.
1189+
x2 (usm_ndarray):
1190+
Second input array, also expected to have numeric data type.
1191+
out ({None, usm_ndarray}, optional):
1192+
Output array to populate.
1193+
Array have the correct shape and the expected data type.
1194+
order ("C","F","A","K", optional):
1195+
Memory layout of the newly output array, if parameter `out` is `None`.
1196+
Default: "K".
1197+
Returns:
1198+
usm_narray:
1199+
An array containing the element-wise products. The data type of
1200+
the returned array is determined by the Type Promotion Rules.
1201+
"""
1202+
maximum = BinaryElementwiseFunc(
1203+
"maximum",
1204+
ti._maximum_result_type,
1205+
ti._maximum,
1206+
_maximum_docstring_,
1207+
)
1208+
1209+
# B??: ==== MINIMUM (x1, x2)
1210+
_minimum_docstring_ = """
1211+
minimum(x1, x2, out=None, order='K')
1212+
1213+
Compares two input arrays `x1` and `x2` and returns
1214+
a new array containing the element-wise minima.
1215+
1216+
Args:
1217+
x1 (usm_ndarray):
1218+
First input array, expected to have numeric data type.
1219+
x2 (usm_ndarray):
1220+
Second input array, also expected to have numeric data type.
1221+
out ({None, usm_ndarray}, optional):
1222+
Output array to populate.
1223+
Array have the correct shape and the expected data type.
1224+
order ("C","F","A","K", optional):
1225+
Memory layout of the newly output array, if parameter `out` is `None`.
1226+
Default: "K".
1227+
Returns:
1228+
usm_narray:
1229+
An array containing the element-wise minima. The data type of
1230+
the returned array is determined by the Type Promotion Rules.
1231+
"""
1232+
minimum = BinaryElementwiseFunc(
1233+
"minimum",
1234+
ti._minimum_result_type,
1235+
ti._minimum,
1236+
_minimum_docstring_,
1237+
)
1238+
11791239
# B19: ==== MULTIPLY (x1, x2)
11801240
_multiply_docstring_ = """
11811241
multiply(x1, x2, out=None, order='K')
Lines changed: 284 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,284 @@
1+
//=== maximum.hpp - Binary function MAXIMUM ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of MAXIMUM(x1, x2)
23+
/// function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <CL/sycl.hpp>
28+
#include <cstddef>
29+
#include <cstdint>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/elementwise_functions/common.hpp"
37+
#include <pybind11/pybind11.h>
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace maximum
46+
{
47+
48+
namespace py = pybind11;
49+
namespace td_ns = dpctl::tensor::type_dispatch;
50+
namespace tu_ns = dpctl::tensor::type_utils;
51+
52+
template <typename argT1, typename argT2, typename resT> struct MaximumFunctor
53+
{
54+
55+
using supports_sg_loadstore = std::negation<
56+
std::disjunction<tu_ns::is_complex<argT1>, tu_ns::is_complex<argT2>>>;
57+
using supports_vec = std::conjunction<
58+
std::is_same<argT1, argT2>,
59+
std::negation<std::disjunction<tu_ns::is_complex<argT1>,
60+
tu_ns::is_complex<argT2>>>>;
61+
62+
resT operator()(const argT1 &in1, const argT2 &in2)
63+
{
64+
if constexpr (tu_ns::is_complex<argT1>::value ||
65+
tu_ns::is_complex<argT2>::value)
66+
{
67+
static_assert(std::is_same_v<argT1, argT2>);
68+
using realT = typename argT1::value_type;
69+
realT real1 = std::real(in1);
70+
realT real2 = std::real(in2);
71+
realT imag1 = std::imag(in1);
72+
realT imag2 = std::imag(in2);
73+
74+
if (std::isnan(real1) || std::isnan(imag1))
75+
return in1;
76+
else if (std::isnan(real2) || std::isnan(imag2))
77+
return in2;
78+
else if (real1 == real2)
79+
return imag1 > imag2 ? in1 : in2;
80+
else
81+
return real1 > real2 ? in1 : in2;
82+
}
83+
else {
84+
return (in1 != in1 || in1 > in2) ? in1 : in2;
85+
}
86+
}
87+
88+
template <int vec_sz>
89+
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT1, vec_sz> &in1,
90+
const sycl::vec<argT2, vec_sz> &in2)
91+
{
92+
sycl::vec<resT, vec_sz> res;
93+
#pragma unroll
94+
for (int i = 0; i < vec_sz; ++i) {
95+
res[i] = (in1[i] != in1[i] || in1[i] > in2[i]) ? in1[i] : in2[i];
96+
}
97+
return res;
98+
}
99+
};
100+
101+
template <typename argT1,
102+
typename argT2,
103+
typename resT,
104+
unsigned int vec_sz = 4,
105+
unsigned int n_vecs = 2>
106+
using MaximumContigFunctor =
107+
elementwise_common::BinaryContigFunctor<argT1,
108+
argT2,
109+
resT,
110+
MaximumFunctor<argT1, argT2, resT>,
111+
vec_sz,
112+
n_vecs>;
113+
114+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
115+
using MaximumStridedFunctor = elementwise_common::BinaryStridedFunctor<
116+
argT1,
117+
argT2,
118+
resT,
119+
IndexerT,
120+
MaximumFunctor<argT1, argT2, resT>>;
121+
122+
template <typename T1, typename T2> struct MaximumOutputType
123+
{
124+
using value_type = typename std::disjunction< // disjunction is C++17
125+
// feature, supported by DPC++
126+
td_ns::BinaryTypeMapResultEntry<T1, bool, T2, bool, bool>,
127+
td_ns::BinaryTypeMapResultEntry<T1,
128+
std::uint8_t,
129+
T2,
130+
std::uint8_t,
131+
std::uint8_t>,
132+
td_ns::BinaryTypeMapResultEntry<T1,
133+
std::int8_t,
134+
T2,
135+
std::int8_t,
136+
std::int8_t>,
137+
td_ns::BinaryTypeMapResultEntry<T1,
138+
std::uint16_t,
139+
T2,
140+
std::uint16_t,
141+
std::uint16_t>,
142+
td_ns::BinaryTypeMapResultEntry<T1,
143+
std::int16_t,
144+
T2,
145+
std::int16_t,
146+
std::int16_t>,
147+
td_ns::BinaryTypeMapResultEntry<T1,
148+
std::uint32_t,
149+
T2,
150+
std::uint32_t,
151+
std::uint32_t>,
152+
td_ns::BinaryTypeMapResultEntry<T1,
153+
std::int32_t,
154+
T2,
155+
std::int32_t,
156+
std::int32_t>,
157+
td_ns::BinaryTypeMapResultEntry<T1,
158+
std::uint64_t,
159+
T2,
160+
std::uint64_t,
161+
std::uint64_t>,
162+
td_ns::BinaryTypeMapResultEntry<T1,
163+
std::int64_t,
164+
T2,
165+
std::int64_t,
166+
std::int64_t>,
167+
td_ns::BinaryTypeMapResultEntry<T1,
168+
sycl::half,
169+
T2,
170+
sycl::half,
171+
sycl::half>,
172+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
173+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
174+
td_ns::BinaryTypeMapResultEntry<T1,
175+
std::complex<float>,
176+
T2,
177+
std::complex<float>,
178+
std::complex<float>>,
179+
td_ns::BinaryTypeMapResultEntry<T1,
180+
std::complex<double>,
181+
T2,
182+
std::complex<double>,
183+
std::complex<double>>,
184+
td_ns::DefaultResultEntry<void>>::result_type;
185+
};
186+
187+
template <typename argT1,
188+
typename argT2,
189+
typename resT,
190+
unsigned int vec_sz,
191+
unsigned int n_vecs>
192+
class maximum_contig_kernel;
193+
194+
template <typename argTy1, typename argTy2>
195+
sycl::event maximum_contig_impl(sycl::queue exec_q,
196+
size_t nelems,
197+
const char *arg1_p,
198+
py::ssize_t arg1_offset,
199+
const char *arg2_p,
200+
py::ssize_t arg2_offset,
201+
char *res_p,
202+
py::ssize_t res_offset,
203+
const std::vector<sycl::event> &depends = {})
204+
{
205+
return elementwise_common::binary_contig_impl<
206+
argTy1, argTy2, MaximumOutputType, MaximumContigFunctor,
207+
maximum_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
208+
arg2_offset, res_p, res_offset, depends);
209+
}
210+
211+
template <typename fnT, typename T1, typename T2> struct MaximumContigFactory
212+
{
213+
fnT get()
214+
{
215+
if constexpr (std::is_same_v<
216+
typename MaximumOutputType<T1, T2>::value_type, void>)
217+
{
218+
fnT fn = nullptr;
219+
return fn;
220+
}
221+
else {
222+
fnT fn = maximum_contig_impl<T1, T2>;
223+
return fn;
224+
}
225+
}
226+
};
227+
228+
template <typename fnT, typename T1, typename T2> struct MaximumTypeMapFactory
229+
{
230+
/*! @brief get typeid for output type of maximum(T1 x, T2 y) */
231+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
232+
{
233+
using rT = typename MaximumOutputType<T1, T2>::value_type;
234+
;
235+
return td_ns::GetTypeid<rT>{}.get();
236+
}
237+
};
238+
239+
template <typename T1, typename T2, typename resT, typename IndexerT>
240+
class maximum_strided_kernel;
241+
242+
template <typename argTy1, typename argTy2>
243+
sycl::event
244+
maximum_strided_impl(sycl::queue exec_q,
245+
size_t nelems,
246+
int nd,
247+
const py::ssize_t *shape_and_strides,
248+
const char *arg1_p,
249+
py::ssize_t arg1_offset,
250+
const char *arg2_p,
251+
py::ssize_t arg2_offset,
252+
char *res_p,
253+
py::ssize_t res_offset,
254+
const std::vector<sycl::event> &depends,
255+
const std::vector<sycl::event> &additional_depends)
256+
{
257+
return elementwise_common::binary_strided_impl<
258+
argTy1, argTy2, MaximumOutputType, MaximumStridedFunctor,
259+
maximum_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
260+
arg1_offset, arg2_p, arg2_offset, res_p,
261+
res_offset, depends, additional_depends);
262+
}
263+
264+
template <typename fnT, typename T1, typename T2> struct MaximumStridedFactory
265+
{
266+
fnT get()
267+
{
268+
if constexpr (std::is_same_v<
269+
typename MaximumOutputType<T1, T2>::value_type, void>)
270+
{
271+
fnT fn = nullptr;
272+
return fn;
273+
}
274+
else {
275+
fnT fn = maximum_strided_impl<T1, T2>;
276+
return fn;
277+
}
278+
}
279+
};
280+
281+
} // namespace maximum
282+
} // namespace kernels
283+
} // namespace tensor
284+
} // namespace dpctl

0 commit comments

Comments
 (0)