Skip to content

Commit 1094ffc

Browse files
authored
[flang][fir] Add MLIR op for do concurrent (#130893)
Adds new MLIR ops to model `do concurrent`. In order to make `do concurrent` representation self-contained, a loop is modeled using 2 ops, one wrapper and one that contains the actual body of the loop. For example, a 2D `do concurrent` loop is modeled as follows: ```mlir fir.do_concurrent { %i = fir.alloca i32 %j = fir.alloca i32 fir.do_concurrent.loop (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st) { %0 = fir.convert %i_iv : (index) -> i32 fir.store %0 to %i : !fir.ref<i32> %1 = fir.convert %j_iv : (index) -> i32 fir.store %1 to %j : !fir.ref<i32> } } ``` The `fir.do_concurrent` wrapper op encapsulates both the actual loop and the allocations required for the iteration variables. The `fir.do_concurrent.loop` op is a multi-dimensional op that contains the loop control and body. See the ops' docs for more info.
1 parent 036c6cb commit 1094ffc

File tree

4 files changed

+453
-0
lines changed

4 files changed

+453
-0
lines changed

flang/include/flang/Optimizer/Dialect/FIROps.td

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3446,4 +3446,109 @@ def fir_BoxTotalElementsOp
34463446
let hasCanonicalizer = 1;
34473447
}
34483448

3449+
def fir_DoConcurrentOp : fir_Op<"do_concurrent",
3450+
[SingleBlock, AutomaticAllocationScope]> {
3451+
let summary = "do concurrent loop wrapper";
3452+
3453+
let description = [{
3454+
A wrapper operation for the actual op modeling `do concurrent` loops:
3455+
`fir.do_concurrent.loop` (see op declaration below for more info about it).
3456+
3457+
The `fir.do_concurrent` wrapper op consists of one single-block region with
3458+
the following properties:
3459+
- The first ops in the region are responsible for allocating storage for the
3460+
loop's iteration variables. This is property is **not** enforced by the op
3461+
verifier, but expected to be respected when building the op.
3462+
- The terminator of the region is an instance of `fir.do_concurrent.loop`.
3463+
3464+
For example, a 2D loop nest would be represented as follows:
3465+
```
3466+
fir.do_concurrent {
3467+
%i = fir.alloca i32
3468+
%j = fir.alloca i32
3469+
fir.do_concurrent.loop ...
3470+
}
3471+
```
3472+
}];
3473+
3474+
let regions = (region SizedRegion<1>:$region);
3475+
3476+
let assemblyFormat = "$region attr-dict";
3477+
let hasVerifier = 1;
3478+
}
3479+
3480+
def fir_DoConcurrentLoopOp : fir_Op<"do_concurrent.loop",
3481+
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface>,
3482+
Terminator, NoTerminator, SingleBlock, ParentOneOf<["DoConcurrentOp"]>]> {
3483+
let summary = "do concurrent loop";
3484+
3485+
let description = [{
3486+
An operation that models a Fortran `do concurrent` loop's header and block.
3487+
This is a single-region single-block terminator op that is expected to
3488+
terminate the region of a `omp.do_concurrent` wrapper op.
3489+
3490+
This op borrows from both `scf.parallel` and `fir.do_loop` ops. Similar to
3491+
`scf.parallel`, a loop nest takes 3 groups of SSA values as operands that
3492+
represent the lower bounds, upper bounds, and steps. Similar to `fir.do_loop`
3493+
the op takes one additional group of SSA values to represent reductions.
3494+
3495+
The body region **does not** have a terminator.
3496+
3497+
For example, a 2D loop nest with 2 reductions (sum and max) would be
3498+
represented as follows:
3499+
```
3500+
// The wrapper of the loop
3501+
fir.do_concurrent {
3502+
%i = fir.alloca i32
3503+
%j = fir.alloca i32
3504+
3505+
// The actual `do concurrent` loop
3506+
fir.do_concurrent.loop
3507+
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
3508+
reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>,
3509+
#fir.reduce_attr<max> -> %max : !fir.ref<f32>) {
3510+
3511+
%0 = fir.convert %i_iv : (index) -> i32
3512+
fir.store %0 to %i : !fir.ref<i32>
3513+
3514+
%1 = fir.convert %j_iv : (index) -> i32
3515+
fir.store %1 to %j : !fir.ref<i32>
3516+
3517+
// ... loop body goes here ...
3518+
}
3519+
}
3520+
```
3521+
3522+
Description of arguments:
3523+
- `lowerBound`: The group of SSA values for the nest's lower bounds.
3524+
- `upperBound`: The group of SSA values for the nest's upper bounds.
3525+
- `step`: The group of SSA values for the nest's steps.
3526+
- `reduceOperands`: The reduction SSA values, if any.
3527+
- `reduceAttrs`: Attributes to store reduction operations, if any.
3528+
- `loopAnnotation`: Loop metadata to be passed down the compiler pipeline to
3529+
LLVM.
3530+
}];
3531+
3532+
let arguments = (ins
3533+
Variadic<Index>:$lowerBound,
3534+
Variadic<Index>:$upperBound,
3535+
Variadic<Index>:$step,
3536+
Variadic<AnyType>:$reduceOperands,
3537+
OptionalAttr<ArrayAttr>:$reduceAttrs,
3538+
OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
3539+
);
3540+
3541+
let regions = (region SizedRegion<1>:$region);
3542+
3543+
let hasCustomAssemblyFormat = 1;
3544+
let hasVerifier = 1;
3545+
3546+
let extraClassDeclaration = [{
3547+
// Get Number of reduction operands
3548+
unsigned getNumReduceOperands() {
3549+
return getReduceOperands().size();
3550+
}
3551+
}];
3552+
}
3553+
34493554
#endif

flang/lib/Optimizer/Dialect/FIROps.cpp

Lines changed: 161 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4748,6 +4748,167 @@ void fir::BoxTotalElementsOp::getCanonicalizationPatterns(
47484748
patterns.add<SimplifyBoxTotalElementsOp>(context);
47494749
}
47504750

4751+
//===----------------------------------------------------------------------===//
4752+
// DoConcurrentOp
4753+
//===----------------------------------------------------------------------===//
4754+
4755+
llvm::LogicalResult fir::DoConcurrentOp::verify() {
4756+
mlir::Block *body = getBody();
4757+
4758+
if (body->empty())
4759+
return emitOpError("body cannot be empty");
4760+
4761+
if (!body->mightHaveTerminator() ||
4762+
!mlir::isa<fir::DoConcurrentLoopOp>(body->getTerminator()))
4763+
return emitOpError("must be terminated by 'fir.do_concurrent.loop'");
4764+
4765+
return mlir::success();
4766+
}
4767+
4768+
//===----------------------------------------------------------------------===//
4769+
// DoConcurrentLoopOp
4770+
//===----------------------------------------------------------------------===//
4771+
4772+
mlir::ParseResult fir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
4773+
mlir::OperationState &result) {
4774+
auto &builder = parser.getBuilder();
4775+
// Parse an opening `(` followed by induction variables followed by `)`
4776+
llvm::SmallVector<mlir::OpAsmParser::Argument, 4> ivs;
4777+
if (parser.parseArgumentList(ivs, mlir::OpAsmParser::Delimiter::Paren))
4778+
return mlir::failure();
4779+
4780+
// Parse loop bounds.
4781+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> lower;
4782+
if (parser.parseEqual() ||
4783+
parser.parseOperandList(lower, ivs.size(),
4784+
mlir::OpAsmParser::Delimiter::Paren) ||
4785+
parser.resolveOperands(lower, builder.getIndexType(), result.operands))
4786+
return mlir::failure();
4787+
4788+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> upper;
4789+
if (parser.parseKeyword("to") ||
4790+
parser.parseOperandList(upper, ivs.size(),
4791+
mlir::OpAsmParser::Delimiter::Paren) ||
4792+
parser.resolveOperands(upper, builder.getIndexType(), result.operands))
4793+
return mlir::failure();
4794+
4795+
// Parse step values.
4796+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand, 4> steps;
4797+
if (parser.parseKeyword("step") ||
4798+
parser.parseOperandList(steps, ivs.size(),
4799+
mlir::OpAsmParser::Delimiter::Paren) ||
4800+
parser.resolveOperands(steps, builder.getIndexType(), result.operands))
4801+
return mlir::failure();
4802+
4803+
llvm::SmallVector<mlir::OpAsmParser::UnresolvedOperand> reduceOperands;
4804+
llvm::SmallVector<mlir::Type> reduceArgTypes;
4805+
if (succeeded(parser.parseOptionalKeyword("reduce"))) {
4806+
// Parse reduction attributes and variables.
4807+
llvm::SmallVector<fir::ReduceAttr> attributes;
4808+
if (failed(parser.parseCommaSeparatedList(
4809+
mlir::AsmParser::Delimiter::Paren, [&]() {
4810+
if (parser.parseAttribute(attributes.emplace_back()) ||
4811+
parser.parseArrow() ||
4812+
parser.parseOperand(reduceOperands.emplace_back()) ||
4813+
parser.parseColonType(reduceArgTypes.emplace_back()))
4814+
return mlir::failure();
4815+
return mlir::success();
4816+
})))
4817+
return mlir::failure();
4818+
// Resolve input operands.
4819+
for (auto operand_type : llvm::zip(reduceOperands, reduceArgTypes))
4820+
if (parser.resolveOperand(std::get<0>(operand_type),
4821+
std::get<1>(operand_type), result.operands))
4822+
return mlir::failure();
4823+
llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
4824+
attributes.end());
4825+
result.addAttribute(getReduceAttrsAttrName(result.name),
4826+
builder.getArrayAttr(arrayAttr));
4827+
}
4828+
4829+
// Now parse the body.
4830+
mlir::Region *body = result.addRegion();
4831+
for (auto &iv : ivs)
4832+
iv.type = builder.getIndexType();
4833+
if (parser.parseRegion(*body, ivs))
4834+
return mlir::failure();
4835+
4836+
// Set `operandSegmentSizes` attribute.
4837+
result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
4838+
builder.getDenseI32ArrayAttr(
4839+
{static_cast<int32_t>(lower.size()),
4840+
static_cast<int32_t>(upper.size()),
4841+
static_cast<int32_t>(steps.size()),
4842+
static_cast<int32_t>(reduceOperands.size())}));
4843+
4844+
// Parse attributes.
4845+
if (parser.parseOptionalAttrDict(result.attributes))
4846+
return mlir::failure();
4847+
4848+
return mlir::success();
4849+
}
4850+
4851+
void fir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
4852+
p << " (" << getBody()->getArguments() << ") = (" << getLowerBound()
4853+
<< ") to (" << getUpperBound() << ") step (" << getStep() << ")";
4854+
4855+
if (!getReduceOperands().empty()) {
4856+
p << " reduce(";
4857+
auto attrs = getReduceAttrsAttr();
4858+
auto operands = getReduceOperands();
4859+
llvm::interleaveComma(llvm::zip(attrs, operands), p, [&](auto it) {
4860+
p << std::get<0>(it) << " -> " << std::get<1>(it) << " : "
4861+
<< std::get<1>(it).getType();
4862+
});
4863+
p << ')';
4864+
}
4865+
4866+
p << ' ';
4867+
p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);
4868+
p.printOptionalAttrDict(
4869+
(*this)->getAttrs(),
4870+
/*elidedAttrs=*/{DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
4871+
DoConcurrentLoopOp::getReduceAttrsAttrName()});
4872+
}
4873+
4874+
llvm::SmallVector<mlir::Region *> fir::DoConcurrentLoopOp::getLoopRegions() {
4875+
return {&getRegion()};
4876+
}
4877+
4878+
llvm::LogicalResult fir::DoConcurrentLoopOp::verify() {
4879+
mlir::Operation::operand_range lbValues = getLowerBound();
4880+
mlir::Operation::operand_range ubValues = getUpperBound();
4881+
mlir::Operation::operand_range stepValues = getStep();
4882+
4883+
if (lbValues.empty())
4884+
return emitOpError(
4885+
"needs at least one tuple element for lowerBound, upperBound and step");
4886+
4887+
if (lbValues.size() != ubValues.size() ||
4888+
ubValues.size() != stepValues.size())
4889+
return emitOpError("different number of tuple elements for lowerBound, "
4890+
"upperBound or step");
4891+
4892+
// Check that the body defines the same number of block arguments as the
4893+
// number of tuple elements in step.
4894+
mlir::Block *body = getBody();
4895+
if (body->getNumArguments() != stepValues.size())
4896+
return emitOpError() << "expects the same number of induction variables: "
4897+
<< body->getNumArguments()
4898+
<< " as bound and step values: " << stepValues.size();
4899+
for (auto arg : body->getArguments())
4900+
if (!arg.getType().isIndex())
4901+
return emitOpError(
4902+
"expects arguments for the induction variable to be of index type");
4903+
4904+
auto reduceAttrs = getReduceAttrsAttr();
4905+
if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
4906+
return emitOpError(
4907+
"mismatch in number of reduction variables and reduction attributes");
4908+
4909+
return mlir::success();
4910+
}
4911+
47514912
//===----------------------------------------------------------------------===//
47524913
// FIROpsDialect
47534914
//===----------------------------------------------------------------------===//

flang/test/Fir/do_concurrent.fir

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Test fir.do_concurrent operation parse, verify (no errors), and unparse
2+
3+
// RUN: fir-opt %s | fir-opt | FileCheck %s
4+
5+
func.func @dc_1d(%i_lb: index, %i_ub: index, %i_st: index) {
6+
fir.do_concurrent {
7+
%i = fir.alloca i32
8+
fir.do_concurrent.loop (%i_iv) = (%i_lb) to (%i_ub) step (%i_st) {
9+
%0 = fir.convert %i_iv : (index) -> i32
10+
fir.store %0 to %i : !fir.ref<i32>
11+
}
12+
}
13+
return
14+
}
15+
16+
// CHECK-LABEL: func.func @dc_1d
17+
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index)
18+
// CHECK: fir.do_concurrent {
19+
// CHECK: %[[I:.*]] = fir.alloca i32
20+
// CHECK: fir.do_concurrent.loop (%[[I_IV:.*]]) = (%[[I_LB]]) to (%[[I_UB]]) step (%[[I_ST]]) {
21+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
22+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
23+
// CHECK: }
24+
// CHECK: }
25+
26+
func.func @dc_2d(%i_lb: index, %i_ub: index, %i_st: index,
27+
%j_lb: index, %j_ub: index, %j_st: index) {
28+
fir.do_concurrent {
29+
%i = fir.alloca i32
30+
%j = fir.alloca i32
31+
fir.do_concurrent.loop
32+
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st) {
33+
%0 = fir.convert %i_iv : (index) -> i32
34+
fir.store %0 to %i : !fir.ref<i32>
35+
36+
%1 = fir.convert %j_iv : (index) -> i32
37+
fir.store %1 to %j : !fir.ref<i32>
38+
}
39+
}
40+
return
41+
}
42+
43+
// CHECK-LABEL: func.func @dc_2d
44+
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index, %[[J_LB:.*]]: index, %[[J_UB:.*]]: index, %[[J_ST:.*]]: index)
45+
// CHECK: fir.do_concurrent {
46+
// CHECK: %[[I:.*]] = fir.alloca i32
47+
// CHECK: %[[J:.*]] = fir.alloca i32
48+
// CHECK: fir.do_concurrent.loop
49+
// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) {
50+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
51+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
52+
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
53+
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
54+
// CHECK: }
55+
// CHECK: }
56+
57+
func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
58+
%j_lb: index, %j_ub: index, %j_st: index) {
59+
%sum = fir.alloca i32
60+
61+
fir.do_concurrent {
62+
%i = fir.alloca i32
63+
%j = fir.alloca i32
64+
fir.do_concurrent.loop
65+
(%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
66+
reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
67+
%0 = fir.convert %i_iv : (index) -> i32
68+
fir.store %0 to %i : !fir.ref<i32>
69+
70+
%1 = fir.convert %j_iv : (index) -> i32
71+
fir.store %1 to %j : !fir.ref<i32>
72+
}
73+
}
74+
return
75+
}
76+
77+
// CHECK-LABEL: func.func @dc_2d_reduction
78+
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index, %[[J_LB:.*]]: index, %[[J_UB:.*]]: index, %[[J_ST:.*]]: index)
79+
80+
// CHECK: %[[SUM:.*]] = fir.alloca i32
81+
82+
// CHECK: fir.do_concurrent {
83+
// CHECK: %[[I:.*]] = fir.alloca i32
84+
// CHECK: %[[J:.*]] = fir.alloca i32
85+
// CHECK: fir.do_concurrent.loop
86+
// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) reduce(#fir.reduce_attr<add> -> %[[SUM]] : !fir.ref<i32>) {
87+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
88+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
89+
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
90+
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
91+
// CHECK: }
92+
// CHECK: }

0 commit comments

Comments
 (0)