-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add spirv-to-llvm conversion for OpControlBarrier #111864
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
The translation is based on the expected llvm function from the LLVM/SPIRV translation tool
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Finlay (FMarno) ChangesThe translation is based on the expected llvm function from the LLVM/SPIRV translation tool Full diff: https://github.com/llvm/llvm-project/pull/111864.diff 4 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
index 1ebea94fced0a3..14593305490661 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td
@@ -54,7 +54,7 @@ def SPIRV_ControlBarrierOp : SPIRV_Op<"ControlBarrier", []> {
#### Example:
```mlir
- spirv.ControlBarrier "Workgroup", "Device", "Acquire|UniformMemory"
+ spirv.ControlBarrier <Workgroup>, <Device>, <Acquire|UniformMemory>
```
}];
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
index 71ecabfb444bd0..022cbbbb6720fb 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td
@@ -1,4 +1,4 @@
-//===-- SPIRVBarrierOps.td - MLIR SPIR-V Barrier Ops -------*- tablegen -*-===//
+//===-- SPIRVMiscOps.td - MLIR SPIR-V Misc Ops -------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
diff --git a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
index 74c169c9a7e76a..50d090ddad901f 100644
--- a/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
+++ b/mlir/lib/Conversion/SPIRVToLLVM/SPIRVToLLVM.cpp
@@ -1024,6 +1024,71 @@ class ReturnValuePattern : public SPIRVToLLVMConversion<spirv::ReturnValueOp> {
}
};
+static LLVM::LLVMFuncOp lookupOrCreateSPIRVFn(Operation *symbolTable,
+ StringRef name,
+ ArrayRef<Type> paramTypes,
+ Type resultType) {
+ auto func = dyn_cast_or_null<LLVM::LLVMFuncOp>(
+ SymbolTable::lookupSymbolIn(symbolTable, name));
+ if (!func) {
+ OpBuilder b(symbolTable->getRegion(0));
+ func = b.create<LLVM::LLVMFuncOp>(
+ symbolTable->getLoc(), name,
+ LLVM::LLVMFunctionType::get(resultType, paramTypes));
+ func.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
+ func.setConvergent(true);
+ func.setNoUnwind(true);
+ func.setWillReturn(true);
+ }
+ return func;
+}
+
+static LLVM::CallOp createSPIRVBuiltinCall(Location loc,
+ ConversionPatternRewriter &rewriter,
+ LLVM::LLVMFuncOp func,
+ ValueRange args) {
+ auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+ call.setCConv(func.getCConv());
+ call.setConvergentAttr(func.getConvergentAttr());
+ call.setNoUnwindAttr(func.getNoUnwindAttr());
+ call.setWillReturnAttr(func.getWillReturnAttr());
+ return call;
+}
+
+class ControlBarrierPattern
+ : public SPIRVToLLVMConversion<spirv::ControlBarrierOp> {
+public:
+ using SPIRVToLLVMConversion<spirv::ControlBarrierOp>::SPIRVToLLVMConversion;
+
+ LogicalResult
+ matchAndRewrite(spirv::ControlBarrierOp controlBarrierOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ constexpr StringRef funcName = "_Z22__spirv_ControlBarrieriii";
+ Operation *symbolTable =
+ controlBarrierOp->getParentWithTrait<OpTrait::SymbolTable>();
+
+ Type i32 = rewriter.getI32Type();
+
+ Type voidTy = rewriter.getType<LLVM::LLVMVoidType>();
+ LLVM::LLVMFuncOp func =
+ lookupOrCreateSPIRVFn(symbolTable, funcName, {i32, i32, i32}, voidTy);
+
+ auto loc = controlBarrierOp->getLoc();
+ Value execution = rewriter.create<LLVM::ConstantOp>(
+ loc, i32, static_cast<int32_t>(adaptor.getExecutionScope()));
+ Value memory = rewriter.create<LLVM::ConstantOp>(
+ loc, i32, static_cast<int32_t>(adaptor.getMemoryScope()));
+ Value semantics = rewriter.create<LLVM::ConstantOp>(
+ loc, i32, static_cast<int32_t>(adaptor.getMemorySemantics()));
+
+ auto call = createSPIRVBuiltinCall(loc, rewriter, func,
+ {execution, memory, semantics});
+
+ rewriter.replaceOp(controlBarrierOp, call);
+ return success();
+ }
+};
+
/// Converts `spirv.mlir.loop` to LLVM dialect. All blocks within selection
/// should be reachable for conversion to succeed. The structure of the loop in
/// LLVM dialect will be the following:
@@ -1648,7 +1713,10 @@ void mlir::populateSPIRVToLLVMConversionPatterns(
ShiftPattern<spirv::ShiftLeftLogicalOp, LLVM::ShlOp>,
// Return ops
- ReturnPattern, ReturnValuePattern>(patterns.getContext(), typeConverter);
+ ReturnPattern, ReturnValuePattern,
+
+ // Barrier ops
+ ControlBarrierPattern>(patterns.getContext(), typeConverter);
patterns.add<GlobalVariablePattern>(clientAPI, patterns.getContext(),
typeConverter);
diff --git a/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
new file mode 100644
index 00000000000000..d53afeeea15d10
--- /dev/null
+++ b/mlir/test/Conversion/SPIRVToLLVM/barrier-ops-to-llvm.mlir
@@ -0,0 +1,23 @@
+// RUN: mlir-opt -convert-spirv-to-llvm %s | FileCheck %s
+
+//===----------------------------------------------------------------------===//
+// spirv.ControlBarrierOp
+//===----------------------------------------------------------------------===//
+
+// CHECK: llvm.func spir_funccc @_Z22__spirv_ControlBarrieriii(i32, i32, i32) attributes {convergent, no_unwind, will_return}
+
+// CHECK-LABEL: @control_barrier
+spirv.func @control_barrier() "None" {
+ // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(768 : i32) : i32
+ // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+ spirv.ControlBarrier <Workgroup>, <Workgroup>, <CrossWorkgroupMemory|WorkgroupMemory>
+
+ // CHECK: [[EXECUTION:%.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: [[MEMORY:%.*]] = llvm.mlir.constant(2 : i32) : i32
+ // CHECK: [[SEMANTICS:%.*]] = llvm.mlir.constant(256 : i32) : i32
+ // CHECK: llvm.call spir_funccc @_Z22__spirv_ControlBarrieriii([[EXECUTION]], [[MEMORY]], [[SEMANTICS]]) {convergent, no_unwind, will_return} : (i32, i32, i32) -> ()
+ spirv.ControlBarrier <Workgroup>, <Workgroup>, <WorkgroupMemory>
+ spirv.Return
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the changes, I only have one more comment. I'm not clicking the 'approve' button because I don't have stakes in the spirv-llvm lowering.
Thanks very much for the review |
I'm planning to merge this at 1600 UTC today. Let me know if that is an issue. |
) Lowering for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification from SPIR-V dialect to LLVM dialect. Similar to #111864, lower the operations to builtin functions understood by SPIR-V tools. --------- Signed-off-by: Lukas Sommer <[email protected]>
…#115501) Lowering for some of the uniform and non-uniform group operations defined in section 3.52.21 of the SPIR-V specification from SPIR-V dialect to LLVM dialect. Similar to llvm#111864, lower the operations to builtin functions understood by SPIR-V tools. --------- Signed-off-by: Lukas Sommer <[email protected]>
The conversion is based on the expected llvm function from the LLVM/SPIRV translation tool.
I see there is no existing translation for things that can't be directly translated as an LLVM operation, so I'd be open to discussion about how this should work.