Skip to content

[OpenMP][MLIR] Create LLVM IR lifetime markers for OpenMP loop-related allocations #74843

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,92 @@ static llvm::BasicBlock *convertOmpOpRegions(
return continuationBlock;
}

/// Finds the set of \c llvm.alloca instructions associated to \c LLVM::AllocaOp
/// MLIR operations for primitive types that are defined outside of the given
/// \p region but only used inside of it.
static void
gatherSinkableAllocas(const LLVM::ModuleTranslation &moduleTranslation,
Region &region,
SetVector<llvm::AllocaInst *> &allocasToSink) {
Operation *op = region.getParentOp();

auto processLoadStore = [&](auto loadStoreOp) {
Value addr = loadStoreOp.getAddr();
Operation *addrOp = addr.getDefiningOp();

// The destination address is already defined in this region or it is not an
// llvm.alloca operation, so skip it.
if (!isa_and_present<LLVM::AllocaOp>(addrOp) || op->isAncestor(addrOp))
return;

// Get LLVM value to which the address is mapped. It has to be mapped to the
// allocation instruction of a scalar type to be marked as sinkable by this
// function.
llvm::Value *llvmAddr = moduleTranslation.lookupValue(addr);
if (!isa_and_present<llvm::AllocaInst>(llvmAddr))
return;

auto *llvmAlloca = cast<llvm::AllocaInst>(llvmAddr);
if (llvmAlloca->getAllocatedType()->getPrimitiveSizeInBits() == 0)
return;

// Check that the address is only used inside of the region.
bool addressUsedOnlyInternally = true;
for (auto &addrUse : addr.getUses()) {
if (!op->isAncestor(addrUse.getOwner())) {
addressUsedOnlyInternally = false;
break;
}
}

if (!addressUsedOnlyInternally)
return;

allocasToSink.insert(llvmAlloca);
};

region.walk([&processLoadStore](Operation *op) {
if (auto loadOp = dyn_cast<LLVM::LoadOp>(op))
processLoadStore(loadOp);
else if (auto storeOp = dyn_cast<LLVM::StoreOp>(op))
processLoadStore(storeOp);
});
}

/// Converts the given region that appears within an OpenMP dialect operation to
/// LLVM IR, according to the process described in \c convertOmpOpRegions(), and
/// marks the lifetime of allocas read/written exclusively inside of the region
/// but defined outside of it.
///
/// This information enables later compilation stages to sink these allocations
/// inside of the region, such as when outlining it into a separate function.
static llvm::BasicBlock *convertOmpOpRegionsWithAllocaLifetimes(
Region &region, StringRef blockName, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation, LogicalResult &bodyGenStatus) {
SetVector<llvm::AllocaInst *> allocasToSink;
gatherSinkableAllocas(moduleTranslation, region, allocasToSink);

for (auto *alloca : allocasToSink) {
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
builder.CreateLifetimeStart(alloca, builder.getInt64(size));
}

llvm::BasicBlock *continuationBlock = convertOmpOpRegions(
region, blockName, builder, moduleTranslation, bodyGenStatus);

if (!allocasToSink.empty()) {
llvm::IRBuilderBase::InsertPointGuard guard(builder);
builder.SetInsertPoint(continuationBlock, continuationBlock->begin());

for (auto *alloca : allocasToSink) {
unsigned size = alloca->getAllocatedType()->getPrimitiveSizeInBits() / 8;
builder.CreateLifetimeEnd(alloca, builder.getInt64(size));
}
}

return continuationBlock;
}

/// Convert ProcBindKind from MLIR-generated enum to LLVM enum.
static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
switch (kind) {
Expand Down Expand Up @@ -910,8 +996,9 @@ convertOmpWsLoop(Operation &opInst, llvm::IRBuilderBase &builder,

// Convert the body of the loop.
builder.restoreIP(ip);
convertOmpOpRegions(loop.getRegion(), "omp.wsloop.region", builder,
moduleTranslation, bodyGenStatus);
convertOmpOpRegionsWithAllocaLifetimes(loop.getRegion(),
"omp.wsloop.region", builder,
moduleTranslation, bodyGenStatus);
};

// Delegate actual loop construction to the OpenMP IRBuilder.
Expand Down Expand Up @@ -1151,8 +1238,9 @@ convertOmpSimdLoop(Operation &opInst, llvm::IRBuilderBase &builder,

// Convert the body of the loop.
builder.restoreIP(ip);
convertOmpOpRegions(loop.getRegion(), "omp.simdloop.region", builder,
moduleTranslation, bodyGenStatus);
convertOmpOpRegionsWithAllocaLifetimes(loop.getRegion(),
"omp.simdloop.region", builder,
moduleTranslation, bodyGenStatus);
};

// Delegate actual loop construction to the OpenMP IRBuilder.
Expand Down
124 changes: 124 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-alloca-lifetime.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
// This test checks the introduction of lifetime information for allocas defined
// outside of omp.wsloop and omp.simdloop regions but only used inside of them.

// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

llvm.func @foo(%arg0 : i32) {
llvm.return
}

llvm.func @bar(%arg0 : i64) {
llvm.return
}

// CHECK-LABEL: define void @wsloop_i32
llvm.func @wsloop_i32(%size : i64, %lb : i32, %ub : i32, %step : i32) {
// CHECK-DAG: %[[LASTITER:.*]] = alloca i32
// CHECK-DAG: %[[LB:.*]] = alloca i32
// CHECK-DAG: %[[UB:.*]] = alloca i32
// CHECK-DAG: %[[STRIDE:.*]] = alloca i32
// CHECK-DAG: %[[I:.*]] = alloca i32
%1 = llvm.alloca %size x i32 : (i64) -> !llvm.ptr

// CHECK-NOT: %[[I]]
// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %[[I]])
// CHECK-NEXT: br label %[[WSLOOP_BB:.*]]
// CHECK-NOT: %[[I]]
// CHECK: [[WSLOOP_BB]]:
// CHECK-NOT: {{^.*}}:
// CHECK: br label %[[CONT_BB:.*]]
// CHECK: [[CONT_BB]]:
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr %[[I]])
// CHECK-NOT: %[[I]]
omp.wsloop for (%iv) : i32 = (%lb) to (%ub) step (%step) {
llvm.store %iv, %1 : i32, !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
omp.yield
}

// CHECK: ret void
llvm.return
}

// CHECK-LABEL: define void @wsloop_i64
llvm.func @wsloop_i64(%size : i64, %lb : i64, %ub : i64, %step : i64) {
// CHECK-DAG: %[[LASTITER:.*]] = alloca i32
// CHECK-DAG: %[[LB:.*]] = alloca i64
// CHECK-DAG: %[[UB:.*]] = alloca i64
// CHECK-DAG: %[[STRIDE:.*]] = alloca i64
// CHECK-DAG: %[[I:.*]] = alloca i64
%1 = llvm.alloca %size x i64 : (i64) -> !llvm.ptr

// CHECK-NOT: %[[I]]
// CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[I]])
// CHECK-NEXT: br label %[[WSLOOP_BB:.*]]
// CHECK-NOT: %[[I]]
// CHECK: [[WSLOOP_BB]]:
// CHECK-NOT: {{^.*}}:
// CHECK: br label %[[CONT_BB:.*]]
// CHECK: [[CONT_BB]]:
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr %[[I]])
// CHECK-NOT: %[[I]]
omp.wsloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
llvm.store %iv, %1 : i64, !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> i64
llvm.call @bar(%2) : (i64) -> ()
omp.yield
}

// CHECK: ret void
llvm.return
}

// CHECK-LABEL: define void @simdloop_i32
llvm.func @simdloop_i32(%size : i64, %lb : i32, %ub : i32, %step : i32) {
// CHECK: %[[I:.*]] = alloca i32
%1 = llvm.alloca %size x i32 : (i64) -> !llvm.ptr

// CHECK-NOT: %[[I]]
// CHECK: call void @llvm.lifetime.start.p0(i64 4, ptr %[[I]])
// CHECK-NEXT: br label %[[SIMDLOOP_BB:.*]]
// CHECK-NOT: %[[I]]
// CHECK: [[SIMDLOOP_BB]]:
// CHECK-NOT: {{^.*}}:
// CHECK: br label %[[CONT_BB:.*]]
// CHECK: [[CONT_BB]]:
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 4, ptr %[[I]])
// CHECK-NOT: %[[I]]
omp.simdloop for (%iv) : i32 = (%lb) to (%ub) step (%step) {
llvm.store %iv, %1 : i32, !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> i32
llvm.call @foo(%2) : (i32) -> ()
omp.yield
}

// CHECK: ret void
llvm.return
}

// CHECK-LABEL: define void @simdloop_i64
llvm.func @simdloop_i64(%size : i64, %lb : i64, %ub : i64, %step : i64) {
// CHECK: %[[I:.*]] = alloca i64
%1 = llvm.alloca %size x i64 : (i64) -> !llvm.ptr

// CHECK-NOT: %[[I]]
// CHECK: call void @llvm.lifetime.start.p0(i64 8, ptr %[[I]])
// CHECK-NEXT: br label %[[SIMDLOOP_BB:.*]]
// CHECK-NOT: %[[I]]
// CHECK: [[SIMDLOOP_BB]]:
// CHECK-NOT: {{^.*}}:
// CHECK: br label %[[CONT_BB:.*]]
// CHECK: [[CONT_BB]]:
// CHECK-NEXT: call void @llvm.lifetime.end.p0(i64 8, ptr %[[I]])
// CHECK-NOT: %[[I]]
omp.simdloop for (%iv) : i64 = (%lb) to (%ub) step (%step) {
llvm.store %iv, %1 : i64, !llvm.ptr
%2 = llvm.load %1 : !llvm.ptr -> i64
llvm.call @bar(%2) : (i64) -> ()
omp.yield
}

// CHECK: ret void
llvm.return
}