Skip to content

Commit 1f07bfb

Browse files
sabaumaSpenser Bauman
andauthored
[mlir][tensor] Implement folding logic for size 0 tensor and memref ops (#90814)
Implement folding and rewrite logic to eliminate no-op tensor and memref operations. This handles two specific cases: 1. tensor.insert_slice operations where the size of the inserted slice is known to be 0. 2. memref.copy operations where either the source or target memrefs are known to be emtpy. Co-authored-by: Spenser Bauman <sabauma@fastmail>
1 parent 250c39c commit 1f07bfb

File tree

4 files changed

+46
-1
lines changed

4 files changed

+46
-1
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -833,11 +833,31 @@ struct FoldSelfCopy : public OpRewritePattern<CopyOp> {
833833
return success();
834834
}
835835
};
836+
837+
struct FoldEmptyCopy final : public OpRewritePattern<CopyOp> {
838+
using OpRewritePattern<CopyOp>::OpRewritePattern;
839+
840+
static bool isEmptyMemRef(BaseMemRefType type) {
841+
return type.hasRank() &&
842+
llvm::any_of(type.getShape(), [](int64_t x) { return x == 0; });
843+
}
844+
845+
LogicalResult matchAndRewrite(CopyOp copyOp,
846+
PatternRewriter &rewriter) const override {
847+
if (isEmptyMemRef(copyOp.getSource().getType()) ||
848+
isEmptyMemRef(copyOp.getTarget().getType())) {
849+
rewriter.eraseOp(copyOp);
850+
return success();
851+
}
852+
853+
return failure();
854+
}
855+
};
836856
} // namespace
837857

838858
void CopyOp::getCanonicalizationPatterns(RewritePatternSet &results,
839859
MLIRContext *context) {
840-
results.add<FoldCopyOfCast, FoldSelfCopy>(context);
860+
results.add<FoldCopyOfCast, FoldEmptyCopy, FoldSelfCopy>(context);
841861
}
842862

843863
LogicalResult CopyOp::fold(FoldAdaptor adaptor,

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2609,6 +2609,9 @@ OpFoldResult InsertSliceOp::fold(FoldAdaptor) {
26092609
return getResult();
26102610
if (auto result = foldInsertAfterExtractSlice(*this))
26112611
return result;
2612+
if (llvm::any_of(getMixedSizes(),
2613+
[](OpFoldResult ofr) { return isConstantIntValue(ofr, 0); }))
2614+
return getDest();
26122615
return OpFoldResult();
26132616
}
26142617

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,16 @@ func.func @self_copy(%m1: memref<?xf32>) {
692692

693693
// -----
694694

695+
// CHECK-LABEL: func @empty_copy
696+
// CHECK-NEXT: return
697+
func.func @empty_copy(%m1: memref<0x10xf32>, %m2: memref<?x10xf32>) {
698+
memref.copy %m1, %m2 : memref<0x10xf32> to memref<?x10xf32>
699+
memref.copy %m2, %m1 : memref<?x10xf32> to memref<0x10xf32>
700+
return
701+
}
702+
703+
// -----
704+
695705
func.func @scopeMerge() {
696706
memref.alloca_scope {
697707
%cnt = "test.count"() : () -> index

mlir/test/Dialect/Tensor/canonicalize.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,6 +542,18 @@ func.func @trivial_insert_slice(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6
542542

543543
// -----
544544

545+
// CHECK-LABEL: func @empty_insert_slice
546+
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<0x2xi8>
547+
// CHECK-SAME: %[[ARG1:.[a-z0-9A-Z_]+]]: tensor<3x3xi8>
548+
// CHECK-NOT: tensor.extract_slice
549+
// CHECK: return %[[ARG1]] : tensor<3x3xi8>
550+
func.func @empty_insert_slice(%arg0 : tensor<0x2xi8>, %arg1 : tensor<3x3xi8>) -> tensor<3x3xi8> {
551+
%0 = tensor.insert_slice %arg0 into %arg1[0, 0] [0, 2] [1, 1] : tensor<0x2xi8> into tensor<3x3xi8>
552+
return %0 : tensor<3x3xi8>
553+
}
554+
555+
// -----
556+
545557
// CHECK-LABEL: func @rank_reducing_tensor_of_cast
546558
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
547559
// CHECK: %[[S:.+]] = tensor.extract_slice %arg0[0, 1, 0, 0] [1, 1, 16, 32] [1, 1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>

0 commit comments

Comments
 (0)