Skip to content

Commit b67677f

Browse files
committed
[mlir][tensor] move tensor insert canonicalization to pattern
Signed-off-by: Asra Ali <[email protected]>
1 parent d204aa9 commit b67677f

File tree

6 files changed

+33
-21
lines changed

6 files changed

+33
-21
lines changed

mlir/include/mlir/Dialect/Tensor/IR/Tensor.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,9 @@ void populateFoldConstantExtractSlicePatterns(
176176
return false;
177177
});
178178

179+
/// Patterns to fold inserts into a constant into a new constant.
180+
void populateFoldInsertAfterConstant(RewritePatternSet &patterns);
181+
179182
/// Patterns to fold extracts of a collapse_shaped tensor to an extract of the
180183
/// source tensor.
181184
void populateFoldCollapseExtractPatterns(RewritePatternSet &patterns);

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,6 @@ def Tensor_InsertOp : Tensor_Op<"insert", [
827827

828828
let hasFolder = 1;
829829
let hasVerifier = 1;
830-
let hasCanonicalizer = 1;
831830
}
832831

833832
//===----------------------------------------------------------------------===//

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1653,7 +1653,7 @@ class InsertOpConstantFold final : public OpRewritePattern<InsertOp> {
16531653
return failure();
16541654

16551655
// Pattern requires constant indices
1656-
SmallVector<uint64_t, 8> indices;
1656+
SmallVector<uint64_t> indices;
16571657
for (OpFoldResult indice : getAsOpFoldResult(insertOp.getIndices())) {
16581658
auto indiceAttr = dyn_cast<Attribute>(indice);
16591659
if (!indiceAttr)
@@ -1717,9 +1717,8 @@ OpFoldResult InsertOp::fold(FoldAdaptor adaptor) {
17171717
return {};
17181718
}
17191719

1720-
void InsertOp::getCanonicalizationPatterns(RewritePatternSet &results,
1721-
MLIRContext *context) {
1722-
results.add<InsertOpConstantFold>(context);
1720+
void populateFoldInsertAfterConstant(RewritePatternSet &patterns) {
1721+
patterns.add<InsertOpConstantFold>(patterns.getContext());
17231722
}
17241723

17251724
//===----------------------------------------------------------------------===//

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 0 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -231,22 +231,6 @@ func.func @fold_insert(%arg0 : index) -> (tensor<4xf32>) {
231231
return %ins_1 : tensor<4xf32>
232232
}
233233

234-
235-
// -----
236-
237-
func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
238-
// Fold an insert into a splat.
239-
// CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
240-
// CHECK-LITERAL:
241-
// CHECK-NEXT: return %[[C4]]
242-
%cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
243-
%c0 = arith.constant 0 : index
244-
%c1 = arith.constant 1 : index
245-
%c4_i32 = arith.constant 4 : i32
246-
%inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
247-
return %inserted : tensor<2x2xi32>
248-
}
249-
250234
// -----
251235

252236
// CHECK-LABEL: func @extract_from_tensor.cast
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-fold-insert-after-constant %s | FileCheck %s
2+
3+
func.func @canonicalize_insert_after_constant() -> (tensor<2x2xi32>) {
4+
// Fold an insert into a splat.
5+
// CHECK: %[[C4:.+]] = arith.constant dense<{{\[\[}}1, 2], [4, 4]]> : tensor<2x2xi32>
6+
// CHECK-LITERAL:
7+
// CHECK-NEXT: return %[[C4]]
8+
%cst = arith.constant dense<[[1, 2], [3, 4]]> : tensor<2x2xi32>
9+
%c0 = arith.constant 0 : index
10+
%c1 = arith.constant 1 : index
11+
%c4_i32 = arith.constant 4 : i32
12+
%inserted = tensor.insert %c4_i32 into %cst[%c1, %c0] : tensor<2x2xi32>
13+
return %inserted : tensor<2x2xi32>
14+
}

mlir/test/lib/Dialect/Tensor/TestTensorTransforms.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,11 @@ struct TestTensorTransforms
8282
llvm::cl::desc("Test folding of extract from collapse_shape"),
8383
llvm::cl::init(false)};
8484

85+
Option<bool> testFoldInsertAfterConstant{
86+
*this, "test-fold-insert-after-constant",
87+
llvm::cl::desc("Test folding of insert of a constant"),
88+
llvm::cl::init(false)};
89+
8590
Option<bool> useForeach{
8691
*this, "use-foreach",
8792
llvm::cl::desc(
@@ -143,6 +148,12 @@ static void applyFoldExtractFromCollapseShapePatterns(Operation *rootOp) {
143148
(void)applyPatternsGreedily(rootOp, std::move(patterns));
144149
}
145150

151+
static void applyFoldInsertAfterConstantPattern(Operation *rootOp) {
152+
RewritePatternSet patterns(rootOp->getContext());
153+
tensor::populateFoldInsertAfterConstant(patterns);
154+
(void)applyPatternsGreedily(rootOp, std::move(patterns));
155+
}
156+
146157
namespace {
147158
/// Base pattern to rewrite a `tensor.collapse_shape -> tensor.extract_slice`.
148159
/// The `tensor.extract_slice` is replaced by a loop or gather operation that
@@ -393,6 +404,8 @@ void TestTensorTransforms::runOnOperation() {
393404
}
394405
if (testFoldExtractFromCollapseShape)
395406
applyFoldExtractFromCollapseShapePatterns(rootOp);
407+
if (testFoldInsertAfterConstant)
408+
applyFoldInsertAfterConstantPattern(rootOp);
396409
if (testTrackingListener)
397410
if (failed(testTrackingListenerReplacements(rootOp)))
398411
return signalPassFailure();

0 commit comments

Comments
 (0)