Skip to content

Commit 11193c5

Browse files
committed
Add support for folding tosa.slice with tosa.slice
1 parent 61a20b8 commit 11193c5

File tree

2 files changed

+99
-0
lines changed

2 files changed

+99
-0
lines changed

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1463,6 +1463,30 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
14631463
}
14641464

14651465
OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
1466+
const auto tryFoldWithPrecedingSlice = [this](FoldAdaptor adaptor) {
1467+
auto precedingSliceOp = getInput1().getDefiningOp<SliceOp>();
1468+
if (!precedingSliceOp)
1469+
return failure();
1470+
const auto precedingSliceStart = precedingSliceOp.getStart();
1471+
const auto thisSliceStart = getStart();
1472+
SmallVector<int64_t> newSliceStart;
1473+
newSliceStart.reserve(precedingSliceStart.size());
1474+
for (auto [startPreceding, startThis] :
1475+
llvm::zip_equal(precedingSliceStart, thisSliceStart)) {
1476+
newSliceStart.push_back(startPreceding + startThis);
1477+
}
1478+
setOperand(precedingSliceOp->getOperand(0));
1479+
setStart(newSliceStart);
1480+
getOperation()->setLoc(
1481+
FusedLoc::get(getContext(), {precedingSliceOp->getLoc(), getLoc()}));
1482+
return success();
1483+
};
1484+
1485+
// First try folding the preceding slice, this also works if the shapes are
1486+
// dynamic
1487+
if (succeeded(tryFoldWithPrecedingSlice(adaptor)))
1488+
return getResult();
1489+
14661490
auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
14671491
auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());
14681492

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -690,6 +690,81 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
690690

691691
// -----
692692

693+
// CHECK-LABEL: @slice_fuse
694+
func.func @slice_fuse(%arg0: tensor<3x4xf32>) -> tensor<1x2xf32> {
695+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x2xf32> {
696+
// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 2>, start = array<i64: 0, 0>} : (tensor<3x4xf32>) -> tensor<1x2xf32>
697+
// CHECK: return [[VAR_0_]] : tensor<1x2xf32>
698+
%0 = tosa.slice %arg0 { size = array<i64: 2, 3>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<2x3xf32>
699+
%1 = tosa.slice %0 { size = array<i64: 1, 2>, start = array<i64: 0, 0>}: (tensor<2x3xf32>) -> tensor<1x2xf32>
700+
return %1 : tensor<1x2xf32>
701+
}
702+
703+
// -----
704+
705+
// CHECK-LABEL: @slice_fuse_different_step
706+
func.func @slice_fuse_different_step(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> {
707+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> {
708+
// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 0, 0>} : (tensor<3x4xf32>) -> tensor<1x1xf32>
709+
// CHECK: return [[VAR_0_]] : tensor<1x1xf32>
710+
%0 = tosa.slice %arg0 { size = array<i64: 1, 3>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<1x3xf32>
711+
%1 = tosa.slice %0 { size = array<i64: 1, 1>, start = array<i64: 0, 0>}: (tensor<1x3xf32>) -> tensor<1x1xf32>
712+
return %1 : tensor<1x1xf32>
713+
}
714+
715+
// -----
716+
717+
// CHECK-LABEL: @slice_fuse_different_start
718+
func.func @slice_fuse_different_start(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> {
719+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> {
720+
// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 1, 0>} : (tensor<3x4xf32>) -> tensor<1x1xf32>
721+
// CHECK: return [[VAR_0_]] : tensor<1x1xf32>
722+
%0 = tosa.slice %arg0 { size = array<i64: 1, 3>, start = array<i64: 1, 0>}: (tensor<3x4xf32>) -> tensor<1x3xf32>
723+
%1 = tosa.slice %0 { size = array<i64: 1, 1>, start = array<i64: 0, 0>}: (tensor<1x3xf32>) -> tensor<1x1xf32>
724+
return %1 : tensor<1x1xf32>
725+
}
726+
727+
// -----
728+
729+
// CHECK-LABEL: @slice_fuse_different_start_2
730+
func.func @slice_fuse_different_start_2(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> {
731+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> {
732+
// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 1>} : (tensor<10x10xf32>) -> tensor<1x1xf32>
733+
// CHECK: return [[VAR_0_]] : tensor<1x1xf32>
734+
%0 = tosa.slice %arg0 { size = array<i64: 5, 5>, start = array<i64: 4, 0>}: (tensor<10x10xf32>) -> tensor<5x5xf32>
735+
%1 = tosa.slice %0 { size = array<i64: 3, 3>, start = array<i64: 0, 0>}: (tensor<5x5xf32>) -> tensor<3x3xf32>
736+
%2 = tosa.slice %1 { size = array<i64: 1, 1>, start = array<i64: 0, 1>}: (tensor<3x3xf32>) -> tensor<1x1xf32>
737+
return %2 : tensor<1x1xf32>
738+
}
739+
740+
// -----
741+
742+
// CHECK-LABEL: @slice_fuse_different_start_3
743+
func.func @slice_fuse_different_start_3(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> {
744+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> {
745+
// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 2>} : (tensor<10x10xf32>) -> tensor<1x1xf32>
746+
// CHECK: return [[VAR_0_]] : tensor<1x1xf32>
747+
%0 = tosa.slice %arg0 { size = array<i64: 5, 5>, start = array<i64: 4, 1>}: (tensor<10x10xf32>) -> tensor<5x5xf32>
748+
%1 = tosa.slice %0 { size = array<i64: 3, 3>, start = array<i64: 0, 0>}: (tensor<5x5xf32>) -> tensor<3x3xf32>
749+
%2 = tosa.slice %1 { size = array<i64: 1, 1>, start = array<i64: 0, 1>}: (tensor<3x3xf32>) -> tensor<1x1xf32>
750+
return %2 : tensor<1x1xf32>
751+
}
752+
753+
// -----
754+
755+
// CHECK-LABEL: func.func @slice_fuse_different_start_dynamic
756+
func.func @slice_fuse_different_start_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
757+
// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
758+
// CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 1>} : (tensor<*xf32>) -> tensor<*xf32>
759+
// CHECK: return [[VAR_0_]] : tensor<*xf32>
760+
%0 = tosa.slice %arg0 { size = array<i64: 5, 5>, start = array<i64: 4, 0>}: (tensor<*xf32>) -> tensor<*xf32>
761+
%1 = tosa.slice %0 { size = array<i64: 3, 3>, start = array<i64: 0, 0>}: (tensor<*xf32>) -> tensor<*xf32>
762+
%2 = tosa.slice %1 { size = array<i64: 1, 1>, start = array<i64: 0, 1>}: (tensor<*xf32>) -> tensor<*xf32>
763+
return %2 : tensor<*xf32>
764+
}
765+
766+
// -----
767+
693768
// CHECK-LABEL: @tile_fold
694769
func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
695770
// CHECK: return %arg0

0 commit comments

Comments
 (0)