Commit 2a88feb

[mlir][tosa] Canonicalize slice over overlapped or inside a pad. (#138900)
Update the padding and/or the slice parameters when a `tosa.slice` that follows a `tosa.pad` reads a region that only partially overlaps, or lies entirely within, the original (unpadded) tensor. Signed-off-by: Georgios Pinitas <[email protected]>
1 parent a3f58f3 commit 2a88feb
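The rewrite works per dimension: for an input extent dimSize with pads (padLo, padHi) and a slice covering [sliceStart, sliceStart + sliceSize) of the padded extent, the new parameters are clamped so the slice starts inside the original data and only the padding the slice actually reads survives. Below is a minimal standalone sketch of that arithmetic (a hypothetical helper written for illustration, not code from this patch), checked against the width axis of the @canonicalize_pad_slice_overlap test added further down:

#include <algorithm>
#include <cassert>
#include <cstdint>

// Per-dimension parameter update mirroring PadSliceOptimization (sketch).
struct DimUpdate {
  int64_t newSliceStart; // slice start relative to the unpadded input
  int64_t newPadLo;      // low-side padding the slice still reads
  int64_t newPadHi;      // high-side padding the slice still reads
};

static DimUpdate updateDim(int64_t dimSize, int64_t padLo, int64_t padHi,
                           int64_t sliceStart, int64_t sliceSize) {
  const int64_t sliceEnd = sliceStart + sliceSize;
  // The pattern rejects out-of-bounds slices before doing this arithmetic.
  assert(sliceStart >= 0 && sliceEnd <= padLo + dimSize + padHi);
  return {std::max<int64_t>(sliceStart - padLo, 0),
          std::max<int64_t>(padLo - sliceStart, 0),
          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0)};
}

int main() {
  // Width axis of @canonicalize_pad_slice_overlap: dimSize = 16,
  // pads (2, 2), slice [1, 19) -> pads shrink to (1, 1), start drops to 0.
  const DimUpdate d = updateDim(16, 2, 2, 1, 18);
  assert(d.newSliceStart == 0 && d.newPadLo == 1 && d.newPadHi == 1);
  return 0;
}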

File tree

2 files changed: +209 -2 lines changed

  mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
  mlir/test/Dialect/Tosa/canonicalize.mlir

mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp

Lines changed: 137 additions & 2 deletions
@@ -731,6 +731,141 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+    Value sliceInput = sliceOp.getInput1();
+
+    // Check if producer is a PadOp
+    auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+    if (!padOp)
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a pad operation");
+
+    // Check PadOp has a single consumer
+    if (!padOp->hasOneUse())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "pad shall have a single consumer");
+
+    // Check input is statically ranked
+    auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+    auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+    if (!inputTy || !padTy || !inputTy.hasRank())
+      return rewriter.notifyMatchFailure(sliceOp,
+                                         "slice input must be a ranked tensor");
+
+    // Validate and extract tosa::PadOp padding
+    DenseIntElementsAttr paddingElems;
+    if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+      return rewriter.notifyMatchFailure(
+          sliceOp,
+          "`padding` input specified on the tosa::PadOp must be constant.");
+    }
+    llvm::SmallVector<int64_t> padPaddings =
+        llvm::to_vector(paddingElems.getValues<int64_t>());
+
+    // Extract slice parameters
+    DenseElementsAttr startElems;
+    if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "start of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceStarts =
+        llvm::to_vector(startElems.getValues<int64_t>());
+
+    DenseElementsAttr sizeElems;
+    if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+      return rewriter.notifyMatchFailure(
+          sliceOp, "size of slice must be a static ranked shape");
+    llvm::SmallVector<int64_t> sliceSizes =
+        llvm::to_vector(sizeElems.getValues<int64_t>());
+
+    // Check if dynamic dimensions are sliced
+    const int64_t rank = inputTy.getRank();
+    if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
+          const bool isDimDynamic = inputTy.isDynamicDim(i);
+          const bool isDimSliced =
+              (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
+
+          return isDimDynamic && isDimSliced;
+        })) {
+      return rewriter.notifyMatchFailure(
+          sliceOp, "axis that are sliced shall be statically known.");
+    }
+
+    // Update the parameters
+    llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+    llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+    llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
+    bool updated = false;
+
+    for (int64_t i = 0; i < rank; ++i) {
+      const int64_t padLo = padPaddings[i * 2];
+      const int64_t padHi = padPaddings[i * 2 + 1];
+      const int64_t sliceStart = sliceStarts[i];
+      const int64_t sliceSize = sliceSizes[i];
+      const int64_t sliceEnd = sliceStart + sliceSize;
+
+      // If dimension is dynamic pass-through
+      if (inputTy.isDynamicDim(i)) {
+        newPadPaddings[i * 2] = padLo;
+        newPadPaddings[i * 2 + 1] = padHi;
+        newSliceStarts[i] = sliceStart;
+        continue;
+      }
+
+      // Handle static dimensions
+      const int64_t dimSize = inputTy.getShape()[i];
+      const int64_t dimTotal = padLo + dimSize + padHi;
+
+      // Check slice within bounds
+      if (sliceStart < 0 || sliceEnd > dimTotal)
+        return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
+
+      // Compute updated slice start parameter
+      const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+      newSliceStarts[i] = newSliceStart;
+      updated |= newSliceStart != sliceStart;
+
+      // Compute updated pad parameters
+      const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+      const int64_t newPadHi =
+          std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
+      newPadPaddings[i * 2] = newPadLo;
+      newPadPaddings[i * 2 + 1] = newPadHi;
+      updated |= (newPadLo != padLo) || (newPadHi != padHi);
+
+      // Calculate new pad output shape
+      newPadShape[i] =
+          newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
+    }
+
+    // Check that we actually need to proceed with the rewrite
+    if (!updated)
+      return rewriter.notifyMatchFailure(
+          sliceOp, "terminate condition; nothing to rewrite");
+
+    // Create a PadOp with updated padding
+    auto newPaddingsOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+    auto newPadTy =
+        RankedTensorType::get(newPadShape, inputTy.getElementType());
+    auto newPadOp = rewriter.create<tosa::PadOp>(
+        padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+        padOp.getPadConst());
+
+    // Update SliceOp and point to new PadOp
+    auto newStartOp =
+        getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+    rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+                                               newPadOp.getResult(), newStartOp,
+                                               sliceOp.getSize());
+
+    return success();
+  }
+};
+
 // Update size operand of tosa.slice if size has dynamic dims but corresponding
 // output dim is static
 struct SliceDynamicSizeCanonicalization
@@ -779,8 +914,8 @@ struct SliceDynamicSizeCanonicalization
 
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
-      context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization,
+              SliceDynamicSizeCanonicalization>(context);
 }
 
 //===----------------------------------------------------------------------===//
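As a cross-check of the tests in the file below, trace the width axis (i = 2) of @canonicalize_pad_slice_overlap: dimSize = 16, (padLo, padHi) = (2, 2), and the slice covers [1, 1 + 18) = [1, 19) of the padded extent [0, 20). The pattern computes

  newSliceStart = max(1 - 2, 0)         = 0
  newPadLo      = max(2 - 1, 0)         = 1
  newPadHi      = max(19 - (2 + 16), 0) = 1

so the padding shrinks from (2, 2) to (1, 1) and the slice start drops to 0, matching the %[[PADDING]] and %[[ZERO]] values the CHECK lines expect.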

mlir/test/Dialect/Tosa/canonicalize.mlir

Lines changed: 72 additions & 0 deletions
@@ -985,6 +985,78 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf
 
 // -----
 
+// CHECK-LABEL: @canonicalize_pad_slice_overlap
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_overlap(%arg0: tensor<?x16x16x3xf32>) -> tensor<?x14x18x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<?x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<?x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<?x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x14x18x3xf32>
+  return %sliced : tensor<?x14x18x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_pad_slice_inside
+// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
+// CHECK-NOT: tosa.pad
+// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x14x3xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x14x3xf32>
+  return %sliced : tensor<1x14x14x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_pad_slice_exact
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_exact(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x20x2xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x20x2xf32>
+  return %sliced : tensor<1x16x20x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_pad_slice_dynamic_noupdate
+// CHECK-DAG: tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>}
+// CHECK: tosa.pad
+// CHECK: tosa.slice
+func.func @canonicalize_pad_slice_dynamic_noupdate(%arg0: tensor<1x16x?x3xf32>) -> tensor<1x16x?x2xf32> {
+  %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+  %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x?x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x?x3xf32>
+  %start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %size = tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
+  %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x?x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x?x2xf32>
+  return %sliced : tensor<1x16x?x2xf32>
+}
+
+// -----
+
 // CHECK-LABEL: @fold_log_exp
 func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
   // CHECK: return %arg{{.*}} : tensor<?x1xf32>
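Two boundary cases above are worth the same arithmetic. In @canonicalize_pad_slice_exact the slice covers the whole padded extent on the width axis, so newPadLo = max(2 - 0, 0) = 2 and newPadHi = max(20 - (2 + 16), 0) = 2 equal the original pads; `updated` stays false and the pattern deliberately bails, which is what keeps the rewrite from re-firing forever. In @canonicalize_pad_slice_inside the slice stays entirely within the original data, so both new pads clamp to 0 and the resulting all-zero `tosa.pad` folds away, hence the CHECK-NOT.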
