Skip to content

[OpenMP][MLIR] Set omp.composite attr for composite loop wrappers and add verifier checks #102341

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Aug 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions flang/lib/Lower/OpenMP/OpenMP.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2063,10 +2063,12 @@ static void genCompositeDistributeSimd(
// TODO: Populate entry block arguments with private variables.
auto distributeOp = genWrapperOp<mlir::omp::DistributeOp>(
converter, loc, distributeClauseOps, /*blockArgTypes=*/{});
distributeOp.setComposite(/*val=*/true);

// TODO: Populate entry block arguments with reduction and private variables.
auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
/*blockArgTypes=*/{});
simdOp.setComposite(/*val=*/true);

// Construct wrapper entry block list and associated symbols. It is important
// that the symbol order and the block argument order match, so that the
Expand Down Expand Up @@ -2111,10 +2113,12 @@ static void genCompositeDoSimd(lower::AbstractConverter &converter,
// TODO: Add private variables to entry block arguments.
auto wsloopOp = genWrapperOp<mlir::omp::WsloopOp>(
converter, loc, wsloopClauseOps, wsloopReductionTypes);
wsloopOp.setComposite(/*val=*/true);

// TODO: Populate entry block arguments with reduction and private variables.
auto simdOp = genWrapperOp<mlir::omp::SimdOp>(converter, loc, simdClauseOps,
/*blockArgTypes=*/{});
simdOp.setComposite(/*val=*/true);

// Construct wrapper entry block list and associated symbols. It is important
// that the symbol and block argument order match, so that the symbol-value
Expand Down
18 changes: 13 additions & 5 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ def PrivateClauseOp : OpenMP_Op<"private", [IsolatedFromAbove, RecipeInterface]>

def ParallelOp : OpenMP_Op<"parallel", traits = [
AttrSizedOperandSegments, AutomaticAllocationScope,
DeclareOpInterfaceMethods<ComposableOpInterface>,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
DeclareOpInterfaceMethods<OutlineableOpenMPOpInterface>,
RecursiveMemoryEffects
Expand Down Expand Up @@ -356,7 +357,9 @@ def LoopNestOp : OpenMP_Op<"loop_nest", traits = [
//===----------------------------------------------------------------------===//

def WsloopOp : OpenMP_Op<"wsloop", traits = [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ComposableOpInterface>,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
RecursiveMemoryEffects, SingleBlock
], clauses = [
OpenMP_AllocateClauseSkip<assemblyFormat = true>,
Expand Down Expand Up @@ -432,7 +435,9 @@ def WsloopOp : OpenMP_Op<"wsloop", traits = [
//===----------------------------------------------------------------------===//

def SimdOp : OpenMP_Op<"simd", traits = [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ComposableOpInterface>,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
RecursiveMemoryEffects, SingleBlock
], clauses = [
OpenMP_AlignedClause, OpenMP_IfClause, OpenMP_LinearClause,
Expand Down Expand Up @@ -499,7 +504,9 @@ def YieldOp : OpenMP_Op<"yield",
// Distribute construct [2.9.4.1]
//===----------------------------------------------------------------------===//
def DistributeOp : OpenMP_Op<"distribute", traits = [
AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopWrapperInterface>,
AttrSizedOperandSegments,
DeclareOpInterfaceMethods<ComposableOpInterface>,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
RecursiveMemoryEffects, SingleBlock
], clauses = [
OpenMP_AllocateClause, OpenMP_DistScheduleClause, OpenMP_OrderClause,
Expand Down Expand Up @@ -587,8 +594,9 @@ def TaskOp : OpenMP_Op<"task", traits = [

def TaskloopOp : OpenMP_Op<"taskloop", traits = [
AttrSizedOperandSegments, AutomaticAllocationScope,
DeclareOpInterfaceMethods<LoopWrapperInterface>, RecursiveMemoryEffects,
SingleBlock
DeclareOpInterfaceMethods<ComposableOpInterface>,
DeclareOpInterfaceMethods<LoopWrapperInterface>,
RecursiveMemoryEffects, SingleBlock
], clauses = [
OpenMP_AllocateClause, OpenMP_FinalClause, OpenMP_GrainsizeClause,
OpenMP_IfClause, OpenMP_InReductionClauseSkip<extraClassDeclaration = true>,
Expand Down
36 changes: 36 additions & 0 deletions mlir/include/mlir/Dialect/OpenMP/OpenMPOpsInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,42 @@ def LoopWrapperInterface : OpInterface<"LoopWrapperInterface"> {
];
}

def ComposableOpInterface : OpInterface<"ComposableOpInterface"> {
let description = [{
OpenMP operations that can represent a single leaf of a composite OpenMP
construct.
}];

let cppNamespace = "::mlir::omp";

let methods = [
InterfaceMethod<
/*description=*/[{
Check whether the operation is representing a leaf of a composite OpenMP
construct.
}],
/*retTy=*/"bool",
/*methodName=*/"isComposite",
(ins ), [{}], [{
return $_op->hasAttr("omp.composite");
}]
>,
InterfaceMethod<
/*description=*/[{
Mark the operation as part of an OpenMP composite construct.
}],
/*retTy=*/"void",
/*methodName=*/"setComposite",
(ins "bool":$val), [{}], [{
if (val)
$_op->setDiscardableAttr("omp.composite", mlir::UnitAttr::get($_op->getContext()));
else
$_op->removeDiscardableAttr("omp.composite");
}]
>
];
}

def DeclareTargetInterface : OpInterface<"DeclareTargetInterface"> {
let description = [{
OpenMP operations that support declare target have this interface.
Expand Down
52 changes: 52 additions & 0 deletions mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1546,6 +1546,9 @@ LogicalResult ParallelOp::verify() {
if (!isWrapper())
return emitOpError() << "must take a loop wrapper role if nested inside "
"of 'omp.distribute'";
if (!isComposite())
return emitError()
<< "'omp.composite' attribute missing from composite wrapper";

if (LoopWrapperInterface nested = getNestedWrapper()) {
// Check for the allowed leaf constructs that may appear in a composite
Expand All @@ -1555,6 +1558,9 @@ LogicalResult ParallelOp::verify() {
} else {
return emitOpError() << "must not wrap an 'omp.loop_nest' directly";
}
} else if (isComposite()) {
return emitError()
<< "'omp.composite' attribute present in non-composite wrapper";
}

if (getAllocateVars().size() != getAllocatorVars().size())
Expand Down Expand Up @@ -1748,11 +1754,28 @@ LogicalResult WsloopOp::verify() {
if (!isWrapper())
return emitOpError() << "must be a loop wrapper";

auto wrapper =
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
bool isCompositeChildLeaf =
wrapper && wrapper.isWrapper() &&
(!llvm::isa<ParallelOp>(wrapper) ||
llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));
if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
return emitError()
<< "'omp.composite' attribute missing from composite wrapper";

// Check for the allowed leaf constructs that may appear in a composite
// construct directly after DO/FOR.
if (!isa<SimdOp>(nested))
return emitError() << "only supported nested wrapper is 'omp.simd'";

} else if (isComposite() && !isCompositeChildLeaf) {
return emitError()
<< "'omp.composite' attribute present in non-composite wrapper";
} else if (!isComposite() && isCompositeChildLeaf) {
return emitError()
<< "'omp.composite' attribute missing from composite wrapper";
}

return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
Expand Down Expand Up @@ -1796,6 +1819,21 @@ LogicalResult SimdOp::verify() {
if (getNestedWrapper())
return emitOpError() << "must wrap an 'omp.loop_nest' directly";

auto wrapper =
llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
bool isCompositeChildLeaf =
wrapper && wrapper.isWrapper() &&
(!llvm::isa<ParallelOp>(wrapper) ||
llvm::isa_and_present<DistributeOp>(wrapper->getParentOp()));

if (!isComposite() && isCompositeChildLeaf)
return emitError()
<< "'omp.composite' attribute missing from composite wrapper";

if (isComposite() && !isCompositeChildLeaf)
return emitError()
<< "'omp.composite' attribute present in non-composite wrapper";

return success();
}

Expand Down Expand Up @@ -1825,11 +1863,17 @@ LogicalResult DistributeOp::verify() {
return emitOpError() << "must be a loop wrapper";

if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
return emitError()
<< "'omp.composite' attribute missing from composite wrapper";
// Check for the allowed leaf constructs that may appear in a composite
// construct directly after DISTRIBUTE.
if (!isa<ParallelOp, SimdOp>(nested))
return emitError() << "only supported nested wrappers are 'omp.parallel' "
"and 'omp.simd'";
} else if (isComposite()) {
return emitError()
<< "'omp.composite' attribute present in non-composite wrapper";
}

return success();
Expand Down Expand Up @@ -2031,11 +2075,19 @@ LogicalResult TaskloopOp::verify() {
return emitOpError() << "must be a loop wrapper";

if (LoopWrapperInterface nested = getNestedWrapper()) {
if (!isComposite())
return emitError()
<< "'omp.composite' attribute missing from composite wrapper";

// Check for the allowed leaf constructs that may appear in a composite
// construct directly after TASKLOOP.
if (!isa<SimdOp>(nested))
return emitError() << "only supported nested wrapper is 'omp.simd'";
} else if (isComposite()) {
return emitError()
<< "'omp.composite' attribute present in non-composite wrapper";
}

return success();
}

Expand Down
Loading
Loading