Skip to content

Commit 34d8275

Browse files
authored
[mlir][tensor] add tensor insert/extract op folders (#142458)
This commit adds several canonicalizers, folders, and rewrite patterns to tensor ops:
* tensor.insert folder: an insert into a constant is replaced with a new constant;
* tensor.extract folder: an extract from a parent tensor that was inserted at the same indices is folded into the inserted value;
* a rewrite pattern that replaces an extract of a collapse_shape with an extract of the source tensor (requires static source dimensions).
Signed-off-by: Asra Ali <[email protected]>
1 parent b9dec5a commit 34d8275

File tree

6 files changed

+240
-3
lines changed

6 files changed

+240
-3
lines changed

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ void populateFoldConstantExtractSlicePatterns(
176176
return false;
177177
});
178178

179+
/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
180+
/// source tensor.
181+
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);
182+
179183
} // namespace tensor
180184
} // namespace mlir
181185

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -827,6 +827,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
827827

828828
let hasFolder = 1;
829829
let hasVerifier = 1;
830+
let hasCanonicalizer = 1;
830831
}
831832

832833
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "mlir/IR/IRMapping.h"
2323
#include "mlir/IR/Matchers.h"
2424
#include "mlir/IR/OpDefinition.h"
25+
#include "mlir/IR/PatternMatch.h"
2526
#include "mlir/IR/TypeUtilities.h"
2627
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2728
#include "mlir/Interfaces/InferIntRangeInterface.h"
@@ -33,10 +34,12 @@
3334
#include "llvm/ADT/STLExtras.h"
3435
#include "llvm/ADT/SmallBitVector.h"
3536
#include "llvm/ADT/StringRef.h"
37+
#include "llvm/Support/Casting.h"
3638
#include "llvm/Support/LogicalResult.h"
3739
#include "llvm/Support/MathExtras.h"
3840
#include <algorithm>
3941
#include <optional>
42+
#include <vector>
4043

4144
using namespace mlir;
4245
using namespace mlir::tensor;
@@ -1288,6 +1291,68 @@ struct ExtractFromTensorCast : public OpRewritePattern<tensor::ExtractOp> {
12881291
}
12891292
};
12901293

1294+
/// Canonicalizes the pattern of the form
1295+
///
1296+
/// %val = tensor.collapse_shape %src[[0, 1]] : tensor<3x4xf64> into
1297+
/// tensor<12xf64>
1298+
/// %extracted_element = tensor.extract %val[%c10] :
1299+
/// tensor<12xf64>
1300+
///
1301+
/// to
1302+
///
1303+
/// %extracted_element = tensor.extract %src[%c2, %c2] : tensor<3x4xf64>
1304+
struct ExtractFromCollapseShape : public OpRewritePattern<tensor::ExtractOp> {
1305+
using OpRewritePattern<tensor::ExtractOp>::OpRewritePattern;
1306+
1307+
LogicalResult matchAndRewrite(tensor::ExtractOp extractOp,
1308+
PatternRewriter &rewriter) const final {
1309+
auto collapseOp =
1310+
extractOp.getTensor().getDefiningOp<tensor::CollapseShapeOp>();
1311+
if (!collapseOp)
1312+
return failure();
1313+
if (!collapseOp.getSrcType().hasStaticShape())
1314+
return failure();
1315+
1316+
auto sourceSizes = collapseOp.getSrcType().getShape();
1317+
1318+
SmallVector<Value> indices(extractOp.getIndices().begin(),
1319+
extractOp.getIndices().end());
1320+
SmallVector<Value> sourceIndices;
1321+
for (auto [index, group] :
1322+
llvm::zip(indices, collapseOp.getReassociationIndices())) {
1323+
assert(!group.empty() && "association indices groups cannot be empty");
1324+
auto groupSize = group.size();
1325+
1326+
if (groupSize == 1) {
1327+
sourceIndices.push_back(index);
1328+
continue;
1329+
}
1330+
1331+
SmallVector<int64_t> basis =
1332+
llvm::map_to_vector(group, [&](int64_t d) { return sourceSizes[d]; });
1333+
auto delinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
1334+
extractOp.getLoc(), index, basis, /*hasOuterBound=*/true);
1335+
llvm::append_range(sourceIndices, delinearize.getResults());
1336+
}
1337+
if (collapseOp.getReassociationIndices().empty()) {
1338+
auto zeroAffineMap = rewriter.getConstantAffineMap(0);
1339+
int64_t srcRank =
1340+
cast<RankedTensorType>(collapseOp.getSrcType()).getRank();
1341+
OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
1342+
rewriter, extractOp.getLoc(), zeroAffineMap,
1343+
ArrayRef<OpFoldResult>{});
1344+
for (int64_t i = 0; i < srcRank; i++) {
1345+
sourceIndices.push_back(
1346+
getValueOrCreateConstantIndexOp(rewriter, extractOp.getLoc(), ofr));
1347+
}
1348+
}
1349+
1350+
rewriter.replaceOpWithNewOp<tensor::ExtractOp>(
1351+
extractOp, collapseOp.getSrc(), sourceIndices);
1352+
return success();
1353+
}
1354+
};
1355+
12911356
} // namespace
12921357

12931358
void ExtractOp::getAsmResultNames(
@@ -1303,6 +1368,23 @@ LogicalResult ExtractOp::verify() {
13031368
return success();
13041369
}
13051370

1371+
/// If we have an ExtractOp consuming an InsertOp with the same
1372+
/// indices, we can return the InsertOp's scalar directly.
1373+
// TODO: This only checks the immediate producer; extend to go up the
1374+
// insert/extract chain if the slices are disjoint.
1375+
static Value foldExtractAfterInsert(ExtractOp extractOp) {
1376+
auto insertOp = extractOp.getTensor().getDefiningOp<InsertOp>();
1377+
1378+
auto isSame = [](Value a, Value b) {
1379+
return getAsOpFoldResult(a) == getAsOpFoldResult(b);
1380+
};
1381+
if (insertOp && insertOp.getScalar().getType() == extractOp.getType() &&
1382+
llvm::equal(insertOp.getIndices(), extractOp.getIndices(), isSame))
1383+
return insertOp.getScalar();
1384+
1385+
return {};
1386+
}
1387+
13061388
OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
13071389
if (Attribute tensor = adaptor.getTensor()) {
13081390
// If this is a splat elements attribute, simply return the value.
@@ -1350,6 +1432,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
13501432
return elementsAttr.getValues<Attribute>()[indices];
13511433
}
13521434

1435+
if (Value result = foldExtractAfterInsert(*this))
1436+
return result;
1437+
13531438
return {};
13541439
}
13551440

@@ -1358,6 +1443,11 @@ void ExtractOp::getCanonicalizationPatterns(RewritePatternSet &results,
13581443
results.add<ExtractFromTensorCast>(context);
13591444
}
13601445

1446+
/// Registers the pattern that rewrites an extract of a collapse_shape into an
/// extract of the (static-shaped) source tensor.
void mlir::tensor::populateFoldCollapseExtractPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<ExtractFromCollapseShape>(ctx);
}
1450+
13611451
//===----------------------------------------------------------------------===//
13621452
// FromElementsOp
13631453
//===----------------------------------------------------------------------===//
@@ -1534,6 +1624,76 @@ OpFoldResult GatherOp::fold(FoldAdaptor adaptor) {
15341624
// InsertOp
15351625
//===----------------------------------------------------------------------===//
15361626

1627+
namespace {
1628+
1629+
/// Pattern to fold an insert op of a constant destination and scalar to a new
1630+
/// constant.
1631+
///
1632+
/// Example:
1633+
/// ```
1634+
/// %0 = arith.constant dense<[1.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1635+
/// %c0 = arith.constant 0 : index
1636+
/// %c4_f32 = arith.constant 4.0 : f32
1637+
/// %1 = tensor.insert %c4_f32 into %0[%c0] : tensor<4xf32>
1638+
/// ```
1639+
/// is rewritten into:
1640+
/// ```
1641+
/// %1 = arith.constant dense<[4.0, 2.0, 3.0, 4.0]> : tensor<4xf32>
1642+
/// ```
1643+
class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
1644+
public:
1645+
using OpRewritePattern<InsertOp>::OpRewritePattern;
1646+
1647+
LogicalResult matchAndRewrite(InsertOp insertOp,
1648+
PatternRewriter &rewriter) const override {
1649+
// Requires a ranked tensor type.
1650+
auto destType =
1651+
llvm::dyn_cast<RankedTensorType>(insertOp.getDest().getType());
1652+
if (!destType)
1653+
return failure();
1654+
1655+
// Pattern requires constant indices
1656+
SmallVector<uint64_t, 8> indices;
1657+
for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
1658+
auto indiceAttr = dyn_cast<Attribute>(indice);
1659+
if (!indiceAttr)
1660+
return failure();
1661+
indices.push_back(llvm::cast<IntegerAttr>(indiceAttr).getInt());
1662+
}
1663+
1664+
// Requires a constant scalar to insert
1665+
OpFoldResult scalar = getAsOpFoldResult(insertOp.getScalar());
1666+
Attribute scalarAttr = dyn_cast<Attribute>(scalar);
1667+
if (!scalarAttr)
1668+
return failure();
1669+
1670+
if (auto constantOp = dyn_cast_or_null<arith::ConstantOp>(
1671+
insertOp.getDest().getDefiningOp())) {
1672+
if (auto sourceAttr =
1673+
llvm::dyn_cast<ElementsAttr>(constantOp.getValue())) {
1674+
// Update the attribute at the inserted index.
1675+
auto sourceValues = sourceAttr.getValues<Attribute>();
1676+
auto flattenedIndex = sourceAttr.getFlattenedIndex(indices);
1677+
std::vector<Attribute> updatedValues;
1678+
updatedValues.reserve(sourceAttr.getNumElements());
1679+
for (auto i = 0; i < sourceAttr.getNumElements(); ++i) {
1680+
updatedValues.push_back(i == flattenedIndex ? scalarAttr
1681+
: sourceValues[i]);
1682+
}
1683+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
1684+
insertOp, sourceAttr.getType(),
1685+
DenseElementsAttr::get(cast<ShapedType>(sourceAttr.getType()),
1686+
updatedValues));
1687+
return success();
1688+
}
1689+
}
1690+
1691+
return failure();
1692+
}
1693+
};
1694+
1695+
} // namespace
1696+
15371697
void InsertOp::getAsmResultNames(
15381698
function_ref<void(Value, StringRef)> setNameFn) {
15391699
setNameFn(getResult(), "inserted");
@@ -1557,6 +1717,11 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
15571717
return {};
15581718
}
15591719

1720+
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  // Fold inserts of a constant scalar into a constant destination tensor.
  results.add<InsertOpConstantFold>(context);
}
1724+
15601725
//===----------------------------------------------------------------------===//
15611726
// GenerateOp
15621727
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ func.func @infer_concat_return_type(%arg0: tensor<5x12xi32>, %arg1: tensor<?x12x
163163
// -----
164164

165165
// CHECK-LABEL: func @fold_extract
166-
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
166+
func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>, i32) {
167167
%const_0 = arith.constant 0 : index
168168
%const_1 = arith.constant 1 : index
169169
%const_3 = arith.constant 3 : index
@@ -193,8 +193,15 @@ func.func @fold_extract(%arg0 : index) -> (f32, f16, f16, i32, complex<f32>) {
193193
%4 = arith.constant dense<(1.2, 2.3)> : tensor<complex<f32>>
194194
%ext_5 = tensor.extract %4[] : tensor<complex<f32>>
195195

196-
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]]
197-
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5 : f32, f16, f16, i32, complex<f32>
196+
// Fold an extract after an insert.
197+
// CHECK-DAG: [[C6:%.+]] = arith.constant 4 : i32
198+
%c4_i32 = arith.constant 4 : i32
199+
%5 = arith.constant dense<[[1, 3], [0, 2]]> : tensor<2x2xi32>
200+
%inserted = tensor.insert %c4_i32 into %5[%const_1, %const_0] : tensor<2x2xi32>
201+
%ext_6 = tensor.extract %inserted[%const_1, %const_0] : tensor<2x2xi32>
202+
203+
// CHECK-NEXT: return [[C4]], [[CM2]], [[C0]], [[C64]], [[C5]], [[C6]]
204+
return %ext_1, %ext_2, %ext_3, %ext_4, %ext_5, %ext_6 : f32, f16, f16, i32, complex<f32>, i32
198205
}
199206

200207
// -----
@@ -224,6 +231,22 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
224231
return %ins_1 : tensor<4xf32>
225232
}
226233

234+
235+
// -----
236+
237+
func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
238+
// Fold an insert into a splat.
239+
// CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
240+
// CHECK-LITERAL:
241+
// CHECK-NEXT: return %[[C4]]
242+
%cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
243+
%c0 = arith.constant 0 : index
244+
%c1 = arith.constant 1 : index
245+
%c4_i32 = arith.constant 4 : i32
246+
%inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
247+
return %inserted : tensor<2x2xi32>
248+
}
249+
227250
// -----
228251

229252
// CHECK-LABEL: func @extract_from_tensor.cast
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-extract-from-collapse-shape %s | FileCheck %s
2+
3+
// CHECK-LABEL: @extract_from_collapse_shape
4+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<1x1x8xi8>)
5+
func.func @extract_from_collapse_shape(%arg0: tensor<1x1x8xi8>) -> (i8, i8) {
6+
%c1 = arith.constant 1 : index
7+
%c0 = arith.constant 0 : index
8+
%collapsed = tensor.collapse_shape %arg0 [[0, 1, 2]] : tensor<1x1x8xi8> into tensor<8xi8>
9+
%extracted = tensor.extract %collapsed[%c0] : tensor<8xi8>
10+
%extracted_0 = tensor.extract %collapsed[%c1] : tensor<8xi8>
11+
func.return %extracted, %extracted_0 : i8, i8
12+
}
13+
14+
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
15+
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
16+
// CHECK-DAG: %[[RESULT0:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]] : tensor<1x1x8xi8>
17+
// CHECK-DAG: %[[RESULT1:.*]] = tensor.extract %[[ARG0]][%[[C0]], %[[C0]], %[[C1]]] : tensor<1x1x8xi8>
18+
// CHECK-NEXT: return %[[RESULT0]], %[[RESULT1]] : i8, i8
19+
20+
// -----
21+
22+
// CHECK-LABEL: @extract_from_static_shape
23+
// CHECK-SAME: (%[[ARG0:.*]]: tensor<2x6x32xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index)
24+
func.func @extract_from_static_shape(%arg0 : tensor<2x6x32xf32>, %arg1 : index, %arg2 : index) -> f32 {
25+
%0 = tensor.collapse_shape %arg0 [[0, 1], [2]] : tensor<2x6x32xf32> into tensor<12x32xf32>
26+
%1 = tensor.extract %0[%arg1, %arg2] : tensor<12x32xf32>
27+
return %1 : f32
28+
}
29+
// CHECK-NEXT: %[[MODIFIED_INDEXES:.*]]:2 = affine.delinearize_index %[[ARG1]] into (2, 6)
30+
// CHECK-NEXT: %[[RESULT:.*]] = tensor.extract %[[ARG0]][%[[MODIFIED_INDEXES]]#0, %[[MODIFIED_INDEXES]]#1, %[[ARG2]]] : tensor<2x6x32xf32>
31+
// CHECK-NEXT: return %[[RESULT]] : f32

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,11 @@ struct TestTensorTransforms
7777
llvm::cl::desc("Test folding of expand_shape/collapse_shape"),
7878
llvm::cl::init(false)};
7979

80+
Option<bool> testFoldExtractFromCollapseShape{
81+
*this, "test-fold-extract-from-collapse-shape",
82+
llvm::cl::desc("Test folding of extract from collapse_shape"),
83+
llvm::cl::init(false)};
84+
8085
Option<bool> useForeach{
8186
*this, "use-foreach",
8287
llvm::cl::desc(
@@ -132,6 +137,12 @@ applyDropRedundantInsertSliceRankExpansionPatterns(Operation *rootOp) {
132137
(void)applyPatternsGreedily(rootOp, std::move(patterns));
133138
}
134139

140+
/// Greedily applies the extract-of-collapse_shape folding patterns under
/// `rootOp`.
static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
  MLIRContext *ctx = rootOp->getContext();
  RewritePatternSet patterns(ctx);
  tensor::populateFoldCollapseExtractPatterns(patterns);
  (void)applyPatternsGreedily(rootOp, std::move(patterns));
}
145+
135146
namespace {
136147
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
137148
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -380,6 +391,8 @@ void TestTensorTransforms::runOnOperation() {
380391
applyRewriteExtractFromCollapseShapePatterns(rootOp, useForeach)))
381392
return signalPassFailure();
382393
}
394+
if (testFoldExtractFromCollapseShape)
395+
applyFoldExtractFromCollapseShapePatterns(rootOp);
383396
if (testTrackingListener)
384397
if (failed(testTrackingListenerReplacements(rootOp)))
385398
return signalPassFailure();

0 commit comments

Comments
 (0)