[MLIR][LLVM] Fix memory explosion when converting global variable bodies in ModuleTranslation #82708
…ies in ModuleTranslation

Converting the body (initializer region) of a large global variable, e.g. a constant array, causes a memory explosion.

For example, when translating a constant array of 100000 strings:

llvm.mlir.global internal constant @cats_strings() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<100000 x ptr<i8>> {
  %0 = llvm.mlir.undef : !llvm.array<100000 x ptr<i8>>
  %1 = llvm.mlir.addressof @om_1 : !llvm.ptr<array<1 x i8>>
  %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
  %3 = llvm.insertvalue %2, %0[0] : !llvm.array<100000 x ptr<i8>>
  %4 = llvm.mlir.addressof @om_2 : !llvm.ptr<array<1 x i8>>
  %5 = llvm.getelementptr %4[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
  %6 = llvm.insertvalue %5, %3[1] : !llvm.array<100000 x ptr<i8>>
  %7 = llvm.mlir.addressof @om_3 : !llvm.ptr<array<1 x i8>>
  %8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
  %9 = llvm.insertvalue %8, %6[2] : !llvm.array<100000 x ptr<i8>>
  %10 = llvm.mlir.addressof @om_4 : !llvm.ptr<array<1 x i8>>
  %11 = llvm.getelementptr %10[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8>
  %12 = llvm.insertvalue %11, %9[3] : !llvm.array<100000 x ptr<i8>>
  ... (ignore the remaining part)
}

where @om_1, @om_2, ... are string global constants.

Each time an operation is converted to LLVM, a new constant is created. When an llvm.insertvalue is converted, a new constant array of 100000 elements is created and the old constant array (the input) is not destroyed. This causes the memory explosion. We observed that on a system with 128 GB of memory, the translation of 100000 elements was killed after using up all the memory. On a system with 64 GB, 65536 elements were enough to get the translation killed.

A previous patch (https://reviews.llvm.org/D148487) fixed this issue but was reverted because of llvm#62802.

The old patch checked the generated constants and destroyed them if they had no uses. But that check ran too early, so a constant could be removed before its use was created.

This new patch adds a map that records the expected use count for each constant. The count is decremented each time a use is reached, and the constant is erased only when the count reaches zero.

With the new patch, the repro in llvm#62802 finishes correctly.
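A rough sense of scale for the numbers in the description: if each of the N insertvalue ops leaves behind its own N-slot array, the intermediates grow quadratically. The sketch below is only a lower-bound estimate under a stated assumption (8 bytes per element slot, hypothetical helper names); the real per-element cost inside LLVM is likely higher, so the figures should not be read as measurements.

```cpp
#include <cassert>
#include <cstdint>
#include <initializer_list>
#include <iostream>

// Lower-bound estimate of the bytes held by the intermediate constant arrays
// produced by a chain of `numElements` insertvalue ops, assuming (for
// illustration only) that every intermediate array keeps `numElements` slots
// of `bytesPerSlot` bytes and that none of them is ever freed.
std::uint64_t intermediateBytes(std::uint64_t numElements,
                                std::uint64_t bytesPerSlot) {
  // Keeps numElements * numElements * bytesPerSlot within 64 bits.
  assert(numElements < (std::uint64_t{1} << 28) && bytesPerSlot <= 64);
  return numElements * numElements * bytesPerSlot;
}

// Prints the estimate for the two array sizes quoted in the description.
void printEstimates() {
  const std::uint64_t kBytesPerSlot = 8; // assumption: one 64-bit pointer
  const double kGiB = 1024.0 * 1024.0 * 1024.0;
  for (std::uint64_t n : {std::uint64_t{65536}, std::uint64_t{100000}})
    std::cout << n << " elements: at least "
              << static_cast<double>(intermediateBytes(n, kBytesPerSlot)) / kGiB
              << " GiB of intermediate arrays\n";
}
```

Under that assumption, 65536 elements already give 2^35 bytes (32 GiB) and 100000 elements give 8 * 10^10 bytes, which is in line with the out-of-memory kills reported for the 64 GB and 128 GB machines.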
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-llvm
Author: Xiang Li (python3kgae)
Full diff: https://github.com/llvm/llvm-project/pull/82708.diff
2 Files Affected:
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
index ee8fffd959c883..64c37b1d5fa961 100644
--- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
@@ -51,11 +51,15 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <optional>
+#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
+
using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
@@ -1042,17 +1046,77 @@ LogicalResult ModuleTranslation::convertGlobals() {
   for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
     if (Block *initializer = op.getInitializerBlock()) {
       llvm::IRBuilder<> builder(llvmModule->getContext());
+
+      int numConstantsHit = 0;
+      int numConstantsErased = 0;
+      DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
+
       for (auto &op : initializer->without_terminator()) {
         if (failed(convertOperation(op, builder)) ||
             !isa<llvm::Constant>(lookupValue(op.getResult(0))))
           return emitError(op.getLoc(), "unemittable constant value");
+
+        // When emitting an LLVM constant, a new constant is created and the old
+        // constant may become dangling and take space. We should remove the
+        // dangling constants to avoid memory explosion especially for constant
+        // arrays whose number of elements is large.
+        // Because multiple operations may refer to the same constant, we need
+        // to count the number of uses of each constant array and remove it only
+        // when the count becomes zero.
+        if (op.getNumResults() == 1) {
+          Value result = op.getResult(0);
+          auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(result));
+          if (!cst)
+            continue;
+          numConstantsHit++;
+          auto iter = constantAggregateUseMap.find(cst);
+          int numUsers = std::distance(result.use_begin(), result.use_end());
+          if (iter == constantAggregateUseMap.end())
+            constantAggregateUseMap.try_emplace(cst, numUsers);
+          else
+            iter->second += numUsers;
+        }
+        for (Value v : op.getOperands()) {
+          auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
+          if (!cst)
+            continue;
+          auto iter = constantAggregateUseMap.find(cst);
+          assert(iter != constantAggregateUseMap.end() && "constant not found");
+          iter->second--;
+          if (iter->second == 0) {
+            cst->removeDeadConstantUsers();
+            if (cst->user_empty()) {
+              cst->destroyConstant();
+              numConstantsErased++;
+            }
+            constantAggregateUseMap.erase(iter);
+          }
+        }
       }
+
       ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
       llvm::Constant *cst =
           cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
       auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
       if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
         global->setInitializer(cst);
+
+      // Try to remove the dangling constants again after all operations are
+      // converted.
+      for (auto it : constantAggregateUseMap) {
+        auto cst = it.first;
+        cst->removeDeadConstantUsers();
+        if (cst->user_empty()) {
+          cst->destroyConstant();
+          numConstantsErased++;
+        }
+      }
+
+      LLVM_DEBUG(llvm::dbgs()
+                     << "Convert initializer for " << op.getName() << "\n";
+                 llvm::dbgs() << numConstantsHit << " new constants hit\n";
+                 llvm::dbgs()
+                 << numConstantsErased << " dangling constants erased\n";);
     }
   }
diff --git a/mlir/test/Target/LLVMIR/erase-dangling-constants.mlir b/mlir/test/Target/LLVMIR/erase-dangling-constants.mlir
new file mode 100644
index 00000000000000..b3b5d540ae88fc
--- /dev/null
+++ b/mlir/test/Target/LLVMIR/erase-dangling-constants.mlir
@@ -0,0 +1,72 @@
+// RUN: mlir-translate -mlir-to-llvmir %s -debug-only=llvm-dialect-to-llvm-ir 2>&1 | FileCheck %s
+
+// CHECK: Convert initializer for dup_const
+// CHECK: 6 new constants hit
+// CHECK: 3 dangling constants erased
+// CHECK: Convert initializer for unique_const
+// CHECK: 6 new constants hit
+// CHECK: 5 dangling constants erased
+
+
+// CHECK:@dup_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02] }
+
+llvm.mlir.global @dup_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
+ %c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
+ %c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64
+
+ %empty0 = llvm.mlir.undef : !llvm.array<2 x f64>
+ %a00 = llvm.insertvalue %c0, %empty0[0] : !llvm.array<2 x f64>
+
+ %empty1 = llvm.mlir.undef : !llvm.array<2 x f64>
+ %a10 = llvm.insertvalue %c0, %empty1[0] : !llvm.array<2 x f64>
+
+ %empty2 = llvm.mlir.undef : !llvm.array<2 x f64>
+ %a20 = llvm.insertvalue %c0, %empty2[0] : !llvm.array<2 x f64>
+
+// NOTE: a00, a10, a20 are all same ConstantAggregate which not used at this point.
+// should not delete it before all of the uses of the ConstantAggregate finished.
+
+ %a01 = llvm.insertvalue %c1, %a00[1] : !llvm.array<2 x f64>
+ %a11 = llvm.insertvalue %c1, %a10[1] : !llvm.array<2 x f64>
+ %a21 = llvm.insertvalue %c1, %a20[1] : !llvm.array<2 x f64>
+ %empty_r = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+ %r0 = llvm.insertvalue %a01, %empty_r[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+ %r1 = llvm.insertvalue %a11, %r0[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+ %r2 = llvm.insertvalue %a21, %r1[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+ llvm.return %r2 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+ }
+
+// CHECK:@unique_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.312250e-02, double 5.219230e-02], [2 x double] [double 3.412250e-02, double 5.419230e-02] }
+
+llvm.mlir.global @unique_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
+ %c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
+ %c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64
+
+ %c2 = llvm.mlir.constant(3.312250e-02 : f64) : f64
+ %c3 = llvm.mlir.constant(5.219230e-02 : f64) : f64
+
+ %c4 = llvm.mlir.constant(3.412250e-02 : f64) : f64
+ %c5 = llvm.mlir.constant(5.419230e-02 : f64) : f64
+
+ %2 = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+ %3 = llvm.mlir.undef : !llvm.array<2 x f64>
+
+ %4 = llvm.insertvalue %c0, %3[0] : !llvm.array<2 x f64>
+ %5 = llvm.insertvalue %c1, %4[1] : !llvm.array<2 x f64>
+
+ %6 = llvm.insertvalue %5, %2[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+ %7 = llvm.insertvalue %c2, %3[0] : !llvm.array<2 x f64>
+ %8 = llvm.insertvalue %c3, %7[1] : !llvm.array<2 x f64>
+
+ %9 = llvm.insertvalue %8, %6[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+ %10 = llvm.insertvalue %c4, %3[0] : !llvm.array<2 x f64>
+ %11 = llvm.insertvalue %c5, %10[1] : !llvm.array<2 x f64>
+
+ %12 = llvm.insertvalue %11, %9[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+
+ llvm.return %12 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
+}
// when the count becomes zero.
if (op.getNumResults() == 1) {
  Value result = op.getResult(0);
  auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(result));

Nit: you already call lookupValue at line 1056; can we restructure the code to perform a single lookup?

Updated.
  continue;
numConstantsHit++;
auto iter = constantAggregateUseMap.find(cst);
int numUsers = std::distance(result.use_begin(), result.use_end());

This is O(N) on a linked list, hopefully small though...

Most of the ops will be insertvalue, whose result is used only once while building the constant.
      }
      constantAggregateUseMap.erase(iter);
    }
  }
}

I feel I'm missing something: why can't we have a DenseSet<llvm::ConstantAggregate *> constantAggregates; that we keep around, so that on every iteration of the outer loop we could just do:

for (llvm::ConstantAggregate *cst : constantAggregates) {
  cst->removeDeadConstantUsers();
  if (cst->user_empty()) {
    cst->destroyConstant(); // TODO: add removing from the DenseSet
  }
}

Actually this is also done at line 1104, but I don't get why we need the dance around constantAggregateUseMap here.

> I feel I'm missing something, why can't we have a DenseSet<llvm::ConstantAggregate *> constantAggregates; that we keep around and every iteration of the outer loop we could just do:
>
> for (llvm::ConstantAggregate *cst : constantAggregates) {
>   cst->removeDeadConstantUsers();
>   if (cst->user_empty()) {
>     cst->destroyConstant(); // TODO: add removing from the DenseSet
>   }
> }

My understanding is that we're converting all the ops in the initializer to llvm::Constant.
When a later op is converted into a ConstantAggregate, it adds a use of the constant converted from an earlier op.
So we cannot remove a Constant until all the ops that use it have been converted.
Checking user_empty alone is not enough, because the user might not have been generated yet.
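The point made in the reply above can be illustrated with a small standalone model. This is only a sketch, not the code in the patch: ToyOp, Stats, and translateWithUseCounts are hypothetical names, values are identified by plain integer ids, and LLVM's constant uniquing (which is why the patch accumulates counts with `+=`) is not modeled. It shows a value being kept alive until the last op that reads it has been translated, even while it has no users yet.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <unordered_map>
#include <vector>

// Hypothetical stand-in for one op in a global initializer: it defines
// `result` and reads `operands` (ids of results defined by earlier ops).
struct ToyOp {
  int result;
  std::vector<int> operands;
};

struct Stats {
  std::size_t peakLive = 0; // most constants alive at the same time
  std::size_t erased = 0;   // constants freed before the end of the block
};

Stats translateWithUseCounts(const std::vector<ToyOp> &ops) {
  // Expected number of uses of every result, known from the source IR before
  // any user has been translated.
  std::unordered_map<int, int> totalUses;
  for (const ToyOp &op : ops)
    for (int operand : op.operands)
      ++totalUses[operand];

  // Live constants mapped to the number of uses not yet translated.
  std::unordered_map<int, int> remainingUses;
  Stats stats;
  for (const ToyOp &op : ops) {
    // "Translate" the op: its result becomes a live constant.
    remainingUses.try_emplace(op.result, totalUses[op.result]);
    stats.peakLive = std::max(stats.peakLive, remainingUses.size());
    // Every operand has now been consumed once more; free it only when no
    // untranslated user remains.
    for (int operand : op.operands) {
      auto it = remainingUses.find(operand);
      assert(it != remainingUses.end() && "operand used before definition");
      if (--it->second == 0) {
        remainingUses.erase(it);
        ++stats.erased;
      }
    }
  }
  return stats;
}
```

For a chain such as `{0,{}}, {1,{0}}, {2,{1}}, {3,{2}}` the model never holds more than two constants at once, while a value read by three later ops is freed only after the third one has been processed.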
Ah, I get it. How critical is it to garbage collect during the interpretation of the initializer block instead of just at the end of the block (line 1104)?

The worst block in my test case builds a constant of type [75525 x [100 x f32]].
I can test doing the removal only at the end of the block; the code will be much simpler if that works.

Also: isn't this something that MLIR can constant fold?

The constant will fold, but the input constant is still in memory, and so many temporary constants are created that they eat all the memory.
I tried to avoid creating the insertvalue ops that get converted to temporary constants.
Here's what I got from the experiment:
llvm.mlir.global private @krnl_global_28(dense<["", "ab", "xyz", ...]> : tensor<75525x!krnl.string>) {addr_space = 0 : i32, alignment = 8 : i64} : !llvm.array<75525 x i64>
It crashes because !llvm.array<75525 x i64> is not the correct type.
Any suggestions for how to express an array of strings in the LLVM dialect?

Wouldn't it be just a concat of all the strings into a single dense attribute of the right type?

> The worst block is building a constant with type [75525 x [100 x f32]] for my test case. I can test only do the removing at the end of the block, code will be much simplified if it works.

Removing only at the end does not work :( One block alone already uses all of the memory.

> Wouldn't it be just a concat of all the strings into a single dense attribute of the right type?

Thanks for this suggestion. It saves a lot of the time spent creating a GlobalOp for each string!
But to create the array of strings, I still have to create a lot of insertvalue ops.
Is it possible to express the following LLVM IR with the LLVM dialect?
@c = dso_local global [7 x i8] c"abcxby\00", align 1
@a = dso_local global [3 x ptr] [ptr @c, ptr getelementptr (i8, ptr @c, i64 2), ptr getelementptr (i8, ptr @c, i64 5)]
Then no insertvalue would be needed.
This PR only avoids running out of memory; converting those insertvalue ops still takes a lot of cycles.
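The "concat all the strings" idea from this thread amounts to one byte buffer plus a table of offsets into it. The sketch below shows only that data layout, under the assumption that each string is stored NUL-terminated (the IR example in the comment uses its own offsets); PackedStrings and packStrings are hypothetical names, and the sketch does not answer whether the pointer array can be expressed in the LLVM dialect without insertvalue.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <vector>

// One contiguous buffer plus the byte offset at which each string starts.
struct PackedStrings {
  std::string buffer;
  std::vector<std::size_t> offsets;
};

// Concatenates `strings` into a single buffer, appending a NUL terminator
// after each one (an assumption of this sketch) and recording where each
// string begins.
PackedStrings packStrings(const std::vector<std::string> &strings) {
  PackedStrings packed;
  packed.offsets.reserve(strings.size());
  for (const std::string &s : strings) {
    packed.offsets.push_back(packed.buffer.size());
    packed.buffer += s;
    packed.buffer.push_back('\0');
  }
  return packed;
}
```

Each entry of the offset table corresponds to one getelementptr constant into the single string global, so the per-string globals disappear and only the array of pointers remains to be built.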
auto iter = constantAggregateUseMap.find(cst);
int numUsers = std::distance(result.use_begin(), result.use_end());
if (iter == constantAggregateUseMap.end())
  constantAggregateUseMap.try_emplace(cst, numUsers);

Instead of a find followed by an insertion, you can just try to insert and update the existing entry if the insertion "failed": that means a single lookup.

Updated to try_emplace.
Use try_emplace for map.
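The single-lookup pattern suggested in this thread looks roughly as follows. The sketch uses std::unordered_map with a string key so that it compiles on its own; llvm::DenseMap offers a try_emplace that likewise returns an iterator and a bool, but the exact code in the updated revision is not shown in full on this page, so addUses is a hypothetical helper.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>

// Adds `numUsers` to the count stored for `key`, creating the entry if it is
// absent. try_emplace performs one lookup: it either inserts (key, numUsers)
// or leaves the map unchanged and returns an iterator to the existing entry
// together with inserted == false.
void addUses(std::unordered_map<std::string, int> &useCounts,
             const std::string &key, int numUsers) {
  auto [it, inserted] = useCounts.try_emplace(key, numUsers);
  if (!inserted)
    it->second += numUsers;
}
```

Compared with find followed by try_emplace, the key is hashed and probed once per call instead of twice on the insertion path.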
✅ With the latest revision this PR passed the C/C++ code formatter.
…ated in later conversion.
Seems reasonable.
    iterator->second += numUsers;
  }
}
for (Value v : op.getOperands()) {

Worth adding a comment for this loop, I think.

Updated.
Add more comment.
Hi! This new test only passes if the tool was built with debug enabled. Without it, it fails. Maybe it's possible to modify the test to first check whether the tool was compiled with debug support?
Sure. I'll modify the test.
I saw another test that uses -debug-only and adds // REQUIRES: asserts. Is that good enough to limit the test to debug builds?
PR created to limit the test: #83145
Converting the body (initializer region) of a large global variable, e.g. a constant array, causes a memory explosion.
For example, when translating a constant array of 100000 strings:
llvm.mlir.global internal constant @cats_strings() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<100000 x ptr> {
%0 = llvm.mlir.undef : !llvm.array<100000 x ptr>
%1 = llvm.mlir.addressof @om_1 : !llvm.ptr<array<1 x i8>>
%2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr
%3 = llvm.insertvalue %2, %0[0] : !llvm.array<100000 x ptr>
%4 = llvm.mlir.addressof @om_2 : !llvm.ptr<array<1 x i8>>
%5 = llvm.getelementptr %4[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr
%6 = llvm.insertvalue %5, %3[1] : !llvm.array<100000 x ptr>
%7 = llvm.mlir.addressof @om_3 : !llvm.ptr<array<1 x i8>>
%8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr
%9 = llvm.insertvalue %8, %6[2] : !llvm.array<100000 x ptr>
%10 = llvm.mlir.addressof @om_4 : !llvm.ptr<array<1 x i8>>
%11 = llvm.getelementptr %10[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr
%12 = llvm.insertvalue %11, %9[3] : !llvm.array<100000 x ptr>
}
where @om_1, @om_2, ... are string global constants.
Each time an operation is converted to LLVM, a new constant is created. When an llvm.insertvalue is converted, a new constant array of 100000 elements is created and the old constant array (the input) is not destroyed. This causes the memory explosion. We observed that on a system with 128 GB of memory, the translation of 100000 elements was killed after using up all the memory. On a system with 64 GB, 65536 elements were enough to get the translation killed.
A previous patch (https://reviews.llvm.org/D148487) fixed this issue but was reverted because of #62802.
The old patch checked the generated constants and destroyed them if they had no uses. But that check ran too early, so a constant could be removed before its use was created.
This new patch adds a map that records the expected use count for each constant. The count is decremented each time a use is reached,
and the constant is erased only when the count reaches zero.
With the new patch, the repro in #62802 finishes correctly.