@@ -4748,6 +4748,167 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
4748
4748
patterns.add <SimplifyBoxTotalElementsOp>(context);
4749
4749
}
4750
4750
4751
+ // ===----------------------------------------------------------------------===//
4752
+ // DoConcurrentOp
4753
+ // ===----------------------------------------------------------------------===//
4754
+
4755
+ llvm::LogicalResult fir::DoConcurrentOp::verify () {
4756
+ mlir::Block *body = getBody ();
4757
+
4758
+ if (body->empty ())
4759
+ return emitOpError (" body cannot be empty" );
4760
+
4761
+ if (!body->mightHaveTerminator () ||
4762
+ !mlir::isa<fir::DoConcurrentLoopOp>(body->getTerminator ()))
4763
+ return emitOpError (" must be terminated by 'fir.do_concurrent.loop'" );
4764
+
4765
+ return mlir::success ();
4766
+ }
4767
+
4768
+ // ===----------------------------------------------------------------------===//
4769
+ // DoConcurrentLoopOp
4770
+ // ===----------------------------------------------------------------------===//
4771
+
4772
+ mlir::ParseResult fir::DoConcurrentLoopOp::parse (mlir::OpAsmParser &parser,
4773
+ mlir::OperationState &result) {
4774
+ auto &builder = parser.getBuilder ();
4775
+ // Parse an opening `(` followed by induction variables followed by `)`
4776
+ llvm::SmallVector<mlir::OpAsmParser::Argument, 4 > ivs;
4777
+ if (parser.parseArgumentList (ivs, mlir::OpAsmParser::Delimiter::Paren))
4778
+ return mlir::failure ();
4779
+
4780
+ // Parse loop bounds.
4781
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > lower;
4782
+ if (parser.parseEqual () ||
4783
+ parser.parseOperandList (lower, ivs.size (),
4784
+ mlir::OpAsmParser::Delimiter::Paren) ||
4785
+ parser.resolveOperands (lower, builder.getIndexType (), result.operands ))
4786
+ return mlir::failure ();
4787
+
4788
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > upper;
4789
+ if (parser.parseKeyword (" to" ) ||
4790
+ parser.parseOperandList (upper, ivs.size (),
4791
+ mlir::OpAsmParser::Delimiter::Paren) ||
4792
+ parser.resolveOperands (upper, builder.getIndexType (), result.operands ))
4793
+ return mlir::failure ();
4794
+
4795
+ // Parse step values.
4796
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4 > steps;
4797
+ if (parser.parseKeyword (" step" ) ||
4798
+ parser.parseOperandList (steps, ivs.size (),
4799
+ mlir::OpAsmParser::Delimiter::Paren) ||
4800
+ parser.resolveOperands (steps, builder.getIndexType (), result.operands ))
4801
+ return mlir::failure ();
4802
+
4803
+ llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
4804
+ llvm::SmallVector<mlir::Type> reduceArgTypes;
4805
+ if (succeeded (parser.parseOptionalKeyword (" reduce" ))) {
4806
+ // Parse reduction attributes and variables.
4807
+ llvm::SmallVector<fir::ReduceAttr> attributes;
4808
+ if (failed (parser.parseCommaSeparatedList (
4809
+ mlir::AsmParser::Delimiter::Paren, [&]() {
4810
+ if (parser.parseAttribute (attributes.emplace_back ()) ||
4811
+ parser.parseArrow () ||
4812
+ parser.parseOperand (reduceOperands.emplace_back ()) ||
4813
+ parser.parseColonType (reduceArgTypes.emplace_back ()))
4814
+ return mlir::failure ();
4815
+ return mlir::success ();
4816
+ })))
4817
+ return mlir::failure ();
4818
+ // Resolve input operands.
4819
+ for (auto operand_type : llvm::zip (reduceOperands, reduceArgTypes))
4820
+ if (parser.resolveOperand (std::get<0 >(operand_type),
4821
+ std::get<1 >(operand_type), result.operands ))
4822
+ return mlir::failure ();
4823
+ llvm::SmallVector<mlir::Attribute> arrayAttr (attributes.begin (),
4824
+ attributes.end ());
4825
+ result.addAttribute (getReduceAttrsAttrName (result.name ),
4826
+ builder.getArrayAttr (arrayAttr));
4827
+ }
4828
+
4829
+ // Now parse the body.
4830
+ mlir::Region *body = result.addRegion ();
4831
+ for (auto &iv : ivs)
4832
+ iv.type = builder.getIndexType ();
4833
+ if (parser.parseRegion (*body, ivs))
4834
+ return mlir::failure ();
4835
+
4836
+ // Set `operandSegmentSizes` attribute.
4837
+ result.addAttribute (DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
4838
+ builder.getDenseI32ArrayAttr (
4839
+ {static_cast <int32_t >(lower.size ()),
4840
+ static_cast <int32_t >(upper.size ()),
4841
+ static_cast <int32_t >(steps.size ()),
4842
+ static_cast <int32_t >(reduceOperands.size ())}));
4843
+
4844
+ // Parse attributes.
4845
+ if (parser.parseOptionalAttrDict (result.attributes ))
4846
+ return mlir::failure ();
4847
+
4848
+ return mlir::success ();
4849
+ }
4850
+
4851
+ void fir::DoConcurrentLoopOp::print (mlir::OpAsmPrinter &p) {
4852
+ p << " (" << getBody ()->getArguments () << " ) = (" << getLowerBound ()
4853
+ << " ) to (" << getUpperBound () << " ) step (" << getStep () << " )" ;
4854
+
4855
+ if (!getReduceOperands ().empty ()) {
4856
+ p << " reduce(" ;
4857
+ auto attrs = getReduceAttrsAttr ();
4858
+ auto operands = getReduceOperands ();
4859
+ llvm::interleaveComma (llvm::zip (attrs, operands), p, [&](auto it) {
4860
+ p << std::get<0 >(it) << " -> " << std::get<1 >(it) << " : "
4861
+ << std::get<1 >(it).getType ();
4862
+ });
4863
+ p << ' )' ;
4864
+ }
4865
+
4866
+ p << ' ' ;
4867
+ p.printRegion (getRegion (), /* printEntryBlockArgs=*/ false );
4868
+ p.printOptionalAttrDict (
4869
+ (*this )->getAttrs (),
4870
+ /* elidedAttrs=*/ {DoConcurrentLoopOp::getOperandSegmentSizeAttr (),
4871
+ DoConcurrentLoopOp::getReduceAttrsAttrName ()});
4872
+ }
4873
+
4874
+ llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions () {
4875
+ return {&getRegion ()};
4876
+ }
4877
+
4878
+ llvm::LogicalResult fir::DoConcurrentLoopOp::verify () {
4879
+ mlir::Operation::operand_range lbValues = getLowerBound ();
4880
+ mlir::Operation::operand_range ubValues = getUpperBound ();
4881
+ mlir::Operation::operand_range stepValues = getStep ();
4882
+
4883
+ if (lbValues.empty ())
4884
+ return emitOpError (
4885
+ " needs at least one tuple element for lowerBound, upperBound and step" );
4886
+
4887
+ if (lbValues.size () != ubValues.size () ||
4888
+ ubValues.size () != stepValues.size ())
4889
+ return emitOpError (" different number of tuple elements for lowerBound, "
4890
+ " upperBound or step" );
4891
+
4892
+ // Check that the body defines the same number of block arguments as the
4893
+ // number of tuple elements in step.
4894
+ mlir::Block *body = getBody ();
4895
+ if (body->getNumArguments () != stepValues.size ())
4896
+ return emitOpError () << " expects the same number of induction variables: "
4897
+ << body->getNumArguments ()
4898
+ << " as bound and step values: " << stepValues.size ();
4899
+ for (auto arg : body->getArguments ())
4900
+ if (!arg.getType ().isIndex ())
4901
+ return emitOpError (
4902
+ " expects arguments for the induction variable to be of index type" );
4903
+
4904
+ auto reduceAttrs = getReduceAttrsAttr ();
4905
+ if (getNumReduceOperands () != (reduceAttrs ? reduceAttrs.size () : 0 ))
4906
+ return emitOpError (
4907
+ " mismatch in number of reduction variables and reduction attributes" );
4908
+
4909
+ return mlir::success ();
4910
+ }
4911
+
4751
4912
// ===----------------------------------------------------------------------===//
4752
4913
// FIROpsDialect
4753
4914
// ===----------------------------------------------------------------------===//
0 commit comments