Skip to content

[MLIR][LLVM] Fix memory explosion when converting global variable bodies in ModuleTranslation #82708

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Feb 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 69 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <optional>

#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"

using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
Expand Down Expand Up @@ -1042,17 +1046,80 @@ LogicalResult ModuleTranslation::convertGlobals() {
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
if (Block *initializer = op.getInitializerBlock()) {
llvm::IRBuilder<> builder(llvmModule->getContext());

int numConstantsHit = 0;
int numConstantsErased = 0;
DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;

for (auto &op : initializer->without_terminator()) {
if (failed(convertOperation(op, builder)) ||
!isa<llvm::Constant>(lookupValue(op.getResult(0))))
if (failed(convertOperation(op, builder)))
return emitError(op.getLoc(), "fail to convert global initializer");
auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
if (!cst)
return emitError(op.getLoc(), "unemittable constant value");

// When emitting an LLVM constant, a new constant is created and the old
// constant may become dangling and take space. We should remove the
// dangling constants to avoid memory explosion especially for constant
// arrays whose number of elements is large.
// Because multiple operations may refer to the same constant, we need
// to count the number of uses of each constant array and remove it only
// when the count becomes zero.
if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
numConstantsHit++;
Value result = op.getResult(0);
int numUsers = std::distance(result.use_begin(), result.use_end());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is O(N) on a linked list, hopefully small though...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Most of the ops will be insertvalue which only used once for build constant.

auto [iterator, inserted] =
constantAggregateUseMap.try_emplace(agg, numUsers);
if (!inserted) {
// Key already exists, update the value
iterator->second += numUsers;
}
}
// Scan the operands of the operation to decrement the use count of
// constants. Erase the constant if the use count becomes zero.
for (Value v : op.getOperands()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth adding a comment for this loop I think

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Updated.

auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
if (!cst)
continue;
auto iter = constantAggregateUseMap.find(cst);
assert(iter != constantAggregateUseMap.end() && "constant not found");
iter->second--;
if (iter->second == 0) {
// NOTE: cannot call removeDeadConstantUsers() here because it
// may remove the constant which has uses not be converted yet.
if (cst->user_empty()) {
cst->destroyConstant();
numConstantsErased++;
}
constantAggregateUseMap.erase(iter);
}
}
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel I'm missing something, why can't we have a DenseSet<llvm::ConstantAggregate *> constantAggregates; that we keep around and every iteration of the outer loop we could just do:

for (llvm::ConstantAggregate *cst : constantAggregates) {
  cst->removeDeadConstantUsers();
  if (cst->user_empty()) {
    cst->destroyConstant(); // TODO: add removing from the DenseSet
  }
}

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually this is done also line 1104, but I don't get why we need the dance around constantAggregateUseMap here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel I'm missing something, why can't we have a DenseSet<llvm::ConstantAggregate *> constantAggregates; that we keep around and every iteration of the outer loop we could just do:

for (llvm::ConstantAggregate *cst : constantAggregates) {
  cst->removeDeadConstantUsers();
  if (cst->user_empty()) {
    cst->destroyConstant(); // TODO: add removing from the DenseSet
  }
}

My understanding is that we're converting all the ops in intializer to llvm::Constant.
When a later op is converted into ConstantAggregate, it will add use of the constant converted from early op.
So we cannot remove a Constant unless all the ops which use it are converted.
Only check user_empty is not enough because the user might not yet be generated.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I get it: and how critical is it to garbage collect during the interpretation of the initializer block instead of just at the end of the block?
(line 1104)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The worst block is building a constant with type [75525 x [100 x f32]] for my test case.
I can test only do the removing at the end of the block, code will be much simplified if it works.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also: isn't it something that MLIR can constant fold?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constant will fold, but the input constant is still in memory.
And too many temp constants are created which eats all the memory.

I tried to avoid creating these insertvalue which will be converted to temp constant.

Here's what I got with experiment


  llvm.mlir.global private @krnl_global_28(dense<["", "ab", "xyz", ...]> : tensor<75525x!krnl.string>) {addr_space = 0 : i32, alignment = 8 : i64} : !llvm.array<75525 x i64>

It crashes because !llvm.array<75525 x i64> is not correct type.

Any suggestion for how to express array of string for llvm dialect?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be just a concat of all the strings into a single dense attribute of the right type?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The worst block is building a constant with type [75525 x [100 x f32]] for my test case. I can test only do the removing at the end of the block, code will be much simplified if it works.

Removing at the end not work :( One block already used all of the memory.

Copy link
Contributor Author

@python3kgae python3kgae Feb 24, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't it be just a concat of all the strings into a single dense attribute of the right type?

Thanks for this suggestion. It saves a lot of time for create globalOp for each string!

But to create the array of string, I still have to create a lot of insertvalue.

Is it possible to express the following llvm ir with LLVM dialect?


@c = dso_local global [7 x i8] c"abcxby\00", align 1
@a = dso_local global [3 x ptr] [ptr @c, ptr getelementptr (i8, ptr @c, i64 2), ptr getelementptr (i8, ptr @c, i64 5)]

Then there will be no insertvalue needed.
This PR will only avoid being out of memory, conversion of those insertvalues still takes a lot of cycles.


ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
llvm::Constant *cst =
cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
global->setInitializer(cst);

// Try to remove the dangling constants again after all operations are
// converted.
for (auto it : constantAggregateUseMap) {
auto cst = it.first;
cst->removeDeadConstantUsers();
if (cst->user_empty()) {
cst->destroyConstant();
numConstantsErased++;
}
}

LLVM_DEBUG(llvm::dbgs()
<< "Convert initializer for " << op.getName() << "\n";
llvm::dbgs() << numConstantsHit << " new constants hit\n";
llvm::dbgs()
<< numConstantsErased << " dangling constants erased\n";);
}
}

Expand Down
72 changes: 72 additions & 0 deletions mlir/test/Target/LLVMIR/erase-dangling-constants.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: mlir-translate -mlir-to-llvmir %s -debug-only=llvm-dialect-to-llvm-ir 2>&1 | FileCheck %s

// CHECK: Convert initializer for dup_const
// CHECK: 6 new constants hit
// CHECK: 3 dangling constants erased
// CHECK: Convert initializer for unique_const
// CHECK: 6 new constants hit
// CHECK: 5 dangling constants erased


// CHECK:@dup_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02] }

llvm.mlir.global @dup_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64

%empty0 = llvm.mlir.undef : !llvm.array<2 x f64>
%a00 = llvm.insertvalue %c0, %empty0[0] : !llvm.array<2 x f64>

%empty1 = llvm.mlir.undef : !llvm.array<2 x f64>
%a10 = llvm.insertvalue %c0, %empty1[0] : !llvm.array<2 x f64>

%empty2 = llvm.mlir.undef : !llvm.array<2 x f64>
%a20 = llvm.insertvalue %c0, %empty2[0] : !llvm.array<2 x f64>

// NOTE: a00, a10, a20 are all same ConstantAggregate which not used at this point.
// should not delete it before all of the uses of the ConstantAggregate finished.

%a01 = llvm.insertvalue %c1, %a00[1] : !llvm.array<2 x f64>
%a11 = llvm.insertvalue %c1, %a10[1] : !llvm.array<2 x f64>
%a21 = llvm.insertvalue %c1, %a20[1] : !llvm.array<2 x f64>
%empty_r = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r0 = llvm.insertvalue %a01, %empty_r[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r1 = llvm.insertvalue %a11, %r0[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r2 = llvm.insertvalue %a21, %r1[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

llvm.return %r2 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
}

// CHECK:@unique_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.312250e-02, double 5.219230e-02], [2 x double] [double 3.412250e-02, double 5.419230e-02] }

llvm.mlir.global @unique_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64

%c2 = llvm.mlir.constant(3.312250e-02 : f64) : f64
%c3 = llvm.mlir.constant(5.219230e-02 : f64) : f64

%c4 = llvm.mlir.constant(3.412250e-02 : f64) : f64
%c5 = llvm.mlir.constant(5.419230e-02 : f64) : f64

%2 = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

%3 = llvm.mlir.undef : !llvm.array<2 x f64>

%4 = llvm.insertvalue %c0, %3[0] : !llvm.array<2 x f64>
%5 = llvm.insertvalue %c1, %4[1] : !llvm.array<2 x f64>

%6 = llvm.insertvalue %5, %2[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

%7 = llvm.insertvalue %c2, %3[0] : !llvm.array<2 x f64>
%8 = llvm.insertvalue %c3, %7[1] : !llvm.array<2 x f64>

%9 = llvm.insertvalue %8, %6[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

%10 = llvm.insertvalue %c4, %3[0] : !llvm.array<2 x f64>
%11 = llvm.insertvalue %c5, %10[1] : !llvm.array<2 x f64>

%12 = llvm.insertvalue %11, %9[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

llvm.return %12 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
}