Skip to content

Commit c11627c

Browse files
authored
[MLIR][LLVM] Fix memory explosion when converting global variable bodies in ModuleTranslation (#82708)
There is memory explosion when converting the body or initializer region of a large global variable, e.g. a constant array. For example, when translating a constant array of 100000 strings: llvm.mlir.global internal constant @cats_strings() {addr_space = 0 : i32, alignment = 16 : i64} : !llvm.array<100000 x ptr<i8>> { %0 = llvm.mlir.undef : !llvm.array<100000 x ptr<i8>> %1 = llvm.mlir.addressof @om_1 : !llvm.ptr<array<1 x i8>> %2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8> %3 = llvm.insertvalue %2, %0[0] : !llvm.array<100000 x ptr<i8>> %4 = llvm.mlir.addressof @om_2 : !llvm.ptr<array<1 x i8>> %5 = llvm.getelementptr %4[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8> %6 = llvm.insertvalue %5, %3[1] : !llvm.array<100000 x ptr<i8>> %7 = llvm.mlir.addressof @om_3 : !llvm.ptr<array<1 x i8>> %8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8> %9 = llvm.insertvalue %8, %6[2] : !llvm.array<100000 x ptr<i8>> %10 = llvm.mlir.addressof @om_4 : !llvm.ptr<array<1 x i8>> %11 = llvm.getelementptr %10[0, 0] : (!llvm.ptr<array<1 x i8>>) -> !llvm.ptr<i8> %12 = llvm.insertvalue %11, %9[3] : !llvm.array<100000 x ptr<i8>> ... (ignore the remaining part) } where @om_1, @om_2, ... are string global constants. Each time an operation is converted to LLVM, a new constant is created. When it comes to llvm.insertvalue, a new constant array of 100000 elements is created and the old constant array (input) is not destroyed. This causes memory explosion. We observed that, on a system with 128 GB memory, the translation of 100000 elements got killed due to using up all the memory. On a system with 64 GB, 65536 elements was enough to cause the translation killed. There is a previous patch (https://reviews.llvm.org/D148487) which fix this issue but was reverted for #62802 The old patch checks generated constants and destroyed them if there is no use. But the check of use for the constant is too early, which cause the constant be removed before use. This new patch added a map was added a map to save expected use count for a constant. Then decrease when reach each use. And only erase the constant when the use count reach to zero With new patch, the repro in #62802 finished correctly.
1 parent 86f6caa commit c11627c

File tree

2 files changed

+141
-2
lines changed

2 files changed

+141
-2
lines changed

mlir/lib/Target/LLVMIR/ModuleTranslation.cpp

Lines changed: 69 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,15 @@
5151
#include "llvm/IR/MDBuilder.h"
5252
#include "llvm/IR/Module.h"
5353
#include "llvm/IR/Verifier.h"
54+
#include "llvm/Support/Debug.h"
55+
#include "llvm/Support/raw_ostream.h"
5456
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
5557
#include "llvm/Transforms/Utils/Cloning.h"
5658
#include "llvm/Transforms/Utils/ModuleUtils.h"
5759
#include <optional>
5860

61+
#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"
62+
5963
using namespace mlir;
6064
using namespace mlir::LLVM;
6165
using namespace mlir::LLVM::detail;
@@ -1042,17 +1046,80 @@ LogicalResult ModuleTranslation::convertGlobals() {
10421046
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
10431047
if (Block *initializer = op.getInitializerBlock()) {
10441048
llvm::IRBuilder<> builder(llvmModule->getContext());
1049+
1050+
int numConstantsHit = 0;
1051+
int numConstantsErased = 0;
1052+
DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;
1053+
10451054
for (auto &op : initializer->without_terminator()) {
1046-
if (failed(convertOperation(op, builder)) ||
1047-
!isa<llvm::Constant>(lookupValue(op.getResult(0))))
1055+
if (failed(convertOperation(op, builder)))
1056+
return emitError(op.getLoc(), "fail to convert global initializer");
1057+
auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
1058+
if (!cst)
10481059
return emitError(op.getLoc(), "unemittable constant value");
1060+
1061+
// When emitting an LLVM constant, a new constant is created and the old
1062+
// constant may become dangling and take space. We should remove the
1063+
// dangling constants to avoid memory explosion especially for constant
1064+
// arrays whose number of elements is large.
1065+
// Because multiple operations may refer to the same constant, we need
1066+
// to count the number of uses of each constant array and remove it only
1067+
// when the count becomes zero.
1068+
if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
1069+
numConstantsHit++;
1070+
Value result = op.getResult(0);
1071+
int numUsers = std::distance(result.use_begin(), result.use_end());
1072+
auto [iterator, inserted] =
1073+
constantAggregateUseMap.try_emplace(agg, numUsers);
1074+
if (!inserted) {
1075+
// Key already exists, update the value
1076+
iterator->second += numUsers;
1077+
}
1078+
}
1079+
// Scan the operands of the operation to decrement the use count of
1080+
// constants. Erase the constant if the use count becomes zero.
1081+
for (Value v : op.getOperands()) {
1082+
auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
1083+
if (!cst)
1084+
continue;
1085+
auto iter = constantAggregateUseMap.find(cst);
1086+
assert(iter != constantAggregateUseMap.end() && "constant not found");
1087+
iter->second--;
1088+
if (iter->second == 0) {
1089+
// NOTE: cannot call removeDeadConstantUsers() here because it
1090+
// may remove the constant which has uses not be converted yet.
1091+
if (cst->user_empty()) {
1092+
cst->destroyConstant();
1093+
numConstantsErased++;
1094+
}
1095+
constantAggregateUseMap.erase(iter);
1096+
}
1097+
}
10491098
}
1099+
10501100
ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
10511101
llvm::Constant *cst =
10521102
cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
10531103
auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
10541104
if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
10551105
global->setInitializer(cst);
1106+
1107+
// Try to remove the dangling constants again after all operations are
1108+
// converted.
1109+
for (auto it : constantAggregateUseMap) {
1110+
auto cst = it.first;
1111+
cst->removeDeadConstantUsers();
1112+
if (cst->user_empty()) {
1113+
cst->destroyConstant();
1114+
numConstantsErased++;
1115+
}
1116+
}
1117+
1118+
LLVM_DEBUG(llvm::dbgs()
1119+
<< "Convert initializer for " << op.getName() << "\n";
1120+
llvm::dbgs() << numConstantsHit << " new constants hit\n";
1121+
llvm::dbgs()
1122+
<< numConstantsErased << " dangling constants erased\n";);
10561123
}
10571124
}
10581125

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// RUN: mlir-translate -mlir-to-llvmir %s -debug-only=llvm-dialect-to-llvm-ir 2>&1 | FileCheck %s
2+
3+
// CHECK: Convert initializer for dup_const
4+
// CHECK: 6 new constants hit
5+
// CHECK: 3 dangling constants erased
6+
// CHECK: Convert initializer for unique_const
7+
// CHECK: 6 new constants hit
8+
// CHECK: 5 dangling constants erased
9+
10+
11+
// CHECK:@dup_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02] }
12+
13+
llvm.mlir.global @dup_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
14+
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
15+
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64
16+
17+
%empty0 = llvm.mlir.undef : !llvm.array<2 x f64>
18+
%a00 = llvm.insertvalue %c0, %empty0[0] : !llvm.array<2 x f64>
19+
20+
%empty1 = llvm.mlir.undef : !llvm.array<2 x f64>
21+
%a10 = llvm.insertvalue %c0, %empty1[0] : !llvm.array<2 x f64>
22+
23+
%empty2 = llvm.mlir.undef : !llvm.array<2 x f64>
24+
%a20 = llvm.insertvalue %c0, %empty2[0] : !llvm.array<2 x f64>
25+
26+
// NOTE: a00, a10, a20 are all same ConstantAggregate which not used at this point.
27+
// should not delete it before all of the uses of the ConstantAggregate finished.
28+
29+
%a01 = llvm.insertvalue %c1, %a00[1] : !llvm.array<2 x f64>
30+
%a11 = llvm.insertvalue %c1, %a10[1] : !llvm.array<2 x f64>
31+
%a21 = llvm.insertvalue %c1, %a20[1] : !llvm.array<2 x f64>
32+
%empty_r = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
33+
%r0 = llvm.insertvalue %a01, %empty_r[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
34+
%r1 = llvm.insertvalue %a11, %r0[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
35+
%r2 = llvm.insertvalue %a21, %r1[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
36+
37+
llvm.return %r2 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
38+
}
39+
40+
// CHECK:@unique_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.312250e-02, double 5.219230e-02], [2 x double] [double 3.412250e-02, double 5.419230e-02] }
41+
42+
llvm.mlir.global @unique_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
43+
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
44+
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64
45+
46+
%c2 = llvm.mlir.constant(3.312250e-02 : f64) : f64
47+
%c3 = llvm.mlir.constant(5.219230e-02 : f64) : f64
48+
49+
%c4 = llvm.mlir.constant(3.412250e-02 : f64) : f64
50+
%c5 = llvm.mlir.constant(5.419230e-02 : f64) : f64
51+
52+
%2 = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
53+
54+
%3 = llvm.mlir.undef : !llvm.array<2 x f64>
55+
56+
%4 = llvm.insertvalue %c0, %3[0] : !llvm.array<2 x f64>
57+
%5 = llvm.insertvalue %c1, %4[1] : !llvm.array<2 x f64>
58+
59+
%6 = llvm.insertvalue %5, %2[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
60+
61+
%7 = llvm.insertvalue %c2, %3[0] : !llvm.array<2 x f64>
62+
%8 = llvm.insertvalue %c3, %7[1] : !llvm.array<2 x f64>
63+
64+
%9 = llvm.insertvalue %8, %6[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
65+
66+
%10 = llvm.insertvalue %c4, %3[0] : !llvm.array<2 x f64>
67+
%11 = llvm.insertvalue %c5, %10[1] : !llvm.array<2 x f64>
68+
69+
%12 = llvm.insertvalue %11, %9[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
70+
71+
llvm.return %12 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
72+
}

0 commit comments

Comments
 (0)