22
22
#include " mlir/IR/BuiltinOps.h"
23
23
#include " mlir/IR/PatternMatch.h"
24
24
#include " mlir/Transforms/DialectConversion.h"
25
+ #include " llvm/ADT/TypeSwitch.h"
25
26
#include " llvm/Support/Debug.h"
26
27
#include " llvm/Support/FormatVariadic.h"
27
28
@@ -1027,7 +1028,8 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
1027
1028
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn (Operation *symbolTable,
1028
1029
StringRef name,
1029
1030
ArrayRef<Type> paramTypes,
1030
- Type resultType) {
1031
+ Type resultType,
1032
+ bool convergent = true ) {
1031
1033
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
1032
1034
SymbolTable::lookupSymbolIn (symbolTable, name));
1033
1035
if (func)
@@ -1038,7 +1040,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
1038
1040
symbolTable->getLoc (), name,
1039
1041
LLVM::LLVMFunctionType::get (resultType, paramTypes));
1040
1042
func.setCConv (LLVM::cconv::CConv::SPIR_FUNC);
1041
- func.setConvergent (true );
1043
+ func.setConvergent (convergent );
1042
1044
func.setNoUnwind (true );
1043
1045
func.setWillReturn (true );
1044
1046
return func;
@@ -1089,6 +1091,181 @@ class ControlBarrierPattern
1089
1091
}
1090
1092
};
1091
1093
1094
+ namespace {
1095
+
1096
+ StringRef getTypeMangling (Type type, bool isSigned) {
1097
+ return llvm::TypeSwitch<Type, StringRef>(type)
1098
+ .Case <Float16Type>([](auto ) { return " Dh" ; })
1099
+ .Case <Float32Type>([](auto ) { return " f" ; })
1100
+ .Case <Float64Type>([](auto ) { return " d" ; })
1101
+ .Case <IntegerType>([isSigned](IntegerType intTy) {
1102
+ switch (intTy.getWidth ()) {
1103
+ case 1 :
1104
+ return " b" ;
1105
+ case 8 :
1106
+ return (isSigned) ? " a" : " c" ;
1107
+ case 16 :
1108
+ return (isSigned) ? " s" : " t" ;
1109
+ case 32 :
1110
+ return (isSigned) ? " i" : " j" ;
1111
+ case 64 :
1112
+ return (isSigned) ? " l" : " m" ;
1113
+ default :
1114
+ llvm_unreachable (" Unsupported integer width" );
1115
+ }
1116
+ })
1117
+ .Default ([](auto ) {
1118
+ llvm_unreachable (" No mangling defined" );
1119
+ return " " ;
1120
+ });
1121
+ }
1122
+
1123
+ template <typename ReduceOp>
1124
+ constexpr StringLiteral getGroupFuncName ();
1125
+
1126
+ template <>
1127
+ constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1128
+ return " _Z17__spirv_GroupIAddii" ;
1129
+ }
1130
+ template <>
1131
+ constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1132
+ return " _Z17__spirv_GroupFAddii" ;
1133
+ }
1134
+ template <>
1135
+ constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1136
+ return " _Z17__spirv_GroupSMinii" ;
1137
+ }
1138
+ template <>
1139
+ constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1140
+ return " _Z17__spirv_GroupUMinii" ;
1141
+ }
1142
+ template <>
1143
+ constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1144
+ return " _Z17__spirv_GroupFMinii" ;
1145
+ }
1146
+ template <>
1147
+ constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1148
+ return " _Z17__spirv_GroupSMaxii" ;
1149
+ }
1150
+ template <>
1151
+ constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1152
+ return " _Z17__spirv_GroupUMaxii" ;
1153
+ }
1154
+ template <>
1155
+ constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1156
+ return " _Z17__spirv_GroupFMaxii" ;
1157
+ }
1158
+ template <>
1159
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1160
+ return " _Z27__spirv_GroupNonUniformIAddii" ;
1161
+ }
1162
+ template <>
1163
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1164
+ return " _Z27__spirv_GroupNonUniformFAddii" ;
1165
+ }
1166
+ template <>
1167
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1168
+ return " _Z27__spirv_GroupNonUniformIMulii" ;
1169
+ }
1170
+ template <>
1171
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1172
+ return " _Z27__spirv_GroupNonUniformFMulii" ;
1173
+ }
1174
+ template <>
1175
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1176
+ return " _Z27__spirv_GroupNonUniformSMinii" ;
1177
+ }
1178
+ template <>
1179
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1180
+ return " _Z27__spirv_GroupNonUniformUMinii" ;
1181
+ }
1182
+ template <>
1183
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1184
+ return " _Z27__spirv_GroupNonUniformFMinii" ;
1185
+ }
1186
+ template <>
1187
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1188
+ return " _Z27__spirv_GroupNonUniformSMaxii" ;
1189
+ }
1190
+ template <>
1191
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1192
+ return " _Z27__spirv_GroupNonUniformUMaxii" ;
1193
+ }
1194
+ template <>
1195
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1196
+ return " _Z27__spirv_GroupNonUniformFMaxii" ;
1197
+ }
1198
+ template <>
1199
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1200
+ return " _Z33__spirv_GroupNonUniformBitwiseAndii" ;
1201
+ }
1202
+ template <>
1203
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1204
+ return " _Z32__spirv_GroupNonUniformBitwiseOrii" ;
1205
+ }
1206
+ template <>
1207
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1208
+ return " _Z33__spirv_GroupNonUniformBitwiseXorii" ;
1209
+ }
1210
+ template <>
1211
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1212
+ return " _Z33__spirv_GroupNonUniformLogicalAndii" ;
1213
+ }
1214
+ template <>
1215
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1216
+ return " _Z32__spirv_GroupNonUniformLogicalOrii" ;
1217
+ }
1218
+ template <>
1219
+ constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1220
+ return " _Z33__spirv_GroupNonUniformLogicalXorii" ;
1221
+ }
1222
+ } // namespace
1223
+
1224
+ template <typename ReduceOp, bool Signed = false , bool NonUniform = false >
1225
+ class GroupReducePattern : public SPIRVToLLVMConversion <ReduceOp> {
1226
+ public:
1227
+ using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1228
+
1229
+ LogicalResult
1230
+ matchAndRewrite (ReduceOp op, typename ReduceOp::Adaptor adaptor,
1231
+ ConversionPatternRewriter &rewriter) const override {
1232
+
1233
+ Type retTy = op.getResult ().getType ();
1234
+ if (!retTy.isIntOrFloat ()) {
1235
+ return failure ();
1236
+ }
1237
+ SmallString<36 > funcName = getGroupFuncName<ReduceOp>();
1238
+ funcName += getTypeMangling (retTy, false );
1239
+
1240
+ Type i32Ty = rewriter.getI32Type ();
1241
+ SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1242
+ if constexpr (NonUniform) {
1243
+ if (adaptor.getClusterSize ()) {
1244
+ funcName += " j" ;
1245
+ paramTypes.push_back (i32Ty);
1246
+ }
1247
+ }
1248
+
1249
+ Operation *symbolTable =
1250
+ op->template getParentWithTrait <OpTrait::SymbolTable>();
1251
+
1252
+ LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn (
1253
+ symbolTable, funcName, paramTypes, retTy, !NonUniform);
1254
+
1255
+ Location loc = op.getLoc ();
1256
+ Value scope = rewriter.create <LLVM::ConstantOp>(
1257
+ loc, i32Ty, static_cast <int32_t >(adaptor.getExecutionScope ()));
1258
+ Value groupOp = rewriter.create <LLVM::ConstantOp>(
1259
+ loc, i32Ty, static_cast <int32_t >(adaptor.getGroupOperation ()));
1260
+ SmallVector<Value> operands{scope, groupOp};
1261
+ operands.append (adaptor.getOperands ().begin (), adaptor.getOperands ().end ());
1262
+
1263
+ auto call = createSPIRVBuiltinCall (loc, rewriter, func, operands);
1264
+ rewriter.replaceOp (op, call);
1265
+ return success ();
1266
+ }
1267
+ };
1268
+
1092
1269
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within the loop
1093
1270
/// should be reachable for conversion to succeed. The structure of the loop in
1094
1271
/// LLVM dialect will be the following:
@@ -1722,7 +1899,50 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
1722
1899
ReturnPattern, ReturnValuePattern,
1723
1900
1724
1901
// Barrier ops
1725
- ControlBarrierPattern>(patterns.getContext (), typeConverter);
1902
+ ControlBarrierPattern,
1903
+
1904
+ // Group reduction operations
1905
+ GroupReducePattern<spirv::GroupIAddOp>,
1906
+ GroupReducePattern<spirv::GroupFAddOp>,
1907
+ GroupReducePattern<spirv::GroupFMinOp>,
1908
+ GroupReducePattern<spirv::GroupUMinOp>,
1909
+ GroupReducePattern<spirv::GroupSMinOp, /* Signed=*/ true >,
1910
+ GroupReducePattern<spirv::GroupFMaxOp>,
1911
+ GroupReducePattern<spirv::GroupUMaxOp>,
1912
+ GroupReducePattern<spirv::GroupSMaxOp, /* Signed=*/ true >,
1913
+ GroupReducePattern<spirv::GroupNonUniformIAddOp, /* Signed=*/ false ,
1914
+ /* NonUniform=*/ true >,
1915
+ GroupReducePattern<spirv::GroupNonUniformFAddOp, /* Signed=*/ false ,
1916
+ /* NonUniform=*/ true >,
1917
+ GroupReducePattern<spirv::GroupNonUniformIMulOp, /* Signed=*/ false ,
1918
+ /* NonUniform=*/ true >,
1919
+ GroupReducePattern<spirv::GroupNonUniformFMulOp, /* Signed=*/ false ,
1920
+ /* NonUniform=*/ true >,
1921
+ GroupReducePattern<spirv::GroupNonUniformSMinOp, /* Signed=*/ true ,
1922
+ /* NonUniform=*/ true >,
1923
+ GroupReducePattern<spirv::GroupNonUniformUMinOp, /* Signed=*/ false ,
1924
+ /* NonUniform=*/ true >,
1925
+ GroupReducePattern<spirv::GroupNonUniformFMinOp, /* Signed=*/ false ,
1926
+ /* NonUniform=*/ true >,
1927
+ GroupReducePattern<spirv::GroupNonUniformSMaxOp, /* Signed=*/ true ,
1928
+ /* NonUniform=*/ true >,
1929
+ GroupReducePattern<spirv::GroupNonUniformUMaxOp, /* Signed=*/ false ,
1930
+ /* NonUniform=*/ true >,
1931
+ GroupReducePattern<spirv::GroupNonUniformFMaxOp, /* Signed=*/ false ,
1932
+ /* NonUniform=*/ true >,
1933
+ GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /* Signed=*/ false ,
1934
+ /* NonUniform=*/ true >,
1935
+ GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /* Signed=*/ false ,
1936
+ /* NonUniform=*/ true >,
1937
+ GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /* Signed=*/ false ,
1938
+ /* NonUniform=*/ true >,
1939
+ GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /* Signed=*/ false ,
1940
+ /* NonUniform=*/ true >,
1941
+ GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /* Signed=*/ false ,
1942
+ /* NonUniform=*/ true >,
1943
+ GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /* Signed=*/ false ,
1944
+ /* NonUniform=*/ true >>(patterns.getContext (),
1945
+ typeConverter);
1726
1946
1727
1947
patterns.add <GlobalVariablePattern>(clientAPI, patterns.getContext (),
1728
1948
typeConverter);
0 commit comments