[mlir][tosa] Canonicalize a slice that overlaps or lies inside a pad. #138900
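The new PadSliceOptimization pattern below rewrites tosa.slice(tosa.pad(x)) so that the slice starts relative to the unpadded input and the pad keeps only the padding the slice actually reads; when the sliced region lies entirely inside the original data, the pad folds away. A schematic before/after, adapted from the @canonicalize_pad_slice_inside test added in this PR (shape operands abbreviated; the exact IR is in the tests):

  // Before: %arg0 : tensor<1x16x16x3xf32>, padded by [2, 2] on dimension 2.
  %padded = tosa.pad %arg0, padding = [0, 0, 0, 0, 2, 2, 0, 0]
  %sliced = tosa.slice %padded, start = [0, 1, 4, 0], size = [1, 14, 10, 3]

  // After: the slice never touches the padding, so the pad is dropped and
  // the start on dimension 2 is shifted down by the low padding (4 - 2 = 2).
  %sliced = tosa.slice %arg0, start = [0, 1, 2, 0], size = [1, 14, 10, 3]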


Merged 1 commit on May 8, 2025
139 changes: 137 additions & 2 deletions mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,6 +731,141 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
}
};

struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;

LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
PatternRewriter &rewriter) const override {
Value sliceInput = sliceOp.getInput1();

// Check if producer is a PadOp
auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
if (!padOp)
return rewriter.notifyMatchFailure(sliceOp,
"slice input must be a pad operation");

// Check PadOp has a single consumer
if (!padOp->hasOneUse())
return rewriter.notifyMatchFailure(sliceOp,
"pad shall have a single consumer");

// Check the pad input and result are ranked tensors
auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
if (!inputTy || !padTy || !inputTy.hasRank())
return rewriter.notifyMatchFailure(sliceOp,
"slice input must be a ranked tensor");

// Validate and extract tosa::PadOp padding
DenseIntElementsAttr paddingElems;
if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
return rewriter.notifyMatchFailure(
sliceOp,
"`padding` input specified on the tosa::PadOp must be constant.");
}
llvm::SmallVector<int64_t> padPaddings =
llvm::to_vector(paddingElems.getValues<int64_t>());

// Extract slice parameters
DenseElementsAttr startElems;
if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
return rewriter.notifyMatchFailure(
sliceOp, "start of slice must be a static ranked shape");
llvm::SmallVector<int64_t> sliceStarts =
llvm::to_vector(startElems.getValues<int64_t>());

DenseElementsAttr sizeElems;
if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
return rewriter.notifyMatchFailure(
sliceOp, "size of slice must be a static ranked shape");
llvm::SmallVector<int64_t> sliceSizes =
llvm::to_vector(sizeElems.getValues<int64_t>());

// Check if dynamic dimensions are sliced
const int64_t rank = inputTy.getRank();
if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
const bool isDimDynamic = inputTy.isDynamicDim(i);
const bool isDimSliced =
(sliceStarts[i] != 0) || (sliceSizes[i] != -1);

return isDimDynamic && isDimSliced;
})) {
return rewriter.notifyMatchFailure(
sliceOp, "axes that are sliced shall be statically known.");
}

// Recompute the slice start and pad parameters for each dimension
llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
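// Dimensions that are passed through unchanged keep ShapedType::kDynamic in
// the new pad result shape; static dimensions are overwritten below.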
bool updated = false;

for (int64_t i = 0; i < rank; ++i) {
const int64_t padLo = padPaddings[i * 2];
const int64_t padHi = padPaddings[i * 2 + 1];
const int64_t sliceStart = sliceStarts[i];
const int64_t sliceSize = sliceSizes[i];
const int64_t sliceEnd = sliceStart + sliceSize;
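
// Worked example, matching the @canonicalize_pad_slice_overlap test below:
// dimSize = 16, padLo = padHi = 2, sliceStart = 1, sliceSize = 18, so
// sliceEnd = 19. The slice reads one padded element on each side, which
// becomes newSliceStart = 0 with reduced padding [1, 1] further down.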

// If the dimension is dynamic, pass the parameters through unchanged
if (inputTy.isDynamicDim(i)) {
newPadPaddings[i * 2] = padLo;
newPadPaddings[i * 2 + 1] = padHi;
newSliceStarts[i] = sliceStart;
continue;
}

// Handle static dimensions
const int64_t dimSize = inputTy.getShape()[i];
const int64_t dimTotal = padLo + dimSize + padHi;

// Check slice within bounds
if (sliceStart < 0 || sliceEnd > dimTotal)
return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");

// Compute updated slice start parameter
const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
newSliceStarts[i] = newSliceStart;
updated |= newSliceStart != sliceStart;

// Compute updated pad parameters
const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
const int64_t newPadHi =
std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
newPadPaddings[i * 2] = newPadLo;
newPadPaddings[i * 2 + 1] = newPadHi;
updated |= (newPadLo != padLo) || (newPadHi != padHi);

// Calculate new pad output shape
newPadShape[i] =
newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
}
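
// If the slice covers the padded tensor exactly, as in the
// @canonicalize_pad_slice_exact test below, the recomputed start and
// padding equal the originals and there is nothing to rewrite.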

// Check that we actually need to proceed with the rewrite
if (!updated)
return rewriter.notifyMatchFailure(
sliceOp, "terminate condition; nothing to rewrite");

// Create a PadOp with updated padding
auto newPaddingsOp =
getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
auto newPadTy =
RankedTensorType::get(newPadShape, inputTy.getElementType());
auto newPadOp = rewriter.create<tosa::PadOp>(
padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
padOp.getPadConst());

// Update SliceOp and point to new PadOp
auto newStartOp =
getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
newPadOp.getResult(), newStartOp,
sliceOp.getSize());

return success();
}
};

// Update size operand of tosa.slice if size has dynamic dims but corresponding
// output dim is static
struct SliceDynamicSizeCanonicalization
@@ -779,8 +914,8 @@ struct SliceDynamicSizeCanonicalization

void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
-  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
-      context);
+  results.add<ConcatSliceOptimization, PadSliceOptimization,
+              SliceDynamicSizeCanonicalization>(context);
}

//===----------------------------------------------------------------------===//
72 changes: 72 additions & 0 deletions mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -985,6 +985,78 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf

// -----

// CHECK-LABEL: @canonicalize_pad_slice_overlap
// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>}
// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
func.func @canonicalize_pad_slice_overlap(%arg0: tensor<?x16x16x3xf32>) -> tensor<?x14x18x3xf32> {
%pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
%padded = tosa.pad %arg0, %padding, %pad_const : (tensor<?x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<?x16x20x3xf32>
%start = tosa.const_shape {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%size = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
%sliced = tosa.slice %padded, %start, %size : (tensor<?x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x14x18x3xf32>
return %sliced : tensor<?x14x18x3xf32>
}

// -----

// CHECK-LABEL: @canonicalize_pad_slice_inside
// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
// CHECK-NOT: tosa.pad
// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]]
func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x14x3xf32> {
%pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
%padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
%start = tosa.const_shape {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%size = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
%sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x14x3xf32>
return %sliced : tensor<1x14x14x3xf32>
}

// -----

// CHECK-LABEL: func @canonicalize_pad_slice_exact
// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>}
// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
func.func @canonicalize_pad_slice_exact(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x20x2xf32> {
%pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
%padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
%start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%size = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
%sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x20x2xf32>
return %sliced : tensor<1x16x20x2xf32>
}

// -----

// CHECK-LABEL: func @canonicalize_pad_slice_dynamic_noupdate
// CHECK-DAG: tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
// CHECK-DAG: tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>}
// CHECK: tosa.pad
// CHECK: tosa.slice
func.func @canonicalize_pad_slice_dynamic_noupdate(%arg0: tensor<1x16x?x3xf32>) -> tensor<1x16x?x2xf32> {
%pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
%padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
%padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x?x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x?x3xf32>
%start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
%size = tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
%sliced = tosa.slice %padded, %start, %size : (tensor<1x16x?x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x?x2xf32>
return %sliced : tensor<1x16x?x2xf32>
}

// -----

// CHECK-LABEL: @fold_log_exp
func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg{{.*}} : tensor<?x1xf32>