Skip to content

Commit

Permalink
[MLIR][LLVM] Fix memory explosion when converting global variable bod…
Browse files Browse the repository at this point in the history
…ies in ModuleTranslation (llvm#82708)

There is memory explosion when converting the body or initializer region
of a large global variable, e.g. a constant array.

For example, when translating a constant array of 100000 strings:

llvm.mlir.global internal constant @cats_strings() {addr_space = 0 :
i32, alignment = 16 : i64} : !llvm.array<100000 x ptr<i8>> {
    %0 = llvm.mlir.undef : !llvm.array<100000 x ptr<i8>>
    %1 = llvm.mlir.addressof @om_1 : !llvm.ptr<array<1 x i8>>
%2 = llvm.getelementptr %1[0, 0] : (!llvm.ptr<array<1 x i8>>) ->
!llvm.ptr<i8>
    %3 = llvm.insertvalue %2, %0[0] : !llvm.array<100000 x ptr<i8>>
    %4 = llvm.mlir.addressof @om_2 : !llvm.ptr<array<1 x i8>>
%5 = llvm.getelementptr %4[0, 0] : (!llvm.ptr<array<1 x i8>>) ->
!llvm.ptr<i8>
    %6 = llvm.insertvalue %5, %3[1] : !llvm.array<100000 x ptr<i8>>
    %7 = llvm.mlir.addressof @om_3 : !llvm.ptr<array<1 x i8>>
%8 = llvm.getelementptr %7[0, 0] : (!llvm.ptr<array<1 x i8>>) ->
!llvm.ptr<i8>
    %9 = llvm.insertvalue %8, %6[2] : !llvm.array<100000 x ptr<i8>>
    %10 = llvm.mlir.addressof @om_4 : !llvm.ptr<array<1 x i8>>
%11 = llvm.getelementptr %10[0, 0] : (!llvm.ptr<array<1 x i8>>) ->
!llvm.ptr<i8>
    %12 = llvm.insertvalue %11, %9[3] : !llvm.array<100000 x ptr<i8>>

    ... (ignore the remaining part)
}

where @om_1, @om_2, ... are string global constants.

Each time an operation is converted to LLVM, a new constant is created.
When it comes to llvm.insertvalue, a new constant array of 100000
elements is created and the old constant array (input) is not destroyed.
This causes memory explosion. We observed that, on a system with 128 GB
memory, the translation of 100000 elements got killed due to using up
all the memory. On a system with 64 GB, 65536 elements was enough to
cause the translation killed.

There is a previous patch (https://reviews.llvm.org/D148487) which fix
this issue but was reverted for
llvm#62802

The old patch checks generated constants and destroyed them if there is
no use. But the check of use for the constant is too early, which cause
the constant be removed before use.

This new patch added a map was added a map to save expected use count
for a constant. Then decrease when reach each use.
And only erase the constant when the use count reach to zero

With new patch, the repro in
llvm#62802 finished correctly.

commit-id:3cd23a98
  • Loading branch information
python3kgae authored and vzakhari committed Mar 14, 2024
1 parent 8f0da38 commit 4268fec
Show file tree
Hide file tree
Showing 2 changed files with 141 additions and 2 deletions.
71 changes: 69 additions & 2 deletions mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,15 @@
#include "llvm/IR/MDBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
#include <optional>

#define DEBUG_TYPE "llvm-dialect-to-llvm-ir"

using namespace mlir;
using namespace mlir::LLVM;
using namespace mlir::LLVM::detail;
Expand Down Expand Up @@ -1042,17 +1046,80 @@ LogicalResult ModuleTranslation::convertGlobals() {
for (auto op : getModuleBody(mlirModule).getOps<LLVM::GlobalOp>()) {
if (Block *initializer = op.getInitializerBlock()) {
llvm::IRBuilder<> builder(llvmModule->getContext());

int numConstantsHit = 0;
int numConstantsErased = 0;
DenseMap<llvm::ConstantAggregate *, int> constantAggregateUseMap;

for (auto &op : initializer->without_terminator()) {
if (failed(convertOperation(op, builder)) ||
!isa<llvm::Constant>(lookupValue(op.getResult(0))))
if (failed(convertOperation(op, builder)))
return emitError(op.getLoc(), "fail to convert global initializer");
auto *cst = dyn_cast<llvm::Constant>(lookupValue(op.getResult(0)));
if (!cst)
return emitError(op.getLoc(), "unemittable constant value");

// When emitting an LLVM constant, a new constant is created and the old
// constant may become dangling and take space. We should remove the
// dangling constants to avoid memory explosion especially for constant
// arrays whose number of elements is large.
// Because multiple operations may refer to the same constant, we need
// to count the number of uses of each constant array and remove it only
// when the count becomes zero.
if (auto *agg = dyn_cast<llvm::ConstantAggregate>(cst)) {
numConstantsHit++;
Value result = op.getResult(0);
int numUsers = std::distance(result.use_begin(), result.use_end());
auto [iterator, inserted] =
constantAggregateUseMap.try_emplace(agg, numUsers);
if (!inserted) {
// Key already exists, update the value
iterator->second += numUsers;
}
}
// Scan the operands of the operation to decrement the use count of
// constants. Erase the constant if the use count becomes zero.
for (Value v : op.getOperands()) {
auto cst = dyn_cast<llvm::ConstantAggregate>(lookupValue(v));
if (!cst)
continue;
auto iter = constantAggregateUseMap.find(cst);
assert(iter != constantAggregateUseMap.end() && "constant not found");
iter->second--;
if (iter->second == 0) {
// NOTE: cannot call removeDeadConstantUsers() here because it
// may remove the constant which has uses not be converted yet.
if (cst->user_empty()) {
cst->destroyConstant();
numConstantsErased++;
}
constantAggregateUseMap.erase(iter);
}
}
}

ReturnOp ret = cast<ReturnOp>(initializer->getTerminator());
llvm::Constant *cst =
cast<llvm::Constant>(lookupValue(ret.getOperand(0)));
auto *global = cast<llvm::GlobalVariable>(lookupGlobal(op));
if (!shouldDropGlobalInitializer(global->getLinkage(), cst))
global->setInitializer(cst);

// Try to remove the dangling constants again after all operations are
// converted.
for (auto it : constantAggregateUseMap) {
auto cst = it.first;
cst->removeDeadConstantUsers();
if (cst->user_empty()) {
cst->destroyConstant();
numConstantsErased++;
}
}

LLVM_DEBUG(llvm::dbgs()
<< "Convert initializer for " << op.getName() << "\n";
llvm::dbgs() << numConstantsHit << " new constants hit\n";
llvm::dbgs()
<< numConstantsErased << " dangling constants erased\n";);
}
}

Expand Down
72 changes: 72 additions & 0 deletions mlir/test/Target/LLVMIR/erase-dangling-constants.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
// RUN: mlir-translate -mlir-to-llvmir %s -debug-only=llvm-dialect-to-llvm-ir 2>&1 | FileCheck %s

// CHECK: Convert initializer for dup_const
// CHECK: 6 new constants hit
// CHECK: 3 dangling constants erased
// CHECK: Convert initializer for unique_const
// CHECK: 6 new constants hit
// CHECK: 5 dangling constants erased


// CHECK:@dup_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.612250e-02, double 5.119230e-02] }

llvm.mlir.global @dup_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64

%empty0 = llvm.mlir.undef : !llvm.array<2 x f64>
%a00 = llvm.insertvalue %c0, %empty0[0] : !llvm.array<2 x f64>

%empty1 = llvm.mlir.undef : !llvm.array<2 x f64>
%a10 = llvm.insertvalue %c0, %empty1[0] : !llvm.array<2 x f64>

%empty2 = llvm.mlir.undef : !llvm.array<2 x f64>
%a20 = llvm.insertvalue %c0, %empty2[0] : !llvm.array<2 x f64>

// NOTE: a00, a10, a20 are all same ConstantAggregate which not used at this point.
// should not delete it before all of the uses of the ConstantAggregate finished.

%a01 = llvm.insertvalue %c1, %a00[1] : !llvm.array<2 x f64>
%a11 = llvm.insertvalue %c1, %a10[1] : !llvm.array<2 x f64>
%a21 = llvm.insertvalue %c1, %a20[1] : !llvm.array<2 x f64>
%empty_r = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r0 = llvm.insertvalue %a01, %empty_r[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r1 = llvm.insertvalue %a11, %r0[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
%r2 = llvm.insertvalue %a21, %r1[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

llvm.return %r2 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
}

// CHECK:@unique_const = global { [2 x double], [2 x double], [2 x double] } { [2 x double] [double 3.612250e-02, double 5.119230e-02], [2 x double] [double 3.312250e-02, double 5.219230e-02], [2 x double] [double 3.412250e-02, double 5.419230e-02] }

llvm.mlir.global @unique_const() : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)> {
%c0 = llvm.mlir.constant(3.612250e-02 : f64) : f64
%c1 = llvm.mlir.constant(5.119230e-02 : f64) : f64

%c2 = llvm.mlir.constant(3.312250e-02 : f64) : f64
%c3 = llvm.mlir.constant(5.219230e-02 : f64) : f64

%c4 = llvm.mlir.constant(3.412250e-02 : f64) : f64
%c5 = llvm.mlir.constant(5.419230e-02 : f64) : f64

%2 = llvm.mlir.undef : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

%3 = llvm.mlir.undef : !llvm.array<2 x f64>

%4 = llvm.insertvalue %c0, %3[0] : !llvm.array<2 x f64>
%5 = llvm.insertvalue %c1, %4[1] : !llvm.array<2 x f64>

%6 = llvm.insertvalue %5, %2[0] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

%7 = llvm.insertvalue %c2, %3[0] : !llvm.array<2 x f64>
%8 = llvm.insertvalue %c3, %7[1] : !llvm.array<2 x f64>

%9 = llvm.insertvalue %8, %6[1] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

%10 = llvm.insertvalue %c4, %3[0] : !llvm.array<2 x f64>
%11 = llvm.insertvalue %c5, %10[1] : !llvm.array<2 x f64>

%12 = llvm.insertvalue %11, %9[2] : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>

llvm.return %12 : !llvm.struct<(array<2 x f64>, array<2 x f64>, array<2 x f64>)>
}

0 comments on commit 4268fec

Please sign in to comment.