@@ -1659,7 +1659,7 @@ struct TosaFoldConstantPad : public TosaFoldConstantBase<tosa::PadOp> {
1659
1659
if (!matchPattern (input, m_Constant (&inputValues)))
1660
1660
return failure ();
1661
1661
1662
- // Only fold op with multiple users if foldSplatOrSingleUseOnly == true .
1662
+ // Only fold op with multiple users if foldSplatOrSingleUseOnly is false .
1663
1663
if (!llvm::hasSingleElement (input.getDefiningOp ()->getUsers ()) &&
1664
1664
foldSplatOrSingleUseOnly)
1665
1665
return failure ();
@@ -1684,6 +1684,116 @@ struct TosaFoldConstantPad : public TosaFoldConstantBase<tosa::PadOp> {
1684
1684
}
1685
1685
};
1686
1686
1687
+ template <typename BaseType, typename RangeT>
1688
+ void tileArray (ShapedType inputType, RangeT inputValues, ShapedType outputType,
1689
+ SmallVector<BaseType> &outputValues) {
1690
+
1691
+ auto inputShape = inputType.getShape ();
1692
+ auto outputShape = outputType.getShape ();
1693
+
1694
+ SmallVector<int64_t > indexInTarget (outputType.getRank ());
1695
+
1696
+ for (size_t outIndex = 0 , e = outputValues.size (); outIndex < e; ++outIndex) {
1697
+ auto index = offsetToIndex (outputShape, outIndex);
1698
+ for (auto i = 0 ; i < outputType.getRank (); ++i) {
1699
+ indexInTarget[i] = index[i] % inputShape[i];
1700
+ }
1701
+ auto inputIndexOffset = indexToOffset (inputShape, indexInTarget);
1702
+ BaseType value = inputValues[inputIndexOffset];
1703
+ outputValues[outIndex] = value;
1704
+ }
1705
+ }
1706
+
1707
+ template <typename BaseType>
1708
+ DenseElementsAttr tileTypeRaw (DenseElementsAttr attr, ShapedType inputType,
1709
+ ShapedType outputType) {
1710
+ ArrayRef<BaseType> inputValues =
1711
+ cast<DenseIntOrFPElementsAttr>(attr).getNonSplatRawData <BaseType>();
1712
+
1713
+ SmallVector<BaseType> outputValues;
1714
+ outputValues.resize_for_overwrite (outputType.getNumElements ());
1715
+ tileArray<BaseType>(inputType, inputValues, /* out*/ outputType, outputValues);
1716
+
1717
+ ArrayRef rawOutputValues (reinterpret_cast <const char *>(outputValues.data ()),
1718
+ outputValues.size () * sizeof (BaseType));
1719
+ return DenseElementsAttr::getFromRawBuffer (outputType, rawOutputValues);
1720
+ }
1721
+
1722
+ template <typename BaseType>
1723
+ DenseElementsAttr tileType (DenseElementsAttr attr, ShapedType inputType,
1724
+ ShapedType outputType) {
1725
+
1726
+ auto inputValues = attr.getValues <BaseType>();
1727
+ SmallVector<BaseType> outputValues (outputType.getNumElements (),
1728
+ *std::begin (inputValues));
1729
+ tileArray<BaseType>(inputType, inputValues, outputType, /* out*/ outputValues);
1730
+ return DenseElementsAttr::get (outputType,
1731
+ llvm::ArrayRef<BaseType>(outputValues));
1732
+ }
1733
+
1734
+ DenseElementsAttr tile (DenseElementsAttr inputValues, ShapedType outputType) {
1735
+
1736
+ auto inputType = inputValues.getType ();
1737
+ auto baseType = inputType.getElementType ();
1738
+
1739
+ // Handle possible integer types
1740
+ if (auto intType = dyn_cast<IntegerType>(baseType)) {
1741
+ switch (intType.getWidth ()) {
1742
+ case 1 :
1743
+ // i1 has special alignment which is not handled by transposeTypeRaw.
1744
+ return tileType<bool >(inputValues, inputType, outputType);
1745
+ case 8 :
1746
+ return tileTypeRaw<uint8_t >(inputValues, inputType, outputType);
1747
+ case 16 :
1748
+ return tileTypeRaw<uint16_t >(inputValues, inputType, outputType);
1749
+ case 32 :
1750
+ return tileTypeRaw<uint32_t >(inputValues, inputType, outputType);
1751
+ case 64 :
1752
+ return tileTypeRaw<uint64_t >(inputValues, inputType, outputType);
1753
+ default :
1754
+ return tileType<APInt>(inputValues, inputType, outputType);
1755
+ }
1756
+ }
1757
+
1758
+ // Handle possible float types
1759
+ if (baseType.isF32 ()) {
1760
+ return tileTypeRaw<uint32_t >(inputValues, inputType, outputType);
1761
+ }
1762
+ if (baseType.isF64 ()) {
1763
+ return tileTypeRaw<uint64_t >(inputValues, inputType, outputType);
1764
+ }
1765
+ if (baseType.isBF16 ()) {
1766
+ return tileTypeRaw<uint16_t >(inputValues, inputType, outputType);
1767
+ }
1768
+ return tileType<APFloat>(inputValues, inputType, outputType);
1769
+ }
1770
+
1771
+ struct TosaFoldConstantTile : public TosaFoldConstantBase <tosa::TileOp> {
1772
+ using TosaFoldConstantBase::TosaFoldConstantBase;
1773
+
1774
+ LogicalResult matchAndRewrite (tosa::TileOp op,
1775
+ PatternRewriter &rewriter) const override {
1776
+ auto outputType = cast<ShapedType>(op.getType ());
1777
+ // TOSA doesn't support quantized types.
1778
+ if (!outputType.getElementType ().isIntOrIndexOrFloat ())
1779
+ return failure ();
1780
+
1781
+ auto input = op.getInput1 ();
1782
+ DenseElementsAttr inputValues;
1783
+ if (!matchPattern (input, m_Constant (&inputValues)))
1784
+ return failure ();
1785
+
1786
+ // Only fold op with multiple users if foldSplatOrSingleUseOnly is false.
1787
+ if (!llvm::hasSingleElement (input.getDefiningOp ()->getUsers ()) &&
1788
+ foldSplatOrSingleUseOnly)
1789
+ return failure ();
1790
+
1791
+ rewriter.replaceOpWithNewOp <tosa::ConstOp>(op, outputType,
1792
+ tile (inputValues, outputType));
1793
+ return success ();
1794
+ }
1795
+ };
1796
+
1687
1797
/// Getting the axes position of the element which is located
1688
1798
/// in the tensor at the counter index
1689
1799
@@ -1836,41 +1946,42 @@ struct ReduceConstantOptimization : public OpRewritePattern<OperationType> {
1836
1946
1837
1947
void mlir::tosa::populateTosaFoldConstantPatterns (
1838
1948
MLIRContext *ctx, RewritePatternSet &patterns,
1839
- bool foldSplatOrSingleUseOnly,
1840
- bool enableIntCastFolding) {
1841
-
1842
- patterns.add <TosaFoldConstantTranspose>(ctx, foldSplatOrSingleUseOnly);
1843
- patterns.add <TosaFoldConstantReciprocal>(ctx, foldSplatOrSingleUseOnly);
1844
- patterns.add <TosaFoldConstantReshape>(ctx, foldSplatOrSingleUseOnly);
1845
- patterns.add <TosaFoldConstantRSQRT>(ctx, foldSplatOrSingleUseOnly);
1846
- patterns.add <TosaFoldConstantLogicalNot>(ctx, foldSplatOrSingleUseOnly);
1847
- patterns.add <TosaFoldConstantPow>(ctx, foldSplatOrSingleUseOnly);
1848
- patterns.add <TosaFoldConstantMul>(ctx, foldSplatOrSingleUseOnly);
1849
- patterns.add <TosaFoldConstantClamp>(ctx, foldSplatOrSingleUseOnly);
1850
- if (enableIntCastFolding) {
1851
- patterns.add <TosaFoldConstantCast>(ctx, foldSplatOrSingleUseOnly);
1949
+ const TosaLayerwiseConstantFoldPassOptions &options) {
1950
+
1951
+ patterns.add <TosaFoldConstantTranspose>(ctx, options.foldSplatOrSingleUseOnly );
1952
+ patterns.add <TosaFoldConstantReciprocal>(ctx, options.foldSplatOrSingleUseOnly );
1953
+ patterns.add <TosaFoldConstantReshape>(ctx, options.foldSplatOrSingleUseOnly );
1954
+ patterns.add <TosaFoldConstantRSQRT>(ctx, options.foldSplatOrSingleUseOnly );
1955
+ patterns.add <TosaFoldConstantLogicalNot>(ctx, options.foldSplatOrSingleUseOnly );
1956
+ patterns.add <TosaFoldConstantPow>(ctx, options.foldSplatOrSingleUseOnly );
1957
+ patterns.add <TosaFoldConstantMul>(ctx, options.foldSplatOrSingleUseOnly );
1958
+ patterns.add <TosaFoldConstantClamp>(ctx, options.foldSplatOrSingleUseOnly );
1959
+ if (options.enableIntCastFolding ) {
1960
+ patterns.add <TosaFoldConstantCast>(ctx, options.foldSplatOrSingleUseOnly );
1852
1961
} else {
1853
- patterns.add <TosaFoldConstantFloatCasts>(ctx, foldSplatOrSingleUseOnly);
1854
- }
1855
- patterns.add <TosaFoldConstantAdd>(ctx, foldSplatOrSingleUseOnly);
1856
- patterns.add <TosaFoldConstantSub>(ctx, foldSplatOrSingleUseOnly);
1857
- patterns.add <TosaFoldConstantGreater>(ctx, foldSplatOrSingleUseOnly);
1858
- patterns.add <TosaFoldConstantBitwiseNot>(ctx, foldSplatOrSingleUseOnly);
1859
- patterns.add <TosaFoldConstantFloor>(ctx, foldSplatOrSingleUseOnly);
1860
- patterns.add <TosaFoldConstantCeil>(ctx, foldSplatOrSingleUseOnly);
1861
- patterns.add <TosaFoldConstantErf>(ctx, foldSplatOrSingleUseOnly);
1862
- patterns.add <TosaFoldConstantExp>(ctx, foldSplatOrSingleUseOnly);
1863
- patterns.add <TosaFoldConstantLog>(ctx, foldSplatOrSingleUseOnly);
1864
- patterns.add <TosaFoldConstantCos>(ctx, foldSplatOrSingleUseOnly);
1865
- patterns.add <TosaFoldConstantSin>(ctx, foldSplatOrSingleUseOnly);
1866
- patterns.add <TosaFoldConstantBitwiseAnd>(ctx, foldSplatOrSingleUseOnly);
1867
- patterns.add <TosaFoldConstantBitwiseOr>(ctx, foldSplatOrSingleUseOnly);
1868
- patterns.add <TosaFoldConstantGreaterEqual>(ctx, foldSplatOrSingleUseOnly);
1869
- patterns.add <TosaFoldConstantEqual>(ctx, foldSplatOrSingleUseOnly);
1870
- patterns.add <TosaFoldConstantMinimum>(ctx, foldSplatOrSingleUseOnly);
1871
- patterns.add <TosaFoldConstantMaximum>(ctx, foldSplatOrSingleUseOnly);
1872
- patterns.add <TosaFoldConstantPad>(ctx, foldSplatOrSingleUseOnly);
1873
- patterns.add <TosaFoldConstantMatMul>(ctx, foldSplatOrSingleUseOnly);
1962
+ patterns.add <TosaFoldConstantFloatCasts>(ctx, options.foldSplatOrSingleUseOnly );
1963
+ }
1964
+ patterns.add <TosaFoldConstantAdd>(ctx, options.foldSplatOrSingleUseOnly );
1965
+ patterns.add <TosaFoldConstantSub>(ctx, options.foldSplatOrSingleUseOnly );
1966
+ patterns.add <TosaFoldConstantGreater>(ctx, options.foldSplatOrSingleUseOnly );
1967
+ patterns.add <TosaFoldConstantBitwiseNot>(ctx, options.foldSplatOrSingleUseOnly );
1968
+ patterns.add <TosaFoldConstantFloor>(ctx, options.foldSplatOrSingleUseOnly );
1969
+ patterns.add <TosaFoldConstantCeil>(ctx, options.foldSplatOrSingleUseOnly );
1970
+ patterns.add <TosaFoldConstantErf>(ctx, options.foldSplatOrSingleUseOnly );
1971
+ patterns.add <TosaFoldConstantExp>(ctx, options.foldSplatOrSingleUseOnly );
1972
+ patterns.add <TosaFoldConstantLog>(ctx, options.foldSplatOrSingleUseOnly );
1973
+ patterns.add <TosaFoldConstantCos>(ctx, options.foldSplatOrSingleUseOnly );
1974
+ patterns.add <TosaFoldConstantSin>(ctx, options.foldSplatOrSingleUseOnly );
1975
+ patterns.add <TosaFoldConstantBitwiseAnd>(ctx, options.foldSplatOrSingleUseOnly );
1976
+ patterns.add <TosaFoldConstantBitwiseOr>(ctx, options.foldSplatOrSingleUseOnly );
1977
+ patterns.add <TosaFoldConstantGreaterEqual>(ctx, options.foldSplatOrSingleUseOnly );
1978
+ patterns.add <TosaFoldConstantEqual>(ctx, options.foldSplatOrSingleUseOnly );
1979
+ patterns.add <TosaFoldConstantMinimum>(ctx, options.foldSplatOrSingleUseOnly );
1980
+ patterns.add <TosaFoldConstantMaximum>(ctx, options.foldSplatOrSingleUseOnly );
1981
+ patterns.add <TosaFoldConstantPad>(ctx, options.foldSplatOrSingleUseOnly );
1982
+ patterns.add <TosaFoldConstantMatMul>(ctx, options.foldSplatOrSingleUseOnly );
1983
+ if (options.enableTileFolding )
1984
+ patterns.add <TosaFoldConstantTile>(ctx, options.foldSplatOrSingleUseOnly );
1874
1985
}
1875
1986
1876
1987
void mlir::tosa::populateTosaConstantReduction (MLIRContext *ctx,
0 commit comments