-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op #78510
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
✅ With the latest revision this PR passed the C/C++ code formatter. |
@llvm/pr-subscribers-mlir-spirv @llvm/pr-subscribers-mlir Author: Dimple Prajapati (drprajap) ChangesThis change contains following: Full diff: https://github.com/llvm/llvm-project/pull/78510.diff 6 Files Affected:
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index c7c2fe8bc742c1..b5ca27d7d75316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
#### Example:
```mlir
- %0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+ %0 = spirv.CL.printf %0 : !spirv.ptr<i8, UniformConstant>(%1, %2 : i32, i32) -> i32
```
}];
@@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
);
let assemblyFormat = [{
- $format `,` $arguments attr-dict `:` `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
+ $format `:` type($format) ( `(` $arguments^ `:` type($arguments) `)`)? attr-dict `->` type($result)
}];
let hasVerifier = 0;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index d7885e0359592d..8d9f4554d8d799 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -135,6 +135,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
ConversionPatternRewriter &rewriter) const override;
};
+class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
+public:
+ using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override;
+};
+
} // namespace
//===----------------------------------------------------------------------===//
@@ -607,6 +616,108 @@ class GPUSubgroupReduceConversion final
}
};
+LogicalResult GPUPrintfConversion::matchAndRewrite(
+ gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const {
+
+ auto loc = gpuPrintfOp.getLoc();
+
+ auto funcOp =
+ rewriter.getBlock()->getParent()->getParentOfType<mlir::spirv::FuncOp>();
+
+ auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();
+
+ const char formatStringPrefix[] = "printfMsg";
+ unsigned stringNumber = 0;
+ mlir::SmallString<16> globalVarName;
+ mlir::spirv::GlobalVariableOp globalVar;
+
+ // formulate spirv global variable name
+ do {
+ globalVarName.clear();
+ (formatStringPrefix + llvm::Twine(stringNumber++))
+ .toStringRef(globalVarName);
+ } while (moduleOp.lookupSymbol(globalVarName));
+
+ auto i8Type = rewriter.getI8Type();
+ auto i32Type = rewriter.getI32Type();
+
+ unsigned scNum = 0;
+ auto createSpecConstant = [&](unsigned value) {
+ auto attr = rewriter.getI8IntegerAttr(value);
+ mlir::SmallString<16> specCstName;
+ (llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
+ .toStringRef(specCstName);
+
+ return rewriter.create<mlir::spirv::SpecConstantOp>(
+ loc, rewriter.getStringAttr(specCstName), attr);
+ };
+
+ // define GlobalVarOp with printf format string using SpecConstants
+ // and make composite of SpecConstants
+ {
+ mlir::Operation *parent =
+ mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
+
+ mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+
+ mlir::Block &entryBlock = *parent->getRegion(0).begin();
+ rewriter.setInsertionPointToStart(
+ &entryBlock); // insertion point at module level
+
+ // Create Constituents with SpecConstant to construct
+ // SpecConstantCompositeOp
+ llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
+ formatString.push_back('\0'); // Null terminate for C
+ mlir::SmallVector<mlir::Attribute, 4> constituents;
+ for (auto c : formatString) {
+ auto cSpecConstantOp = createSpecConstant(c);
+ constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
+ }
+
+ // Create specialization constant composite defined via spirv.SpecConstant
+ size_t contentSize = constituents.size();
+ auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
+ mlir::spirv::SpecConstantCompositeOp specCstComposite;
+ mlir::SmallString<16> specCstCompositeName;
+ (llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
+ specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
+ loc, mlir::TypeAttr::get(globalType),
+ rewriter.getStringAttr(specCstCompositeName),
+ rewriter.getArrayAttr(constituents));
+
+ // Define GlobalVariable initialized from Constant Composite
+ globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
+ loc,
+ mlir::spirv::PointerType::get(
+ globalType, mlir::spirv::StorageClass::UniformConstant),
+ globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
+
+ globalVar->setAttr("Constant", rewriter.getUnitAttr());
+ }
+
+ // Get SSA value of Global variable
+ mlir::Value globalPtr =
+ rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
+ mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
+ loc,
+ mlir::spirv::PointerType::get(i8Type,
+ mlir::spirv::StorageClass::UniformConstant),
+ globalPtr);
+
+ // Get printf arguments
+ auto argsRange = adaptor.getArgs();
+ mlir::SmallVector<mlir::Value, 4> printfArgs;
+ printfArgs.reserve(argsRange.size() + 1);
+ printfArgs.append(argsRange.begin(), argsRange.end());
+
+ rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+
+ rewriter.eraseOp(gpuPrintfOp);
+
+ return mlir::success();
+}
+
//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
@@ -630,5 +741,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
spirv::BuiltIn::SubgroupSize>,
WorkGroupSizeConversion, GPUAllReduceConversion,
- GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
+ GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
+ patterns.getContext());
}
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 02d03b3a0faeee..89a72260290e22 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -309,6 +309,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
+ case spirv::Decoration::Constant:
if (words.size() != 2) {
return emitError(unknownLoc, "OpDecoration with ")
<< decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 40337e007bbf74..2252c339af0a75 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -272,6 +272,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
case spirv::Decoration::RelaxedPrecision:
case spirv::Decoration::Restrict:
case spirv::Decoration::RestrictPointer:
+ case spirv::Decoration::Constant:
// For unit attributes and decoration attributes, the args list
// has no values so we do nothing.
if (isa<UnitAttr, DecorationAttr>(attr))
diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
new file mode 100644
index 00000000000000..4c77195f916014
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+ func.func @main() {
+ %c1 = arith.constant 1 : index
+
+ gpu.launch_func @kernels::@printf
+ blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+ args()
+ return
+ }
+
+ gpu.module @kernels {
+ // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+ // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+ // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+ // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ gpu.func @printf() kernel
+ attributes
+ {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ // CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+ // CHECK-NEXT {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : (!spirv.ptr<i8, UniformConstant>) -> i32
+ gpu.printf "\nHello\n"
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
+
+// -----
+
+module attributes {
+ gpu.container_module,
+ spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+ func.func @main() {
+ %c1 = arith.constant 1 : index
+ %c100 = arith.constant 100: i32
+ %cst_f32 = arith.constant 314.4: f32
+
+ gpu.launch_func @kernels1::@printf_args
+ blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+ args(%c100: i32, %cst_f32: f32)
+ return
+ }
+
+ gpu.module @kernels1 {
+ // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+ // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+ // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+ // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant} : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
+ attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+ %0 = gpu.block_id x
+ %1 = gpu.block_id y
+ %2 = gpu.thread_id x
+
+ // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+ // CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+ // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}}, {{%.*}} : i32, f32, i32) -> i32
+ gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
+
+ // CHECK: spirv.Return
+ gpu.return
+ }
+ }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 81ba471d3f51e3..171087a167850f 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -275,8 +275,8 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
//===----------------------------------------------------------------------===//
// CHECK-LABEL: func.func @printf(
func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
- // CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
- %0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+ // CHECK: spirv.CL.printf {{%.*}} : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}} : i32, i32) -> i32
+ %0 = spirv.CL.printf %arg0 : !spirv.ptr<i8, UniformConstant>(%arg1, %arg2 : i32, i32) -> i32
return %0 : i32
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks for the contribution, look like a useful debugging feature to have! I left some comments.
I find it hard to follow the matchAndRewrite
code -- could we brake it down a bit and introduce some helper functions?
mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>( | ||
loc, | ||
mlir::spirv::PointerType::get(i8Type, | ||
mlir::spirv::StorageClass::UniformConstant), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should we instead check GV's storage class and use it instead of constant? Both SPIR-V and OpenCL have extensions relaxing printf string arg requirement, see KhronosGroup/SPIRV-Registry#148
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure I got your suggestion, do you mean we can get the storage class from already defined globalVar op?
thanks for pointing out to the requirements docs, while implementing this, when I checked OpenCL specs for printf , they mention fmt string needs to be in constant address space, hence I used UniformConstant here- https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#printf
This change contains following: - adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass. - Fixes Constant decoration parsing for spirv GlobalVariable. - minor modification to spirv.CL.printf op assembly format.
21ae63b
to
b0e9c5f
Compare
Thank you @kuhar for your feedback and sorry for the delay in addressing them. |
@drprajap there are some test failures: https://buildkite.com/llvm-project/github-pull-requests/builds/101724#019202b6-5c05-43e1-aa01-48aba3a6a4de/6-202 |
@drprajap could you click the 'Resolve' button on past comments that you believe have been addressed? This will make it easier for me to tell what are the remaining open discussion threads. |
review feedback Co-authored-by: Jakub Kuderski <[email protected]>
Yes, sorry I missed to update test cases with new printf format changes, updated and verified locally now. |
@kuhar I addressed all the feedback and resolved comments, please check it out and let me know if looks okay. Thanks so much for your valuable feedback. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Overall looks good, just some remaining issues
formatting changes Co-authored-by: Jakub Kuderski <[email protected]>
Thank you. Addressed remaining issues. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
minor formatting changes Co-authored-by: Jakub Kuderski <[email protected]>
Thanks, I do not have write permission, can you please help merge it if possible? else I can ask other team members. |
Sure, I will merge this on Monday. Please ping me if I forget to do so. |
Gentle ping to merge this. Thank you :) |
…78510) This change contains following: - adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass. - Fixes Constant decoration parsing for spirv GlobalVariable. - minor modification to spirv.CL.printf op assembly format. --------- Co-authored-by: Jakub Kuderski <[email protected]>
This change contains following:
- adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass.
- Fixes Constant decoration parsing for spirv GlobalVariable.
- minor modification to spirv.CL.printf op assembly format.