-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[tosa]: canonicalize dynamic size of tosa.slice to static output shape #135429
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-tosa @llvm/pr-subscribers-mlir Author: Sayan Saha (sahas3) ChangesAddresses #135389 Full diff: https://github.com/llvm/llvm-project/pull/135429.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c4ef7d0bb9ff5..67d8baf32539f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,9 +731,61 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
}
};
+// Update size operand of tosa.slice if size has dynamic dims but corresponding
+// output dim is static
+struct SliceDynamicSizeCanonicalization : public OpRewritePattern<tosa::SliceOp> {
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+ PatternRewriter &rewriter) const override {
+ ShapedType resultType = cast<ShapedType>(sliceOp.getType());
+
+ ElementsAttr sizeElems;
+ if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "size of slice must be a static ranked shape");
+ }
+
+ llvm::SmallVector<int64_t> sliceSizes =
+ llvm::to_vector(sizeElems.getValues<int64_t>());
+
+ bool replaceSliceSize{false};
+ // if size op has -1 indicating dynamic shape but corresponding dim on the
+ // output is statically known, update size to match with known output dim shape
+ for (const auto i : llvm::enumerate(sliceSizes)) {
+ int64_t size = i.value();
+ size_t index = i.index();
+ if (size == -1 && !resultType.isDynamicDim(index)) {
+ sliceSizes[index] = resultType.getDimSize(index);
+ replaceSliceSize = true;
+ }
+ }
+
+ if (!replaceSliceSize) {
+ return rewriter.notifyMatchFailure(
+ sliceOp, "no dimension of size of slice is dynamic that resolves "
+ "to static output shape");
+ }
+
+ auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
+ auto newSliceOp = rewriter.create<tosa::SliceOp>(
+ sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
+ sliceOp.getStart(), size_op);
+
+ rewriter.replaceOp(sliceOp, newSliceOp.getResult());
+
+ // Remove const_shape size op when it no longer has use point.
+ Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
+ if (sizeConstShape->getResult(0).hasOneUse())
+ rewriter.eraseOp(sizeConstShape);
+
+ return success();
+ }
+};
+
void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
MLIRContext *context) {
- results.add<ConcatSliceOptimization>(context);
+ results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(context);
}
//===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index b366b4f1e4fd4..a754a46be603f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1212,3 +1212,18 @@ func.func @do_not_fold_intdiv_division_by_0() -> tensor<1x24x2xi32> {
%16 = tosa.intdiv %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32>
return %16 : tensor<1x24x2xi32>
}
+
+
+// ----
+// CHECK-LABEL: func.func @slice_dynamic_size_static_output_canonicalize(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
+// CHECK: %[[START:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[SIZE:.*]] = tosa.const_shape {values = dense<[2, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK: %[[SLICE:.*]] = tosa.slice %[[ARG0]], %[[START]], %[[SIZE]] : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
+// CHECK: return %[[SLICE]]
+func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
+ %0 = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %1 = tosa.const_shape {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+ %2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
+ return %2 : tensor<2x60x58x?xf32>
+ }
|
Hi @Tai78641 couldn't find you in reviewer list, so tagging you here. Thanks! |
✅ With the latest revision this PR passed the C/C++ code formatter. |
43099d9
to
6dc0911
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
just one nit
d4cc270
to
2dcd893
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
thanks for the fix!
Dialect/Tosa/canonicalize.mlir is failing on an asan bot (https://lab.llvm.org/buildbot/#/builders/169/builds/10405) |
Based on the ASan output, I think after the replaceOp on line 775, it's no longer valid to do getSize() on sliceOp:
From mlir/lib/IR/PatternMatch.cpp:
ASan log:
|
…put shape" (#135525) Reverts #135429 due buildbot breakage: https://lab.llvm.org/buildbot/#/builders/169/builds/10405 Based on the ASan output, I think after the replaceOp on line 775, it's no longer valid to do getSize() on sliceOp: ``` 775 rewriter.replaceOp(sliceOp, newSliceOp.getResult()); 776 777 // Remove const_shape size op when it no longer has use point. 778 Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); ```
I reverted this in #135525 to hopefully bring the buildbots back to green. Please reland with a fix at your convenience. Thanks! |
… static output shape" (#135525) Reverts llvm/llvm-project#135429 due buildbot breakage: https://lab.llvm.org/buildbot/#/builders/169/builds/10405 Based on the ASan output, I think after the replaceOp on line 775, it's no longer valid to do getSize() on sliceOp: ``` 775 rewriter.replaceOp(sliceOp, newSliceOp.getResult()); 776 777 // Remove const_shape size op when it no longer has use point. 778 Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); ```
Thanks for looking into the ASAN build failure @thurstond. I've created #135560 to re-enable this patch and verified ASAN works fine locally. |
Thank you! :-) |
Removed the calls to `sizeOp` after replacing `SliceOp`: ``` // Remove const_shape size op when it no longer has use point. Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); ``` Turns out as part of canonicalization, trivially dead ops are removed anyway, so the above piece of code isn't actually needed.
…put shape" (llvm#135525) Reverts llvm#135429 due buildbot breakage: https://lab.llvm.org/buildbot/#/builders/169/builds/10405 Based on the ASan output, I think after the replaceOp on line 775, it's no longer valid to do getSize() on sliceOp: ``` 775 rewriter.replaceOp(sliceOp, newSliceOp.getResult()); 776 777 // Remove const_shape size op when it no longer has use point. 778 Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); ```
Removed the calls to `sizeOp` after replacing `SliceOp`: ``` // Remove const_shape size op when it no longer has use point. Operation *sizeConstShape = sliceOp.getSize().getDefiningOp(); ``` Turns out as part of canonicalization, trivially dead ops are removed anyway, so the above piece of code isn't actually needed.
Addresses #135389