22
22
#include " mlir/IR/IRMapping.h"
23
23
#include " mlir/IR/Matchers.h"
24
24
#include " mlir/IR/OpDefinition.h"
25
+ #include " mlir/IR/PatternMatch.h"
25
26
#include " mlir/IR/TypeUtilities.h"
26
27
#include " mlir/Interfaces/DestinationStyleOpInterface.h"
27
28
#include " mlir/Interfaces/InferIntRangeInterface.h"
33
34
#include " llvm/ADT/STLExtras.h"
34
35
#include " llvm/ADT/SmallBitVector.h"
35
36
#include " llvm/ADT/StringRef.h"
37
+ #include " llvm/Support/Casting.h"
36
38
#include " llvm/Support/LogicalResult.h"
37
39
#include " llvm/Support/MathExtras.h"
38
40
#include < algorithm>
39
41
#include < optional>
42
+ #include < vector>
40
43
41
44
using namespace mlir ;
42
45
using namespace mlir ::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
1288
1291
}
1289
1292
};
1290
1293
1294
+ // / Canonicalizes the pattern of the form
1295
+ // /
1296
+ // / %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1297
+ // / tensor<12xf64>
1298
+ // / %extracted_element = tensor.extract %val[%c10] :
1299
+ // / tensor<12xf64>
1300
+ // /
1301
+ // / to
1302
+ // /
1303
+ // / %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1304
+ struct ExtractFromCollapseShape : public OpRewritePattern <tensor::ExtractOp> {
1305
+ using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1306
+
1307
+ LogicalResult matchAndRewrite (tensor::ExtractOp extractOp,
1308
+ PatternRewriter &rewriter) const final {
1309
+ auto collapseOp =
1310
+ extractOp.getTensor ().getDefiningOp <tensor::CollapseShapeOp>();
1311
+ if (!collapseOp)
1312
+ return failure ();
1313
+ if (!collapseOp.getSrcType ().hasStaticShape ())
1314
+ return failure ();
1315
+
1316
+ auto sourceSizes = collapseOp.getSrcType ().getShape ();
1317
+
1318
+ SmallVector<Value> indices (extractOp.getIndices ().begin (),
1319
+ extractOp.getIndices ().end ());
1320
+ SmallVector<Value> sourceIndices;
1321
+ for (auto [index, group] :
1322
+ llvm::zip (indices, collapseOp.getReassociationIndices ())) {
1323
+ assert (!group.empty () && " association indices groups cannot be empty" );
1324
+ auto groupSize = group.size ();
1325
+
1326
+ if (groupSize == 1 ) {
1327
+ sourceIndices.push_back (index);
1328
+ continue ;
1329
+ }
1330
+
1331
+ SmallVector<int64_t > basis =
1332
+ llvm::map_to_vector (group, [&](int64_t d) { return sourceSizes[d]; });
1333
+ auto delinearize = rewriter.create <affine::AffineDelinearizeIndexOp>(
1334
+ extractOp.getLoc (), index, basis, /* hasOuterBound=*/ true );
1335
+ llvm::append_range (sourceIndices, delinearize.getResults ());
1336
+ }
1337
+ if (collapseOp.getReassociationIndices ().empty ()) {
1338
+ auto zeroAffineMap = rewriter.getConstantAffineMap (0 );
1339
+ int64_t srcRank =
1340
+ cast<RankedTensorType>(collapseOp.getSrcType ()).getRank ();
1341
+ OpFoldResult ofr = affine::makeComposedFoldedAffineApply (
1342
+ rewriter, extractOp.getLoc (), zeroAffineMap,
1343
+ ArrayRef<OpFoldResult>{});
1344
+ for (int64_t i = 0 ; i < srcRank; i++) {
1345
+ sourceIndices.push_back (
1346
+ getValueOrCreateConstantIndexOp (rewriter, extractOp.getLoc (), ofr));
1347
+ }
1348
+ }
1349
+
1350
+ rewriter.replaceOpWithNewOp <tensor::ExtractOp>(
1351
+ extractOp, collapseOp.getSrc (), sourceIndices);
1352
+ return success ();
1353
+ }
1354
+ };
1355
+
1291
1356
} // namespace
1292
1357
1293
1358
void ExtractOp::getAsmResultNames (
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
1303
1368
return success ();
1304
1369
}
1305
1370
1371
+ // / If we have an ExtractOp consuming an InsertOp with the same
1372
+ // / indices, we can return the InsertOp's scalar directly.
1373
+ // TODO: This only checks the immediate producer; extend to go up the
1374
+ // insert/extract chain if the slices are disjoint.
1375
+ static Value foldExtractAfterInsert (ExtractOp extractOp) {
1376
+ auto insertOp = extractOp.getTensor ().getDefiningOp <InsertOp>();
1377
+
1378
+ auto isSame = [](Value a, Value b) {
1379
+ return getAsOpFoldResult (a) == getAsOpFoldResult (b);
1380
+ };
1381
+ if (insertOp && insertOp.getScalar ().getType () == extractOp.getType () &&
1382
+ llvm::equal (insertOp.getIndices (), extractOp.getIndices (), isSame))
1383
+ return insertOp.getScalar ();
1384
+
1385
+ return {};
1386
+ }
1387
+
1306
1388
OpFoldResult ExtractOp::fold (FoldAdaptor adaptor) {
1307
1389
if (Attribute tensor = adaptor.getTensor ()) {
1308
1390
// If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
1350
1432
return elementsAttr.getValues <Attribute>()[indices];
1351
1433
}
1352
1434
1435
+ if (Value result = foldExtractAfterInsert (*this ))
1436
+ return result;
1437
+
1353
1438
return {};
1354
1439
}
1355
1440
@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
1358
1443
results.add <ExtractFromTensorCast>(context);
1359
1444
}
1360
1445
1446
/// Populates `patterns` with the pattern that rewrites a tensor.extract of a
/// tensor.collapse_shape result into an extract on the collapse source
/// (ExtractFromCollapseShape). Exposed separately rather than registered as a
/// canonicalization so callers can opt in.
void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  patterns.add<ExtractFromCollapseShape>(patterns.getContext());
}
1450
+
1361
1451
// ===----------------------------------------------------------------------===//
1362
1452
// FromElementsOp
1363
1453
// ===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
1534
1624
// InsertOp
1535
1625
// ===----------------------------------------------------------------------===//
1536
1626
1627
+ namespace {
1628
+
1629
+ // / Pattern to fold an insert op of a constant destination and scalar to a new
1630
+ // / constant.
1631
+ // /
1632
+ // / Example:
1633
+ // / ```
1634
+ // / %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1635
+ // / %c0 = arith.constant 0 : index
1636
+ // / %c4_f32 = arith.constant 4.0 : f32
1637
+ // / %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
1638
+ // / ```
1639
+ // / is rewritten into:
1640
+ // / ```
1641
+ // / %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1642
+ // / ```
1643
+ class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
1644
+ public:
1645
+ using OpRewritePattern<InsertOp>::OpRewritePattern;
1646
+
1647
+ LogicalResult matchAndRewrite (InsertOp insertOp,
1648
+ PatternRewriter &rewriter) const override {
1649
+ // Requires a ranked tensor type.
1650
+ auto destType =
1651
+ llvm::dyn_cast<RankedTensorType>(insertOp.getDest ().getType ());
1652
+ if (!destType)
1653
+ return failure ();
1654
+
1655
+ // Pattern requires constant indices
1656
+ SmallVector<uint64_t , 8 > indices;
1657
+ for (OpFoldResult indice : getAsOpFoldResult (insertOp.getIndices ())) {
1658
+ auto indiceAttr = dyn_cast<Attribute>(indice);
1659
+ if (!indiceAttr)
1660
+ return failure ();
1661
+ indices.push_back (llvm::cast<IntegerAttr>(indiceAttr).getInt ());
1662
+ }
1663
+
1664
+ // Requires a constant scalar to insert
1665
+ OpFoldResult scalar = getAsOpFoldResult (insertOp.getScalar ());
1666
+ Attribute scalarAttr = dyn_cast<Attribute>(scalar);
1667
+ if (!scalarAttr)
1668
+ return failure ();
1669
+
1670
+ if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
1671
+ insertOp.getDest ().getDefiningOp ())) {
1672
+ if (auto sourceAttr =
1673
+ llvm::dyn_cast<ElementsAttr>(constantOp.getValue ())) {
1674
+ // Update the attribute at the inserted index.
1675
+ auto sourceValues = sourceAttr.getValues <Attribute>();
1676
+ auto flattenedIndex = sourceAttr.getFlattenedIndex (indices);
1677
+ std::vector<Attribute> updatedValues;
1678
+ updatedValues.reserve (sourceAttr.getNumElements ());
1679
+ for (auto i = 0 ; i < sourceAttr.getNumElements (); ++i) {
1680
+ updatedValues.push_back (i == flattenedIndex ? scalarAttr
1681
+ : sourceValues[i]);
1682
+ }
1683
+ rewriter.replaceOpWithNewOp <arith::ConstantOp>(
1684
+ insertOp, sourceAttr.getType (),
1685
+ DenseElementsAttr::get (cast<ShapedType>(sourceAttr.getType ()),
1686
+ updatedValues));
1687
+ return success ();
1688
+ }
1689
+ }
1690
+
1691
+ return failure ();
1692
+ }
1693
+ };
1694
+
1695
+ } // namespace
1696
+
1537
1697
void InsertOp::getAsmResultNames (
1538
1698
function_ref<void (Value, StringRef)> setNameFn) {
1539
1699
setNameFn (getResult (), " inserted" );
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
1557
1717
return {};
1558
1718
}
1559
1719
1720
/// Registers canonicalization patterns for tensor.insert: folding an insert
/// of a constant scalar into a constant destination tensor into a single new
/// constant (InsertOpConstantFold).
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<InsertOpConstantFold>(context);
}
1724
+
1560
1725
// ===----------------------------------------------------------------------===//
1561
1726
// GenerateOp
1562
1727
// ===----------------------------------------------------------------------===//
0 commit comments