-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][GPU] gpu.printf: Do not emit duplicate format strings #110504
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Even if the same format string is used multiple times, emit just one `LLVM:GlobalOp`.
@llvm/pr-subscribers-mlir-gpu @llvm/pr-subscribers-mlir Author: Matthias Springer (matthias-springer) ChangesEven if the same format string is used multiple times, emit just one Full diff: https://github.com/llvm/llvm-project/pull/110504.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
index 5b590a457f7714..06d759b5f54175 100644
--- a/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
+++ b/mlir/lib/Conversion/GPUCommon/GPUOpsLowering.cpp
@@ -340,6 +340,34 @@ static SmallString<16> getUniqueFormatGlobalName(gpu::GPUModuleOp moduleOp) {
return stringConstName;
}
+/// Create an global that contains the given format string. If a global with
+/// the same format string exists already in the module, return that global.
+static LLVM::GlobalOp getOrCreateFormatStringConstant(
+ OpBuilder &b, Location loc, gpu::GPUModuleOp moduleOp, Type llvmI8,
+ StringRef str, uint64_t alignment = 0, unsigned addrSpace = 0) {
+ llvm::SmallString<20> formatString(str);
+ formatString.push_back('\0'); // Null terminate for C
+ auto globalType =
+ LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
+ StringAttr attr = b.getStringAttr(formatString);
+
+ // Try to find existing global.
+ for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>())
+ if (globalOp.getGlobalType() == globalType && globalOp.getConstant() &&
+ globalOp.getValueAttr() == attr &&
+ globalOp.getAlignment().value_or(0) == alignment &&
+ globalOp.getAddrSpace() == addrSpace)
+ return globalOp;
+
+ // Not found: create new global.
+ OpBuilder::InsertionGuard guard(b);
+ b.setInsertionPointToStart(moduleOp.getBody());
+ SmallString<16> name = getUniqueFormatGlobalName(moduleOp);
+ return b.create<LLVM::GlobalOp>(loc, globalType,
+ /*isConstant=*/true, LLVM::Linkage::Internal,
+ name, attr, alignment, addrSpace);
+}
+
template <typename T>
static LLVM::LLVMFuncOp getOrDefineFunction(T &moduleOp, const Location loc,
ConversionPatternRewriter &rewriter,
@@ -391,33 +419,20 @@ LogicalResult GPUPrintfOpToHIPLowering::matchAndRewrite(
auto printfBeginCall = rewriter.create<LLVM::CallOp>(loc, ocklBegin, zeroI64);
Value printfDesc = printfBeginCall.getResult();
- // Get a unique global name for the format.
- SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
- llvm::SmallString<20> formatString(adaptor.getFormat());
- formatString.push_back('\0'); // Null terminate for C
- size_t formatStringSize = formatString.size_in_bytes();
-
- auto globalType = LLVM::LLVMArrayType::get(llvmI8, formatStringSize);
- LLVM::GlobalOp global;
- {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- global = rewriter.create<LLVM::GlobalOp>(
- loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
- rewriter.getStringAttr(formatString));
- }
+ // Create the global op or find an existing one.
+ LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+ rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
// Get a pointer to the format string's first element and pass it to printf()
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
- Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
- Value stringLen =
- rewriter.create<LLVM::ConstantOp>(loc, llvmI64, formatStringSize);
+ Value stringStart =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringLen = rewriter.create<LLVM::ConstantOp>(
+ loc, llvmI64, cast<StringAttr>(global.getValueAttr()).size());
Value oneI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 1);
Value zeroI32 = rewriter.create<LLVM::ConstantOp>(loc, llvmI32, 0);
@@ -486,30 +501,19 @@ LogicalResult GPUPrintfOpToLLVMCallLowering::matchAndRewrite(
LLVM::LLVMFuncOp printfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "printf", printfType);
- // Get a unique global name for the format.
- SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
- llvm::SmallString<20> formatString(adaptor.getFormat());
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- LLVM::GlobalOp global;
- {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- global = rewriter.create<LLVM::GlobalOp>(
- loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
- rewriter.getStringAttr(formatString), /*allignment=*/0, addressSpace);
- }
+ // Create the global op or find an existing one.
+ LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+ rewriter, loc, moduleOp, llvmI8, adaptor.getFormat(), /*alignment=*/0,
+ addressSpace);
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(
loc,
LLVM::LLVMPointerType::get(rewriter.getContext(), global.getAddrSpace()),
global.getSymNameAttr());
- Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringStart =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
// Construct arguments and function call
auto argsRange = adaptor.getArgs();
@@ -541,27 +545,15 @@ LogicalResult GPUPrintfOpToVPrintfLowering::matchAndRewrite(
LLVM::LLVMFuncOp vprintfDecl =
getOrDefineFunction(moduleOp, loc, rewriter, "vprintf", vprintfType);
- // Get a unique global name for the format.
- SmallString<16> stringConstName = getUniqueFormatGlobalName(moduleOp);
-
- llvm::SmallString<20> formatString(adaptor.getFormat());
- formatString.push_back('\0'); // Null terminate for C
- auto globalType =
- LLVM::LLVMArrayType::get(llvmI8, formatString.size_in_bytes());
- LLVM::GlobalOp global;
- {
- ConversionPatternRewriter::InsertionGuard guard(rewriter);
- rewriter.setInsertionPointToStart(moduleOp.getBody());
- global = rewriter.create<LLVM::GlobalOp>(
- loc, globalType,
- /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
- rewriter.getStringAttr(formatString), /*allignment=*/0);
- }
+ // Create the global op or find an existing one.
+ LLVM::GlobalOp global = getOrCreateFormatStringConstant(
+ rewriter, loc, moduleOp, llvmI8, adaptor.getFormat());
// Get a pointer to the format string's first element
Value globalPtr = rewriter.create<LLVM::AddressOfOp>(loc, global);
- Value stringStart = rewriter.create<LLVM::GEPOp>(
- loc, ptrType, globalType, globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
+ Value stringStart =
+ rewriter.create<LLVM::GEPOp>(loc, ptrType, global.getGlobalType(),
+ globalPtr, ArrayRef<LLVM::GEPArg>{0, 0});
SmallVector<Type> types;
SmallVector<Value> args;
// Promote and pack the arguments into a stack allocation.
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index ad4e9ec1791a77..748dfe8c68fc7e 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -610,6 +610,13 @@ gpu.module @test_module_29 {
// CHECK-NEXT: %[[ALLOC:.*]] = llvm.alloca %[[O]] x !llvm.struct<()> : (i64) -> !llvm.ptr
// CHECK-NEXT: llvm.call @vprintf(%[[FORMATSTART]], %[[ALLOC]]) : (!llvm.ptr, !llvm.ptr) -> i32
gpu.printf "Hello, world\n"
+
+ // Make sure that the same global is reused.
+ // CHECK: %[[FORMATSTR2:.*]] = llvm.mlir.addressof @[[$PRINT_GLOBAL0]] : !llvm.ptr
+ // CHECK: %[[FORMATSTART2:.*]] = llvm.getelementptr %[[FORMATSTR2]][0, 0] : (!llvm.ptr) -> !llvm.ptr, !llvm.array<14 x i8>
+ // CHECK: llvm.call @vprintf(%[[FORMATSTART2]], %{{.*}}) : (!llvm.ptr, !llvm.ptr) -> i32
+ gpu.printf "Hello, world\n"
+
gpu.return
}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The improvement is nice all in all.
StringAttr attr = b.getStringAttr(formatString); | ||
|
||
// Try to find existing global. | ||
for (auto globalOp : moduleOp.getOps<LLVM::GlobalOp>()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you look at the getDynamicSharedMemorySymbol
below, it does something very similar. Can we unify them?
I don't want to block the PR because of this, but I'd like to have this unification at some point.
Side comment: I kind of concerned that it iterates over all globals for each printf making it |
There are two options to do that:
|
Actually, a cache is dangerous because there could be a pattern that erases a cached op and now we have a dangling pointer. |
Yeah. Another possible solution is to leave pattern as is and instead have a separate pass, which merges globals. This will also be useful outside this specific lowering. |
Also, you can potentially track all erases as all IR modifications must be done through |
That's an interesting idea, but I'm not sure if it's generally safe to CSE equivalent symbols. E.g., what if the symbol is a |
I'm going to merge this now. The number of symbols should be fairly small, so I don't expect a performance problem. Let's discuss the symbol-CSE approach on Discord/Discourse and change the implementation here if we decide to go with it. |
…0504) Even if the same format string is used multiple times, emit just one `LLVM:GlobalOp`.
Even if the same format string is used multiple times, emit just one
LLVM:GlobalOp
.