[mlir][tosa] Canonicalize slice over overlapped or inside a pad. #138900
Update the paddings and/or the slice parameters when a `tosa.slice` that follows a `tosa.pad` accesses a region that only partially overlaps the padding, or that lies entirely inside the original tensor. Signed-off-by: Georgios Pinitas <[email protected]>
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa

Author: Georgios Pinitas (GeorgeARM)

Changes: Update the paddings and/or the slice parameters when a `tosa.slice` that follows a `tosa.pad` accesses a region that only partially overlaps the padding, or that lies entirely inside the original tensor.

Full diff: https://github.com/llvm/llvm-project/pull/138900.diff

2 Files Affected:
- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
- mlir/test/Dialect/Tosa/canonicalize.mlir
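In brief, the new pattern hoists a `tosa.slice` through its producing `tosa.pad`: the slice start is re-expressed relative to the unpadded input, and the pad keeps only the padding the slice actually reads. A minimal before/after sketch of the overlap case (static shapes assumed for brevity; the tests below also cover a dynamic-batch variant):

```mlir
func.func @pad_then_slice_overlap(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x18x3xf32> {
  %cst = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
  // Pad dim 2 by 2 on each side: 16 -> 20 columns.
  %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
  %padded = tosa.pad %arg0, %padding, %cst : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
  // Slice columns [1, 19): keeps only 1 padded column on each side.
  %start = tosa.const_shape {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %size = tosa.const_shape {values = dense<[1, 16, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %out = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x18x3xf32>
  return %out : tensor<1x16x18x3xf32>
}
```

After canonicalization the pad shrinks to `[0, 0, 0, 0, 1, 1, 0, 0]` and the slice starts at `[0, 0, 0, 0]`, so columns that were padded only to be discarded are never materialized.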
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 47368532df169..eeb7d3e4a27b7 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,6 +731,141 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
}
};
+struct PadSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ Value sliceInput = sliceOp.getInput1();
+
+ // Check if producer is a PadOp
+ auto padOp = sliceInput.getDefiningOp<tosa::PadOp>();
+ if (!padOp)
+ return rewriter.notifyMatchFailure(sliceOp,
+ "slice input must be a pad operation");
+
+ // Check PadOp has a single consumer
+ if (!padOp->hasOneUse())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "pad shall have a single consumer");
+
+ // Check input is statically ranked
+ auto inputTy = dyn_cast<RankedTensorType>(padOp.getInput1().getType());
+ auto padTy = dyn_cast<RankedTensorType>(padOp.getType());
+ if (!inputTy || !padTy || !inputTy.hasRank())
+ return rewriter.notifyMatchFailure(sliceOp,
+ "slice input must be a ranked tensor");
+
+ // Validate and extract tosa::PadOp padding
+ DenseIntElementsAttr paddingElems;
+ if (!matchPattern(padOp.getPadding(), m_Constant(&paddingElems))) {
+ return rewriter.notifyMatchFailure(
+ sliceOp,
+ "`padding` input specified on the tosa::PadOp must be constant.");
+ }
+ llvm::SmallVector<int64_t> padPaddings =
+ llvm::to_vector(paddingElems.getValues<int64_t>());
+
+ // Extract slice parameters
+ DenseElementsAttr startElems;
+ if (!matchPattern(sliceOp.getStart(), m_Constant(&startElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "start of slice must be a static ranked shape");
+ llvm::SmallVector<int64_t> sliceStarts =
+ llvm::to_vector(startElems.getValues<int64_t>());
+
+ DenseElementsAttr sizeElems;
+ if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems)))
+ return rewriter.notifyMatchFailure(
+ sliceOp, "size of slice must be a static ranked shape");
+ llvm::SmallVector<int64_t> sliceSizes =
+ llvm::to_vector(sizeElems.getValues<int64_t>());
+
+ // Check if dynamic dimensions are sliced
+ const int64_t rank = inputTy.getRank();
+ if (llvm::any_of(llvm::seq<int64_t>(0, rank), [&](int64_t i) {
+ const bool isDimDynamic = inputTy.isDynamicDim(i);
+ const bool isDimSliced =
+ (sliceStarts[i] != 0) || (sliceSizes[i] != -1);
+
+ return isDimDynamic && isDimSliced;
+ })) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "axis that are sliced shall be statically known.");
+ }
+
+ // Update the parameters
+ llvm::SmallVector<int64_t> newSliceStarts(rank, 0);
+ llvm::SmallVector<int64_t> newPadPaddings(2 * rank, 0);
+ llvm::SmallVector<int64_t> newPadShape(rank, ShapedType::kDynamic);
+ bool updated = false;
+
+ for (int64_t i = 0; i < rank; ++i) {
+ const int64_t padLo = padPaddings[i * 2];
+ const int64_t padHi = padPaddings[i * 2 + 1];
+ const int64_t sliceStart = sliceStarts[i];
+ const int64_t sliceSize = sliceSizes[i];
+ const int64_t sliceEnd = sliceStart + sliceSize;
+
+ // If the dimension is dynamic, pass the parameters through unchanged
+ if (inputTy.isDynamicDim(i)) {
+ newPadPaddings[i * 2] = padLo;
+ newPadPaddings[i * 2 + 1] = padHi;
+ newSliceStarts[i] = sliceStart;
+ continue;
+ }
+
+ // Handle static dimensions
+ const int64_t dimSize = inputTy.getShape()[i];
+ const int64_t dimTotal = padLo + dimSize + padHi;
+
+ // Check slice within bounds
+ if (sliceStart < 0 || sliceEnd > dimTotal)
+ return rewriter.notifyMatchFailure(sliceOp, "slice is out-of-bounds");
+
+ // Compute updated slice start parameter
+ const int64_t newSliceStart = std::max<int64_t>(sliceStart - padLo, 0);
+ newSliceStarts[i] = newSliceStart;
+ updated |= newSliceStart != sliceStart;
+
+ // Compute updated pad parameters
+ const int64_t newPadLo = std::max<int64_t>(padLo - sliceStart, 0);
+ const int64_t newPadHi =
+ std::max<int64_t>(sliceEnd - (padLo + dimSize), 0);
+ newPadPaddings[i * 2] = newPadLo;
+ newPadPaddings[i * 2 + 1] = newPadHi;
+ updated |= (newPadLo != padLo) || (newPadHi != padHi);
+
+ // Calculate new pad output shape
+ newPadShape[i] =
+ newPadPaddings[i * 2] + dimSize + newPadPaddings[i * 2 + 1];
+ }
+
+ // Check that we actually need to proceed with the rewrite
+ if (!updated)
+ return rewriter.notifyMatchFailure(
+ sliceOp, "terminate condition; nothing to rewrite");
+
+ // Create a PadOp with updated padding
+ auto newPaddingsOp =
+ getTosaConstShape(rewriter, sliceOp.getLoc(), newPadPaddings);
+ auto newPadTy =
+ RankedTensorType::get(newPadShape, inputTy.getElementType());
+ auto newPadOp = rewriter.create<tosa::PadOp>(
+ padOp.getLoc(), newPadTy, padOp.getInput1(), newPaddingsOp,
+ padOp.getPadConst());
+
+ // Update SliceOp and point to new PadOp
+ auto newStartOp =
+ getTosaConstShape(rewriter, sliceOp.getLoc(), newSliceStarts);
+ rewriter.replaceOpWithNewOp<tosa::SliceOp>(sliceOp, sliceOp.getType(),
+ newPadOp.getResult(), newStartOp,
+ sliceOp.getSize());
+
+ return success();
+ }
+};
+
// Update size operand of tosa.slice if size has dynamic dims but corresponding
// output dim is static
struct SliceDynamicSizeCanonicalization
@@ -779,8 +914,8 @@ struct SliceDynamicSizeCanonicalization
void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
- context);
+ results.add<ConcatSliceOptimization, PadSliceOptimization,
+ SliceDynamicSizeCanonicalization>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 59fd490330691..c37a0a3d1fb0c 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -985,6 +985,78 @@ func.func @canonicalize_concat_slice_on_non_concat_axis(%arg0 : tensor<1x12x12xf
// -----
+// CHECK-LABEL: @canonicalize_pad_slice_overlap
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 1, 1, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_overlap(%arg0: tensor<?x16x16x3xf32>) -> tensor<?x14x18x3xf32> {
+ %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<?x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<?x16x20x3xf32>
+ %start = tosa.const_shape {values = dense<[0, 0, 1, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %size = tosa.const_shape {values = dense<[-1, 14, 18, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %sliced = tosa.slice %padded, %start, %size : (tensor<?x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<?x14x18x3xf32>
+ return %sliced : tensor<?x14x18x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @canonicalize_pad_slice_inside
+// CHECK-DAG: %[[SLICE_START:.*]] = tosa.const_shape {values = dense<[0, 1, 2, 0]> : tensor<4xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>}
+// CHECK-NOT: tosa.pad
+// CHECK: %[[SLICED:.*]] = tosa.slice %arg0, %[[SLICE_START]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_inside(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x14x14x3xf32> {
+ %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+ %start = tosa.const_shape {values = dense<[0, 1, 4, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %size = tosa.const_shape {values = dense<[1, 14, 10, 3]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x14x14x3xf32>
+ return %sliced : tensor<1x14x14x3xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_pad_slice_exact
+// CHECK-DAG: %[[PAD_CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+// CHECK-DAG: %[[ZERO:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK-DAG: %[[PADDING:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: %[[SLICE_SIZE:.*]] = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>}
+// CHECK: %[[PADDED:.*]] = tosa.pad %arg0, %[[PADDING]], %[[PAD_CONST]]
+// CHECK: %[[SLICED:.*]] = tosa.slice %[[PADDED]], %[[ZERO]], %[[SLICE_SIZE]]
+func.func @canonicalize_pad_slice_exact(%arg0: tensor<1x16x16x3xf32>) -> tensor<1x16x20x2xf32> {
+ %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x16x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x20x3xf32>
+ %start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %size = tosa.const_shape {values = dense<[1, 16, 20, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x20x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x20x2xf32>
+ return %sliced : tensor<1x16x20x2xf32>
+}
+
+// -----
+
+// CHECK-LABEL: func @canonicalize_pad_slice_dynamic_noupdate
+// CHECK-DAG: tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>}
+// CHECK-DAG: tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>}
+// CHECK: tosa.pad
+// CHECK: tosa.slice
+func.func @canonicalize_pad_slice_dynamic_noupdate(%arg0: tensor<1x16x?x3xf32>) -> tensor<1x16x?x2xf32> {
+ %pad_const = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32>
+ %padding = tosa.const_shape {values = dense<[0, 0, 0, 0, 2, 2, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8>
+ %padded = tosa.pad %arg0, %padding, %pad_const : (tensor<1x16x?x3xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x16x?x3xf32>
+ %start = tosa.const_shape {values = dense<[0, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %size = tosa.const_shape {values = dense<[1, 16, 15, 2]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %sliced = tosa.slice %padded, %start, %size : (tensor<1x16x?x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x16x?x2xf32>
+ return %sliced : tensor<1x16x?x2xf32>
+}
+
+// -----
+
// CHECK-LABEL: @fold_log_exp
func.func @fold_log_exp(%arg0: tensor<?x1xf32>) -> tensor<?x1xf32> {
// CHECK: return %arg{{.*}} : tensor<?x1xf32>
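To make the parameter update concrete, consider dim 2 of `@canonicalize_pad_slice_overlap` above: `dimSize = 16`, `padLo = padHi = 2`, `sliceStart = 1`, and `sliceSize = 18`, so `sliceEnd = 19`. The pattern computes `newSliceStart = max(1 - 2, 0) = 0`, `newPadLo = max(2 - 1, 0) = 1`, and `newPadHi = max(19 - (2 + 16), 0) = 1`, which is exactly the `[0, 0, 0, 0, 1, 1, 0, 0]` padding and all-zero slice start checked in the test.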
For future reference, original review was completed in #138270
Thanks for the updates @GeorgeARM - looks great!