@@ -1593,67 +1593,89 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1593
1593
const Fortran::parser::AccClauseList &accClauseList,
1594
1594
bool needEarlyReturnHandling = false ) {
1595
1595
fir::FirOpBuilder &builder = converter.getFirOpBuilder ();
1596
-
1597
- mlir::Value workerNum;
1598
- mlir::Value vectorNum;
1599
- mlir::Value gangNum;
1600
- mlir::Value gangDim;
1601
- mlir::Value gangStatic;
1602
1596
llvm::SmallVector<mlir::Value> tileOperands, privateOperands,
1603
- reductionOperands, cacheOperands;
1597
+ reductionOperands, cacheOperands, vectorOperands, workerNumOperands,
1598
+ gangOperands;
1604
1599
llvm::SmallVector<mlir::Attribute> privatizations, reductionRecipes;
1605
- bool hasGang = false , hasVector = false , hasWorker = false ;
1600
+ llvm::SmallVector<int32_t > tileOperandsSegments, gangOperandsSegments;
1601
+ llvm::SmallVector<int64_t > collapseValues;
1602
+
1603
+ llvm::SmallVector<mlir::Attribute> gangArgTypes;
1604
+ llvm::SmallVector<mlir::Attribute> seqDeviceTypes, independentDeviceTypes,
1605
+ autoDeviceTypes, vectorOperandsDeviceTypes, workerNumOperandsDeviceTypes,
1606
+ vectorDeviceTypes, workerNumDeviceTypes, tileOperandsDeviceTypes,
1607
+ collapseDeviceTypes, gangDeviceTypes, gangOperandsDeviceTypes;
1608
+
1609
+ // device_type attribute is set to `none` until a device_type clause is
1610
+ // encountered.
1611
+ auto crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
1612
+ builder.getContext (), mlir::acc::DeviceType::None);
1606
1613
1607
1614
for (const Fortran::parser::AccClause &clause : accClauseList.v ) {
1608
1615
mlir::Location clauseLocation = converter.genLocation (clause.source );
1609
1616
if (const auto *gangClause =
1610
1617
std::get_if<Fortran::parser::AccClause::Gang>(&clause.u )) {
1611
1618
if (gangClause->v ) {
1619
+ auto crtGangOperands = gangOperands.size ();
1612
1620
const Fortran::parser::AccGangArgList &x = *gangClause->v ;
1613
1621
for (const Fortran::parser::AccGangArg &gangArg : x.v ) {
1614
1622
if (const auto *num =
1615
1623
std::get_if<Fortran::parser::AccGangArg::Num>(&gangArg.u )) {
1616
- gangNum = fir::getBase (converter.genExprValue (
1617
- *Fortran::semantics::GetExpr (num->v ), stmtCtx));
1624
+ gangOperands.push_back (fir::getBase (converter.genExprValue (
1625
+ *Fortran::semantics::GetExpr (num->v ), stmtCtx)));
1626
+ gangArgTypes.push_back (mlir::acc::GangArgTypeAttr::get (
1627
+ builder.getContext (), mlir::acc::GangArgType::Num));
1618
1628
} else if (const auto *staticArg =
1619
1629
std::get_if<Fortran::parser::AccGangArg::Static>(
1620
1630
&gangArg.u )) {
1621
1631
const Fortran::parser::AccSizeExpr &sizeExpr = staticArg->v ;
1622
1632
if (sizeExpr.v ) {
1623
- gangStatic = fir::getBase (converter.genExprValue (
1624
- *Fortran::semantics::GetExpr (*sizeExpr.v ), stmtCtx));
1633
+ gangOperands. push_back ( fir::getBase (converter.genExprValue (
1634
+ *Fortran::semantics::GetExpr (*sizeExpr.v ), stmtCtx))) ;
1625
1635
} else {
1626
1636
// * was passed as value and will be represented as a special
1627
1637
// constant.
1628
- gangStatic = builder.createIntegerConstant (
1629
- clauseLocation, builder.getIndexType (), starCst);
1638
+ gangOperands. push_back ( builder.createIntegerConstant (
1639
+ clauseLocation, builder.getIndexType (), starCst)) ;
1630
1640
}
1641
+ gangArgTypes.push_back (mlir::acc::GangArgTypeAttr::get (
1642
+ builder.getContext (), mlir::acc::GangArgType::Static));
1631
1643
} else if (const auto *dim =
1632
1644
std::get_if<Fortran::parser::AccGangArg::Dim>(
1633
1645
&gangArg.u )) {
1634
- gangDim = fir::getBase (converter.genExprValue (
1635
- *Fortran::semantics::GetExpr (dim->v ), stmtCtx));
1646
+ gangOperands.push_back (fir::getBase (converter.genExprValue (
1647
+ *Fortran::semantics::GetExpr (dim->v ), stmtCtx)));
1648
+ gangArgTypes.push_back (mlir::acc::GangArgTypeAttr::get (
1649
+ builder.getContext (), mlir::acc::GangArgType::Dim));
1636
1650
}
1637
1651
}
1652
+ gangOperandsSegments.push_back (gangOperands.size () - crtGangOperands);
1653
+ gangOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1654
+ } else {
1655
+ gangDeviceTypes.push_back (crtDeviceTypeAttr);
1638
1656
}
1639
- hasGang = true ;
1640
1657
} else if (const auto *workerClause =
1641
1658
std::get_if<Fortran::parser::AccClause::Worker>(&clause.u )) {
1642
1659
if (workerClause->v ) {
1643
- workerNum = fir::getBase (converter.genExprValue (
1644
- *Fortran::semantics::GetExpr (*workerClause->v ), stmtCtx));
1660
+ workerNumOperands.push_back (fir::getBase (converter.genExprValue (
1661
+ *Fortran::semantics::GetExpr (*workerClause->v ), stmtCtx)));
1662
+ workerNumOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1663
+ } else {
1664
+ workerNumDeviceTypes.push_back (crtDeviceTypeAttr);
1645
1665
}
1646
- hasWorker = true ;
1647
1666
} else if (const auto *vectorClause =
1648
1667
std::get_if<Fortran::parser::AccClause::Vector>(&clause.u )) {
1649
1668
if (vectorClause->v ) {
1650
- vectorNum = fir::getBase (converter.genExprValue (
1651
- *Fortran::semantics::GetExpr (*vectorClause->v ), stmtCtx));
1669
+ vectorOperands.push_back (fir::getBase (converter.genExprValue (
1670
+ *Fortran::semantics::GetExpr (*vectorClause->v ), stmtCtx)));
1671
+ vectorOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1672
+ } else {
1673
+ vectorDeviceTypes.push_back (crtDeviceTypeAttr);
1652
1674
}
1653
- hasVector = true ;
1654
1675
} else if (const auto *tileClause =
1655
1676
std::get_if<Fortran::parser::AccClause::Tile>(&clause.u )) {
1656
1677
const Fortran::parser::AccTileExprList &accTileExprList = tileClause->v ;
1678
+ auto crtTileOperands = tileOperands.size ();
1657
1679
for (const auto &accTileExpr : accTileExprList.v ) {
1658
1680
const auto &expr =
1659
1681
std::get<std::optional<Fortran::parser::ScalarIntConstantExpr>>(
@@ -1669,6 +1691,8 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1669
1691
tileOperands.push_back (tileStar);
1670
1692
}
1671
1693
}
1694
+ tileOperandsDeviceTypes.push_back (crtDeviceTypeAttr);
1695
+ tileOperandsSegments.push_back (tileOperands.size () - crtTileOperands);
1672
1696
} else if (const auto *privateClause =
1673
1697
std::get_if<Fortran::parser::AccClause::Private>(
1674
1698
&clause.u )) {
@@ -1680,17 +1704,46 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1680
1704
&clause.u )) {
1681
1705
genReductions (reductionClause->v , converter, semanticsContext, stmtCtx,
1682
1706
reductionOperands, reductionRecipes);
1707
+ } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u )) {
1708
+ seqDeviceTypes.push_back (crtDeviceTypeAttr);
1709
+ } else if (std::get_if<Fortran::parser::AccClause::Independent>(
1710
+ &clause.u )) {
1711
+ independentDeviceTypes.push_back (crtDeviceTypeAttr);
1712
+ } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u )) {
1713
+ autoDeviceTypes.push_back (crtDeviceTypeAttr);
1714
+ } else if (const auto *deviceTypeClause =
1715
+ std::get_if<Fortran::parser::AccClause::DeviceType>(
1716
+ &clause.u )) {
1717
+ const Fortran::parser::AccDeviceTypeExprList &deviceTypeExprList =
1718
+ deviceTypeClause->v ;
1719
+ assert (deviceTypeExprList.v .size () == 1 &&
1720
+ " expect only one device_type expr" );
1721
+ crtDeviceTypeAttr = mlir::acc::DeviceTypeAttr::get (
1722
+ builder.getContext (), getDeviceType (deviceTypeExprList.v .front ().v ));
1723
+ } else if (const auto *collapseClause =
1724
+ std::get_if<Fortran::parser::AccClause::Collapse>(
1725
+ &clause.u )) {
1726
+ const Fortran::parser::AccCollapseArg &arg = collapseClause->v ;
1727
+ const auto &force = std::get<bool >(arg.t );
1728
+ if (force)
1729
+ TODO (clauseLocation, " OpenACC collapse force modifier" );
1730
+ const auto &intExpr =
1731
+ std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t );
1732
+ const auto *expr = Fortran::semantics::GetExpr (intExpr);
1733
+ const std::optional<int64_t > collapseValue =
1734
+ Fortran::evaluate::ToInt64 (*expr);
1735
+ assert (collapseValue && " expect integer value for the collapse clause" );
1736
+ collapseValues.push_back (*collapseValue);
1737
+ collapseDeviceTypes.push_back (crtDeviceTypeAttr);
1683
1738
}
1684
1739
}
1685
1740
1686
1741
// Prepare the operand segment size attribute and the operands value range.
1687
1742
llvm::SmallVector<mlir::Value> operands;
1688
1743
llvm::SmallVector<int32_t > operandSegments;
1689
- addOperand (operands, operandSegments, gangNum);
1690
- addOperand (operands, operandSegments, gangDim);
1691
- addOperand (operands, operandSegments, gangStatic);
1692
- addOperand (operands, operandSegments, workerNum);
1693
- addOperand (operands, operandSegments, vectorNum);
1744
+ addOperands (operands, operandSegments, gangOperands);
1745
+ addOperands (operands, operandSegments, workerNumOperands);
1746
+ addOperands (operands, operandSegments, vectorOperands);
1694
1747
addOperands (operands, operandSegments, tileOperands);
1695
1748
addOperands (operands, operandSegments, cacheOperands);
1696
1749
addOperands (operands, operandSegments, privateOperands);
@@ -1708,12 +1761,42 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1708
1761
builder, currentLocation, eval, operands, operandSegments,
1709
1762
/* outerCombined=*/ false , retTy, yieldValue);
1710
1763
1711
- if (hasGang)
1712
- loopOp.setHasGangAttr (builder.getUnitAttr ());
1713
- if (hasWorker)
1714
- loopOp.setHasWorkerAttr (builder.getUnitAttr ());
1715
- if (hasVector)
1716
- loopOp.setHasVectorAttr (builder.getUnitAttr ());
1764
+ if (!gangDeviceTypes.empty ())
1765
+ loopOp.setGangAttr (builder.getArrayAttr (gangDeviceTypes));
1766
+ if (!gangArgTypes.empty ())
1767
+ loopOp.setGangOperandsArgTypeAttr (builder.getArrayAttr (gangArgTypes));
1768
+ if (!gangOperandsSegments.empty ())
1769
+ loopOp.setGangOperandsSegmentsAttr (
1770
+ builder.getDenseI32ArrayAttr (gangOperandsSegments));
1771
+ if (!gangOperandsDeviceTypes.empty ())
1772
+ loopOp.setGangOperandsDeviceTypeAttr (
1773
+ builder.getArrayAttr (gangOperandsDeviceTypes));
1774
+
1775
+ if (!workerNumDeviceTypes.empty ())
1776
+ loopOp.setWorkerAttr (builder.getArrayAttr (workerNumDeviceTypes));
1777
+ if (!workerNumOperandsDeviceTypes.empty ())
1778
+ loopOp.setWorkerNumOperandsDeviceTypeAttr (
1779
+ builder.getArrayAttr (workerNumOperandsDeviceTypes));
1780
+
1781
+ if (!vectorDeviceTypes.empty ())
1782
+ loopOp.setVectorAttr (builder.getArrayAttr (vectorDeviceTypes));
1783
+ if (!vectorOperandsDeviceTypes.empty ())
1784
+ loopOp.setVectorOperandsDeviceTypeAttr (
1785
+ builder.getArrayAttr (vectorOperandsDeviceTypes));
1786
+
1787
+ if (!tileOperandsDeviceTypes.empty ())
1788
+ loopOp.setTileOperandsDeviceTypeAttr (
1789
+ builder.getArrayAttr (tileOperandsDeviceTypes));
1790
+ if (!tileOperandsSegments.empty ())
1791
+ loopOp.setTileOperandsSegmentsAttr (
1792
+ builder.getDenseI32ArrayAttr (tileOperandsSegments));
1793
+
1794
+ if (!seqDeviceTypes.empty ())
1795
+ loopOp.setSeqAttr (builder.getArrayAttr (seqDeviceTypes));
1796
+ if (!independentDeviceTypes.empty ())
1797
+ loopOp.setIndependentAttr (builder.getArrayAttr (independentDeviceTypes));
1798
+ if (!autoDeviceTypes.empty ())
1799
+ loopOp.setAuto_Attr (builder.getArrayAttr (autoDeviceTypes));
1717
1800
1718
1801
if (!privatizations.empty ())
1719
1802
loopOp.setPrivatizationsAttr (
@@ -1723,33 +1806,11 @@ createLoopOp(Fortran::lower::AbstractConverter &converter,
1723
1806
loopOp.setReductionRecipesAttr (
1724
1807
mlir::ArrayAttr::get (builder.getContext (), reductionRecipes));
1725
1808
1726
- // Lower clauses mapped to attributes
1727
- for (const Fortran::parser::AccClause &clause : accClauseList.v ) {
1728
- mlir::Location clauseLocation = converter.genLocation (clause.source );
1729
- if (const auto *collapseClause =
1730
- std::get_if<Fortran::parser::AccClause::Collapse>(&clause.u )) {
1731
- const Fortran::parser::AccCollapseArg &arg = collapseClause->v ;
1732
- const auto &force = std::get<bool >(arg.t );
1733
- if (force)
1734
- TODO (clauseLocation, " OpenACC collapse force modifier" );
1735
- const auto &intExpr =
1736
- std::get<Fortran::parser::ScalarIntConstantExpr>(arg.t );
1737
- const auto *expr = Fortran::semantics::GetExpr (intExpr);
1738
- const std::optional<int64_t > collapseValue =
1739
- Fortran::evaluate::ToInt64 (*expr);
1740
- if (collapseValue) {
1741
- loopOp.setCollapseAttr (builder.getI64IntegerAttr (*collapseValue));
1742
- }
1743
- } else if (std::get_if<Fortran::parser::AccClause::Seq>(&clause.u )) {
1744
- loopOp.setSeqAttr (builder.getUnitAttr ());
1745
- } else if (std::get_if<Fortran::parser::AccClause::Independent>(
1746
- &clause.u )) {
1747
- loopOp.setIndependentAttr (builder.getUnitAttr ());
1748
- } else if (std::get_if<Fortran::parser::AccClause::Auto>(&clause.u )) {
1749
- loopOp->setAttr (mlir::acc::LoopOp::getAutoAttrStrName (),
1750
- builder.getUnitAttr ());
1751
- }
1752
- }
1809
+ if (!collapseValues.empty ())
1810
+ loopOp.setCollapseAttr (builder.getI64ArrayAttr (collapseValues));
1811
+ if (!collapseDeviceTypes.empty ())
1812
+ loopOp.setCollapseDeviceTypeAttr (builder.getArrayAttr (collapseDeviceTypes));
1813
+
1753
1814
return loopOp;
1754
1815
}
1755
1816
0 commit comments