Skip to content

[mlir][spirv] Add gpu printf op lowering to spirv.CL.printf op #78510

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Sep 30, 2024

Conversation

drprajap
Copy link
Contributor

This change contains following:
- adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass.
- Fixes Constant decoration parsing for spirv GlobalVariable.
- minor modification to spirv.CL.printf op assembly format.

Copy link

github-actions bot commented Jan 17, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@drprajap drprajap marked this pull request as ready for review January 17, 2024 23:16
@llvmbot
Copy link
Member

llvmbot commented Jan 17, 2024

@llvm/pr-subscribers-mlir-spirv
@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Dimple Prajapati (drprajap)

Changes

This change contains following:
- adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass.
- Fixes Constant decoration parsing for spirv GlobalVariable.
- minor modification to spirv.CL.printf op assembly format.


Full diff: https://github.com/llvm/llvm-project/pull/78510.diff

6 Files Affected:

  • (modified) mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td (+2-2)
  • (modified) mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp (+113-1)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+1)
  • (modified) mlir/lib/Target/SPIRV/Serialization/Serializer.cpp (+1)
  • (added) mlir/test/Conversion/GPUToSPIRV/printf.mlir (+71)
  • (modified) mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir (+2-2)
diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
index c7c2fe8bc742c1..b5ca27d7d75316 100644
--- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
+++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCLOps.td
@@ -875,7 +875,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
     #### Example:
 
     ```mlir
-    %0 = spirv.CL.printf %0 %1 %2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+    %0 = spirv.CL.printf %0 : !spirv.ptr<i8, UniformConstant>(%1, %2 : i32, i32) -> i32
     ```
   }];
 
@@ -889,7 +889,7 @@ def SPIRV_CLPrintfOp : SPIRV_CLOp<"printf", 184, []> {
   );
 
   let assemblyFormat = [{
-  $format `,` $arguments  attr-dict `:`  `(` type($format) `,` `(` type($arguments) `)` `)` `->` type($result)
+    $format `:` type($format) ( `(` $arguments^ `:` type($arguments) `)`)? attr-dict `->` type($result)
   }];
 
   let hasVerifier = 0;
diff --git a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
index d7885e0359592d..8d9f4554d8d799 100644
--- a/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
+++ b/mlir/lib/Conversion/GPUToSPIRV/GPUToSPIRV.cpp
@@ -135,6 +135,15 @@ class GPUShuffleConversion final : public OpConversionPattern<gpu::ShuffleOp> {
                   ConversionPatternRewriter &rewriter) const override;
 };
 
+class GPUPrintfConversion final : public OpConversionPattern<gpu::PrintfOp> {
+public:
+  using OpConversionPattern<gpu::PrintfOp>::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override;
+};
+
 } // namespace
 
 //===----------------------------------------------------------------------===//
@@ -607,6 +616,108 @@ class GPUSubgroupReduceConversion final
   }
 };
 
+LogicalResult GPUPrintfConversion::matchAndRewrite(
+    gpu::PrintfOp gpuPrintfOp, OpAdaptor adaptor,
+    ConversionPatternRewriter &rewriter) const {
+
+  auto loc = gpuPrintfOp.getLoc();
+
+  auto funcOp =
+      rewriter.getBlock()->getParent()->getParentOfType<mlir::spirv::FuncOp>();
+
+  auto moduleOp = funcOp->getParentOfType<mlir::spirv::ModuleOp>();
+
+  const char formatStringPrefix[] = "printfMsg";
+  unsigned stringNumber = 0;
+  mlir::SmallString<16> globalVarName;
+  mlir::spirv::GlobalVariableOp globalVar;
+
+  // formulate spirv global variable name
+  do {
+    globalVarName.clear();
+    (formatStringPrefix + llvm::Twine(stringNumber++))
+        .toStringRef(globalVarName);
+  } while (moduleOp.lookupSymbol(globalVarName));
+
+  auto i8Type = rewriter.getI8Type();
+  auto i32Type = rewriter.getI32Type();
+
+  unsigned scNum = 0;
+  auto createSpecConstant = [&](unsigned value) {
+    auto attr = rewriter.getI8IntegerAttr(value);
+    mlir::SmallString<16> specCstName;
+    (llvm::Twine(globalVarName) + "_sc" + llvm::Twine(scNum++))
+        .toStringRef(specCstName);
+
+    return rewriter.create<mlir::spirv::SpecConstantOp>(
+        loc, rewriter.getStringAttr(specCstName), attr);
+  };
+
+  // define GlobalVarOp with printf format string using SpecConstants
+  // and make composite of SpecConstants
+  {
+    mlir::Operation *parent =
+        mlir::SymbolTable::getNearestSymbolTable(gpuPrintfOp->getParentOp());
+
+    mlir::ConversionPatternRewriter::InsertionGuard guard(rewriter);
+
+    mlir::Block &entryBlock = *parent->getRegion(0).begin();
+    rewriter.setInsertionPointToStart(
+        &entryBlock); // insertion point at module level
+
+    // Create Constituents with SpecConstant to construct
+    // SpecConstantCompositeOp
+    llvm::SmallString<20> formatString(gpuPrintfOp.getFormat());
+    formatString.push_back('\0'); // Null terminate for C
+    mlir::SmallVector<mlir::Attribute, 4> constituents;
+    for (auto c : formatString) {
+      auto cSpecConstantOp = createSpecConstant(c);
+      constituents.push_back(mlir::SymbolRefAttr::get(cSpecConstantOp));
+    }
+
+    // Create specialization constant composite defined via spirv.SpecConstant
+    size_t contentSize = constituents.size();
+    auto globalType = mlir::spirv::ArrayType::get(i8Type, contentSize);
+    mlir::spirv::SpecConstantCompositeOp specCstComposite;
+    mlir::SmallString<16> specCstCompositeName;
+    (llvm::Twine(globalVarName) + "_scc").toStringRef(specCstCompositeName);
+    specCstComposite = rewriter.create<mlir::spirv::SpecConstantCompositeOp>(
+        loc, mlir::TypeAttr::get(globalType),
+        rewriter.getStringAttr(specCstCompositeName),
+        rewriter.getArrayAttr(constituents));
+
+    // Define GlobalVariable initialized from Constant Composite
+    globalVar = rewriter.create<mlir::spirv::GlobalVariableOp>(
+        loc,
+        mlir::spirv::PointerType::get(
+            globalType, mlir::spirv::StorageClass::UniformConstant),
+        globalVarName, mlir::FlatSymbolRefAttr::get(specCstComposite));
+
+    globalVar->setAttr("Constant", rewriter.getUnitAttr());
+  }
+
+  // Get SSA value of Global variable
+  mlir::Value globalPtr =
+      rewriter.create<mlir::spirv::AddressOfOp>(loc, globalVar);
+  mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
+      loc,
+      mlir::spirv::PointerType::get(i8Type,
+                                    mlir::spirv::StorageClass::UniformConstant),
+      globalPtr);
+
+  // Get printf arguments
+  auto argsRange = adaptor.getArgs();
+  mlir::SmallVector<mlir::Value, 4> printfArgs;
+  printfArgs.reserve(argsRange.size() + 1);
+  printfArgs.append(argsRange.begin(), argsRange.end());
+
+  rewriter.create<mlir::spirv::CLPrintfOp>(loc, i32Type, fmtStr, printfArgs);
+
+  rewriter.eraseOp(gpuPrintfOp);
+
+  return mlir::success();
+}
+
 //===----------------------------------------------------------------------===//
 // GPU To SPIRV Patterns.
 //===----------------------------------------------------------------------===//
@@ -630,5 +741,6 @@ void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
       SingleDimLaunchConfigConversion<gpu::SubgroupSizeOp,
                                       spirv::BuiltIn::SubgroupSize>,
       WorkGroupSizeConversion, GPUAllReduceConversion,
-      GPUSubgroupReduceConversion>(typeConverter, patterns.getContext());
+      GPUSubgroupReduceConversion, GPUPrintfConversion>(typeConverter,
+                                                        patterns.getContext());
 }
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 02d03b3a0faeee..89a72260290e22 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -309,6 +309,7 @@ LogicalResult spirv::Deserializer::processDecoration(ArrayRef<uint32_t> words) {
   case spirv::Decoration::RelaxedPrecision:
   case spirv::Decoration::Restrict:
   case spirv::Decoration::RestrictPointer:
+  case spirv::Decoration::Constant:
     if (words.size() != 2) {
       return emitError(unknownLoc, "OpDecoration with ")
              << decorationName << "needs a single target <id>";
diff --git a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
index 40337e007bbf74..2252c339af0a75 100644
--- a/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/Serializer.cpp
@@ -272,6 +272,7 @@ LogicalResult Serializer::processDecorationAttr(Location loc, uint32_t resultID,
   case spirv::Decoration::RelaxedPrecision:
   case spirv::Decoration::Restrict:
   case spirv::Decoration::RestrictPointer:
+  case spirv::Decoration::Constant:
     // For unit attributes and decoration attributes, the args list
     // has no values so we do nothing.
     if (isa<UnitAttr, DecorationAttr>(attr))
diff --git a/mlir/test/Conversion/GPUToSPIRV/printf.mlir b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
new file mode 100644
index 00000000000000..4c77195f916014
--- /dev/null
+++ b/mlir/test/Conversion/GPUToSPIRV/printf.mlir
@@ -0,0 +1,71 @@
+// RUN: mlir-opt -allow-unregistered-dialect -split-input-file  -convert-gpu-to-spirv -verify-diagnostics %s -o - | FileCheck %s
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+  func.func @main() {
+    %c1 = arith.constant 1 : index
+
+    gpu.launch_func @kernels::@printf
+        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+        args()
+    return
+  }
+  
+  gpu.module @kernels {
+    // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+    // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+    // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+    // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant}  : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+    gpu.func @printf() kernel
+      attributes 
+        {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+          // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+          // CHECK-NEXT: [[FMTSTR_PTR:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+          // CHECK-NEXT {{%.*}} = spirv.CL.printf [[FMTSTR_PTR]] : (!spirv.ptr<i8, UniformConstant>) -> i32
+          gpu.printf "\nHello\n"
+          // CHECK: spirv.Return
+          gpu.return
+    }
+  }
+}
+
+// -----
+
+module attributes {
+  gpu.container_module,
+  spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Addresses, Int8, Kernel], []>, #spirv.resource_limits<>>
+} {
+  func.func @main() {
+    %c1   = arith.constant 1 : index
+    %c100 = arith.constant 100: i32
+    %cst_f32 = arith.constant 314.4: f32
+
+    gpu.launch_func @kernels1::@printf_args
+        blocks in (%c1, %c1, %c1) threads in (%c1, %c1, %c1)
+        args(%c100: i32, %cst_f32: f32)
+    return
+  }
+
+   gpu.module @kernels1 {
+    // CHECK: spirv.module @{{.*}} Physical32 OpenCL {
+    // CHECK-DAG: spirv.SpecConstant [[SPECCST:@.*]] = {{.*}} : i8
+    // CHECK-DAG: spirv.SpecConstantComposite [[SPECCSTCOMPOSITE:@.*]] ([[SPECCST]], {{.*}}) : !spirv.array<[[ARRAYSIZE:.*]] x i8>
+    // CHECK-DAG: spirv.GlobalVariable [[PRINTMSG:@.*]] initializer([[SPECCSTCOMPOSITE]]) {Constant}  : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+    gpu.func @printf_args(%arg0: i32, %arg1: f32) kernel
+      attributes {spirv.entry_point_abi = #spirv.entry_point_abi<>} {
+        %0 = gpu.block_id x
+        %1 = gpu.block_id y
+        %2 = gpu.thread_id x
+
+        // CHECK: [[FMTSTR_ADDR:%.*]] = spirv.mlir.addressof [[PRINTMSG]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant>
+        // CHECK-NEXT: [[FMTSTR_PTR1:%.*]] = spirv.Bitcast [[FMTSTR_ADDR]] : !spirv.ptr<!spirv.array<[[ARRAYSIZE]] x i8>, UniformConstant> to !spirv.ptr<i8, UniformConstant>
+        // CHECK-NEXT: {{%.*}} = spirv.CL.printf [[FMTSTR_PTR1]] : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}}, {{%.*}} : i32, f32, i32) -> i32
+        gpu.printf "\nHello, world : %d %f \n Thread id: %d\n" %arg0, %arg1, %2: i32, f32, index
+        
+        // CHECK: spirv.Return
+        gpu.return
+    }
+  }
+}
\ No newline at end of file
diff --git a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
index 81ba471d3f51e3..171087a167850f 100644
--- a/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/ocl-ops.mlir
@@ -275,8 +275,8 @@ func.func @rintvec(%arg0 : vector<3xf16>) -> () {
 //===----------------------------------------------------------------------===//
 // CHECK-LABEL: func.func @printf(
 func.func @printf(%arg0 : !spirv.ptr<i8, UniformConstant>, %arg1 : i32, %arg2 : i32) -> i32 {
-  // CHECK: spirv.CL.printf {{%.*}}, {{%.*}}, {{%.*}} : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
-  %0 = spirv.CL.printf %arg0, %arg1, %arg2 : (!spirv.ptr<i8, UniformConstant>, (i32, i32)) -> i32
+  // CHECK: spirv.CL.printf {{%.*}} : !spirv.ptr<i8, UniformConstant>({{%.*}}, {{%.*}} : i32, i32) -> i32
+  %0 = spirv.CL.printf %arg0 : !spirv.ptr<i8, UniformConstant>(%arg1, %arg2 : i32, i32) -> i32
   return %0 : i32
 }
 

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the contribution, look like a useful debugging feature to have! I left some comments.

I find it hard to follow the matchAndRewrite code -- could we brake it down a bit and introduce some helper functions?

mlir::Value fmtStr = rewriter.create<mlir::spirv::BitcastOp>(
loc,
mlir::spirv::PointerType::get(i8Type,
mlir::spirv::StorageClass::UniformConstant),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we instead check GV's storage class and use it instead of constant? Both SPIR-V and OpenCL have extensions relaxing printf string arg requirement, see KhronosGroup/SPIRV-Registry#148

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not sure I got your suggestion, do you mean we can get the storage class from already defined globalVar op?

thanks for pointing out to the requirements docs, while implementing this, when I checked OpenCL specs for printf , they mention fmt string needs to be in constant address space, hence I used UniformConstant here- https://registry.khronos.org/OpenCL/specs/3.0-unified/html/OpenCL_C.html#printf

This change contains following:
	- adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass.
	- Fixes Constant decoration parsing for spirv GlobalVariable.
	- minor modification to spirv.CL.printf op assembly format.
@drprajap
Copy link
Contributor Author

Thanks for the contribution, look like a useful debugging feature to have! I left some comments.

I find it hard to follow the matchAndRewrite code -- could we brake it down a bit and introduce some helper functions?

Thank you @kuhar for your feedback and sorry for the delay in addressing them.
I added more comments and tried to clarify the functionality, let me know if it makes it a bit more clear.

@kuhar
Copy link
Member

kuhar commented Sep 18, 2024

@kuhar
Copy link
Member

kuhar commented Sep 18, 2024

@drprajap could you click the 'Resolve' button on past comments that you believe have been addressed? This will make it easier for me to tell what are the remaining open discussion threads.

@drprajap
Copy link
Contributor Author

drprajap commented Sep 27, 2024

@drprajap there are some test failures: https://buildkite.com/llvm-project/github-pull-requests/builds/101724#019202b6-5c05-43e1-aa01-48aba3a6a4de/6-202

Yes, sorry I missed to update test cases with new printf format changes, updated and verified locally now.

@drprajap
Copy link
Contributor Author

@drprajap could you click the 'Resolve' button on past comments that you believe have been addressed? This will make it easier for me to tell what are the remaining open discussion threads.

@kuhar I addressed all the feedback and resolved comments, please check it out and let me know if looks okay. Thanks so much for your valuable feedback.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Overall looks good, just some remaining issues

@drprajap
Copy link
Contributor Author

Overall looks good, just some remaining issues

Thank you. Addressed remaining issues.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

minor formatting changes

Co-authored-by: Jakub Kuderski <[email protected]>
@drprajap
Copy link
Contributor Author

LGTM

Thanks, I do not have write permission, can you please help merge it if possible? else I can ask other team members.

@kuhar
Copy link
Member

kuhar commented Sep 28, 2024

LGTM

Thanks, I do not have write permission, can you please help merge it if possible? else I can ask other team members.

Sure, I will merge this on Monday. Please ping me if I forget to do so.

@drprajap
Copy link
Contributor Author

LGTM

Thanks, I do not have write permission, can you please help merge it if possible? else I can ask other team members.

Sure, I will merge this on Monday. Please ping me if I forget to do so.

Gentle ping to merge this. Thank you :)

@kuhar kuhar merged commit f8ba021 into llvm:main Sep 30, 2024
8 checks passed
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
…78510)

This change contains following:
	- adds lowering of printf op to spirv.CL.printf op in GPUToSPIRV pass.
	- Fixes Constant decoration parsing for spirv GlobalVariable.
	- minor modification to spirv.CL.printf op assembly format.

---------

Co-authored-by: Jakub Kuderski <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants