Skip to content

Commit 6ade03d

Browse files
authored
[mlir][spirv] Add spirv-to-llvm conversion for group operations (#115501)
Lowering for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification from SPIR-V dialect to LLVM dialect. Similar to #111864, lower the operations to builtin functions understood by SPIR-V tools. --------- Signed-off-by: Lukas Sommer <[email protected]>
1 parent 58ca707 commit 6ade03d

File tree

3 files changed

+782
-3
lines changed

3 files changed

+782
-3
lines changed

mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp

Lines changed: 223 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/BuiltinOps.h"
2323
#include "mlir/IR/PatternMatch.h"
2424
#include "mlir/Transforms/DialectConversion.h"
25+
#include "llvm/ADT/TypeSwitch.h"
2526
#include "llvm/Support/Debug.h"
2627
#include "llvm/Support/FormatVariadic.h"
2728

@@ -1027,7 +1028,8 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
10271028
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
10281029
StringRef name,
10291030
ArrayRef<Type> paramTypes,
1030-
Type resultType) {
1031+
Type resultType,
1032+
bool convergent = true) {
10311033
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
10321034
SymbolTable::lookupSymbolIn(symbolTable, name));
10331035
if (func)
@@ -1038,7 +1040,7 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
10381040
symbolTable->getLoc(), name,
10391041
LLVM::LLVMFunctionType::get(resultType, paramTypes));
10401042
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
1041-
func.setConvergent(true);
1043+
func.setConvergent(convergent);
10421044
func.setNoUnwind(true);
10431045
func.setWillReturn(true);
10441046
return func;
@@ -1089,6 +1091,181 @@ class ControlBarrierPattern
10891091
}
10901092
};
10911093

1094+
namespace {
1095+
1096+
StringRef getTypeMangling(Type type, bool isSigned) {
1097+
return llvm::TypeSwitch<Type, StringRef>(type)
1098+
.Case<Float16Type>([](auto) { return "Dh"; })
1099+
.Case<Float32Type>([](auto) { return "f"; })
1100+
.Case<Float64Type>([](auto) { return "d"; })
1101+
.Case<IntegerType>([isSigned](IntegerType intTy) {
1102+
switch (intTy.getWidth()) {
1103+
case 1:
1104+
return "b";
1105+
case 8:
1106+
return (isSigned) ? "a" : "c";
1107+
case 16:
1108+
return (isSigned) ? "s" : "t";
1109+
case 32:
1110+
return (isSigned) ? "i" : "j";
1111+
case 64:
1112+
return (isSigned) ? "l" : "m";
1113+
default:
1114+
llvm_unreachable("Unsupported integer width");
1115+
}
1116+
})
1117+
.Default([](auto) {
1118+
llvm_unreachable("No mangling defined");
1119+
return "";
1120+
});
1121+
}
1122+
1123+
template <typename ReduceOp>
1124+
constexpr StringLiteral getGroupFuncName();
1125+
1126+
template <>
1127+
constexpr StringLiteral getGroupFuncName<spirv::GroupIAddOp>() {
1128+
return "_Z17__spirv_GroupIAddii";
1129+
}
1130+
template <>
1131+
constexpr StringLiteral getGroupFuncName<spirv::GroupFAddOp>() {
1132+
return "_Z17__spirv_GroupFAddii";
1133+
}
1134+
template <>
1135+
constexpr StringLiteral getGroupFuncName<spirv::GroupSMinOp>() {
1136+
return "_Z17__spirv_GroupSMinii";
1137+
}
1138+
template <>
1139+
constexpr StringLiteral getGroupFuncName<spirv::GroupUMinOp>() {
1140+
return "_Z17__spirv_GroupUMinii";
1141+
}
1142+
template <>
1143+
constexpr StringLiteral getGroupFuncName<spirv::GroupFMinOp>() {
1144+
return "_Z17__spirv_GroupFMinii";
1145+
}
1146+
template <>
1147+
constexpr StringLiteral getGroupFuncName<spirv::GroupSMaxOp>() {
1148+
return "_Z17__spirv_GroupSMaxii";
1149+
}
1150+
template <>
1151+
constexpr StringLiteral getGroupFuncName<spirv::GroupUMaxOp>() {
1152+
return "_Z17__spirv_GroupUMaxii";
1153+
}
1154+
template <>
1155+
constexpr StringLiteral getGroupFuncName<spirv::GroupFMaxOp>() {
1156+
return "_Z17__spirv_GroupFMaxii";
1157+
}
1158+
template <>
1159+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIAddOp>() {
1160+
return "_Z27__spirv_GroupNonUniformIAddii";
1161+
}
1162+
template <>
1163+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFAddOp>() {
1164+
return "_Z27__spirv_GroupNonUniformFAddii";
1165+
}
1166+
template <>
1167+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformIMulOp>() {
1168+
return "_Z27__spirv_GroupNonUniformIMulii";
1169+
}
1170+
template <>
1171+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMulOp>() {
1172+
return "_Z27__spirv_GroupNonUniformFMulii";
1173+
}
1174+
template <>
1175+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMinOp>() {
1176+
return "_Z27__spirv_GroupNonUniformSMinii";
1177+
}
1178+
template <>
1179+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMinOp>() {
1180+
return "_Z27__spirv_GroupNonUniformUMinii";
1181+
}
1182+
template <>
1183+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMinOp>() {
1184+
return "_Z27__spirv_GroupNonUniformFMinii";
1185+
}
1186+
template <>
1187+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformSMaxOp>() {
1188+
return "_Z27__spirv_GroupNonUniformSMaxii";
1189+
}
1190+
template <>
1191+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformUMaxOp>() {
1192+
return "_Z27__spirv_GroupNonUniformUMaxii";
1193+
}
1194+
template <>
1195+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformFMaxOp>() {
1196+
return "_Z27__spirv_GroupNonUniformFMaxii";
1197+
}
1198+
template <>
1199+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseAndOp>() {
1200+
return "_Z33__spirv_GroupNonUniformBitwiseAndii";
1201+
}
1202+
template <>
1203+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseOrOp>() {
1204+
return "_Z32__spirv_GroupNonUniformBitwiseOrii";
1205+
}
1206+
template <>
1207+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformBitwiseXorOp>() {
1208+
return "_Z33__spirv_GroupNonUniformBitwiseXorii";
1209+
}
1210+
template <>
1211+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalAndOp>() {
1212+
return "_Z33__spirv_GroupNonUniformLogicalAndii";
1213+
}
1214+
template <>
1215+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalOrOp>() {
1216+
return "_Z32__spirv_GroupNonUniformLogicalOrii";
1217+
}
1218+
template <>
1219+
constexpr StringLiteral getGroupFuncName<spirv::GroupNonUniformLogicalXorOp>() {
1220+
return "_Z33__spirv_GroupNonUniformLogicalXorii";
1221+
}
1222+
} // namespace
1223+
1224+
template <typename ReduceOp, bool Signed = false, bool NonUniform = false>
1225+
class GroupReducePattern : public SPIRVToLLVMConversion<ReduceOp> {
1226+
public:
1227+
using SPIRVToLLVMConversion<ReduceOp>::SPIRVToLLVMConversion;
1228+
1229+
LogicalResult
1230+
matchAndRewrite(ReduceOp op, typename ReduceOp::Adaptor adaptor,
1231+
ConversionPatternRewriter &rewriter) const override {
1232+
1233+
Type retTy = op.getResult().getType();
1234+
if (!retTy.isIntOrFloat()) {
1235+
return failure();
1236+
}
1237+
SmallString<36> funcName = getGroupFuncName<ReduceOp>();
1238+
funcName += getTypeMangling(retTy, false);
1239+
1240+
Type i32Ty = rewriter.getI32Type();
1241+
SmallVector<Type> paramTypes{i32Ty, i32Ty, retTy};
1242+
if constexpr (NonUniform) {
1243+
if (adaptor.getClusterSize()) {
1244+
funcName += "j";
1245+
paramTypes.push_back(i32Ty);
1246+
}
1247+
}
1248+
1249+
Operation *symbolTable =
1250+
op->template getParentWithTrait<OpTrait::SymbolTable>();
1251+
1252+
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
1253+
symbolTable, funcName, paramTypes, retTy, !NonUniform);
1254+
1255+
Location loc = op.getLoc();
1256+
Value scope = rewriter.create<LLVM::ConstantOp>(
1257+
loc, i32Ty, static_cast<int32_t>(adaptor.getExecutionScope()));
1258+
Value groupOp = rewriter.create<LLVM::ConstantOp>(
1259+
loc, i32Ty, static_cast<int32_t>(adaptor.getGroupOperation()));
1260+
SmallVector<Value> operands{scope, groupOp};
1261+
operands.append(adaptor.getOperands().begin(), adaptor.getOperands().end());
1262+
1263+
auto call = createSPIRVBuiltinCall(loc, rewriter, func, operands);
1264+
rewriter.replaceOp(op, call);
1265+
return success();
1266+
}
1267+
};
1268+
10921269
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
10931270
/// should be reachable for conversion to succeed. The structure of the loop in
10941271
/// LLVM dialect will be the following:
@@ -1722,7 +1899,50 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
17221899
ReturnPattern, ReturnValuePattern,
17231900

17241901
// Barrier ops
1725-
ControlBarrierPattern>(patterns.getContext(), typeConverter);
1902+
ControlBarrierPattern,
1903+
1904+
// Group reduction operations
1905+
GroupReducePattern<spirv::GroupIAddOp>,
1906+
GroupReducePattern<spirv::GroupFAddOp>,
1907+
GroupReducePattern<spirv::GroupFMinOp>,
1908+
GroupReducePattern<spirv::GroupUMinOp>,
1909+
GroupReducePattern<spirv::GroupSMinOp, /*Signed=*/true>,
1910+
GroupReducePattern<spirv::GroupFMaxOp>,
1911+
GroupReducePattern<spirv::GroupUMaxOp>,
1912+
GroupReducePattern<spirv::GroupSMaxOp, /*Signed=*/true>,
1913+
GroupReducePattern<spirv::GroupNonUniformIAddOp, /*Signed=*/false,
1914+
/*NonUniform=*/true>,
1915+
GroupReducePattern<spirv::GroupNonUniformFAddOp, /*Signed=*/false,
1916+
/*NonUniform=*/true>,
1917+
GroupReducePattern<spirv::GroupNonUniformIMulOp, /*Signed=*/false,
1918+
/*NonUniform=*/true>,
1919+
GroupReducePattern<spirv::GroupNonUniformFMulOp, /*Signed=*/false,
1920+
/*NonUniform=*/true>,
1921+
GroupReducePattern<spirv::GroupNonUniformSMinOp, /*Signed=*/true,
1922+
/*NonUniform=*/true>,
1923+
GroupReducePattern<spirv::GroupNonUniformUMinOp, /*Signed=*/false,
1924+
/*NonUniform=*/true>,
1925+
GroupReducePattern<spirv::GroupNonUniformFMinOp, /*Signed=*/false,
1926+
/*NonUniform=*/true>,
1927+
GroupReducePattern<spirv::GroupNonUniformSMaxOp, /*Signed=*/true,
1928+
/*NonUniform=*/true>,
1929+
GroupReducePattern<spirv::GroupNonUniformUMaxOp, /*Signed=*/false,
1930+
/*NonUniform=*/true>,
1931+
GroupReducePattern<spirv::GroupNonUniformFMaxOp, /*Signed=*/false,
1932+
/*NonUniform=*/true>,
1933+
GroupReducePattern<spirv::GroupNonUniformBitwiseAndOp, /*Signed=*/false,
1934+
/*NonUniform=*/true>,
1935+
GroupReducePattern<spirv::GroupNonUniformBitwiseOrOp, /*Signed=*/false,
1936+
/*NonUniform=*/true>,
1937+
GroupReducePattern<spirv::GroupNonUniformBitwiseXorOp, /*Signed=*/false,
1938+
/*NonUniform=*/true>,
1939+
GroupReducePattern<spirv::GroupNonUniformLogicalAndOp, /*Signed=*/false,
1940+
/*NonUniform=*/true>,
1941+
GroupReducePattern<spirv::GroupNonUniformLogicalOrOp, /*Signed=*/false,
1942+
/*NonUniform=*/true>,
1943+
GroupReducePattern<spirv::GroupNonUniformLogicalXorOp, /*Signed=*/false,
1944+
/*NonUniform=*/true>>(patterns.getContext(),
1945+
typeConverter);
17261946

17271947
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
17281948
typeConverter);

0 commit comments

Comments
 (0)