Skip to content

Commit 477946e

Browse files
Merge pull request #1730 from IntelPython/dpctl-tensor-nextafter
Implements `dpctl.tensor.nextafter` per array API
2 parents 28a231e + e5f9810 commit 477946e

File tree

9 files changed

+605
-0
lines changed

9 files changed

+605
-0
lines changed

dpctl/tensor/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ set(_elementwise_sources
7676
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/minimum.cpp
7777
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/multiply.cpp
7878
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/negative.cpp
79+
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/nextafter.cpp
7980
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/not_equal.cpp
8081
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/positive.cpp
8182
${CMAKE_CURRENT_SOURCE_DIR}/libtensor/source/elementwise_functions/pow.cpp

dpctl/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,7 @@
152152
minimum,
153153
multiply,
154154
negative,
155+
nextafter,
155156
not_equal,
156157
positive,
157158
pow,
@@ -371,4 +372,5 @@
371372
"cumulative_logsumexp",
372373
"cumulative_prod",
373374
"cumulative_sum",
375+
"nextafter",
374376
]

dpctl/tensor/_elementwise_funcs.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1487,6 +1487,40 @@
14871487
)
14881488
del _negative_docstring_
14891489

1490+
# B28: ==== NEXTAFTER (x1, x2)
1491+
_nextafter_docstring_ = r"""
1492+
nextafter(x1, x2, /, \*, out=None, order='K')
1493+
1494+
Calculates the next floating-point value after element `x1_i` of the input
1495+
array `x1` toward the respective element `x2_i` of the input array `x2`.
1496+
1497+
Args:
1498+
x1 (usm_ndarray):
1499+
First input array.
1500+
x2 (usm_ndarray):
1501+
Second input array.
1502+
out (Union[usm_ndarray, None], optional):
1503+
Output array to populate.
1504+
Array must have the correct shape and the expected data type.
1505+
order ("C","F","A","K", optional):
1506+
Memory layout of the new output array, if parameter
1507+
`out` is ``None``.
1508+
Default: "K".
1509+
1510+
Returns:
1511+
usm_ndarray:
1512+
An array containing the element-wise next representable values of `x1`
1513+
in the direction of `x2`. The data type of the returned array is
1514+
determined by the Type Promotion Rules.
1515+
"""
1516+
nextafter = BinaryElementwiseFunc(
1517+
"nextafter",
1518+
ti._nextafter_result_type,
1519+
ti._nextafter,
1520+
_nextafter_docstring_,
1521+
)
1522+
del _nextafter_docstring_
1523+
14901524
# B20: ==== NOT_EQUAL (x1, x2)
14911525
_not_equal_docstring_ = r"""
14921526
not_equal(x1, x2, /, \*, out=None, order='K')
Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
//=== NEXTAFTER.hpp - Binary function NEXTAFTER ------ *-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2024 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===---------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for elementwise evaluation of NEXTAFTER(x1, x2)
23+
/// function.
24+
//===---------------------------------------------------------------------===//
25+
26+
#pragma once
27+
#include <cstddef>
28+
#include <cstdint>
29+
#include <sycl/sycl.hpp>
30+
#include <type_traits>
31+
32+
#include "utils/offset_utils.hpp"
33+
#include "utils/type_dispatch_building.hpp"
34+
#include "utils/type_utils.hpp"
35+
36+
#include "kernels/dpctl_tensor_types.hpp"
37+
#include "kernels/elementwise_functions/common.hpp"
38+
39+
namespace dpctl
40+
{
41+
namespace tensor
42+
{
43+
namespace kernels
44+
{
45+
namespace nextafter
46+
{
47+
48+
namespace td_ns = dpctl::tensor::type_dispatch;
49+
namespace tu_ns = dpctl::tensor::type_utils;
50+
51+
template <typename argT1, typename argT2, typename resT> struct NextafterFunctor
52+
{
53+
54+
using supports_sg_loadstore = std::true_type;
55+
using supports_vec = std::true_type;
56+
57+
resT operator()(const argT1 &in1, const argT2 &in2) const
58+
{
59+
return sycl::nextafter(in1, in2);
60+
}
61+
62+
template <int vec_sz>
63+
sycl::vec<resT, vec_sz>
64+
operator()(const sycl::vec<argT1, vec_sz> &in1,
65+
const sycl::vec<argT2, vec_sz> &in2) const
66+
{
67+
auto res = sycl::nextafter(in1, in2);
68+
if constexpr (std::is_same_v<resT,
69+
typename decltype(res)::element_type>) {
70+
return res;
71+
}
72+
else {
73+
using dpctl::tensor::type_utils::vec_cast;
74+
75+
return vec_cast<resT, typename decltype(res)::element_type, vec_sz>(
76+
res);
77+
}
78+
}
79+
};
80+
81+
template <typename argT1,
82+
typename argT2,
83+
typename resT,
84+
unsigned int vec_sz = 4,
85+
unsigned int n_vecs = 2,
86+
bool enable_sg_loadstore = true>
87+
using NextafterContigFunctor = elementwise_common::BinaryContigFunctor<
88+
argT1,
89+
argT2,
90+
resT,
91+
NextafterFunctor<argT1, argT2, resT>,
92+
vec_sz,
93+
n_vecs,
94+
enable_sg_loadstore>;
95+
96+
template <typename argT1, typename argT2, typename resT, typename IndexerT>
97+
using NextafterStridedFunctor = elementwise_common::BinaryStridedFunctor<
98+
argT1,
99+
argT2,
100+
resT,
101+
IndexerT,
102+
NextafterFunctor<argT1, argT2, resT>>;
103+
104+
template <typename T1, typename T2> struct NextafterOutputType
105+
{
106+
using value_type = typename std::disjunction< // disjunction is C++17
107+
// feature, supported by DPC++
108+
td_ns::BinaryTypeMapResultEntry<T1,
109+
sycl::half,
110+
T2,
111+
sycl::half,
112+
sycl::half>,
113+
td_ns::BinaryTypeMapResultEntry<T1, float, T2, float, float>,
114+
td_ns::BinaryTypeMapResultEntry<T1, double, T2, double, double>,
115+
td_ns::DefaultResultEntry<void>>::result_type;
116+
};
117+
118+
template <typename argT1,
119+
typename argT2,
120+
typename resT,
121+
unsigned int vec_sz,
122+
unsigned int n_vecs>
123+
class nextafter_contig_kernel;
124+
125+
template <typename argTy1, typename argTy2>
126+
sycl::event nextafter_contig_impl(sycl::queue &exec_q,
127+
size_t nelems,
128+
const char *arg1_p,
129+
ssize_t arg1_offset,
130+
const char *arg2_p,
131+
ssize_t arg2_offset,
132+
char *res_p,
133+
ssize_t res_offset,
134+
const std::vector<sycl::event> &depends = {})
135+
{
136+
return elementwise_common::binary_contig_impl<
137+
argTy1, argTy2, NextafterOutputType, NextafterContigFunctor,
138+
nextafter_contig_kernel>(exec_q, nelems, arg1_p, arg1_offset, arg2_p,
139+
arg2_offset, res_p, res_offset, depends);
140+
}
141+
142+
template <typename fnT, typename T1, typename T2> struct NextafterContigFactory
143+
{
144+
fnT get()
145+
{
146+
if constexpr (std::is_same_v<
147+
typename NextafterOutputType<T1, T2>::value_type,
148+
void>)
149+
{
150+
fnT fn = nullptr;
151+
return fn;
152+
}
153+
else {
154+
fnT fn = nextafter_contig_impl<T1, T2>;
155+
return fn;
156+
}
157+
}
158+
};
159+
160+
template <typename fnT, typename T1, typename T2> struct NextafterTypeMapFactory
161+
{
162+
/*! @brief get typeid for output type of std::nextafter(T1 x, T2 y) */
163+
std::enable_if_t<std::is_same<fnT, int>::value, int> get()
164+
{
165+
using rT = typename NextafterOutputType<T1, T2>::value_type;
166+
;
167+
return td_ns::GetTypeid<rT>{}.get();
168+
}
169+
};
170+
171+
template <typename T1, typename T2, typename resT, typename IndexerT>
172+
class nextafter_strided_kernel;
173+
174+
template <typename argTy1, typename argTy2>
175+
sycl::event
176+
nextafter_strided_impl(sycl::queue &exec_q,
177+
size_t nelems,
178+
int nd,
179+
const ssize_t *shape_and_strides,
180+
const char *arg1_p,
181+
ssize_t arg1_offset,
182+
const char *arg2_p,
183+
ssize_t arg2_offset,
184+
char *res_p,
185+
ssize_t res_offset,
186+
const std::vector<sycl::event> &depends,
187+
const std::vector<sycl::event> &additional_depends)
188+
{
189+
return elementwise_common::binary_strided_impl<
190+
argTy1, argTy2, NextafterOutputType, NextafterStridedFunctor,
191+
nextafter_strided_kernel>(exec_q, nelems, nd, shape_and_strides, arg1_p,
192+
arg1_offset, arg2_p, arg2_offset, res_p,
193+
res_offset, depends, additional_depends);
194+
}
195+
196+
template <typename fnT, typename T1, typename T2> struct NextafterStridedFactory
197+
{
198+
fnT get()
199+
{
200+
if constexpr (std::is_same_v<
201+
typename NextafterOutputType<T1, T2>::value_type,
202+
void>)
203+
{
204+
fnT fn = nullptr;
205+
return fn;
206+
}
207+
else {
208+
fnT fn = nextafter_strided_impl<T1, T2>;
209+
return fn;
210+
}
211+
}
212+
};
213+
214+
} // namespace nextafter
215+
} // namespace kernels
216+
} // namespace tensor
217+
} // namespace dpctl

dpctl/tensor/libtensor/source/elementwise_functions/elementwise_common.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
#include "minimum.hpp"
7676
#include "multiply.hpp"
7777
#include "negative.hpp"
78+
#include "nextafter.hpp"
7879
#include "not_equal.hpp"
7980
#include "positive.hpp"
8081
#include "pow.hpp"
@@ -158,6 +159,7 @@ void init_elementwise_functions(py::module_ m)
158159
init_maximum(m);
159160
init_minimum(m);
160161
init_multiply(m);
162+
init_nextafter(m);
161163
init_negative(m);
162164
init_not_equal(m);
163165
init_positive(m);

0 commit comments

Comments
 (0)