|
12 | 12 |
|
13 | 13 | #include "flang/Optimizer/HLFIR/HLFIROps.h"
|
14 | 14 |
|
| 15 | +#include "flang/Optimizer/Dialect/FIRAttr.h" |
15 | 16 | #include "flang/Optimizer/Dialect/FIROpsSupport.h"
|
16 | 17 | #include "flang/Optimizer/Dialect/FIRType.h"
|
17 | 18 | #include "flang/Optimizer/Dialect/Support/FIRContext.h"
|
@@ -2246,6 +2247,168 @@ llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
|
2246 | 2247 | return mlir::success();
|
2247 | 2248 | }
|
2248 | 2249 |
|
| 2250 | +//===----------------------------------------------------------------------===// |
| 2251 | +// DoConcurrentOp |
| 2252 | +//===----------------------------------------------------------------------===// |
| 2253 | + |
| 2254 | +llvm::LogicalResult hlfir::DoConcurrentOp::verify() { |
| 2255 | + mlir::Block *body = getBody(); |
| 2256 | + |
| 2257 | + if (body->empty()) |
| 2258 | + return emitOpError("body cannot be empty"); |
| 2259 | + |
| 2260 | + if (!body->mightHaveTerminator() || |
| 2261 | + !mlir::isa<hlfir::DoConcurrentLoopOp>(body->getTerminator())) |
| 2262 | + return emitOpError("must be terminated by 'hlfir.do_concurrent.loop'"); |
| 2263 | + |
| 2264 | + return mlir::success(); |
| 2265 | +} |
| 2266 | + |
| 2267 | +//===----------------------------------------------------------------------===// |
| 2268 | +// DoConcurrentLoopOp |
| 2269 | +//===----------------------------------------------------------------------===// |
| 2270 | + |
| 2271 | +mlir::ParseResult |
| 2272 | +hlfir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser, |
| 2273 | + mlir::OperationState &result) { |
| 2274 | + auto &builder = parser.getBuilder(); |
| 2275 | + // Parse an opening `(` followed by induction variables followed by `)` |
| 2276 | + llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs; |
| 2277 | + if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren)) |
| 2278 | + return mlir::failure(); |
| 2279 | + |
| 2280 | + // Parse loop bounds. |
| 2281 | + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower; |
| 2282 | + if (parser.parseEqual() || |
| 2283 | + parser.parseOperandList(lower, ivs.size(), |
| 2284 | + mlir::OpAsmParser::Delimiter::Paren) || |
| 2285 | + parser.resolveOperands(lower, builder.getIndexType(), result.operands)) |
| 2286 | + return mlir::failure(); |
| 2287 | + |
| 2288 | + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper; |
| 2289 | + if (parser.parseKeyword("to") || |
| 2290 | + parser.parseOperandList(upper, ivs.size(), |
| 2291 | + mlir::OpAsmParser::Delimiter::Paren) || |
| 2292 | + parser.resolveOperands(upper, builder.getIndexType(), result.operands)) |
| 2293 | + return mlir::failure(); |
| 2294 | + |
| 2295 | + // Parse step values. |
| 2296 | + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps; |
| 2297 | + if (parser.parseKeyword("step") || |
| 2298 | + parser.parseOperandList(steps, ivs.size(), |
| 2299 | + mlir::OpAsmParser::Delimiter::Paren) || |
| 2300 | + parser.resolveOperands(steps, builder.getIndexType(), result.operands)) |
| 2301 | + return mlir::failure(); |
| 2302 | + |
| 2303 | + llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands; |
| 2304 | + llvm::SmallVector<mlir::Type> reduceArgTypes; |
| 2305 | + if (succeeded(parser.parseOptionalKeyword("reduce"))) { |
| 2306 | + // Parse reduction attributes and variables. |
| 2307 | + llvm::SmallVector<fir::ReduceAttr> attributes; |
| 2308 | + if (failed(parser.parseCommaSeparatedList( |
| 2309 | + mlir::AsmParser::Delimiter::Paren, [&]() { |
| 2310 | + if (parser.parseAttribute(attributes.emplace_back()) || |
| 2311 | + parser.parseArrow() || |
| 2312 | + parser.parseOperand(reduceOperands.emplace_back()) || |
| 2313 | + parser.parseColonType(reduceArgTypes.emplace_back())) |
| 2314 | + return mlir::failure(); |
| 2315 | + return mlir::success(); |
| 2316 | + }))) |
| 2317 | + return mlir::failure(); |
| 2318 | + // Resolve input operands. |
| 2319 | + for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes)) |
| 2320 | + if (parser.resolveOperand(std::get<0>(operand_type), |
| 2321 | + std::get<1>(operand_type), result.operands)) |
| 2322 | + return mlir::failure(); |
| 2323 | + llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(), |
| 2324 | + attributes.end()); |
| 2325 | + result.addAttribute(getReduceAttrsAttrName(result.name), |
| 2326 | + builder.getArrayAttr(arrayAttr)); |
| 2327 | + } |
| 2328 | + |
| 2329 | + // Now parse the body. |
| 2330 | + mlir::Region *body = result.addRegion(); |
| 2331 | + for (auto &iv : ivs) |
| 2332 | + iv.type = builder.getIndexType(); |
| 2333 | + if (parser.parseRegion(*body, ivs)) |
| 2334 | + return mlir::failure(); |
| 2335 | + |
| 2336 | + // Set `operandSegmentSizes` attribute. |
| 2337 | + result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(), |
| 2338 | + builder.getDenseI32ArrayAttr( |
| 2339 | + {static_cast<int32_t>(lower.size()), |
| 2340 | + static_cast<int32_t>(upper.size()), |
| 2341 | + static_cast<int32_t>(steps.size()), |
| 2342 | + static_cast<int32_t>(reduceOperands.size())})); |
| 2343 | + |
| 2344 | + // Parse attributes. |
| 2345 | + if (parser.parseOptionalAttrDict(result.attributes)) |
| 2346 | + return mlir::failure(); |
| 2347 | + |
| 2348 | + return mlir::success(); |
| 2349 | +} |
| 2350 | + |
| 2351 | +void hlfir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) { |
| 2352 | + p << " (" << getBody()->getArguments() << ") = (" << getLowerBound() |
| 2353 | + << ") to (" << getUpperBound() << ") step (" << getStep() << ")"; |
| 2354 | + |
| 2355 | + if (hasReduceOperands()) { |
| 2356 | + p << " reduce("; |
| 2357 | + auto attrs = getReduceAttrsAttr(); |
| 2358 | + auto operands = getReduceOperands(); |
| 2359 | + llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) { |
| 2360 | + p << std::get<0>(it) << " -> " << std::get<1>(it) << " : " |
| 2361 | + << std::get<1>(it).getType(); |
| 2362 | + }); |
| 2363 | + p << ')'; |
| 2364 | + } |
| 2365 | + |
| 2366 | + p << ' '; |
| 2367 | + p.printRegion(getRegion(), /*printEntryBlockArgs=*/false); |
| 2368 | + p.printOptionalAttrDict( |
| 2369 | + (*this)->getAttrs(), |
| 2370 | + /*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(), |
| 2371 | + DoConcurrentLoopOp::getReduceAttrsAttrName()}); |
| 2372 | +} |
| 2373 | + |
| 2374 | +llvm::SmallVector<mlir::Region *> hlfir::DoConcurrentLoopOp::getLoopRegions() { |
| 2375 | + return {&getRegion()}; |
| 2376 | +} |
| 2377 | + |
| 2378 | +llvm::LogicalResult hlfir::DoConcurrentLoopOp::verify() { |
| 2379 | + mlir::Operation::operand_range lbValues = getLowerBound(); |
| 2380 | + mlir::Operation::operand_range ubValues = getUpperBound(); |
| 2381 | + mlir::Operation::operand_range stepValues = getStep(); |
| 2382 | + |
| 2383 | + if (lbValues.empty()) |
| 2384 | + return emitOpError( |
| 2385 | + "needs at least one tuple element for lowerBound, upperBound and step"); |
| 2386 | + |
| 2387 | + if (lbValues.size() != ubValues.size() || |
| 2388 | + ubValues.size() != stepValues.size()) |
| 2389 | + return emitOpError( |
| 2390 | + "different number of tuple elements for lowerBound, upperBound or step"); |
| 2391 | + |
| 2392 | + // Check that the body defines the same number of block arguments as the |
| 2393 | + // number of tuple elements in step. |
| 2394 | + mlir::Block *body = getBody(); |
| 2395 | + if (body->getNumArguments() != stepValues.size()) |
| 2396 | + return emitOpError() << "expects the same number of induction variables: " |
| 2397 | + << body->getNumArguments() |
| 2398 | + << " as bound and step values: " << stepValues.size(); |
| 2399 | + for (auto arg : body->getArguments()) |
| 2400 | + if (!arg.getType().isIndex()) |
| 2401 | + return emitOpError( |
| 2402 | + "expects arguments for the induction variable to be of index type"); |
| 2403 | + |
| 2404 | + auto reduceAttrs = getReduceAttrsAttr(); |
| 2405 | + if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0)) |
| 2406 | + return emitOpError( |
| 2407 | + "mismatch in number of reduction variables and reduction attributes"); |
| 2408 | + |
| 2409 | + return mlir::success(); |
| 2410 | +} |
| 2411 | + |
2249 | 2412 | #include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
|
2250 | 2413 | #define GET_OP_CLASSES
|
2251 | 2414 | #include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"
|
|
0 commit comments