Skip to content

[SYCL-PTX] Add warp-reduce path in sub-group reduce #3949

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions libclc/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,9 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
if( ${d} STREQUAL "none" OR ${ARCH} STREQUAL "spirv" OR ${ARCH} STREQUAL "spirv64" )
# FIXME: Ideally we would not be tied to a specific PTX ISA version
if( ${ARCH} STREQUAL nvptx OR ${ARCH} STREQUAL nvptx64 )
set( flags "SHELL:-Xclang -target-feature" "SHELL:-Xclang +ptx64")
# Disables NVVM reflection to defer to after linking
set( flags "SHELL:-Xclang -target-feature" "SHELL:-Xclang +ptx72"
"SHELL:-march=sm_86" "SHELL:-mllvm --nvvm-reflect-enable=false")
endif()
set( arch_suffix "${t}" )
else()
Expand All @@ -327,12 +329,15 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
set( t "spir64--" )
endif()
set( build_flags -O0 -finline-hint-functions )
set( opt_flags )
set( opt_flags -O3 )
set( spvflags --spirv-max-version=1.1 )
elseif( ${ARCH} STREQUAL "clspv" )
set( t "spir--" )
set( build_flags )
set( opt_flags -O3 )
elseif( ${ARCH} STREQUAL "nvptx" OR ${ARCH} STREQUAL "nvptx64" )
set( build_flags )
set( opt_flags -O3 "--nvvm-reflect-enable=false" )
else()
set( build_flags )
set( opt_flags -O3 )
Expand All @@ -342,6 +347,7 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
TRIPLE ${t}
TARGET_ENV libspirv
COMPILE_OPT ${flags}
OPT_FLAGS ${opt_flags}
FILES ${libspirv_files}
ALIASES ${${d}_aliases}
GENERATE_TARGET "generate_convert_spirv.cl" "generate_convert_core.cl"
Expand All @@ -351,6 +357,7 @@ foreach( t ${LIBCLC_TARGETS_TO_BUILD} )
TRIPLE ${t}
TARGET_ENV clc
COMPILE_OPT ${flags}
OPT_FLAGS ${opt_flags}
FILES ${lib_files}
LIB_DEP libspirv-${arch_suffix}
ALIASES ${${d}_aliases}
Expand Down
4 changes: 2 additions & 2 deletions libclc/cmake/modules/AddLibclc.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ macro(add_libclc_builtin_set arch_suffix)
cmake_parse_arguments(ARG
""
"TRIPLE;TARGET_ENV;LIB_DEP;PARENT_TARGET"
"FILES;ALIASES;GENERATE_TARGET;COMPILE_OPT"
"FILES;ALIASES;GENERATE_TARGET;COMPILE_OPT;OPT_FLAGS"
${ARGN})

if (DEFINED ${ARG_LIB_DEP})
Expand Down Expand Up @@ -76,7 +76,7 @@ macro(add_libclc_builtin_set arch_suffix)
# Add opt target
set( builtins_opt_path "${LIBCLC_LIBRARY_OUTPUT_INTDIR}/builtins.opt.${obj_suffix}" )
add_custom_command( OUTPUT "${builtins_opt_path}"
COMMAND ${LLVM_OPT} -O3 -o
COMMAND ${LLVM_OPT} ${ARG_OPT_FLAGS} -o
"${builtins_opt_path}"
"${LIBCLC_LIBRARY_OUTPUT_INTDIR}/builtins.link.${obj_suffix}"
DEPENDS opt "builtins.link.${arch_suffix}" )
Expand Down
79 changes: 49 additions & 30 deletions libclc/ptx-nvidiacl/libspirv/group/collectives.cl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#pragma OPENCL EXTENSION cl_khr_fp64 : enable

int __nvvm_reflect(const char __constant *);

// CLC helpers
__local bool *
__clc__get_group_scratch_bool() __asm("__clc__get_group_scratch_bool");
Expand Down Expand Up @@ -150,43 +152,58 @@ __clc__SubgroupBitwiseAny(uint op, bool predicate, bool *carry) {
#define __CLC_OR(x, y) (x | y)
#define __CLC_AND(x, y) (x & y)

#define __CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
/* Can't use XOR/butterfly shuffles; some lanes may be inactive */ \
for (int o = 1; o < __spirv_SubgroupMaxSize(); o *= 2) { \
TYPE contribution = __clc__SubgroupShuffleUp(x, o); \
bool inactive = (sg_lid < o); \
contribution = (inactive) ? IDENTITY : contribution; \
x = OP(x, contribution); \
} \
/* For Reduce, broadcast result from highest active lane */ \
TYPE result; \
if (op == Reduce) { \
result = __clc__SubgroupShuffle(x, __spirv_SubgroupSize() - 1); \
*carry = result; \
} /* For InclusiveScan, use results as computed */ \
else if (op == InclusiveScan) { \
result = x; \
*carry = result; \
} /* For ExclusiveScan, shift and prepend identity */ \
else if (op == ExclusiveScan) { \
*carry = x; \
result = __clc__SubgroupShuffleUp(x, 1); \
if (sg_lid == 0) { \
result = IDENTITY; \
} \
} \
return result;

#define __CLC_SUBGROUP_COLLECTIVE(NAME, OP, TYPE, IDENTITY) \
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
__clc__Subgroup, NAME)(uint op, TYPE x, TYPE * carry) { \
uint sg_lid = __spirv_SubgroupLocalInvocationId(); \
/* Can't use XOR/butterfly shuffles; some lanes may be inactive */ \
for (int o = 1; o < __spirv_SubgroupMaxSize(); o *= 2) { \
TYPE contribution = __clc__SubgroupShuffleUp(x, o); \
bool inactive = (sg_lid < o); \
contribution = (inactive) ? IDENTITY : contribution; \
x = OP(x, contribution); \
} \
/* For Reduce, broadcast result from highest active lane */ \
TYPE result; \
if (op == Reduce) { \
result = __clc__SubgroupShuffle(x, __spirv_SubgroupSize() - 1); \
*carry = result; \
} /* For InclusiveScan, use results as computed */ \
else if (op == InclusiveScan) { \
result = x; \
__CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
}

#define __CLC_SUBGROUP_COLLECTIVE_REDUX(NAME, OP, REDUX_OP, TYPE, IDENTITY) \
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
__clc__Subgroup, NAME)(uint op, TYPE x, TYPE * carry) { \
/* Fast path for warp reductions for sm_80+ */ \
if (__nvvm_reflect("__CUDA_ARCH") >= 800 && op == Reduce) { \
TYPE result = __nvvm_redux_sync_##REDUX_OP(x, __clc__membermask()); \
*carry = result; \
} /* For ExclusiveScan, shift and prepend identity */ \
else if (op == ExclusiveScan) { \
*carry = x; \
result = __clc__SubgroupShuffleUp(x, 1); \
if (sg_lid == 0) { \
result = IDENTITY; \
} \
return result; \
} \
return result; \
__CLC_SUBGROUP_COLLECTIVE_BODY(OP, TYPE, IDENTITY) \
}

__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, char, 0)
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, uchar, 0)
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, short, 0)
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, ushort, 0)
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, int, 0)
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, uint, 0)
__CLC_SUBGROUP_COLLECTIVE_REDUX(IAdd, __CLC_ADD, add, int, 0)
__CLC_SUBGROUP_COLLECTIVE_REDUX(IAdd, __CLC_ADD, add, uint, 0)
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, long, 0)
__CLC_SUBGROUP_COLLECTIVE(IAdd, __CLC_ADD, ulong, 0)
__CLC_SUBGROUP_COLLECTIVE(FAdd, __CLC_ADD, half, 0)
Expand All @@ -197,8 +214,8 @@ __CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, char, CHAR_MAX)
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, uchar, UCHAR_MAX)
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, short, SHRT_MAX)
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, ushort, USHRT_MAX)
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, int, INT_MAX)
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, uint, UINT_MAX)
__CLC_SUBGROUP_COLLECTIVE_REDUX(SMin, __CLC_MIN, min, int, INT_MAX)
__CLC_SUBGROUP_COLLECTIVE_REDUX(UMin, __CLC_MIN, umin, uint, UINT_MAX)
__CLC_SUBGROUP_COLLECTIVE(SMin, __CLC_MIN, long, LONG_MAX)
__CLC_SUBGROUP_COLLECTIVE(UMin, __CLC_MIN, ulong, ULONG_MAX)
__CLC_SUBGROUP_COLLECTIVE(FMin, __CLC_MIN, half, HALF_MAX)
Expand All @@ -209,15 +226,17 @@ __CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, char, CHAR_MIN)
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, uchar, 0)
__CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, short, SHRT_MIN)
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, ushort, 0)
__CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, int, INT_MIN)
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, uint, 0)
__CLC_SUBGROUP_COLLECTIVE_REDUX(SMax, __CLC_MAX, max, int, INT_MIN)
__CLC_SUBGROUP_COLLECTIVE_REDUX(UMax, __CLC_MAX, umax, uint, 0)
__CLC_SUBGROUP_COLLECTIVE(SMax, __CLC_MAX, long, LONG_MIN)
__CLC_SUBGROUP_COLLECTIVE(UMax, __CLC_MAX, ulong, 0)
__CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, half, -HALF_MAX)
__CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, float, -FLT_MAX)
__CLC_SUBGROUP_COLLECTIVE(FMax, __CLC_MAX, double, -DBL_MAX)

#undef __CLC_SUBGROUP_COLLECTIVE_BODY
#undef __CLC_SUBGROUP_COLLECTIVE
#undef __CLC_SUBGROUP_COLLECTIVE_REDUX

#define __CLC_GROUP_COLLECTIVE(NAME, OP, TYPE, IDENTITY) \
_CLC_DEF _CLC_OVERLOAD _CLC_CONVERGENT TYPE __CLC_APPEND( \
Expand Down