Skip to content

[mlir][GPU] gpu.printf: Do not emit duplicate format strings #110504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 1, 2024

Conversation

matthias-springer
Copy link
Member

Even if the same format string is used multiple times, emit just one LLVM:GlobalOp.

Even if the same format string is used multiple times, emit just one `LLVM:GlobalOp`.
@llvmbot
Copy link
Member

llvmbot commented Sep 30, 2024

@llvm/pr-subscribers-mlir-gpu

@llvm/pr-subscribers-mlir

Author: Matthias Springer (matthias-springer)

Changes

Even if the same format string is used multiple times, emit just one LLVM:GlobalOp.


Full diff: https://github.com/llvm/llvm-project/pull/110504.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp (+49-57)
  • (modified) mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir (+7)
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 5b590a457f7714..06d759b5f54175 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -340,6 +340,34 @@ static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
   return stringConstName;
 }
 
+/// Create an global that contains the given format string. If a global with
+/// the same format string exists already in the module, return that global.
+static LLVM::GlobalOp getOrCreateFormatStringConstant(
+    OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
+    StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
+  llvm::SmallString<20> formatString(str);
+  formatString.push_back('\0'); // Null terminate for C
+  auto globalType =
+      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
+  StringAttr attr = b.getStringAttr(formatString);
+
+  // Try to find existing global.
+  for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+    if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+        globalOp.getValueAttr() == attr &&
+        globalOp.getAlignment().value_or(0) == alignment &&
+        globalOp.getAddrSpace() == addrSpace)
+      return globalOp;
+
+  // Not found: create new global.
+  OpBuilder::InsertionGuard guard(b);
+  b.setInsertionPointToStart(moduleOp.getBody());
+  SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
+  return b.create<LLVM::GlobalOp>(loc, globalType,
+                                  /*isConstant=*/true, LLVM::Linkage::Internal,
+                                  name, attr, alignment, addrSpace);
+}
+
 template <typename T>
 static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
                                             ConversionPatternRewriter &rewriter,
@@ -391,33 +419,20 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
   auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
   Value printfDesc = printfBeginCall.getResult();
 
-  // Get a unique global name for the format.
-  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
-  llvm::SmallString<20> formatString(adaptor.getFormat());
-  formatString.push_back('\0'); // Null terminate for C
-  size_t formatStringSize = formatString.size_in_bytes();
-
-  auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
-  LLVM::GlobalOp global;
-  {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    global = rewriter.create<LLVM::GlobalOp>(
-        loc, globalType,
-        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
-        rewriter.getStringAttr(formatString));
-  }
+  // Create the global op or find an existing one.
+  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
 
   // Get a pointer to the format string's first element and pass it to printf()
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
       loc,
       LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
       global.getSymNameAttr());
-  Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
-  Value stringLen =
-      rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
+  Value stringStart =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+  Value stringLen = rewriter.create<LLVM::ConstantOp>(
+      loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
 
   Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
   Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
@@ -486,30 +501,19 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
   LLVM::LLVMFuncOp printfDecl =
       getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
 
-  // Get a unique global name for the format.
-  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
-  llvm::SmallString<20> formatString(adaptor.getFormat());
-  formatString.push_back('\0'); // Null terminate for C
-  auto globalType =
-      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
-  LLVM::GlobalOp global;
-  {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    global = rewriter.create<LLVM::GlobalOp>(
-        loc, globalType,
-        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
-        rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
-  }
+  // Create the global op or find an existing one.
+  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
+      addressSpace);
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
       loc,
       LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
       global.getSymNameAttr());
-  Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+  Value stringStart =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
 
   // Construct arguments and function call
   auto argsRange = adaptor.getArgs();
@@ -541,27 +545,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
   LLVM::LLVMFuncOp vprintfDecl =
       getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
 
-  // Get a unique global name for the format.
-  SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
-  llvm::SmallString<20> formatString(adaptor.getFormat());
-  formatString.push_back('\0'); // Null terminate for C
-  auto globalType =
-      LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
-  LLVM::GlobalOp global;
-  {
-    ConversionPatternRewriter::InsertionGuard guard(rewriter);
-    rewriter.setInsertionPointToStart(moduleOp.getBody());
-    global = rewriter.create<LLVM::GlobalOp>(
-        loc, globalType,
-        /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
-        rewriter.getStringAttr(formatString), /*allignment=*/0);
-  }
+  // Create the global op or find an existing one.
+  LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+      rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
 
   // Get a pointer to the format string's first element
   Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
-  Value stringStart = rewriter.create<LLVM::GEPOp>(
-      loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+  Value stringStart =
+      rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+                                   globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
   SmallVector<Type> types;
   SmallVector<Value> args;
   // Promote and pack the arguments into a stack allocation.
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ad4e9ec1791a77..748dfe8c68fc7e 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -610,6 +610,13 @@ gpu.module @test_module_29 {
     // CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
     // CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
     gpu.printf "Hello, world\n"
+
+    // Make sure that the same global is reused.
+    // CHECK: %[[FORMATSTR2:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
+    // CHECK: %[[FORMATSTART2:.*]] = llvm.getelementptr %[[FORMATSTR2]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+    // CHECK: llvm.call @vprintf(%[[FORMATSTART2]], %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+    gpu.printf "Hello, world\n"
+
     gpu.return
   }
 

Copy link
Member

@grypp grypp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The improvement is nice all in all.

StringAttr attr = b.getStringAttr(formatString);

// Try to find existing global.
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
Copy link
Member

@grypp grypp Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if you look at the getDynamicSharedMemorySymbol below, it does something very similar. Can we unify them?

I don't want to block the PR because of this, but I'd like to have this unification at some point.

@Hardcode84
Copy link
Contributor

Hardcode84 commented Sep 30, 2024

Side comment: I kind of concerned that it iterates over all globals for each printf making it O(N^2) pass. Can we build a map once and reuse it? I do not want to block anyone either, so we can do it in separate PR.

@matthias-springer
Copy link
Member Author

Side comment: I kind of concerned that it iterates over all globals for each printf making it O(N^2) pass. Can we build a map once and reuse it? I do not want to block anyone either, so we can do it in separate PR.

There are two options to do that:

  1. Pass a DenseMap object to populateGpuToNVVMConversionPatterns. Not ideal.
  2. Store the cache directly in the lowering pattern. matchAndRewrite is a const function, so I would have to mark the cache as mutable. Not sure if that's ideal either.

@matthias-springer
Copy link
Member Author

Actually, a cache is dangerous because there could be a pattern that erases a cached op and now we have a dangling pointer.

@Hardcode84
Copy link
Contributor

Actually, a cache is dangerous because there could be a pattern that erases a cached op and now we have a dangling pointer.

Yeah. Another possible solution is to leave pattern as is and instead have a separate pass, which merges globals. This will also be useful outside this specific lowering.

@Hardcode84
Copy link
Contributor

Hardcode84 commented Sep 30, 2024

Actually, a cache is dangerous because there could be a pattern that erases a cached op and now we have a dangling pointer.

Also, you can potentially track all erases as all IR modifications must be done through rewriter anyway. But it will require a some changes and design discussions on core infra, I believe.

@matthias-springer
Copy link
Member Author

That's an interesting idea, but I'm not sure if it's generally safe to CSE equivalent symbols. E.g., what if the symbol is a memref.global. Merging two memref.global turns two different allocations into one. I asked on Discord...

@matthias-springer
Copy link
Member Author

I'm going to merge this now. The number of symbols should be fairly small, so I don't expect a performance problem. Let's discuss the symbol-CSE approach on Discord/Discourse and change the implementation here if we decide to go with it.

@matthias-springer matthias-springer merged commit 2da417e into main Oct 1, 2024
11 checks passed
@matthias-springer matthias-springer deleted the users/matthias-springer/gpu_format_str branch October 1, 2024 07:12
Sterling-Augustine pushed a commit to Sterling-Augustine/llvm-project that referenced this pull request Oct 3, 2024
…0504)

Even if the same format string is used multiple times, emit just one
`LLVM:GlobalOp`.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants