Skip to content

Commit b76112e

Browse files
clementvalschweitzpgi
authored andcommitted
[flang][openacc] Lower rest of clauses for the loop construct
1 parent 2ba09d4 commit b76112e

File tree

2 files changed

+371
-36
lines changed

2 files changed

+371
-36
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 153 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
//===----------------------------------------------------------------------===//
1212

1313
#include "flang/Lower/OpenACC.h"
14+
#include "SymbolMap.h"
15+
#include "flang/Common/idioms.h"
1416
#include "flang/Lower/Bridge.h"
1517
#include "flang/Lower/FIRBuilder.h"
1618
#include "flang/Lower/PFTBuilder.h"
@@ -21,7 +23,36 @@
2123

2224
#define TODO() llvm_unreachable("not yet implemented")
2325

24-
static void genACC(Fortran::lower::AbstractConverter &absConv,
26+
static const Fortran::parser::Name *
27+
GetDesignatorNameIfDataRef(const Fortran::parser::Designator &designator) {
28+
const auto *dataRef{std::get_if<Fortran::parser::DataRef>(&designator.u)};
29+
return dataRef ? std::get_if<Fortran::parser::Name>(&dataRef->u) : nullptr;
30+
}
31+
32+
static void genObjectList(const Fortran::parser::AccObjectList &objectList,
33+
Fortran::lower::AbstractConverter &converter,
34+
std::int32_t &objectsCount,
35+
SmallVector<Value, 8> &operands) {
36+
for (const auto &accObject : objectList.v) {
37+
std::visit(
38+
Fortran::common::visitors{
39+
[&](const Fortran::parser::Designator &designator) {
40+
if (const auto *name = GetDesignatorNameIfDataRef(designator)) {
41+
++objectsCount;
42+
const auto variable = converter.getSymbolAddress(*name->symbol);
43+
operands.push_back(variable);
44+
}
45+
},
46+
[&](const Fortran::parser::Name &name) {
47+
++objectsCount;
48+
const auto variable = converter.getSymbolAddress(*name.symbol);
49+
operands.push_back(variable);
50+
}},
51+
accObject.u);
52+
}
53+
}
54+
55+
static void genACC(Fortran::lower::AbstractConverter &converter,
2556
Fortran::lower::pft::Evaluation &eval,
2657
const Fortran::parser::OpenACCLoopConstruct &loopConstruct) {
2758

@@ -31,53 +62,142 @@ static void genACC(Fortran::lower::AbstractConverter &absConv,
3162
std::get<Fortran::parser::AccLoopDirective>(beginLoopDirective.t);
3263

3364
if (loopDirective.v == llvm::acc::ACCD_loop) {
34-
auto &firOpBuilder = absConv.getFirOpBuilder();
35-
auto currentLocation = absConv.getCurrentLocation();
65+
auto &firOpBuilder = converter.getFirOpBuilder();
66+
auto currentLocation = converter.getCurrentLocation();
3667
llvm::ArrayRef<mlir::Type> argTy;
37-
mlir::ValueRange range;
38-
// Temporarly set to default 0 as operands are not generated yet.
39-
llvm::SmallVector<int32_t, 2> operandSegmentSizes(/*Size=*/7,
40-
/*Value=*/0);
41-
auto loopOp =
42-
firOpBuilder.create<mlir::acc::LoopOp>(currentLocation, argTy, range);
43-
loopOp.setAttr(mlir::acc::LoopOp::getOperandSegmentSizeAttr(),
44-
firOpBuilder.getI32VectorAttr(operandSegmentSizes));
68+
69+
// Add attribute extracted from clauses.
70+
const auto &accClauseList =
71+
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
72+
73+
mlir::Value workerNum;
74+
mlir::Value vectorLength;
75+
mlir::Value gangNum;
76+
mlir::Value gangStatic;
77+
std::int32_t tileOperands = 0;
78+
std::int32_t privateOperands = 0;
79+
std::int32_t reductionOperands = 0;
80+
std::int64_t executionMapping = mlir::acc::OpenACCExecMapping::NONE;
81+
SmallVector<Value, 8> operands;
82+
83+
// Lower clauses values mapped to operands.
84+
for (const auto &clause : accClauseList.v) {
85+
if (const auto *gangClause =
86+
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
87+
if (gangClause->v) {
88+
const Fortran::parser::AccGangArgument &x = *gangClause->v;
89+
if (const auto &gangNumValue =
90+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
91+
x.t)) {
92+
gangNum = converter.genExprValue(
93+
*Fortran::semantics::GetExpr(gangNumValue.value()));
94+
operands.push_back(gangNum);
95+
}
96+
if (const auto &gangStaticValue =
97+
std::get<std::optional<Fortran::parser::AccSizeExpr>>(x.t)) {
98+
const auto &expr =
99+
std::get<std::optional<Fortran::parser::ScalarIntExpr>>(
100+
gangStaticValue.value().t);
101+
if (expr) {
102+
gangStatic = converter.genExprValue(
103+
*Fortran::semantics::GetExpr(*expr));
104+
} else {
105+
// * was passed as value and will be represented as a -1 constant
106+
// integer.
107+
gangStatic = firOpBuilder.createIntegerConstant(
108+
currentLocation, firOpBuilder.getIntegerType(32),
109+
/* STAR */ -1);
110+
}
111+
operands.push_back(gangStatic);
112+
}
113+
}
114+
executionMapping |= mlir::acc::OpenACCExecMapping::GANG;
115+
} else if (const auto *workerClause =
116+
std::get_if<Fortran::parser::AccClause::Worker>(
117+
&clause.u)) {
118+
if (workerClause->v) {
119+
workerNum = converter.genExprValue(
120+
*Fortran::semantics::GetExpr(*workerClause->v));
121+
operands.push_back(workerNum);
122+
}
123+
executionMapping |= mlir::acc::OpenACCExecMapping::WORKER;
124+
} else if (const auto *vectorClause =
125+
std::get_if<Fortran::parser::AccClause::Vector>(
126+
&clause.u)) {
127+
if (vectorClause->v) {
128+
vectorLength = converter.genExprValue(
129+
*Fortran::semantics::GetExpr(*vectorClause->v));
130+
operands.push_back(vectorLength);
131+
}
132+
executionMapping |= mlir::acc::OpenACCExecMapping::VECTOR;
133+
} else if (const auto *tileClause =
134+
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
135+
const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
136+
for (const auto &accTileExpr : accTileExprList.v) {
137+
const auto &expr =
138+
std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
139+
accTileExpr.t);
140+
++tileOperands;
141+
if (expr) {
142+
operands.push_back(converter.genExprValue(
143+
*Fortran::semantics::GetExpr(*expr)));
144+
} else {
145+
// * was passed as value and will be represented as a -1 constant
146+
// integer.
147+
mlir::Value tileStar = firOpBuilder.createIntegerConstant(
148+
currentLocation, firOpBuilder.getIntegerType(32),
149+
/* STAR */ -1);
150+
operands.push_back(tileStar);
151+
}
152+
}
153+
} else if (const auto *privateClause =
154+
std::get_if<Fortran::parser::AccClause::Private>(
155+
&clause.u)) {
156+
const Fortran::parser::AccObjectList &accObjectList = privateClause->v;
157+
genObjectList(accObjectList, converter, privateOperands, operands);
158+
}
159+
// Reduction clause is left out for the moment as the clause will probably
160+
// end up having its own operation.
161+
}
162+
163+
auto loopOp = firOpBuilder.create<mlir::acc::LoopOp>(currentLocation, argTy,
164+
operands);
165+
45166
firOpBuilder.createBlock(&loopOp.getRegion());
46167
auto &block = loopOp.getRegion().back();
47168
firOpBuilder.setInsertionPointToStart(&block);
48169
// ensure the block is well-formed.
49170
firOpBuilder.create<mlir::acc::YieldOp>(currentLocation);
50171

51-
// Add attribute extracted from clauses.
52-
const auto &accClauseList =
53-
std::get<Fortran::parser::AccClauseList>(beginLoopDirective.t);
172+
loopOp.setAttr(mlir::acc::LoopOp::getOperandSegmentSizeAttr(),
173+
firOpBuilder.getI32VectorAttr(
174+
{gangNum ? 1 : 0, gangStatic ? 1 : 0, workerNum ? 1 : 0,
175+
vectorLength ? 1 : 0, tileOperands, privateOperands,
176+
reductionOperands}));
177+
178+
loopOp.setAttr(mlir::acc::LoopOp::getExecutionMappingAttrName(),
179+
firOpBuilder.getI64IntegerAttr(executionMapping));
54180

181+
// Lower clauses mapped to attributes
55182
for (const auto &clause : accClauseList.v) {
56183
if (const auto *collapseClause =
57184
std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
58-
59185
const auto *expr = Fortran::semantics::GetExpr(collapseClause->v);
60186
const auto collapseValue = Fortran::evaluate::ToInt64(*expr);
61-
if (collapseValue.has_value()) {
187+
if (collapseValue) {
62188
loopOp.setAttr(mlir::acc::LoopOp::getCollapseAttrName(),
63-
firOpBuilder.getI64IntegerAttr(collapseValue.value()));
189+
firOpBuilder.getI64IntegerAttr(*collapseValue));
64190
}
65-
} else if (const auto *seqClause =
66-
std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
67-
(void)seqClause;
68-
} else if (const auto *gangClause =
69-
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
70-
(void)gangClause;
71-
} else if (const auto *vectorClause =
72-
std::get_if<Fortran::parser::AccClause::Vector>(
73-
&clause.u)) {
74-
(void)vectorClause;
75-
} else if (const auto *workerClause =
76-
std::get_if<Fortran::parser::AccClause::Worker>(
77-
&clause.u)) {
78-
(void)workerClause;
79-
} else {
80-
TODO();
191+
} else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
192+
loopOp.setAttr(mlir::acc::LoopOp::getSeqAttrName(),
193+
firOpBuilder.getUnitAttr());
194+
} else if (std::get_if<Fortran::parser::AccClause::Independent>(
195+
&clause.u)) {
196+
loopOp.setAttr(mlir::acc::LoopOp::getIndependentAttrName(),
197+
firOpBuilder.getUnitAttr());
198+
} else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
199+
loopOp.setAttr(mlir::acc::LoopOp::getAutoAttrName(),
200+
firOpBuilder.getUnitAttr());
81201
}
82202
}
83203

0 commit comments

Comments
 (0)