
Commit 692c2e4

Merge pull request #527 from Xilinx/jrickert.slice_folding
Add support for folding tosa.slice with tosa.slice
2 parents: 1134785 + 0f19730

2 files changed: 99 additions, 0 deletions

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 24 additions & 0 deletions
@@ -1390,6 +1390,30 @@ OpFoldResult ReverseOp::fold(FoldAdaptor adaptor) {
 }
 
 OpFoldResult SliceOp::fold(FoldAdaptor adaptor) {
+  const auto tryFoldWithPrecedingSlice = [this](FoldAdaptor adaptor) {
+    auto precedingSliceOp = getInput1().getDefiningOp<SliceOp>();
+    if (!precedingSliceOp)
+      return failure();
+    const auto precedingSliceStart = precedingSliceOp.getStart();
+    const auto thisSliceStart = getStart();
+    SmallVector<int64_t> newSliceStart;
+    newSliceStart.reserve(precedingSliceStart.size());
+    for (auto [startPreceding, startThis] :
+         llvm::zip_equal(precedingSliceStart, thisSliceStart)) {
+      newSliceStart.push_back(startPreceding + startThis);
+    }
+    setOperand(precedingSliceOp->getOperand(0));
+    setStart(newSliceStart);
+    getOperation()->setLoc(
+        FusedLoc::get(getContext(), {precedingSliceOp->getLoc(), getLoc()}));
+    return success();
+  };
+
+  // First try folding the preceding slice, this also works if the shapes are
+  // dynamic
+  if (succeeded(tryFoldWithPrecedingSlice(adaptor)))
+    return getResult();
+
   auto inputTy = llvm::dyn_cast<RankedTensorType>(getInput1().getType());
   auto outputTy = llvm::dyn_cast<RankedTensorType>(getType());

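The new fold rewrites a tosa.slice whose input is produced by another tosa.slice: the outer slice is re-pointed at the preceding slice's input, its start offsets become the element-wise sum of both starts, and its own size (and therefore its result type) is left untouched, which is why the fold also applies to dynamic or unranked shapes. Below is a minimal standalone sketch of that start composition, assuming plain int64_t vectors in place of the real SliceOp start attributes; composeSliceStarts is a hypothetical helper used only for illustration, not part of the commit.

#include <cassert>
#include <cstdint>
#include <vector>

// Composes the start offsets of two nested slices: the result is the start of
// the preceding slice plus the start of the slice being folded into it.
std::vector<int64_t>
composeSliceStarts(const std::vector<int64_t> &precedingStart,
                   const std::vector<int64_t> &thisStart) {
  assert(precedingStart.size() == thisStart.size() && "ranks must match");
  std::vector<int64_t> newStart;
  newStart.reserve(precedingStart.size());
  for (size_t i = 0; i < precedingStart.size(); ++i)
    newStart.push_back(precedingStart[i] + thisStart[i]); // element-wise sum
  return newStart;
}

Applied twice to the @slice_fuse_different_start_2 test below, the starts {4, 0}, {0, 0}, and {0, 1} compose to {4, 1}, which is exactly the start the CHECK line expects on the single remaining slice.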
mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 75 additions & 0 deletions
@@ -693,6 +693,81 @@ func.func @slice_nofold(%arg0: tensor<?x4xf32>) -> tensor<?x4xf32> {
 
 // -----
 
+// CHECK-LABEL: @slice_fuse
+func.func @slice_fuse(%arg0: tensor<3x4xf32>) -> tensor<1x2xf32> {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x2xf32> {
+  // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 2>, start = array<i64: 0, 0>} : (tensor<3x4xf32>) -> tensor<1x2xf32>
+  // CHECK: return [[VAR_0_]] : tensor<1x2xf32>
+  %0 = tosa.slice %arg0 { size = array<i64: 2, 3>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<2x3xf32>
+  %1 = tosa.slice %0 { size = array<i64: 1, 2>, start = array<i64: 0, 0>}: (tensor<2x3xf32>) -> tensor<1x2xf32>
+  return %1 : tensor<1x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fuse_different_step
+func.func @slice_fuse_different_step(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> {
+  // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 0, 0>} : (tensor<3x4xf32>) -> tensor<1x1xf32>
+  // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
+  %0 = tosa.slice %arg0 { size = array<i64: 1, 3>, start = array<i64: 0, 0>}: (tensor<3x4xf32>) -> tensor<1x3xf32>
+  %1 = tosa.slice %0 { size = array<i64: 1, 1>, start = array<i64: 0, 0>}: (tensor<1x3xf32>) -> tensor<1x1xf32>
+  return %1 : tensor<1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fuse_different_start
+func.func @slice_fuse_different_start(%arg0: tensor<3x4xf32>) -> tensor<1x1xf32> {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<3x4xf32>) -> tensor<1x1xf32> {
+  // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 1, 0>} : (tensor<3x4xf32>) -> tensor<1x1xf32>
+  // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
+  %0 = tosa.slice %arg0 { size = array<i64: 1, 3>, start = array<i64: 1, 0>}: (tensor<3x4xf32>) -> tensor<1x3xf32>
+  %1 = tosa.slice %0 { size = array<i64: 1, 1>, start = array<i64: 0, 0>}: (tensor<1x3xf32>) -> tensor<1x1xf32>
+  return %1 : tensor<1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fuse_different_start_2
+func.func @slice_fuse_different_start_2(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> {
+  // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 1>} : (tensor<10x10xf32>) -> tensor<1x1xf32>
+  // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
+  %0 = tosa.slice %arg0 { size = array<i64: 5, 5>, start = array<i64: 4, 0>}: (tensor<10x10xf32>) -> tensor<5x5xf32>
+  %1 = tosa.slice %0 { size = array<i64: 3, 3>, start = array<i64: 0, 0>}: (tensor<5x5xf32>) -> tensor<3x3xf32>
+  %2 = tosa.slice %1 { size = array<i64: 1, 1>, start = array<i64: 0, 1>}: (tensor<3x3xf32>) -> tensor<1x1xf32>
+  return %2 : tensor<1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @slice_fuse_different_start_3
+func.func @slice_fuse_different_start_3(%arg0: tensor<10x10xf32>) -> tensor<1x1xf32> {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<10x10xf32>) -> tensor<1x1xf32> {
+  // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 2>} : (tensor<10x10xf32>) -> tensor<1x1xf32>
+  // CHECK: return [[VAR_0_]] : tensor<1x1xf32>
+  %0 = tosa.slice %arg0 { size = array<i64: 5, 5>, start = array<i64: 4, 1>}: (tensor<10x10xf32>) -> tensor<5x5xf32>
+  %1 = tosa.slice %0 { size = array<i64: 3, 3>, start = array<i64: 0, 0>}: (tensor<5x5xf32>) -> tensor<3x3xf32>
+  %2 = tosa.slice %1 { size = array<i64: 1, 1>, start = array<i64: 0, 1>}: (tensor<3x3xf32>) -> tensor<1x1xf32>
+  return %2 : tensor<1x1xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @slice_fuse_different_start_dynamic
+func.func @slice_fuse_different_start_dynamic(%arg0: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<*xf32>) -> tensor<*xf32> {
+  // CHECK: [[VAR_0_:%.+]] = tosa.slice [[PARAM_0_]] {size = array<i64: 1, 1>, start = array<i64: 4, 1>} : (tensor<*xf32>) -> tensor<*xf32>
+  // CHECK: return [[VAR_0_]] : tensor<*xf32>
+  %0 = tosa.slice %arg0 { size = array<i64: 5, 5>, start = array<i64: 4, 0>}: (tensor<*xf32>) -> tensor<*xf32>
+  %1 = tosa.slice %0 { size = array<i64: 3, 3>, start = array<i64: 0, 0>}: (tensor<*xf32>) -> tensor<*xf32>
+  %2 = tosa.slice %1 { size = array<i64: 1, 1>, start = array<i64: 0, 1>}: (tensor<*xf32>) -> tensor<*xf32>
+  return %2 : tensor<*xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @tile_fold
 func.func @tile_fold(%arg0: tensor<3x4xf32>) -> tensor<3x4xf32> {
   // CHECK: return %arg0

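For reference, these cases can be exercised through the usual canonicalization pipeline, along the lines of: mlir-opt --split-input-file --canonicalize mlir/test/Dialect/Tosa/canonicalize.mlir | FileCheck mlir/test/Dialect/Tosa/canonicalize.mlir. The authoritative RUN line sits at the top of the test file and is not part of this diff.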