Skip to content

Commit e456689

Browse files
authored
[mlir][flang][openacc] Support device_type on loop construct (#76892)
This adds support for the `device_type` clause representation in the OpenACC MLIR dialect on the acc.loop operation and adjusts flang to lower correctly to the new representation. Each "value" that can be impacted by a `device_type` clause is now associated with an array attribute that carries this information. This includes: - `worker` clause information - `gang` clause information - `vector` clause information - `collapse` clause information - `tile` clause information. The representation of the `gang` clause information has been updated and all values are now carried in a single operand segment. This segment is then subdivided by `device_type`. Each value in a segment is also associated with a `GangArgType` so it can be differentiated (num/dim/static). This simplifies the handling of gang values and limits the number of new attributes needed. When a clause can be associated with the operation without any value (`gang`, `vector`, `worker`), it is represented by a dedicated attribute with device_type information. Extra getter functions are provided to make it easier to retrieve a value based on a device_type.
1 parent 71ec301 commit e456689

File tree

10 files changed

+737
-306
lines changed

10 files changed

+737
-306
lines changed

flang/lib/Lower/OpenACC.cpp

Lines changed: 122 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -1593,67 +1593,89 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
15931593
const Fortran::parser::AccClauseList &accClauseList,
15941594
bool needEarlyReturnHandling = false) {
15951595
fir::FirOpBuilder &builder = converter.getFirOpBuilder();
1596-
1597-
mlir::Value workerNum;
1598-
mlir::Value vectorNum;
1599-
mlir::Value gangNum;
1600-
mlir::Value gangDim;
1601-
mlir::Value gangStatic;
16021596
llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
1603-
reductionOperands, cacheOperands;
1597+
reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
1598+
gangOperands;
16041599
llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
1605-
bool hasGang = false, hasVector = false, hasWorker = false;
1600+
llvm::SmallVector<int32_t> tileOperandsSegments, gangOperandsSegments;
1601+
llvm::SmallVector<int64_t> collapseValues;
1602+
1603+
llvm::SmallVector<mlir::Attribute> gangArgTypes;
1604+
llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
1605+
autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
1606+
vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
1607+
collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;
1608+
1609+
// device_type attribute is set to `none` until a device_type clause is
1610+
// encountered.
1611+
auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1612+
builder.getContext(), mlir::acc::DeviceType::None);
16061613

16071614
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
16081615
mlir::Location clauseLocation = converter.genLocation(clause.source);
16091616
if (const auto *gangClause =
16101617
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u)) {
16111618
if (gangClause->v) {
1619+
auto crtGangOperands = gangOperands.size();
16121620
const Fortran::parser::AccGangArgList &x = *gangClause->v;
16131621
for (const Fortran::parser::AccGangArg &gangArg : x.v) {
16141622
if (const auto *num =
16151623
std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u)) {
1616-
gangNum = fir::getBase(converter.genExprValue(
1617-
*Fortran::semantics::GetExpr(num->v), stmtCtx));
1624+
gangOperands.push_back(fir::getBase(converter.genExprValue(
1625+
*Fortran::semantics::GetExpr(num->v), stmtCtx)));
1626+
gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
1627+
builder.getContext(), mlir::acc::GangArgType::Num));
16181628
} else if (const auto *staticArg =
16191629
std::get_if<Fortran::parser::AccGangArg::Static>(
16201630
&gangArg.u)) {
16211631
const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v;
16221632
if (sizeExpr.v) {
1623-
gangStatic = fir::getBase(converter.genExprValue(
1624-
*Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx));
1633+
gangOperands.push_back(fir::getBase(converter.genExprValue(
1634+
*Fortran::semantics::GetExpr(*sizeExpr.v), stmtCtx)));
16251635
} else {
16261636
// * was passed as value and will be represented as a special
16271637
// constant.
1628-
gangStatic = builder.createIntegerConstant(
1629-
clauseLocation, builder.getIndexType(), starCst);
1638+
gangOperands.push_back(builder.createIntegerConstant(
1639+
clauseLocation, builder.getIndexType(), starCst));
16301640
}
1641+
gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
1642+
builder.getContext(), mlir::acc::GangArgType::Static));
16311643
} else if (const auto *dim =
16321644
std::get_if<Fortran::parser::AccGangArg::Dim>(
16331645
&gangArg.u)) {
1634-
gangDim = fir::getBase(converter.genExprValue(
1635-
*Fortran::semantics::GetExpr(dim->v), stmtCtx));
1646+
gangOperands.push_back(fir::getBase(converter.genExprValue(
1647+
*Fortran::semantics::GetExpr(dim->v), stmtCtx)));
1648+
gangArgTypes.push_back(mlir::acc::GangArgTypeAttr::get(
1649+
builder.getContext(), mlir::acc::GangArgType::Dim));
16361650
}
16371651
}
1652+
gangOperandsSegments.push_back(gangOperands.size() - crtGangOperands);
1653+
gangOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1654+
} else {
1655+
gangDeviceTypes.push_back(crtDeviceTypeAttr);
16381656
}
1639-
hasGang = true;
16401657
} else if (const auto *workerClause =
16411658
std::get_if<Fortran::parser::AccClause::Worker>(&clause.u)) {
16421659
if (workerClause->v) {
1643-
workerNum = fir::getBase(converter.genExprValue(
1644-
*Fortran::semantics::GetExpr(*workerClause->v), stmtCtx));
1660+
workerNumOperands.push_back(fir::getBase(converter.genExprValue(
1661+
*Fortran::semantics::GetExpr(*workerClause->v), stmtCtx)));
1662+
workerNumOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1663+
} else {
1664+
workerNumDeviceTypes.push_back(crtDeviceTypeAttr);
16451665
}
1646-
hasWorker = true;
16471666
} else if (const auto *vectorClause =
16481667
std::get_if<Fortran::parser::AccClause::Vector>(&clause.u)) {
16491668
if (vectorClause->v) {
1650-
vectorNum = fir::getBase(converter.genExprValue(
1651-
*Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx));
1669+
vectorOperands.push_back(fir::getBase(converter.genExprValue(
1670+
*Fortran::semantics::GetExpr(*vectorClause->v), stmtCtx)));
1671+
vectorOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1672+
} else {
1673+
vectorDeviceTypes.push_back(crtDeviceTypeAttr);
16521674
}
1653-
hasVector = true;
16541675
} else if (const auto *tileClause =
16551676
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u)) {
16561677
const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v;
1678+
auto crtTileOperands = tileOperands.size();
16571679
for (const auto &accTileExpr : accTileExprList.v) {
16581680
const auto &expr =
16591681
std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
@@ -1669,6 +1691,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
16691691
tileOperands.push_back(tileStar);
16701692
}
16711693
}
1694+
tileOperandsDeviceTypes.push_back(crtDeviceTypeAttr);
1695+
tileOperandsSegments.push_back(tileOperands.size() - crtTileOperands);
16721696
} else if (const auto *privateClause =
16731697
std::get_if<Fortran::parser::AccClause::Private>(
16741698
&clause.u)) {
@@ -1680,17 +1704,46 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
16801704
&clause.u)) {
16811705
genReductions(reductionClause->v, converter, semanticsContext, stmtCtx,
16821706
reductionOperands, reductionRecipes);
1707+
} else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
1708+
seqDeviceTypes.push_back(crtDeviceTypeAttr);
1709+
} else if (std::get_if<Fortran::parser::AccClause::Independent>(
1710+
&clause.u)) {
1711+
independentDeviceTypes.push_back(crtDeviceTypeAttr);
1712+
} else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
1713+
autoDeviceTypes.push_back(crtDeviceTypeAttr);
1714+
} else if (const auto *deviceTypeClause =
1715+
std::get_if<Fortran::parser::AccClause::DeviceType>(
1716+
&clause.u)) {
1717+
const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1718+
deviceTypeClause->v;
1719+
assert(deviceTypeExprList.v.size() == 1 &&
1720+
"expect only one device_type expr");
1721+
crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get(
1722+
builder.getContext(), getDeviceType(deviceTypeExprList.v.front().v));
1723+
} else if (const auto *collapseClause =
1724+
std::get_if<Fortran::parser::AccClause::Collapse>(
1725+
&clause.u)) {
1726+
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
1727+
const auto &force = std::get<bool>(arg.t);
1728+
if (force)
1729+
TODO(clauseLocation, "OpenACC collapse force modifier");
1730+
const auto &intExpr =
1731+
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
1732+
const auto *expr = Fortran::semantics::GetExpr(intExpr);
1733+
const std::optional<int64_t> collapseValue =
1734+
Fortran::evaluate::ToInt64(*expr);
1735+
assert(collapseValue && "expect integer value for the collapse clause");
1736+
collapseValues.push_back(*collapseValue);
1737+
collapseDeviceTypes.push_back(crtDeviceTypeAttr);
16831738
}
16841739
}
16851740

16861741
// Prepare the operand segment size attribute and the operands value range.
16871742
llvm::SmallVector<mlir::Value> operands;
16881743
llvm::SmallVector<int32_t> operandSegments;
1689-
addOperand(operands, operandSegments, gangNum);
1690-
addOperand(operands, operandSegments, gangDim);
1691-
addOperand(operands, operandSegments, gangStatic);
1692-
addOperand(operands, operandSegments, workerNum);
1693-
addOperand(operands, operandSegments, vectorNum);
1744+
addOperands(operands, operandSegments, gangOperands);
1745+
addOperands(operands, operandSegments, workerNumOperands);
1746+
addOperands(operands, operandSegments, vectorOperands);
16941747
addOperands(operands, operandSegments, tileOperands);
16951748
addOperands(operands, operandSegments, cacheOperands);
16961749
addOperands(operands, operandSegments, privateOperands);
@@ -1708,12 +1761,42 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
17081761
builder, currentLocation, eval, operands, operandSegments,
17091762
/*outerCombined=*/false, retTy, yieldValue);
17101763

1711-
if (hasGang)
1712-
loopOp.setHasGangAttr(builder.getUnitAttr());
1713-
if (hasWorker)
1714-
loopOp.setHasWorkerAttr(builder.getUnitAttr());
1715-
if (hasVector)
1716-
loopOp.setHasVectorAttr(builder.getUnitAttr());
1764+
if (!gangDeviceTypes.empty())
1765+
loopOp.setGangAttr(builder.getArrayAttr(gangDeviceTypes));
1766+
if (!gangArgTypes.empty())
1767+
loopOp.setGangOperandsArgTypeAttr(builder.getArrayAttr(gangArgTypes));
1768+
if (!gangOperandsSegments.empty())
1769+
loopOp.setGangOperandsSegmentsAttr(
1770+
builder.getDenseI32ArrayAttr(gangOperandsSegments));
1771+
if (!gangOperandsDeviceTypes.empty())
1772+
loopOp.setGangOperandsDeviceTypeAttr(
1773+
builder.getArrayAttr(gangOperandsDeviceTypes));
1774+
1775+
if (!workerNumDeviceTypes.empty())
1776+
loopOp.setWorkerAttr(builder.getArrayAttr(workerNumDeviceTypes));
1777+
if (!workerNumOperandsDeviceTypes.empty())
1778+
loopOp.setWorkerNumOperandsDeviceTypeAttr(
1779+
builder.getArrayAttr(workerNumOperandsDeviceTypes));
1780+
1781+
if (!vectorDeviceTypes.empty())
1782+
loopOp.setVectorAttr(builder.getArrayAttr(vectorDeviceTypes));
1783+
if (!vectorOperandsDeviceTypes.empty())
1784+
loopOp.setVectorOperandsDeviceTypeAttr(
1785+
builder.getArrayAttr(vectorOperandsDeviceTypes));
1786+
1787+
if (!tileOperandsDeviceTypes.empty())
1788+
loopOp.setTileOperandsDeviceTypeAttr(
1789+
builder.getArrayAttr(tileOperandsDeviceTypes));
1790+
if (!tileOperandsSegments.empty())
1791+
loopOp.setTileOperandsSegmentsAttr(
1792+
builder.getDenseI32ArrayAttr(tileOperandsSegments));
1793+
1794+
if (!seqDeviceTypes.empty())
1795+
loopOp.setSeqAttr(builder.getArrayAttr(seqDeviceTypes));
1796+
if (!independentDeviceTypes.empty())
1797+
loopOp.setIndependentAttr(builder.getArrayAttr(independentDeviceTypes));
1798+
if (!autoDeviceTypes.empty())
1799+
loopOp.setAuto_Attr(builder.getArrayAttr(autoDeviceTypes));
17171800

17181801
if (!privatizations.empty())
17191802
loopOp.setPrivatizationsAttr(
@@ -1723,33 +1806,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
17231806
loopOp.setReductionRecipesAttr(
17241807
mlir::ArrayAttr::get(builder.getContext(), reductionRecipes));
17251808

1726-
// Lower clauses mapped to attributes
1727-
for (const Fortran::parser::AccClause &clause : accClauseList.v) {
1728-
mlir::Location clauseLocation = converter.genLocation(clause.source);
1729-
if (const auto *collapseClause =
1730-
std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u)) {
1731-
const Fortran::parser::AccCollapseArg &arg = collapseClause->v;
1732-
const auto &force = std::get<bool>(arg.t);
1733-
if (force)
1734-
TODO(clauseLocation, "OpenACC collapse force modifier");
1735-
const auto &intExpr =
1736-
std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t);
1737-
const auto *expr = Fortran::semantics::GetExpr(intExpr);
1738-
const std::optional<int64_t> collapseValue =
1739-
Fortran::evaluate::ToInt64(*expr);
1740-
if (collapseValue) {
1741-
loopOp.setCollapseAttr(builder.getI64IntegerAttr(*collapseValue));
1742-
}
1743-
} else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u)) {
1744-
loopOp.setSeqAttr(builder.getUnitAttr());
1745-
} else if (std::get_if<Fortran::parser::AccClause::Independent>(
1746-
&clause.u)) {
1747-
loopOp.setIndependentAttr(builder.getUnitAttr());
1748-
} else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u)) {
1749-
loopOp->setAttr(mlir::acc::LoopOp::getAutoAttrStrName(),
1750-
builder.getUnitAttr());
1751-
}
1752-
}
1809+
if (!collapseValues.empty())
1810+
loopOp.setCollapseAttr(builder.getI64ArrayAttr(collapseValues));
1811+
if (!collapseDeviceTypes.empty())
1812+
loopOp.setCollapseDeviceTypeAttr(builder.getArrayAttr(collapseDeviceTypes));
1813+
17531814
return loopOp;
17541815
}
17551816

0 commit comments

Comments
 (0)