[flang][OpenMP] Translate OpenMP scopes when compiling for target device #130078

Merged · 4 commits · Mar 19, 2025
@@ -29,6 +29,7 @@
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Frontend/OpenMP/OMPConstants.h"
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
@@ -537,6 +538,20 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
llvm_unreachable("Unknown ClauseProcBindKind kind");
}

/// Maps block arguments from \p blockArgIface (which are MLIR values) to the
/// LLVM values of the interface's corresponding operands. This is useful
/// when an OpenMP region with entry block arguments is converted to LLVM. In
/// this case the block arguments are (part of) the OpenMP region's entry
/// arguments and the operands are (part of) the operands of the OpenMP op
/// containing the region.
static void forwardArgs(LLVM::ModuleTranslation &moduleTranslation,
omp::BlockArgOpenMPOpInterface blockArgIface) {
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
blockArgIface.getBlockArgsPairs(blockArgsPairs);
for (auto [var, arg] : blockArgsPairs)
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
}

/// Helper function to map block arguments defined by ignored loop wrappers to
/// LLVM values and prevent any uses of those from triggering null pointer
/// dereferences.
@@ -549,17 +564,10 @@ convertIgnoredWrapper(omp::LoopWrapperInterface opInst,
// Map block arguments directly to the LLVM value associated to the
// corresponding operand. This is semantically equivalent to this wrapper not
// being present.
auto forwardArgs =
[&moduleTranslation](omp::BlockArgOpenMPOpInterface blockArgIface) {
llvm::SmallVector<std::pair<Value, BlockArgument>> blockArgsPairs;
blockArgIface.getBlockArgsPairs(blockArgsPairs);
for (auto [var, arg] : blockArgsPairs)
moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
};

return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
.Case([&](omp::SimdOp op) {
forwardArgs(cast<omp::BlockArgOpenMPOpInterface>(*op));
forwardArgs(moduleTranslation,
cast<omp::BlockArgOpenMPOpInterface>(*op));
op.emitWarning() << "simd information on composite construct discarded";
return success();
})
@@ -5313,6 +5321,46 @@ convertTargetOpsInNest(Operation *op, llvm::IRBuilderBase &builder,
return WalkResult::interrupt();
return WalkResult::skip();
}

// Non-target ops might nest target-related ops; therefore, we
// translate them as non-OpenMP scopes. Translating them is needed
// because nested target-related ops might need LLVM values defined
// in their parent non-target ops.
if (isa<omp::OpenMPDialect>(oper->getDialect()) &&
oper->getParentOfType<LLVM::LLVMFuncOp>() &&
!oper->getRegions().empty()) {
Member:

I'm wondering what should happen to OpenMP ops that don't have regions. If they return a value, it seems like that value could end up impacting what's passed into an omp.map.info as an argument. Maybe we should map their results to something as well.

Member Author:

Can you provide an example where this might happen?

Member:

I was thinking of something like this (since omp.threadprivate seems to be the only non map-related OpenMP op that returns some value at the moment):

module test
  implicit none
  integer :: n
  
  !$omp threadprivate(n)
  
  contains
  subroutine foo(x)
    integer, intent(inout) :: x(10)

    !$omp target map(tofrom: x(1:n))
    call bar(x)
    !$omp end target
  end subroutine
end module

Lowering it to MLIR for the device results in the following sequence of operations:

%3 = omp.threadprivate ...
%4 = fir.declare %3
%13 = omp.map.info var_ptr(%4 : !fir.ref<i32>, i32) ...
omp.target map_entries(%13 -> %arg2 ...) {
  ...
}

Member Author (@ergawy, Mar 17, 2025):

Thanks for the example. However, this is not an issue: when the surrounding region is translated, the LLVM value corresponding to the original n is resolved:

omp.parallel.fake.region:
  %2 = load i32, ptr @_QFtestEn, align 4
  %3 = sext i32 %2 to i64
  %4 = sub i64 %3, 1
  %5 = sub i64 %4, 0
  %6 = add i64 %5, 1
  %7 = mul i64 1, %6
  %8 = mul i64 %7, 4
  br label %omp.region.cont

(Just to clarify, I surrounded the target op by parallel to make sure the issue could be reproduced.)
So the translation happens without issues.

I think we should keep things simple and not add the extra translation for non-region ops since there does not seem to be a need for it, at least at the moment. It would not be much more effort; the main reason is just to simplify things, even if only a little. Let me know if you disagree.

Member:

I'm surprised that the example above isn't causing crashes, since %3 is not mapped to any LLVM values and it contributes to the initialization of %4, which is passed to an omp.map.info.

Actually, if that had caused a crash it looks like any non-OpenMP dialect operation found during the op->walk() in convertTargetOpsInNest would have triggered it in the same way, since their results are not mapped to anything there either. I'm not clear how this is working, but apparently it is.

Member Author:

I will dig deeper into this and try to provide a more detailed explanation.

if (auto blockArgsIface =
dyn_cast<omp::BlockArgOpenMPOpInterface>(oper))
forwardArgs(moduleTranslation, blockArgsIface);

if (auto loopNest = dyn_cast<omp::LoopNestOp>(oper)) {
assert(builder.GetInsertBlock() &&
"No insert block is set for the builder");
for (auto iv : loopNest.getIVs()) {
// Map the IV to a poison value just to preserve IR validity.
moduleTranslation.mapValue(
iv, llvm::PoisonValue::get(
moduleTranslation.convertType(iv.getType())));
}
}

for (Region &region : oper->getRegions()) {
// Regions are fake in the sense that they are not a truthful
// translation of the OpenMP construct being converted (e.g. no
// OpenMP runtime calls will be generated). We just need this to
// prepare the kernel invocation args.
auto result = convertOmpOpRegions(
region, oper->getName().getStringRef().str() + ".fake.region",
Contributor:

Why is this named .fake.region? As far as my understanding of the problem and the solution you have implemented goes, this is a region that has been created by this pass.

Member Author:

It is fake in the sense that it is not a truthful translation of the OpenMP construct being converted. We just need this to prepare the kernel invocation args. Added a comment to clarify.

builder, moduleTranslation);
if (failed(handleError(result, *oper)))
return WalkResult::interrupt();

builder.SetInsertPoint(result.get(), result.get()->end());
}

return WalkResult::skip();
}

return WalkResult::advance();
}).wasInterrupted();
return failure(interrupted);
135 changes: 135 additions & 0 deletions mlir/test/Target/LLVMIR/openmp-target-nesting-in-host-ops.mlir
@@ -0,0 +1,135 @@
// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

module attributes {llvm.target_triple = "amdgcn-amd-amdhsa", omp.is_gpu = true, omp.is_target_device = true} {

omp.private {type = private} @i32_privatizer : i32

llvm.func @test_nested_target_in_parallel(%arg0: !llvm.ptr) {
omp.parallel {
%0 = llvm.mlir.constant(4 : index) : i64
%1 = llvm.mlir.constant(1 : index) : i64
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
omp.terminator
}
omp.terminator
}
llvm.return
}

// CHECK-LABEL: define void @test_nested_target_in_parallel({{.*}}) {
// CHECK-NEXT: br label %omp.parallel.fake.region
// CHECK: omp.parallel.fake.region:
// CHECK-NEXT: br label %omp.region.cont
// CHECK: omp.region.cont:
// CHECK-NEXT: ret void
// CHECK-NEXT: }

llvm.func @test_nested_target_in_wsloop(%arg0: !llvm.ptr) {
%8 = llvm.mlir.constant(1 : i64) : i64
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
%16 = llvm.mlir.constant(10 : i32) : i32
%17 = llvm.mlir.constant(1 : i32) : i32
omp.wsloop private(@i32_privatizer %9 -> %loop_arg : !llvm.ptr) {
omp.loop_nest (%arg1) : i32 = (%17) to (%16) inclusive step (%17) {
llvm.store %arg1, %loop_arg : i32, !llvm.ptr
%0 = llvm.mlir.constant(4 : index) : i64
%1 = llvm.mlir.constant(1 : index) : i64
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%0 : i64) stride(%1 : i64) start_idx(%1 : i64)
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
omp.terminator
}
omp.yield
}
}
llvm.return
}

// CHECK-LABEL: define void @test_nested_target_in_wsloop(ptr %0) {
// CHECK-NEXT: %{{.*}} = alloca i32, i64 1, align 4
// CHECK-NEXT: br label %omp.wsloop.fake.region
// CHECK: omp.wsloop.fake.region:
// CHECK-NEXT: br label %omp.loop_nest.fake.region
// CHECK: omp.loop_nest.fake.region:
// CHECK-NEXT: store i32 poison, ptr %{{.*}}
// CHECK-NEXT: br label %omp.region.cont1
// CHECK: omp.region.cont1:
// CHECK-NEXT: br label %omp.region.cont
// CHECK: omp.region.cont:
// CHECK-NEXT: ret void
// CHECK-NEXT: }

llvm.func @test_nested_target_in_parallel_with_private(%arg0: !llvm.ptr) {
%8 = llvm.mlir.constant(1 : i64) : i64
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
omp.parallel private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
%1 = llvm.mlir.constant(1 : index) : i64
// Use the private clause from omp.parallel to make sure block arguments
// are handled.
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
omp.terminator
}
omp.terminator
}
llvm.return
}

llvm.func @test_nested_target_in_task_with_private(%arg0: !llvm.ptr) {
%8 = llvm.mlir.constant(1 : i64) : i64
%9 = llvm.alloca %8 x i32 {bindc_name = "i"} : (i64) -> !llvm.ptr
omp.task private(@i32_privatizer %9 -> %i_priv_arg : !llvm.ptr) {
%1 = llvm.mlir.constant(1 : index) : i64
// Use the private clause from omp.task to make sure block arguments
// are handled.
%i_val = llvm.load %i_priv_arg : !llvm.ptr -> i64
%4 = omp.map.bounds lower_bound(%1 : i64) upper_bound(%i_val : i64) stride(%1 : i64) start_idx(%1 : i64)
%mapv1 = omp.map.info var_ptr(%arg0 : !llvm.ptr, !llvm.array<10 x i32>) map_clauses(tofrom) capture(ByRef) bounds(%4) -> !llvm.ptr {name = ""}
omp.target map_entries(%mapv1 -> %map_arg : !llvm.ptr) {
omp.terminator
}
omp.terminator
}
llvm.return
}

// CHECK-LABEL: define void @test_nested_target_in_parallel_with_private({{.*}}) {
// CHECK: br label %omp.parallel.fake.region
// CHECK: omp.parallel.fake.region:
// CHECK: br label %omp.region.cont
// CHECK: omp.region.cont:
// CHECK-NEXT: ret void
// CHECK-NEXT: }

// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_nested_target_in_parallel_{{.*}} {
// CHECK: call i32 @__kmpc_target_init
// CHECK: user_code.entry:
// CHECK: call void @__kmpc_target_deinit()
// CHECK: ret void
// CHECK: }

// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_wsloop_{{.*}} {
// CHECK: call i32 @__kmpc_target_init
// CHECK: user_code.entry:
// CHECK: call void @__kmpc_target_deinit()
// CHECK: ret void
// CHECK: }

// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_parallel_with_private_{{.*}} {
// CHECK: call i32 @__kmpc_target_init
// CHECK: user_code.entry:
// CHECK: call void @__kmpc_target_deinit()
// CHECK: ret void
// CHECK: }

// CHECK-LABEL: define {{.*}} amdgpu_kernel void @__omp_offloading_{{.*}}_test_nested_target_in_task_with_private_{{.*}} {
// CHECK: call i32 @__kmpc_target_init
// CHECK: user_code.entry:
// CHECK: call void @__kmpc_target_deinit()
// CHECK: ret void
// CHECK: }
}