@@ -599,14 +599,22 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
599
599
}
600
600
601
601
// / Verifies Reduction Clause
602
- static LogicalResult verifyReductionVarList (Operation *op,
603
- std::optional<ArrayAttr> reductions,
604
- OperandRange reductionVars) {
602
+ static LogicalResult
603
+ verifyReductionVarList (Operation *op, std::optional<ArrayAttr> reductions,
604
+ OperandRange reductionVars,
605
+ std::optional<ArrayRef<bool >> byRef = std::nullopt) {
605
606
if (!reductionVars.empty ()) {
606
607
if (!reductions || reductions->size () != reductionVars.size ())
607
608
return op->emitOpError ()
608
609
<< " expected as many reduction symbol references "
609
610
" as reduction variables" ;
611
+ if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
612
+ assert (byRef);
613
+ else
614
+ assert (!byRef); // TODO: support byref reductions on other operations
615
+ if (byRef && byRef->size () != reductionVars.size ())
616
+ return op->emitError () << " expected as many reduction variable by "
617
+ " reference attributes as reduction variables" ;
610
618
} else {
611
619
if (reductions)
612
620
return op->emitOpError () << " unexpected reduction symbol references" ;
@@ -1520,14 +1528,8 @@ LogicalResult ParallelOp::verify() {
1520
1528
if (failed (verifyPrivateVarList (*this )))
1521
1529
return failure ();
1522
1530
1523
- auto reductionVarsByRef = getReductionVarsByref ();
1524
- if (reductionVarsByRef &&
1525
- reductionVarsByRef->size () != getReductionVars ().size ())
1526
- return emitOpError ()
1527
- << " expected as many reduction variable by reference attributes "
1528
- " as reduction variables" ;
1529
-
1530
- return verifyReductionVarList (*this , getReductions (), getReductionVars ());
1531
+ return verifyReductionVarList (*this , getReductions (), getReductionVars (),
1532
+ getReductionVarsByref ());
1531
1533
}
1532
1534
1533
1535
// ===----------------------------------------------------------------------===//
@@ -1709,14 +1711,8 @@ LogicalResult WsloopOp::verify() {
1709
1711
return emitError () << " only supported nested wrapper is 'omp.simd'" ;
1710
1712
}
1711
1713
1712
- auto reductionVarsByRef = getReductionVarsByref ();
1713
- if (reductionVarsByRef &&
1714
- reductionVarsByRef->size () != getReductionVars ().size ())
1715
- return emitOpError ()
1716
- << " expected as many reduction variable by reference attributes "
1717
- " as reduction variables" ;
1718
-
1719
- return verifyReductionVarList (*this , getReductions (), getReductionVars ());
1714
+ return verifyReductionVarList (*this , getReductions (), getReductionVars (),
1715
+ getReductionVarsByref ());
1720
1716
}
1721
1717
1722
1718
// ===----------------------------------------------------------------------===//
0 commit comments