Skip to content

Commit 5a53add

Browse files
authored
[mlir] Add optimization attrs for gpu-to-llvmspv function declarations and calls (#99301)
Adds the attributes nounwind and willreturn to all function declarations. Adds `memory(none)` equivalent to the id/dimension function declarations. The function declaration attributes are copied to the function calls. `nounwind` is legal because there are no exception in SPIR-V. I also do not see any reason why any of these functions would not return when used correctly. I'm confident that the get id/dim functions will have no externally observable memory effects, but think the convergent functions will have effects.
1 parent 5898a7f commit 5a53add

File tree

2 files changed

+255
-67
lines changed

2 files changed

+255
-67
lines changed

mlir/lib/Conversion/GPUToLLVMSPV/GPUToLLVMSPV.cpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ namespace mlir {
4343
static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
4444
StringRef name,
4545
ArrayRef<Type> paramTypes,
46-
Type resultType,
47-
bool isConvergent = false) {
46+
Type resultType, bool isMemNone,
47+
bool isConvergent) {
4848
auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
4949
SymbolTable::lookupSymbolIn(symbolTable, name));
5050
if (!func) {
@@ -53,6 +53,18 @@ static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
5353
symbolTable->getLoc(), name,
5454
LLVM::LLVMFunctionType::get(resultType, paramTypes));
5555
func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
56+
func.setNoUnwind(true);
57+
func.setWillReturn(true);
58+
59+
if (isMemNone) {
60+
// no externally observable effects
61+
constexpr auto noModRef = mlir::LLVM::ModRefInfo::NoModRef;
62+
auto memAttr = b.getAttr<LLVM::MemoryEffectsAttr>(
63+
/*other=*/noModRef,
64+
/*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
65+
func.setMemoryEffectsAttr(memAttr);
66+
}
67+
5668
func.setConvergent(isConvergent);
5769
}
5870
return func;
@@ -64,6 +76,10 @@ static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
6476
ValueRange args) {
6577
auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
6678
call.setCConv(func.getCConv());
79+
call.setConvergentAttr(func.getConvergentAttr());
80+
call.setNoUnwindAttr(func.getNoUnwindAttr());
81+
call.setWillReturnAttr(func.getWillReturnAttr());
82+
call.setMemoryEffectsAttr(func.getMemoryEffectsAttr());
6783
return call;
6884
}
6985

@@ -91,8 +107,9 @@ struct GPUBarrierConversion final : ConvertOpToLLVMPattern<gpu::BarrierOp> {
91107
assert(moduleOp && "Expecting module");
92108
Type flagTy = rewriter.getI32Type();
93109
Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
94-
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
95-
moduleOp, funcName, flagTy, voidTy, /*isConvergent=*/true);
110+
LLVM::LLVMFuncOp func =
111+
lookupOrCreateSPIRVFn(moduleOp, funcName, flagTy, voidTy,
112+
/*isMemNone=*/false, /*isConvergent=*/true);
96113

97114
// Value used by SPIR-V backend to represent `CLK_LOCAL_MEM_FENCE`.
98115
// See `llvm/lib/Target/SPIRV/SPIRVBuiltins.td`.
@@ -134,8 +151,9 @@ struct LaunchConfigConversion : ConvertToLLVMPattern {
134151
assert(moduleOp && "Expecting module");
135152
Type dimTy = rewriter.getI32Type();
136153
Type indexTy = getTypeConverter()->getIndexType();
137-
LLVM::LLVMFuncOp func =
138-
lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy, indexTy);
154+
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(moduleOp, funcName, dimTy,
155+
indexTy, /*isMemNone=*/true,
156+
/*isConvergent=*/false);
139157

140158
Location loc = op->getLoc();
141159
gpu::Dimension dim = getDimension(op);
@@ -268,9 +286,9 @@ struct GPUShuffleConversion final : ConvertOpToLLVMPattern<gpu::ShuffleOp> {
268286
Type valueType = adaptor.getValue().getType();
269287
Type offsetType = adaptor.getOffset().getType();
270288
Type resultType = valueType;
271-
LLVM::LLVMFuncOp func =
272-
lookupOrCreateSPIRVFn(moduleOp, funcName, {valueType, offsetType},
273-
resultType, /*isConvergent=*/true);
289+
LLVM::LLVMFuncOp func = lookupOrCreateSPIRVFn(
290+
moduleOp, funcName, {valueType, offsetType}, resultType,
291+
/*isMemNone=*/false, /*isConvergent=*/true);
274292

275293
Location loc = op->getLoc();
276294
std::array<Value, 2> args{adaptor.getValue(), adaptor.getOffset()};

0 commit comments

Comments
 (0)