Skip to content

Commit 1fb78eb

Browse files
aratajewsys_zuul
authored and
sys_zuul
committed
Emulation of 64-bit subgroup arithmetic and clustered non-uniform functions for
platforms not supporting 64-bit type natively. Change-Id: If138189b5e343620b4b78c716e234e33f773828e
1 parent 79ea4a6 commit 1fb78eb

File tree

1 file changed

+135
-18
lines changed

1 file changed

+135
-18
lines changed

IGC/BiFModule/Implementation/group.cl

Lines changed: 135 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1725,7 +1725,7 @@ uint __builtin_spirv_OpGroupNonUniformBallotFindMSB_i32_v4i32(uint Execution, ui
17251725
{
17261726
if (Execution == Subgroup)
17271727
{
1728-
return __builtin_spirv_OpenCL_clz_i32(Value.x);
1728+
return (sizeof(uint) * 8) -__builtin_spirv_OpenCL_clz_i32(Value.x);
17291729
}
17301730
return 0;
17311731
}
@@ -2083,6 +2083,117 @@ DEFN_UNIFORM_GROUP_FUNC(SMax, int, i32, __builtin_spirv_OpenCL_s_max_i32_i32,
20832083
DEFN_UNIFORM_GROUP_FUNC(SMax, long, i64, __builtin_spirv_OpenCL_s_max_i64_i64, LONG_MIN)
20842084

20852085
#if defined(cl_khr_subgroup_non_uniform_arithmetic) || defined(cl_khr_subgroup_clustered_reduce)
2086+
#define DEFN_SUB_GROUP_REDUCE_NON_UNIFORM(type, type_abbr, op, identity, X) \
2087+
{ \
2088+
uint activeChannels = __builtin_IB_WaveBallot(true); \
2089+
uint firstActive = __builtin_spirv_OpenCL_ctz_i32(activeChannels); \
2090+
\
2091+
type result = identity; \
2092+
while (activeChannels) \
2093+
{ \
2094+
uint activeId = __builtin_spirv_OpenCL_ctz_i32(activeChannels); \
2095+
\
2096+
type value = intel_sub_group_shuffle(X, activeId); \
2097+
result = op(value, result); \
2098+
\
2099+
uint disable = 1 << activeId; \
2100+
activeChannels ^= disable; \
2101+
} \
2102+
\
2103+
uint3 vec3; \
2104+
vec3.s0 = firstActive; \
2105+
X = __builtin_spirv_OpGroupBroadcast_i32_##type_abbr##_v3i32(Subgroup, result, vec3); \
2106+
}
2107+
2108+
#define DEFN_SUB_GROUP_SCAN_INCL_NON_UNIFORM(type, type_abbr, op, identity, X) \
2109+
{ \
2110+
uint sglid = __builtin_spirv_BuiltInSubgroupLocalInvocationId(); \
2111+
uint activeChannels = __builtin_IB_WaveBallot(true); \
2112+
\
2113+
while (activeChannels) \
2114+
{ \
2115+
uint activeId = __builtin_spirv_OpenCL_ctz_i32(activeChannels); \
2116+
\
2117+
type value = intel_sub_group_shuffle(X, activeId); \
2118+
if (sglid > activeId) \
2119+
X = op(value, X); \
2120+
\
2121+
uint disable = 1 << activeId; \
2122+
activeChannels ^= disable; \
2123+
} \
2124+
}
2125+
2126+
#define DEFN_SUB_GROUP_SCAN_EXCL_NON_UNIFORM(type, type_abbr, op, identity, X) \
2127+
{ \
2128+
uint sglid = __builtin_spirv_BuiltInSubgroupLocalInvocationId(); \
2129+
uint activeChannels = __builtin_IB_WaveBallot(true); \
2130+
\
2131+
uint mask = (1 << sglid) - 1; \
2132+
uint sglidPrev = (sizeof(uint) * 8 - __builtin_spirv_OpenCL_clz_i32(activeChannels & mask)) - 1; \
2133+
uint offsetToPrevActive = sglid - sglidPrev; \
2134+
X = intel_sub_group_shuffle_up((type)identity, X, offsetToPrevActive); \
2135+
\
2136+
while (activeChannels) \
2137+
{ \
2138+
uint activeId = __builtin_spirv_OpenCL_ctz_i32(activeChannels); \
2139+
\
2140+
type value = intel_sub_group_shuffle(X, activeId); \
2141+
if (sglid > activeId) \
2142+
X = op(value, X); \
2143+
\
2144+
uint disable = 1 << activeId; \
2145+
activeChannels ^= disable; \
2146+
} \
2147+
}
2148+
2149+
#define DEFN_SUB_GROUP_CLUSTERED_REDUCE(type, type_abbr, op, identity, X, ClusterSize) \
2150+
{ \
2151+
uint clusterIndex = 0; \
2152+
uint activeChannels = __builtin_IB_WaveBallot(true); \
2153+
uint numActive = __builtin_spirv_OpenCL_popcount_i32(activeChannels); \
2154+
uint numClusters = numActive / ClusterSize; \
2155+
\
2156+
for (uint clusterIndex = 0; clusterIndex < numClusters; clusterIndex++) \
2157+
{ \
2158+
uint Counter = ClusterSize; \
2159+
uint Ballot = activeChannels; \
2160+
uint clusterBallot = 0; \
2161+
while (Counter--) \
2162+
{ \
2163+
uint trailingOne = 1 << __builtin_spirv_OpenCL_ctz_i32(Ballot); \
2164+
clusterBallot |= trailingOne; \
2165+
Ballot ^= trailingOne; \
2166+
} \
2167+
uint active = __builtin_spirv_OpGroupNonUniformInverseBallot_i32_v4i32(Subgroup, clusterBallot); \
2168+
if (active) \
2169+
{ \
2170+
DEFN_SUB_GROUP_REDUCE_NON_UNIFORM(type, type_abbr, op, identity, X) \
2171+
} \
2172+
activeChannels ^= clusterBallot; \
2173+
} \
2174+
}
2175+
2176+
#define SUB_GROUP_SWITCH_NON_UNIFORM(type, type_abbr, op, identity, X, Operation, ClusterSize) \
2177+
{ \
2178+
switch (Operation){ \
2179+
case GroupOperationReduce: \
2180+
DEFN_SUB_GROUP_REDUCE_NON_UNIFORM(type, type_abbr, op, identity, X) \
2181+
break; \
2182+
case GroupOperationInclusiveScan: \
2183+
DEFN_SUB_GROUP_SCAN_INCL_NON_UNIFORM(type, type_abbr, op, identity, X) \
2184+
break; \
2185+
case GroupOperationExclusiveScan: \
2186+
DEFN_SUB_GROUP_SCAN_EXCL_NON_UNIFORM(type, type_abbr, op, identity, X) \
2187+
break; \
2188+
case GroupOperationClusteredReduce: \
2189+
DEFN_SUB_GROUP_CLUSTERED_REDUCE(type, type_abbr, op, identity, X, ClusterSize) \
2190+
break; \
2191+
default: \
2192+
return 0; \
2193+
break; \
2194+
} \
2195+
}
2196+
20862197
// ClusterSize is an optional parameter
20872198
#define DEFN_NON_UNIFORM_GROUP_FUNC(func, type, type_abbr, op, identity) \
20882199
type __builtin_spirv_OpGroupNonUniform##func##_i32_i32_##type_abbr##_i32(uint Execution, uint Operation, type X, uint ClusterSize) \
@@ -2109,7 +2220,8 @@ type __builtin_spirv_OpGroupNonUniform##func##_i32_i32_##type_abbr##_i32(uint E
21092220
} \
21102221
} \
21112222
else { \
2112-
SUB_GROUP_SWITCH(type, type_abbr, op, identity, X, Operation) \
2223+
SUB_GROUP_SWITCH_NON_UNIFORM(type, type_abbr, op, identity, X, Operation, ClusterSize) \
2224+
return X; \
21132225
} \
21142226
return 0; \
21152227
} \
@@ -2171,31 +2283,36 @@ DEFN_NON_UNIFORM_GROUP_FUNC(FMax, half, f16, __builtin_spirv_OpenCL_fmax_f16_f
21712283
#endif // defined(cl_khr_fp16)
21722284

21732285
// OpGroupNonUniformIMul, OpGroupNonUniformFMul
2174-
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, uchar, i8, __intel_mul, 0)
2175-
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, ushort, i16, __intel_mul, 0)
2176-
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, uint, i32, __intel_mul, 0)
2177-
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, ulong, i64, __intel_mul, 0)
2178-
DEFN_NON_UNIFORM_GROUP_FUNC(FMul, float, f32, __intel_mul, 0)
2286+
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, uchar, i8, __intel_mul, 1)
2287+
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, ushort, i16, __intel_mul, 1)
2288+
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, uint, i32, __intel_mul, 1)
2289+
DEFN_NON_UNIFORM_GROUP_FUNC(IMul, ulong, i64, __intel_mul, 1)
2290+
DEFN_NON_UNIFORM_GROUP_FUNC(FMul, float, f32, __intel_mul, 1)
21792291
#if defined(cl_khr_fp64)
2180-
DEFN_NON_UNIFORM_GROUP_FUNC(FMul, double, f64, __intel_mul, 0)
2292+
DEFN_NON_UNIFORM_GROUP_FUNC(FMul, double, f64, __intel_mul, 1)
21812293
#endif // defined(cl_khr_fp64)
21822294
#if defined(cl_khr_fp16)
2183-
DEFN_NON_UNIFORM_GROUP_FUNC(FMul, half, f16, __intel_mul, 0)
2295+
DEFN_NON_UNIFORM_GROUP_FUNC(FMul, half, f16, __intel_mul, 1)
21842296
#endif // defined(cl_khr_fp16)
21852297

21862298
// OpGroupNonUniformBitwiseAnd, OpGroupNonUniformBitwiseOr, OpGroupNonUniformBitwiseXor
2187-
#define DEFN_NON_UNIFORM_BITWISE_OPERATION(func, op) \
2188-
DEFN_NON_UNIFORM_GROUP_FUNC(func, uchar, i8, __intel_##op, 0) \
2189-
DEFN_NON_UNIFORM_GROUP_FUNC(func, ushort, i16, __intel_##op, 0) \
2190-
DEFN_NON_UNIFORM_GROUP_FUNC(func, uint, i32, __intel_##op, 0) \
2191-
DEFN_NON_UNIFORM_GROUP_FUNC(func, ulong, i64, __intel_##op, 0)
2299+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseAnd, uchar, i8, __intel_and, 0xFF)
2300+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseAnd, ushort, i16, __intel_and, 0xFFFF)
2301+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseAnd, uint, i32, __intel_and, 0xFFFFFFFF)
2302+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseAnd, ulong, i64, __intel_and, 0xFFFFFFFFFFFFFFFF)
2303+
2304+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseOr, uchar, i8, __intel_or, 0)
2305+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseOr, ushort, i16, __intel_or, 0)
2306+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseOr, uint, i32, __intel_or, 0)
2307+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseOr, ulong, i64, __intel_or, 0)
21922308

2193-
DEFN_NON_UNIFORM_BITWISE_OPERATION(BitwiseAnd, and)
2194-
DEFN_NON_UNIFORM_BITWISE_OPERATION(BitwiseOr, or)
2195-
DEFN_NON_UNIFORM_BITWISE_OPERATION(BitwiseXor, xor)
2309+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseXor, uchar, i8, __intel_xor, 0)
2310+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseXor, ushort, i16, __intel_xor, 0)
2311+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseXor, uint, i32, __intel_xor, 0)
2312+
DEFN_NON_UNIFORM_GROUP_FUNC(BitwiseXor, ulong, i64, __intel_xor, 0)
21962313

21972314
// OpGroupNonUniformLogicalAnd, OpGroupNonUniformLogicalOr, OpGroupNonUniformLogicalXor
2198-
DEFN_NON_UNIFORM_GROUP_FUNC(LogicalAnd, bool, i1, __intel_and, 0)
2315+
DEFN_NON_UNIFORM_GROUP_FUNC(LogicalAnd, bool, i1, __intel_and, 1)
21992316
DEFN_NON_UNIFORM_GROUP_FUNC(LogicalOr, bool, i1, __intel_or, 0)
22002317
DEFN_NON_UNIFORM_GROUP_FUNC(LogicalXor, bool, i1, __intel_xor, 0)
22012318
#endif // defined(cl_khr_subgroup_non_uniform_arithmetic) || defined(cl_khr_subgroup_clustered_reduce)

0 commit comments

Comments
 (0)