Skip to content

Commit 53c3268

Browse files
authored
[SYCL][CUDA][libclc] Implement nextafter for sycl::half in generic/. (#4939)
sycl::nextafter(half,half) was defaulting to sycl::nextafter(float,float) which does not return the next half. Software implementation written in libclc/generic and #included into ptx-nvidiacl.
1 parent 0c55d3a commit 53c3268

File tree

3 files changed

+70
-1
lines changed

3 files changed

+70
-1
lines changed
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef HALF_NEXTAFTER_INC
10+
#define HALF_NEXTAFTER_INC
11+
12+
#include <clcmacro.h>
13+
#include <math/math.h>
14+
#include <spirv/spirv.h>
15+
16+
#ifdef cl_khr_fp16
17+
18+
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
19+
20+
_CLC_OVERLOAD _CLC_DEF half __spirv_ocl_nextafter(half x, half y) {
21+
// NaNs
22+
if (x != x)
23+
return x;
24+
if (y != y)
25+
return y;
26+
// Parity
27+
if (x == y)
28+
return x;
29+
30+
short *a = (short *)&x;
31+
short *b = (short *)&y;
32+
// Checking for sign digit
33+
if (*a & 0x8000)
34+
*a = 0x8000 - *a;
35+
if (*b & 0x8000)
36+
*b = 0x8000 - *b;
37+
// Increment / decrement
38+
*a += (*a < *b) ? 1 : -1;
39+
// Undo the sign flip if necessary
40+
*a = (*a < 0) ? 0x8000 - *a : *a;
41+
return x;
42+
}
43+
44+
_CLC_BINARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_nextafter, half,
45+
half)
46+
47+
#endif
48+
49+
#endif

libclc/generic/libspirv/math/nextafter.cl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,5 @@ _CLC_DEFINE_BINARY_BUILTIN(double, __spirv_ocl_nextafter, __builtin_nextafter,
2020
double, double)
2121

2222
#endif
23+
24+
#include "half_nextafter.inc"

libclc/ptx-nvidiacl/libspirv/math/nextafter.cl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,10 +8,28 @@
88

99
#include <spirv/spirv.h>
1010

11+
#include "utils.h"
1112
#include <../../include/libdevice.h>
1213
#include <clcmacro.h>
1314

1415
#define __CLC_FUNCTION __spirv_ocl_nextafter
1516
#define __CLC_BUILTIN __nv_nextafter
1617
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
17-
#include <math/binary_builtin.inc>
18+
#define __CLC_BUILTIN_D __CLC_BUILTIN
19+
20+
_CLC_DEFINE_BINARY_BUILTIN(float, __CLC_FUNCTION, __CLC_BUILTIN_F, float, float)
21+
22+
#ifndef __FLOAT_ONLY
23+
24+
#ifdef cl_khr_fp64
25+
26+
#pragma OPENCL EXTENSION cl_khr_fp64 : enable
27+
28+
_CLC_DEFINE_BINARY_BUILTIN(double, __CLC_FUNCTION, __CLC_BUILTIN_D, double,
29+
double)
30+
31+
#endif
32+
33+
#include "../../../generic/libspirv/math/half_nextafter.inc"
34+
35+
#endif

0 commit comments

Comments
 (0)