Skip to content

[libclc] Move mad_sat to CLC; optimize for vector types #125517

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Feb 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions libclc/clc/include/clc/integer/clc_mad_sat.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// Declares the __clc_mad_sat overloads. The declarations are not written out
// here: they are generated by including gentype.inc with __CLC_FUNCTION and
// __CLC_BODY set to the function name and the ternary declaration template.
#ifndef __CLC_INTEGER_CLC_MAD_SAT_H__
#define __CLC_INTEGER_CLC_MAD_SAT_H__

// Template parameters read by the include below.
// NOTE(review): gentype.inc is not visible here; it presumably expands
// __CLC_BODY once per integer scalar/vector type -- confirm.
#define __CLC_FUNCTION __clc_mad_sat
#define __CLC_BODY <clc/shared/ternary_decl.inc>

#include <clc/integer/gentype.inc>

// Drop the template parameters so they do not leak into later includes.
#undef __CLC_BODY
#undef __CLC_FUNCTION

#endif // __CLC_INTEGER_CLC_MAD_SAT_H__
4 changes: 4 additions & 0 deletions libclc/clc/include/clc/integer/definitions.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,12 @@
// Largest and smallest values of the 16-bit signed type. The minimum is
// written as an expression because 32768 itself is not a valid positive
// literal of that width.
#define SHRT_MAX 32767
#define SHRT_MIN (-32767 - 1)
// Limits of the unsigned types. Every *_MIN here is zero; having the names
// defined lets macro code build PREFIX##_MIN / PREFIX##_MAX uniformly for
// signed and unsigned types alike.
#define UCHAR_MAX 255
#define UCHAR_MIN 0
#define USHRT_MAX 65535
#define USHRT_MIN 0
#define UINT_MAX 0xffffffff
#define UINT_MIN 0
// The UL suffix keeps the 64-bit limits unsigned and 64 bits wide.
#define ULONG_MAX 0xffffffffffffffffUL
#define ULONG_MIN 0UL

#endif // __CLC_INTEGER_DEFINITIONS_H__
1 change: 1 addition & 0 deletions libclc/clc/lib/clspv/SOURCES
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
../generic/integer/clc_clz.cl
../generic/integer/clc_hadd.cl
../generic/integer/clc_mad24.cl
../generic/integer/clc_mad_sat.cl
../generic/integer/clc_mul24.cl
../generic/integer/clc_mul_hi.cl
../generic/integer/clc_popcount.cl
Expand Down
1 change: 1 addition & 0 deletions libclc/clc/lib/generic/SOURCES
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ integer/clc_add_sat.cl
integer/clc_clz.cl
integer/clc_hadd.cl
integer/clc_mad24.cl
integer/clc_mad_sat.cl
integer/clc_mul24.cl
integer/clc_mul_hi.cl
integer/clc_popcount.cl
Expand Down
119 changes: 119 additions & 0 deletions libclc/clc/lib/generic/integer/clc_mad_sat.cl
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
#include <clc/clcmacro.h>
#include <clc/integer/clc_add_sat.h>
#include <clc/integer/clc_mad24.h>
#include <clc/integer/clc_mul_hi.h>
#include <clc/integer/clc_upsample.h>
#include <clc/integer/definitions.h>
#include <clc/internal/clc.h>
#include <clc/relational/clc_select.h>
#include <clc/shared/clc_clamp.h>

// Element-wise conversion of the vector X to the vector type TY using the
// compiler builtin.
// NOTE(review): this builtin takes vector operands, which is presumably why
// the scalar mad_sat definitions in this file use plain casts instead.
#define __CLC_CONVERT_TY(X, TY) __builtin_convertvector(X, TY)

// Scalar mad_sat for the narrow types (char/uchar/short/ushort).
// The operands are cast to the wider UP_TYPE, multiplied and added with
// __clc_mad24, and the result is clamped to
// [LIT_PREFIX##_MIN, LIT_PREFIX##_MAX]. The return statement converts the
// clamped UP_TYPE value back to TYPE.
// FIXME: Once using __clc_convert_ty, can easily unify scalar and vector defs
#define __CLC_DEFINE_SIMPLE_MAD_SAT(TYPE, UP_TYPE, LIT_PREFIX) \
  _CLC_OVERLOAD _CLC_DEF TYPE __clc_mad_sat(TYPE x, TYPE y, TYPE z) { \
    UP_TYPE wide_x = (UP_TYPE)x; \
    UP_TYPE wide_y = (UP_TYPE)y; \
    UP_TYPE wide_z = (UP_TYPE)z; \
    UP_TYPE wide_mad = (UP_TYPE)__clc_mad24(wide_x, wide_y, wide_z); \
    UP_TYPE lower = (UP_TYPE)LIT_PREFIX##_MIN; \
    UP_TYPE upper = (UP_TYPE)LIT_PREFIX##_MAX; \
    return __clc_clamp(wide_mad, lower, upper); \
  }

// Vector mad_sat for the narrow types: each operand is widened element-wise
// to UP_TYPE, multiplied and added with __clc_mad24, clamped to
// [LIT_PREFIX##_MIN, LIT_PREFIX##_MAX], and finally narrowed back to TYPE.
#define __CLC_DEFINE_SIMPLE_MAD_SAT_VEC(TYPE, UP_TYPE, LIT_PREFIX) \
  _CLC_OVERLOAD _CLC_DEF TYPE __clc_mad_sat(TYPE x, TYPE y, TYPE z) { \
    UP_TYPE wide_x = __CLC_CONVERT_TY(x, UP_TYPE); \
    UP_TYPE wide_y = __CLC_CONVERT_TY(y, UP_TYPE); \
    UP_TYPE wide_z = __CLC_CONVERT_TY(z, UP_TYPE); \
    UP_TYPE wide_mad = __clc_mad24(wide_x, wide_y, wide_z); \
    UP_TYPE lower = (UP_TYPE)LIT_PREFIX##_MIN; \
    UP_TYPE upper = (UP_TYPE)LIT_PREFIX##_MAX; \
    UP_TYPE bounded = __clc_clamp(wide_mad, lower, upper); \
    return __CLC_CONVERT_TY(bounded, TYPE); \
  }

// Emits the scalar definition plus the vector definitions for widths 2, 3, 4,
// 8 and 16. The vector type names are formed by pasting the width onto TYPE
// and UP_TYPE; LIT_PREFIX (the limit-macro prefix) is shared by all widths.
#define __CLC_DEFINE_SIMPLE_MAD_SAT_ALL_TYS(TYPE, UP_TYPE, LIT_PREFIX) \
__CLC_DEFINE_SIMPLE_MAD_SAT(TYPE, UP_TYPE, LIT_PREFIX) \
__CLC_DEFINE_SIMPLE_MAD_SAT_VEC(TYPE##2, UP_TYPE##2, LIT_PREFIX) \
__CLC_DEFINE_SIMPLE_MAD_SAT_VEC(TYPE##3, UP_TYPE##3, LIT_PREFIX) \
__CLC_DEFINE_SIMPLE_MAD_SAT_VEC(TYPE##4, UP_TYPE##4, LIT_PREFIX) \
__CLC_DEFINE_SIMPLE_MAD_SAT_VEC(TYPE##8, UP_TYPE##8, LIT_PREFIX) \
__CLC_DEFINE_SIMPLE_MAD_SAT_VEC(TYPE##16, UP_TYPE##16, LIT_PREFIX)

// Narrow types are widened to the 32-bit type of matching signedness.
__CLC_DEFINE_SIMPLE_MAD_SAT_ALL_TYS(char, int, CHAR)
__CLC_DEFINE_SIMPLE_MAD_SAT_ALL_TYS(uchar, uint, UCHAR)
__CLC_DEFINE_SIMPLE_MAD_SAT_ALL_TYS(short, int, SHRT)
__CLC_DEFINE_SIMPLE_MAD_SAT_ALL_TYS(ushort, uint, USHRT)

// mad_sat for uint/ulong, used for both scalar and vector types.
// A non-zero high half of x * y means the product alone already exceeds the
// type, so the result is ULIT_PREFIX##_MAX. Otherwise the low half of the
// product is added to z with a saturating add. STYPE is the signed type of the
// same shape and holds the comparison result passed to __clc_select.
#define __CLC_DEFINE_UINTLONG_MAD_SAT(UTYPE, STYPE, ULIT_PREFIX) \
  _CLC_OVERLOAD _CLC_DEF UTYPE __clc_mad_sat(UTYPE x, UTYPE y, UTYPE z) { \
    UTYPE saturated = (UTYPE)ULIT_PREFIX##_MAX; \
    UTYPE mul_add = __clc_add_sat(x * y, z); \
    UTYPE high_half = __clc_mul_hi(x, y); \
    STYPE product_overflows = high_half != (UTYPE)0; \
    return __clc_select(mul_add, saturated, product_overflows); \
  }

// Emits the uint/ulong-style definition for the scalar type and for vector
// widths 2, 3, 4, 8 and 16. The same macro serves scalars and vectors; only the
// pasted width suffix changes.
#define __CLC_DEFINE_UINTLONG_MAD_SAT_ALL_TYS(UTY, STY, ULIT_PREFIX) \
__CLC_DEFINE_UINTLONG_MAD_SAT(UTY, STY, ULIT_PREFIX) \
__CLC_DEFINE_UINTLONG_MAD_SAT(UTY##2, STY##2, ULIT_PREFIX) \
__CLC_DEFINE_UINTLONG_MAD_SAT(UTY##3, STY##3, ULIT_PREFIX) \
__CLC_DEFINE_UINTLONG_MAD_SAT(UTY##4, STY##4, ULIT_PREFIX) \
__CLC_DEFINE_UINTLONG_MAD_SAT(UTY##8, STY##8, ULIT_PREFIX) \
__CLC_DEFINE_UINTLONG_MAD_SAT(UTY##16, STY##16, ULIT_PREFIX)

// Each unsigned type is paired with the signed type of the same width.
__CLC_DEFINE_UINTLONG_MAD_SAT_ALL_TYS(uint, int, UINT)
__CLC_DEFINE_UINTLONG_MAD_SAT_ALL_TYS(ulong, long, ULONG)

// Macro for defining mad_sat variants for int (vector types).
// The exact 64-bit product is assembled from its two halves: the signed high
// half from __clc_mul_hi and the low half from a multiply, joined with
// __clc_upsample. z is then added in the 64-bit type, the sum is clamped to
// [INT_MIN, INT_MAX] and narrowed back to INTTY.
// The low half is computed in the unsigned type: it has the same bits as the
// signed product but is well defined when the product does not fit in INTTY,
// whereas a signed x * y would overflow.
#define __CLC_DEFINE_SINT_MAD_SAT(INTTY, UINTTY, SLONGTY) \
  _CLC_OVERLOAD _CLC_DEF INTTY __clc_mad_sat(INTTY x, INTTY y, INTTY z) { \
    INTTY mhi = __clc_mul_hi(x, y); \
    UINTTY mlo = __clc_as_##UINTTY(x) * __clc_as_##UINTTY(y); \
    SLONGTY m = __clc_upsample(mhi, mlo); \
    m += __CLC_CONVERT_TY(z, SLONGTY); \
    m = __clc_clamp(m, (SLONGTY)INT_MIN, (SLONGTY)INT_MAX); \
    return __CLC_CONVERT_TY(m, INTTY); \
  }

// Defines the scalar int mad_sat inline and then instantiates the vector
// variants for widths 2, 3, 4, 8 and 16.
// The scalar version builds the exact 64-bit product from __clc_mul_hi (high
// half) and an unsigned multiply (low half), adds z in the 64-bit type and
// clamps to [INT_MIN, INT_MAX]. The return statement narrows to INTTY.
// The low half uses the unsigned type so that products which do not fit in
// INTTY wrap in a well-defined way instead of overflowing a signed multiply.
// FIXME: Once using __clc_convert_ty, can easily unify scalar and vector defs
#define __CLC_DEFINE_SINT_MAD_SAT_ALL_TYS(INTTY, UINTTY, SLONGTY) \
  _CLC_OVERLOAD _CLC_DEF INTTY __clc_mad_sat(INTTY x, INTTY y, INTTY z) { \
    INTTY mhi = __clc_mul_hi(x, y); \
    UINTTY mlo = __clc_as_##UINTTY(x) * __clc_as_##UINTTY(y); \
    SLONGTY m = __clc_upsample(mhi, mlo); \
    m += z; \
    return __clc_clamp(m, (SLONGTY)INT_MIN, (SLONGTY)INT_MAX); \
  } \
  __CLC_DEFINE_SINT_MAD_SAT(INTTY##2, UINTTY##2, SLONGTY##2) \
  __CLC_DEFINE_SINT_MAD_SAT(INTTY##3, UINTTY##3, SLONGTY##3) \
  __CLC_DEFINE_SINT_MAD_SAT(INTTY##4, UINTTY##4, SLONGTY##4) \
  __CLC_DEFINE_SINT_MAD_SAT(INTTY##8, UINTTY##8, SLONGTY##8) \
  __CLC_DEFINE_SINT_MAD_SAT(INTTY##16, UINTTY##16, SLONGTY##16)

__CLC_DEFINE_SINT_MAD_SAT_ALL_TYS(int, uint, long)

// Macro for defining mad_sat variants for long (scalar and vector types).
//
// There is no wider integer type to compute in, so the exact value of
// x * y + z is formed as a 128-bit two's-complement number held in two halves:
//   prod_hi:prod_lo  exact product (high half from __clc_mul_hi, low half
//                    from an unsigned multiply, which wraps safely)
//   sum_hi:sum_lo    product plus the sign-extended z; `carry` propagates the
//                    unsigned wrap of the low-half addition into the high half
// The sum fits in 64 bits exactly when sum_hi is the sign extension of
// sum_lo. In that case sum_lo (reinterpreted as signed) is the answer;
// otherwise the sign of sum_hi selects LONG_MIN or LONG_MAX.
//
// Checking the final sum, rather than the product and z separately, covers
// the cases where only the addition overflows (e.g. mad_sat(1, 1, LONG_MAX))
// and where z brings an out-of-range product back into range or leaves an
// in-range product in range (e.g. mad_sat(LONG_MIN, 1, 5)).
//
// Comparison results are cast to SLONGTY before __clc_select so the same
// text works for scalars (0/1) and vectors (0/-1 per lane).
#define __CLC_DEFINE_SLONG_MAD_SAT(SLONGTY, ULONGTY) \
  _CLC_OVERLOAD _CLC_DEF SLONGTY __clc_mad_sat(SLONGTY x, SLONGTY y, \
                                               SLONGTY z) { \
    SLONGTY prod_hi = __clc_mul_hi(x, y); \
    ULONGTY prod_lo = __clc_as_##ULONGTY(x) * __clc_as_##ULONGTY(y); \
    ULONGTY sum_lo = prod_lo + __clc_as_##ULONGTY(z); \
    SLONGTY carry = __clc_select((SLONGTY)0, (SLONGTY)1, \
                                 (SLONGTY)(sum_lo < prod_lo)); \
    SLONGTY sum_hi = prod_hi + (z >> (SLONGTY)63) + carry; \
    SLONGTY wrapped = __clc_as_##SLONGTY(sum_lo); \
    SLONGTY saturated = __clc_select((SLONGTY)LONG_MAX, (SLONGTY)LONG_MIN, \
                                     (SLONGTY)(sum_hi < (SLONGTY)0)); \
    return __clc_select( \
        saturated, wrapped, \
        (SLONGTY)(sum_hi == (wrapped >> (SLONGTY)63))); \
  }

// Emits the long mad_sat definition for the scalar type and for vector widths
// 2, 3, 4, 8 and 16; the vector type names are formed by pasting the width.
#define __CLC_DEFINE_SLONG_MAD_SAT_ALL_TYS(SLONGTY, ULONGTY) \
__CLC_DEFINE_SLONG_MAD_SAT(SLONGTY, ULONGTY) \
__CLC_DEFINE_SLONG_MAD_SAT(SLONGTY##2, ULONGTY##2) \
__CLC_DEFINE_SLONG_MAD_SAT(SLONGTY##3, ULONGTY##3) \
__CLC_DEFINE_SLONG_MAD_SAT(SLONGTY##4, ULONGTY##4) \
__CLC_DEFINE_SLONG_MAD_SAT(SLONGTY##8, ULONGTY##8) \
__CLC_DEFINE_SLONG_MAD_SAT(SLONGTY##16, ULONGTY##16)

__CLC_DEFINE_SLONG_MAD_SAT_ALL_TYS(long, ulong)
1 change: 1 addition & 0 deletions libclc/clc/lib/spirv/SOURCES
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
../generic/integer/clc_clz.cl
../generic/integer/clc_hadd.cl
../generic/integer/clc_mad24.cl
../generic/integer/clc_mad_sat.cl
../generic/integer/clc_mul24.cl
../generic/integer/clc_mul_hi.cl
../generic/integer/clc_popcount.cl
Expand Down
73 changes: 4 additions & 69 deletions libclc/generic/lib/integer/mad_sat.cl
Original file line number Diff line number Diff line change
@@ -1,72 +1,7 @@
#include <clc/clc.h>
#include <clc/clcmacro.h>
#include <clc/integer/clc_mad_sat.h>

// char mad_sat: the operands are cast to short for mad24, the result is cast
// back to short and clamped to [CHAR_MIN, CHAR_MAX]. The return statement
// narrows the clamped value to char.
_CLC_OVERLOAD _CLC_DEF char mad_sat(char x, char y, char z) {
  short wide = (short)mad24((short)x, (short)y, (short)z);
  short lower = (short)CHAR_MIN;
  short upper = (short)CHAR_MAX;
  return clamp(wide, lower, upper);
}
// Template parameters for the gentype.inc include in this file: the public
// builtin being defined and the per-type body used to define it.
// NOTE(review): ternary_def.inc is not visible here; it presumably forwards
// FUNCTION to the matching __clc_-prefixed implementation -- confirm.
#define FUNCTION mad_sat
#define __CLC_BODY <clc/shared/ternary_def.inc>

// uchar mad_sat: the operands are cast to ushort for mad24, the result is cast
// back to ushort and clamped to [0, UCHAR_MAX]. The return statement narrows
// the clamped value to uchar.
_CLC_OVERLOAD _CLC_DEF uchar mad_sat(uchar x, uchar y, uchar z) {
  ushort wide = (ushort)mad24((ushort)x, (ushort)y, (ushort)z);
  ushort lower = (ushort)0;
  ushort upper = (ushort)UCHAR_MAX;
  return clamp(wide, lower, upper);
}

// short mad_sat: multiply-add in int via mad24, then clamp to
// [SHRT_MIN, SHRT_MAX]. The return statement narrows the clamped value to
// short.
_CLC_OVERLOAD _CLC_DEF short mad_sat(short x, short y, short z) {
  int wide = (int)mad24((int)x, (int)y, (int)z);
  int lower = (int)SHRT_MIN;
  int upper = (int)SHRT_MAX;
  return clamp(wide, lower, upper);
}

// ushort mad_sat: multiply-add in uint via mad24, then clamp to
// [0, USHRT_MAX]. The return statement narrows the clamped value to ushort.
_CLC_OVERLOAD _CLC_DEF ushort mad_sat(ushort x, ushort y, ushort z) {
  uint wide = (uint)mad24((uint)x, (uint)y, (uint)z);
  uint lower = (uint)0;
  uint upper = (uint)USHRT_MAX;
  return clamp(wide, lower, upper);
}

// int mad_sat: the exact 64-bit product is rebuilt from its high half (mul_hi)
// and low half joined with upsample, z is added in 64 bits, and the sum is
// saturated to [INT_MIN, INT_MAX].
_CLC_OVERLOAD _CLC_DEF int mad_sat(int x, int y, int z) {
  int mhi = mul_hi(x, y);
  // Unsigned multiply: same low 32 bits as the signed product, but well
  // defined when the product does not fit in int.
  uint mlo = (uint)x * (uint)y;
  long m = upsample(mhi, mlo);
  m += z;
  if (m > INT_MAX)
    return INT_MAX;
  if (m < INT_MIN)
    return INT_MIN;
  return m;
}

// uint mad_sat: a non-zero high half means the product alone overflows, so
// the result is UINT_MAX; otherwise the product is added to z with add_sat.
_CLC_OVERLOAD _CLC_DEF uint mad_sat(uint x, uint y, uint z) {
  return mul_hi(x, y) != 0 ? UINT_MAX : add_sat(x * y, z);
}

// long mad_sat.
//
// There is no wider integer type, so the exact value of x * y + z is formed
// as a 128-bit two's-complement number in two halves and then range-checked.
// Checking the final sum (instead of the product and z separately) covers the
// cases where only the addition overflows, e.g. mad_sat(1, 1, LONG_MAX), and
// where the product sits exactly on a limit, e.g. mad_sat(LONG_MIN, 1, 5).
_CLC_OVERLOAD _CLC_DEF long mad_sat(long x, long y, long z) {
  /* Exact product: signed high half, unsigned (wrapping) low half. */
  long prod_hi = mul_hi(x, y);
  ulong prod_lo = (ulong)x * (ulong)y;
  /* Add the sign-extended z; an unsigned wrap of the low half carries into
   * the high half. */
  ulong sum_lo = prod_lo + (ulong)z;
  long carry = sum_lo < prod_lo ? 1 : 0;
  long sum_hi = prod_hi + (z >> 63) + carry;
  /* The sum fits in 64 bits iff the high half is the sign extension of the
   * low half. */
  long wrapped = (long)sum_lo;
  if (sum_hi == (wrapped >> 63))
    return wrapped;
  /* Out of range: the sign of the exact sum picks the limit. */
  return sum_hi < 0 ? LONG_MIN : LONG_MAX;
}

// ulong mad_sat: a non-zero high half means the product alone overflows, so
// the result is ULONG_MAX; otherwise the product is added to z with add_sat.
_CLC_OVERLOAD _CLC_DEF ulong mad_sat(ulong x, ulong y, ulong z) {
  return mul_hi(x, y) != 0 ? ULONG_MAX : add_sat(x * y, z);
}

// One _CLC_TERNARY_VECTORIZE instantiation per element type.
// NOTE(review): the macro comes from clcmacro.h, which is not visible here; it
// presumably generates the vector overloads of mad_sat from the scalar
// definition of the same element type -- confirm.
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, char, mad_sat, char, char, char)
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uchar, mad_sat, uchar, uchar, uchar)
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, short, mad_sat, short, short, short)
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ushort, mad_sat, ushort, ushort, ushort)
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, int, mad_sat, int, int, int)
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, uint, mad_sat, uint, uint, uint)
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, long, mad_sat, long, long, long)
_CLC_TERNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, ulong, mad_sat, ulong, ulong, ulong)
#include <clc/integer/gentype.inc>