Skip to content

Commit 0490357

Browse files
committed
[flang][hlfir] Add MLIR op for do concurrent
Adds new MLIR ops to model `do concurrent`. In order to make `do concurrent` representation self-contained, a loop is modeled using 2 ops, one wrapper and one that contains the actual body of the loop. For example, a 2D `do concurrent` loop is modeled as follows: ```mlir hlfir.do_concurrent { %i = fir.alloca i32 %j = fir.alloca i32 hlfir.do_concurrent.loop (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st) { %0 = fir.convert %i_iv : (index) -> i32 fir.store %0 to %i : !fir.ref<i32> %1 = fir.convert %j_iv : (index) -> i32 fir.store %1 to %j : !fir.ref<i32> } } ``` The `hlfir.do_concurrent` wrapper op encapsulates both the actual loop and the allocations required for the iteration variables. The `hlfir.do_concurrent.loop` op is a multi-dimensional op that contains the loop control and body. See the ops' docs for more info.
1 parent 606e9fa commit 0490357

File tree

4 files changed

+466
-0
lines changed

4 files changed

+466
-0
lines changed

flang/include/flang/Optimizer/HLFIR/HLFIROps.td

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ include "flang/Optimizer/Dialect/FIRAttr.td"
2121
include "flang/Optimizer/Dialect/FortranVariableInterface.td"
2222
include "mlir/Dialect/Arith/IR/ArithBase.td"
2323
include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td"
24+
include "mlir/Dialect/LLVMIR/LLVMAttrDefs.td"
2425
include "mlir/IR/BuiltinAttributes.td"
2526

2627
// Base class for FIR operations.
@@ -1863,5 +1864,120 @@ def hlfir_EvaluateInMemoryOp : hlfir_Op<"eval_in_mem", [AttrSizedOperandSegments
18631864
let hasVerifier = 1;
18641865
}
18651866

1867+
// Wrapper op for `do concurrent`: a single-block region that holds the
// iteration-variable allocations and is terminated by the
// `hlfir.do_concurrent.loop` op carrying the loop control and body.
def hlfir_DoConcurrentOp : hlfir_Op<"do_concurrent", [SingleBlock]> {
  let summary = "do concurrent loop wrapper";

  let description = [{
    A wrapper operation for the actual op modeling `do concurrent` loops:
    `hlfir.do_concurrent.loop` (see op declaration below for more info about it).

    The `hlfir.do_concurrent` wrapper op consists of one single-block region with
    the following properties:
    - The first ops in the region are responsible for allocating storage for the
      loop's iteration variables. This property is **not** enforced by the op
      verifier, but expected to be respected when building the op.
    - The terminator of the region is an instance of `hlfir.do_concurrent.loop`.

    For example, a 2D loop nest would be represented as follows:
    ```
    hlfir.do_concurrent {
      %i = fir.alloca i32
      %j = fir.alloca i32
      hlfir.do_concurrent.loop ...
    }
    ```
  }];

  let regions = (region SizedRegion<1>:$region);

  let assemblyFormat = "$region attr-dict";
  let hasVerifier = 1;
}
1896+
1897+
// Multi-dimensional loop op holding the `do concurrent` loop control (bounds,
// steps, optional reductions) and the loop body. It is itself the terminator
// of the enclosing `hlfir.do_concurrent` wrapper, while its own body region
// has no terminator.
def hlfir_DoConcurrentLoopOp : hlfir_Op<"do_concurrent.loop",
    [AttrSizedOperandSegments, DeclareOpInterfaceMethods<LoopLikeOpInterface>,
     Terminator, NoTerminator, SingleBlock, ParentOneOf<["DoConcurrentOp"]>]> {
  let summary = "do concurrent loop";

  let description = [{
    An operation that models a Fortran `do concurrent` loop's header and block.
    This is a single-region single-block terminator op that is expected to
    terminate the region of a `hlfir.do_concurrent` wrapper op.

    This op borrows from both `scf.parallel` and `fir.do_loop` ops. Similar to
    `scf.parallel`, a loop nest takes 3 groups of SSA values as operands that
    represent the lower bounds, upper bounds, and steps. Similar to `fir.do_loop`
    the op takes one additional group of SSA values to represent reductions.

    The body region **does not** have a terminator.

    For example, a 2D loop nest with 2 reductions (sum and max) would be
    represented as follows:
    ```
    // The wrapper of the loop
    hlfir.do_concurrent {
      %i = fir.alloca i32
      %j = fir.alloca i32

      // The actual `do concurrent` loop
      hlfir.do_concurrent.loop
        (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
        reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>,
               #fir.reduce_attr<max> -> %max : !fir.ref<f32>) {

        %0 = fir.convert %i_iv : (index) -> i32
        fir.store %0 to %i : !fir.ref<i32>

        %1 = fir.convert %j_iv : (index) -> i32
        fir.store %1 to %j : !fir.ref<i32>

        // ... loop body goes here ...
      }
    }
    ```

    Description of arguments:
    - `lowerBound`: The group of SSA values for the nest's lower bounds.
    - `upperBound`: The group of SSA values for the nest's upper bounds.
    - `step`: The group of SSA values for the nest's steps.
    - `reduceOperands`: The reduction SSA values, if any.
    - `reduceAttrs`: Attributes to store reduction operations, if any.
    - `loopAnnotation`: Loop metadata to be passed down the compiler pipeline to
      LLVM.
  }];

  let arguments = (ins
    Variadic<Index>:$lowerBound,
    Variadic<Index>:$upperBound,
    Variadic<Index>:$step,
    Variadic<AnyType>:$reduceOperands,
    OptionalAttr<ArrayAttr>:$reduceAttrs,
    OptionalAttr<LoopAnnotationAttr>:$loopAnnotation
  );

  let regions = (region SizedRegion<1>:$region);

  let hasCustomAssemblyFormat = 1;
  let hasVerifier = 1;

  let extraClassDeclaration = [{
    /// Get the number of operands in the variadic operand segment
    /// `segmentIdx`, as recorded in the operand segment sizes attribute.
    unsigned getNumOperands(unsigned segmentIdx) {
      auto segments = (*this)->getAttrOfType<mlir::DenseI32ArrayAttr>(
        getOperandSegmentSizeAttr());
      return static_cast<unsigned>(segments[segmentIdx]);
    }

    /// Get the number of reduction operands. Segment 3 is `reduceOperands`:
    /// it follows `lowerBound` (0), `upperBound` (1) and `step` (2) in the
    /// op's argument list.
    unsigned getNumReduceOperands() {
      return getNumOperands(3);
    }

    /// Does the operation hold operands for reduction variables
    bool hasReduceOperands() {
      return getNumReduceOperands() > 0;
    }
  }];
}
18661982

18671983
#endif // FORTRAN_DIALECT_HLFIR_OPS

flang/lib/Optimizer/HLFIR/IR/HLFIROps.cpp

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
#include "flang/Optimizer/HLFIR/HLFIROps.h"
1414

15+
#include "flang/Optimizer/Dialect/FIRAttr.h"
1516
#include "flang/Optimizer/Dialect/FIROpsSupport.h"
1617
#include "flang/Optimizer/Dialect/FIRType.h"
1718
#include "flang/Optimizer/Dialect/Support/FIRContext.h"
@@ -2246,6 +2247,168 @@ llvm::LogicalResult hlfir::EvaluateInMemoryOp::verify() {
22462247
return mlir::success();
22472248
}
22482249

2250+
//===----------------------------------------------------------------------===//
2251+
// DoConcurrentOp
2252+
//===----------------------------------------------------------------------===//
2253+
2254+
/// Checks that the wrapper's block is non-empty and that it ends with an
/// `hlfir.do_concurrent.loop` op.
llvm::LogicalResult hlfir::DoConcurrentOp::verify() {
  mlir::Block &wrapperBlock = *getBody();

  if (wrapperBlock.empty())
    return emitOpError("body cannot be empty");

  // `getTerminator()` is only reached when `mightHaveTerminator()` holds.
  const bool endsWithLoop =
      wrapperBlock.mightHaveTerminator() &&
      mlir::isa<hlfir::DoConcurrentLoopOp>(wrapperBlock.getTerminator());
  if (!endsWithLoop)
    return emitOpError("must be terminated by 'hlfir.do_concurrent.loop'");

  return mlir::success();
}
2266+
2267+
//===----------------------------------------------------------------------===//
2268+
// DoConcurrentLoopOp
2269+
//===----------------------------------------------------------------------===//
2270+
2271+
/// Parses the custom assembly form of `hlfir.do_concurrent.loop`:
///
///   `(` ivs `)` `=` `(` lower-bounds `)` `to` `(` upper-bounds `)`
///   `step` `(` steps `)`
///   [`reduce` `(` attr `->` operand `:` type (`,` ...)* `)`]
///   region [attr-dict]
///
/// Operands are appended to `result` in segment order (lower bounds, upper
/// bounds, steps, reduction operands); the operand segment sizes attribute is
/// derived from the sizes of the parsed groups.
mlir::ParseResult
hlfir::DoConcurrentLoopOp::parse(mlir::OpAsmParser &parser,
                                 mlir::OperationState &result) {
  using UnresolvedOperand = mlir::OpAsmParser::UnresolvedOperand;

  auto &builder = parser.getBuilder();
  mlir::Type indexType = builder.getIndexType();

  // Induction variables: a parenthesized argument list.
  llvm::SmallVector<mlir::OpAsmParser::Argument, 4> inductionVars;
  if (parser.parseArgumentList(inductionVars,
                               mlir::OpAsmParser::Delimiter::Paren))
    return mlir::failure();

  // Parses one parenthesized group holding exactly one operand per induction
  // variable and resolves every operand of the group as `index`.
  auto parseIndexGroup =
      [&](llvm::SmallVector<UnresolvedOperand, 4> &group) -> mlir::ParseResult {
    if (parser.parseOperandList(group, inductionVars.size(),
                                mlir::OpAsmParser::Delimiter::Paren) ||
        parser.resolveOperands(group, indexType, result.operands))
      return mlir::failure();
    return mlir::success();
  };

  // `=` (lower bounds)
  llvm::SmallVector<UnresolvedOperand, 4> lowerBounds;
  if (parser.parseEqual() || parseIndexGroup(lowerBounds))
    return mlir::failure();

  // `to` (upper bounds)
  llvm::SmallVector<UnresolvedOperand, 4> upperBounds;
  if (parser.parseKeyword("to") || parseIndexGroup(upperBounds))
    return mlir::failure();

  // `step` (step values)
  llvm::SmallVector<UnresolvedOperand, 4> stepValues;
  if (parser.parseKeyword("step") || parseIndexGroup(stepValues))
    return mlir::failure();

  // Optional `reduce(...)` clause.
  llvm::SmallVector<UnresolvedOperand> reductionVars;
  llvm::SmallVector<mlir::Type> reductionTypes;
  if (succeeded(parser.parseOptionalKeyword("reduce"))) {
    llvm::SmallVector<fir::ReduceAttr> reductionKinds;

    // One list element: `attr -> operand : type`.
    auto parseOneReduction = [&]() -> mlir::ParseResult {
      if (parser.parseAttribute(reductionKinds.emplace_back()) ||
          parser.parseArrow() ||
          parser.parseOperand(reductionVars.emplace_back()) ||
          parser.parseColonType(reductionTypes.emplace_back()))
        return mlir::failure();
      return mlir::success();
    };
    if (failed(parser.parseCommaSeparatedList(
            mlir::AsmParser::Delimiter::Paren, parseOneReduction)))
      return mlir::failure();

    // Resolve each reduction operand against its own declared type.
    for (auto varAndType : llvm::zip(reductionVars, reductionTypes)) {
      UnresolvedOperand &reductionVar = std::get<0>(varAndType);
      mlir::Type reductionType = std::get<1>(varAndType);
      if (parser.resolveOperand(reductionVar, reductionType, result.operands))
        return mlir::failure();
    }

    llvm::SmallVector<mlir::Attribute> kindsAsAttrs(reductionKinds.begin(),
                                                    reductionKinds.end());
    result.addAttribute(getReduceAttrsAttrName(result.name),
                        builder.getArrayAttr(kindsAsAttrs));
  }

  // Loop body: the induction variables become `index`-typed block arguments.
  mlir::Region *loopBody = result.addRegion();
  for (mlir::OpAsmParser::Argument &inductionVar : inductionVars)
    inductionVar.type = indexType;
  if (parser.parseRegion(*loopBody, inductionVars))
    return mlir::failure();

  // Record how many operands belong to each variadic group.
  const int32_t numLower = static_cast<int32_t>(lowerBounds.size());
  const int32_t numUpper = static_cast<int32_t>(upperBounds.size());
  const int32_t numSteps = static_cast<int32_t>(stepValues.size());
  const int32_t numReductions = static_cast<int32_t>(reductionVars.size());
  result.addAttribute(DoConcurrentLoopOp::getOperandSegmentSizeAttr(),
                      builder.getDenseI32ArrayAttr(
                          {numLower, numUpper, numSteps, numReductions}));

  // Trailing optional attribute dictionary.
  if (parser.parseOptionalAttrDict(result.attributes))
    return mlir::failure();

  return mlir::success();
}
2350+
2351+
/// Prints the custom assembly form: induction variables, bounds and steps,
/// the `reduce(...)` clause when reduction operands are present, the body
/// region (without entry block arguments), and the remaining attributes.
void hlfir::DoConcurrentLoopOp::print(mlir::OpAsmPrinter &p) {
  // Loop header: `(ivs) = (lbs) to (ubs) step (steps)`.
  p << " (" << getBody()->getArguments() << ")";
  p << " = (" << getLowerBound() << ")";
  p << " to (" << getUpperBound() << ")";
  p << " step (" << getStep() << ")";

  if (hasReduceOperands()) {
    auto reduceKinds = getReduceAttrsAttr();
    auto reduceVars = getReduceOperands();

    p << " reduce(";
    llvm::interleaveComma(
        llvm::zip(reduceKinds, reduceVars), p, [&](auto kindAndVar) {
          mlir::Attribute kind = std::get<0>(kindAndVar);
          mlir::Value var = std::get<1>(kindAndVar);
          p << kind << " -> " << var << " : " << var.getType();
        });
    p << ')';
  }

  p << ' ';
  p.printRegion(getRegion(), /*printEntryBlockArgs=*/false);

  // Both attributes are already encoded by the custom syntax printed above.
  p.printOptionalAttrDict(
      (*this)->getAttrs(),
      /*elidedAttrs=*/{getOperandSegmentSizeAttr(), getReduceAttrsAttrName()});
}
2373+
2374+
/// Returns the op's region as its only loop region.
llvm::SmallVector<mlir::Region *> hlfir::DoConcurrentLoopOp::getLoopRegions() {
  llvm::SmallVector<mlir::Region *> loopRegions;
  loopRegions.push_back(&getRegion());
  return loopRegions;
}
2377+
2378+
/// Verifies the structural invariants of `hlfir.do_concurrent.loop`:
///  - there is at least one lower bound, and the lower-bound, upper-bound and
///    step groups all have the same size;
///  - the body block has one `index`-typed argument per loop dimension;
///  - the number of reduction operands matches the number of reduction
///    attributes, and every reduction attribute is a `fir::ReduceAttr`.
llvm::LogicalResult hlfir::DoConcurrentLoopOp::verify() {
  mlir::Operation::operand_range lbValues = getLowerBound();
  mlir::Operation::operand_range ubValues = getUpperBound();
  mlir::Operation::operand_range stepValues = getStep();

  if (lbValues.empty())
    return emitOpError(
        "needs at least one tuple element for lowerBound, upperBound and step");

  if (lbValues.size() != ubValues.size() ||
      ubValues.size() != stepValues.size())
    return emitOpError(
        "different number of tuple elements for lowerBound, upperBound or step");

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  mlir::Block *body = getBody();
  if (body->getNumArguments() != stepValues.size())
    return emitOpError() << "expects the same number of induction variables: "
                         << body->getNumArguments()
                         << " as bound and step values: " << stepValues.size();
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return emitOpError(
          "expects arguments for the induction variable to be of index type");

  auto reduceAttrs = getReduceAttrsAttr();
  if (getNumReduceOperands() != (reduceAttrs ? reduceAttrs.size() : 0))
    return emitOpError(
        "mismatch in number of reduction variables and reduction attributes");

  // `reduceAttrs` is an untyped array attribute, so the generic op form can
  // carry arbitrary elements. The custom parser only accepts `fir::ReduceAttr`
  // entries, so reject anything else to keep the printed form round-trippable.
  if (reduceAttrs && !llvm::all_of(reduceAttrs, [](mlir::Attribute attr) {
        return mlir::isa<fir::ReduceAttr>(attr);
      }))
    return emitOpError(
        "expects every reduction attribute to be a '#fir.reduce_attr'");

  return mlir::success();
}
2411+
22492412
#include "flang/Optimizer/HLFIR/HLFIROpInterfaces.cpp.inc"
22502413
#define GET_OP_CLASSES
22512414
#include "flang/Optimizer/HLFIR/HLFIREnums.cpp.inc"

flang/test/HLFIR/do_concurrent.fir

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// Test hlfir.do_concurrent operation parse, verify (no errors), and unparse
2+
3+
// RUN: fir-opt %s | fir-opt | FileCheck %s
4+
5+
// Round-trip input: one-dimensional loop without reductions.
func.func @dc_1d(%i_lb: index, %i_ub: index, %i_st: index) {
  hlfir.do_concurrent {
    %i = fir.alloca i32
    hlfir.do_concurrent.loop (%i_iv) = (%i_lb) to (%i_ub) step (%i_st) {
      // Store the current induction value into the iteration variable.
      %i_val = fir.convert %i_iv : (index) -> i32
      fir.store %i_val to %i : !fir.ref<i32>
    }
  }
  return
}
15+
16+
// CHECK-LABEL: func.func @dc_1d
17+
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index)
18+
// CHECK: hlfir.do_concurrent {
19+
// CHECK: %[[I:.*]] = fir.alloca i32
20+
// CHECK: hlfir.do_concurrent.loop (%[[I_IV:.*]]) = (%[[I_LB]]) to (%[[I_UB]]) step (%[[I_ST]]) {
21+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
22+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
23+
// CHECK: }
24+
// CHECK: }
25+
26+
// Round-trip input: two-dimensional loop nest without reductions.
func.func @dc_2d(%i_lb: index, %i_ub: index, %i_st: index,
                 %j_lb: index, %j_ub: index, %j_st: index) {
  hlfir.do_concurrent {
    %i = fir.alloca i32
    %j = fir.alloca i32
    hlfir.do_concurrent.loop
      (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st) {
      // Store the current induction values into the iteration variables.
      %i_val = fir.convert %i_iv : (index) -> i32
      fir.store %i_val to %i : !fir.ref<i32>

      %j_val = fir.convert %j_iv : (index) -> i32
      fir.store %j_val to %j : !fir.ref<i32>
    }
  }
  return
}
42+
43+
// CHECK-LABEL: func.func @dc_2d
44+
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index, %[[J_LB:.*]]: index, %[[J_UB:.*]]: index, %[[J_ST:.*]]: index)
45+
// CHECK: hlfir.do_concurrent {
46+
// CHECK: %[[I:.*]] = fir.alloca i32
47+
// CHECK: %[[J:.*]] = fir.alloca i32
48+
// CHECK: hlfir.do_concurrent.loop
49+
// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) {
50+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
51+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
52+
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
53+
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
54+
// CHECK: }
55+
// CHECK: }
56+
57+
// Round-trip input: two-dimensional loop nest with a single `add` reduction.
func.func @dc_2d_reduction(%i_lb: index, %i_ub: index, %i_st: index,
                           %j_lb: index, %j_ub: index, %j_st: index) {
  // Reduction variable, allocated outside the wrapper op.
  %sum = fir.alloca i32

  hlfir.do_concurrent {
    %i = fir.alloca i32
    %j = fir.alloca i32
    hlfir.do_concurrent.loop
      (%i_iv, %j_iv) = (%i_lb, %j_lb) to (%i_ub, %j_ub) step (%i_st, %j_st)
      reduce(#fir.reduce_attr<add> -> %sum : !fir.ref<i32>) {
      // Store the current induction values into the iteration variables.
      %i_val = fir.convert %i_iv : (index) -> i32
      fir.store %i_val to %i : !fir.ref<i32>

      %j_val = fir.convert %j_iv : (index) -> i32
      fir.store %j_val to %j : !fir.ref<i32>
    }
  }
  return
}
76+
77+
// CHECK-LABEL: func.func @dc_2d_reduction
78+
// CHECK-SAME: (%[[I_LB:.*]]: index, %[[I_UB:.*]]: index, %[[I_ST:.*]]: index, %[[J_LB:.*]]: index, %[[J_UB:.*]]: index, %[[J_ST:.*]]: index)
79+
80+
// CHECK: %[[SUM:.*]] = fir.alloca i32
81+
82+
// CHECK: hlfir.do_concurrent {
83+
// CHECK: %[[I:.*]] = fir.alloca i32
84+
// CHECK: %[[J:.*]] = fir.alloca i32
85+
// CHECK: hlfir.do_concurrent.loop
86+
// CHECK-SAME: (%[[I_IV:.*]], %[[J_IV:.*]]) = (%[[I_LB]], %[[J_LB]]) to (%[[I_UB]], %[[J_UB]]) step (%[[I_ST]], %[[J_ST]]) reduce(#fir.reduce_attr<add> -> %[[SUM]] : !fir.ref<i32>) {
87+
// CHECK: %[[I_IV_CVT:.*]] = fir.convert %[[I_IV]] : (index) -> i32
88+
// CHECK: fir.store %[[I_IV_CVT]] to %[[I]] : !fir.ref<i32>
89+
// CHECK: %[[J_IV_CVT:.*]] = fir.convert %[[J_IV]] : (index) -> i32
90+
// CHECK: fir.store %[[J_IV_CVT]] to %[[J]] : !fir.ref<i32>
91+
// CHECK: }
92+
// CHECK: }

0 commit comments

Comments
 (0)