Skip to content

Commit 7f690c4

Browse files
committed
Move byref reduction length check to verifyReductionVarList
1 parent e8c0ba4 commit 7f690c4

File tree

1 file changed

+15
-19
lines changed

1 file changed

+15
-19
lines changed

mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -599,14 +599,22 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
599599
}
600600

601601
/// Verifies Reduction Clause
602-
static LogicalResult verifyReductionVarList(Operation *op,
603-
std::optional<ArrayAttr> reductions,
604-
OperandRange reductionVars) {
602+
static LogicalResult
603+
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
604+
OperandRange reductionVars,
605+
std::optional<ArrayRef<bool>> byRef = std::nullopt) {
605606
if (!reductionVars.empty()) {
606607
if (!reductions || reductions->size() != reductionVars.size())
607608
return op->emitOpError()
608609
<< "expected as many reduction symbol references "
609610
"as reduction variables";
611+
if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
612+
assert(byRef);
613+
else
614+
assert(!byRef); // TODO: support byref reductions on other operations
615+
if (byRef && byRef->size() != reductionVars.size())
616+
return op->emitError() << "expected as many reduction variable by "
617+
"reference attributes as reduction variables";
610618
} else {
611619
if (reductions)
612620
return op->emitOpError() << "unexpected reduction symbol references";
@@ -1520,14 +1528,8 @@ LogicalResult ParallelOp::verify() {
15201528
if (failed(verifyPrivateVarList(*this)))
15211529
return failure();
15221530

1523-
auto reductionVarsByRef = getReductionVarsByref();
1524-
if (reductionVarsByRef &&
1525-
reductionVarsByRef->size() != getReductionVars().size())
1526-
return emitOpError()
1527-
<< "expected as many reduction variable by reference attributes "
1528-
"as reduction variables";
1529-
1530-
return verifyReductionVarList(*this, getReductions(), getReductionVars());
1531+
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
1532+
getReductionVarsByref());
15311533
}
15321534

15331535
//===----------------------------------------------------------------------===//
@@ -1709,14 +1711,8 @@ LogicalResult WsloopOp::verify() {
17091711
return emitError() << "only supported nested wrapper is 'omp.simd'";
17101712
}
17111713

1712-
auto reductionVarsByRef = getReductionVarsByref();
1713-
if (reductionVarsByRef &&
1714-
reductionVarsByRef->size() != getReductionVars().size())
1715-
return emitOpError()
1716-
<< "expected as many reduction variable by reference attributes "
1717-
"as reduction variables";
1718-
1719-
return verifyReductionVarList(*this, getReductions(), getReductionVars());
1714+
return verifyReductionVarList(*this, getReductions(), getReductionVars(),
1715+
getReductionVarsByref());
17201716
}
17211717

17221718
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)