[MLIR][OpenMP] Prevent composite omp.simd related crashes #113680

Merged (2 commits) on Oct 29, 2024

@@ -262,6 +262,62 @@ static llvm::omp::ProcBindKind getProcBindKind(omp::ClauseProcBindKind kind) {
  llvm_unreachable("Unknown ClauseProcBindKind kind");
}

/// Helper function to map block arguments defined by ignored loop wrappers to
/// LLVM values and prevent any uses of those from triggering null pointer
/// dereferences.
///
/// This must be called after block arguments of parent wrappers have already
/// been mapped to LLVM IR values.
static LogicalResult
convertIgnoredWrapper(omp::LoopWrapperInterface &opInst,
                      LLVM::ModuleTranslation &moduleTranslation) {
  // Map block arguments directly to the LLVM value associated to the
  // corresponding operand. This is semantically equivalent to this wrapper not
  // being present.
  auto forwardArgs =
      [&moduleTranslation](llvm::ArrayRef<BlockArgument> blockArgs,
                           OperandRange operands) {
        for (auto [arg, var] : llvm::zip_equal(blockArgs, operands))
          moduleTranslation.mapValue(arg, moduleTranslation.lookupValue(var));
      };

  return llvm::TypeSwitch<Operation *, LogicalResult>(opInst)
      .Case([&](omp::SimdOp op) {
        auto blockArgIface = cast<omp::BlockArgOpenMPOpInterface>(*op);
        forwardArgs(blockArgIface.getPrivateBlockArgs(), op.getPrivateVars());
        forwardArgs(blockArgIface.getReductionBlockArgs(),
                    op.getReductionVars());
        return success();
      })
      .Default([&](Operation *op) {
        return op->emitError() << "cannot ignore nested wrapper";
      });
}
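
For illustration (not part of the patch): the forwarding above boils down to copying an existing operand-to-value mapping over to the wrapper's entry block arguments. A minimal standalone C++ sketch of that pattern, using plain STL containers in place of MLIR's BlockArgument/Value types and ModuleTranslation's value mapping; all names here are illustrative assumptions, not MLIR API:

#include <cassert>
#include <map>
#include <string>
#include <vector>

int main() {
  // Pretend operand "%x" was already translated to some LLVM value (here, 42).
  std::map<std::string, int> valueMap = {{"%x", 42}};
  std::vector<std::string> blockArgs = {"%arg0"}; // wrapper entry block args
  std::vector<std::string> operands = {"%x"};     // corresponding operands
  // llvm::zip_equal asserts that both ranges have the same length; mirror it.
  assert(blockArgs.size() == operands.size());
  for (size_t i = 0; i < blockArgs.size(); ++i)
    valueMap[blockArgs[i]] = valueMap[operands[i]]; // ~ mapValue(arg, lookupValue(var))
  // Uses of the block argument now resolve to the operand's value instead of
  // hitting a missing (null) mapping and crashing.
  assert(valueMap["%arg0"] == 42);
  return 0;
}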

/// Helper function to call \c convertIgnoredWrapper() for all wrappers of the
/// given \c loopOp nested inside of \c parentOp. This has the effect of mapping
/// entry block arguments defined by these operations to outside values.
///
/// It must be called after block arguments of \c parentOp have already been
/// mapped themselves.
static LogicalResult
convertIgnoredWrappers(omp::LoopNestOp loopOp,
                       omp::LoopWrapperInterface parentOp,
                       LLVM::ModuleTranslation &moduleTranslation) {
  SmallVector<omp::LoopWrapperInterface> wrappers;
  loopOp.gatherWrappers(wrappers);

  // Process wrappers nested inside of `parentOp` from outermost to innermost.
  for (auto it =
           std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
       it != wrappers.rend(); ++it) {
    if (failed(convertIgnoredWrapper(*it, moduleTranslation)))
      return failure();
  }

  return success();
}
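
The reverse-iterator traversal above is subtle, so here is a small self-contained C++ illustration of the visiting order, assuming (as the code relies on) that gatherWrappers collects wrappers from innermost to outermost; plain strings stand in for the wrapper operations:

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

int main() {
  // For "omp.wsloop { omp.simd { omp.loop_nest ... } }", innermost first:
  std::vector<std::string> wrappers = {"omp.simd", "omp.wsloop"};
  std::string parentOp = "omp.wsloop";
  // Reverse iteration visits outermost wrappers first; std::next skips
  // parentOp itself, so only wrappers strictly nested inside it are visited.
  // (The real code assumes parentOp is always present among the wrappers.)
  for (auto it =
           std::next(std::find(wrappers.rbegin(), wrappers.rend(), parentOp));
       it != wrappers.rend(); ++it)
    std::cout << "ignoring: " << *it << "\n"; // prints "ignoring: omp.simd"
  return 0;
}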

/// Converts an OpenMP 'masked' operation into LLVM IR using OpenMPIRBuilder.
static LogicalResult
convertOmpMasked(Operation &opInst, llvm::IRBuilderBase &builder,
@@ -1262,9 +1318,6 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
      !wsloopOp.getPrivateVars().empty() || wsloopOp.getPrivateSyms())
    return opInst.emitError("unhandled clauses for translation to LLVM IR");

  // FIXME: Here any other nested wrappers (e.g. omp.simd) are skipped, so
  // codegen for composite constructs like 'DO/FOR SIMD' will be the same as for
  // 'DO/FOR'.
  auto loopOp = cast<omp::LoopNestOp>(wsloopOp.getWrappedLoop());

  llvm::ArrayRef<bool> isByRef = getIsByRef(wsloopOp.getReductionByref());
@@ -1302,6 +1355,13 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
                 isByRef)))
    return failure();

  // TODO: Replace this with proper composite translation support.
  // Currently, all nested wrappers are ignored, so 'do/for simd' will be
  // treated the same as a standalone 'do/for'. This is allowed by the spec,
  // since it's equivalent to always using a SIMD length of 1.
  if (failed(convertIgnoredWrappers(loopOp, wsloopOp, moduleTranslation)))
    return failure();

  // Store the mapping between reduction variables and their private copies on
  // ModuleTranslation stack. It can be then recovered when translating
  // omp.reduce operations in a separate call.
mlir/test/Target/LLVMIR/openmp-reduction.mlir (80 additions, 0 deletions)
@@ -586,3 +586,83 @@ llvm.func @parallel_nested_workshare_reduction(%ub : i64) {
// Reduction function.
// CHECK: define internal void @[[REDFUNC]]
// CHECK: add i32

// -----

omp.declare_reduction @add_f32 : f32
init {
^bb0(%arg: f32):
  %0 = llvm.mlir.constant(0.0 : f32) : f32
  omp.yield (%0 : f32)
}
combiner {
^bb1(%arg0: f32, %arg1: f32):
  %1 = llvm.fadd %arg0, %arg1 : f32
  omp.yield (%1 : f32)
}
atomic {
^bb2(%arg2: !llvm.ptr, %arg3: !llvm.ptr):
  %2 = llvm.load %arg3 : !llvm.ptr -> f32
  llvm.atomicrmw fadd %arg2, %2 monotonic : !llvm.ptr, f32
  omp.yield
}
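
For readers less familiar with omp.declare_reduction, its three regions correspond to initializing a private copy, combining two partial results, and updating an accumulator atomically. A rough C++ analogue, purely illustrative (std::atomic<float>::fetch_add requires C++20, and memory_order_relaxed roughly matches the 'monotonic' ordering above):

#include <atomic>

float init() { return 0.0f; }                      // init region
float combiner(float a, float b) { return a + b; } // combiner region
void atomicUpdate(std::atomic<float> &acc, float v) {
  acc.fetch_add(v, std::memory_order_relaxed);     // ~ llvm.atomicrmw fadd ... monotonic
}

int main() {
  std::atomic<float> acc{init()};
  atomicUpdate(acc, 2.0f);
  return static_cast<int>(combiner(acc.load(), 1.0f)); // 3
}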

// CHECK-LABEL: @wsloop_simd_reduction
llvm.func @wsloop_simd_reduction(%lb : i64, %ub : i64, %step : i64) {
  %c1 = llvm.mlir.constant(1 : i32) : i32
  %0 = llvm.alloca %c1 x i32 : (i32) -> !llvm.ptr
  omp.parallel {
    omp.wsloop reduction(@add_f32 %0 -> %prv1 : !llvm.ptr) {
      omp.simd reduction(@add_f32 %prv1 -> %prv2 : !llvm.ptr) {
        omp.loop_nest (%iv) : i64 = (%lb) to (%ub) step (%step) {
          %1 = llvm.mlir.constant(2.0 : f32) : f32
          %2 = llvm.load %prv2 : !llvm.ptr -> f32
          %3 = llvm.fadd %1, %2 : f32
          llvm.store %3, %prv2 : f32, !llvm.ptr
          omp.yield
        }
      } {omp.composite}
    } {omp.composite}
    omp.terminator
  }
  llvm.return
}

// Same checks as for wsloop reduction, because currently omp.simd is ignored in
// a composite 'do/for simd' construct.
// Call to the outlined function.
// CHECK: call void {{.*}} @__kmpc_fork_call
// CHECK-SAME: @[[OUTLINED:[A-Za-z_.][A-Za-z0-9_.]*]]

// Outlined function.
// CHECK: define internal void @[[OUTLINED]]

// Private reduction variable and its initialization.
// CHECK: %[[PRIVATE:.+]] = alloca float
// CHECK: store float 0.000000e+00, ptr %[[PRIVATE]]

// Call to the reduction function.
// CHECK: call i32 @__kmpc_reduce
// CHECK-SAME: @[[REDFUNC:[A-Za-z_.][A-Za-z0-9_.]*]]

// Atomic reduction.
// CHECK: %[[PARTIAL:.+]] = load float, ptr %[[PRIVATE]]
// CHECK: atomicrmw fadd ptr %{{.*}}, float %[[PARTIAL]]

// Non-atomic reduction:
// CHECK: fadd float
// CHECK: call void @__kmpc_end_reduce
// CHECK: br label %[[FINALIZE:.+]]

// CHECK: [[FINALIZE]]:
// CHECK: call void @__kmpc_barrier

// Update of the private variable using the reduction region
// (the body block currently comes after all the other blocks).
// CHECK: %[[PARTIAL:.+]] = load float, ptr %[[PRIVATE]]
// CHECK: %[[UPDATED:.+]] = fadd float 2.000000e+00, %[[PARTIAL]]
// CHECK: store float %[[UPDATED]], ptr %[[PRIVATE]]

// Reduction function.
// CHECK: define internal void @[[REDFUNC]]
// CHECK: fadd float
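
For context: tests under mlir/test/Target/LLVMIR are driven by a lit RUN line at the top of the file, which this hunk does not show. Assuming the file follows the usual convention for this directory, it would be something like

// RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s

so the CHECK lines above are matched against the LLVM IR that mlir-translate emits for this module.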