Skip to content

[tosa]: canonicalize dynamic size of tosa.slice to static output shape #135429

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 12, 2025

Conversation

sahas3
Copy link
Member

@sahas3 sahas3 commented Apr 11, 2025

Addresses #135389

@llvmbot
Copy link
Member

llvmbot commented Apr 11, 2025

@llvm/pr-subscribers-mlir-tosa

@llvm/pr-subscribers-mlir

Author: Sayan Saha (sahas3)

Changes

Addresses #135389


Full diff: https://github.com/llvm/llvm-project/pull/135429.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+53-1)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+15)
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index c4ef7d0bb9ff5..67d8baf32539f 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -731,9 +731,61 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
   }
 };
 
+// Update size operand of tosa.slice if size has dynamic dims but corresponding
+// output dim is static
+struct SliceDynamicSizeCanonicalization : public OpRewritePattern<tosa::SliceOp> {
+  using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(tosa::SliceOp sliceOp,
+                                PatternRewriter &rewriter) const override {
+      ShapedType resultType = cast<ShapedType>(sliceOp.getType());
+
+      ElementsAttr sizeElems;
+      if (!matchPattern(sliceOp.getSize(), m_Constant(&sizeElems))) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "size of slice must be a static ranked shape");
+      }
+
+      llvm::SmallVector<int64_t> sliceSizes =
+          llvm::to_vector(sizeElems.getValues<int64_t>());
+
+      bool replaceSliceSize{false};
+      // if size op has -1 indicating dynamic shape but corresponding dim on the
+      // output is statically known, update size to match with known output dim shape
+      for (const auto i : llvm::enumerate(sliceSizes)) {
+        int64_t size = i.value();
+        size_t index = i.index();
+        if (size == -1 && !resultType.isDynamicDim(index)) {
+          sliceSizes[index] = resultType.getDimSize(index);
+          replaceSliceSize = true;
+        }
+      }
+
+      if (!replaceSliceSize) {
+        return rewriter.notifyMatchFailure(
+            sliceOp, "no dimension of size of slice is dynamic that resolves "
+                     "to static output shape");
+      }
+      
+      auto size_op = getTosaConstShape(rewriter, sliceOp.getLoc(), sliceSizes);
+      auto newSliceOp = rewriter.create<tosa::SliceOp>(
+          sliceOp.getLoc(), sliceOp.getType(), sliceOp.getInput1(),
+          sliceOp.getStart(), size_op);
+
+      rewriter.replaceOp(sliceOp, newSliceOp.getResult());
+
+      // Remove const_shape size op when it no longer has use point.
+      Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
+      if (sizeConstShape->getResult(0).hasOneUse())
+        rewriter.eraseOp(sizeConstShape);
+
+      return success();
+  }
+};
+
 void SliceOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
-  results.add<ConcatSliceOptimization>(context);
+  results.add<ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(context);
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index b366b4f1e4fd4..a754a46be603f 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -1212,3 +1212,18 @@ func.func @do_not_fold_intdiv_division_by_0() -> tensor<1x24x2xi32> {
   %16 = tosa.intdiv %4, %1 : (tensor<1x24x2xi32>, tensor<1x24x2xi32>) -> tensor<1x24x2xi32>
   return %16 : tensor<1x24x2xi32>
 }
+
+
+// ----
+// CHECK-LABEL:   func.func @slice_dynamic_size_static_output_canonicalize(
+// CHECK-SAME:                     %[[ARG0:.*]]: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
+// CHECK:           %[[START:.*]] = tosa.const_shape  {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[SIZE:.*]] = tosa.const_shape  {values = dense<[2, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+// CHECK:           %[[SLICE:.*]] = tosa.slice %[[ARG0]], %[[START]], %[[SIZE]] : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
+// CHECK:           return %[[SLICE]]
+func.func @slice_dynamic_size_static_output_canonicalize(%arg0: tensor<2x60x59x?xf32>) -> tensor<2x60x58x?xf32> {
+    %0 = tosa.const_shape  {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4>
+    %1 = tosa.const_shape  {values = dense<[-1, 60, 58, -1]> : tensor<4xindex>} : () -> !tosa.shape<4>
+    %2 = tosa.slice %arg0, %0, %1 : (tensor<2x60x59x?xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x60x58x?xf32>
+    return %2 : tensor<2x60x58x?xf32>
+  }

@sahas3 sahas3 requested a review from sjarus April 11, 2025 19:57
@sahas3
Copy link
Member Author

sahas3 commented Apr 11, 2025

Hi @Tai78641 couldn't find you in reviewer list, so tagging you here. Thanks!

Copy link

github-actions bot commented Apr 11, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@sahas3 sahas3 force-pushed the sliceShapeCanonicalize branch from 43099d9 to 6dc0911 Compare April 11, 2025 20:00
Copy link
Contributor

@Tai78641 Tai78641 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
just one nit

@sahas3 sahas3 force-pushed the sliceShapeCanonicalize branch from d4cc270 to 2dcd893 Compare April 12, 2025 20:43
Copy link
Contributor

@Tai78641 Tai78641 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM
thanks for the fix!

@sahas3 sahas3 merged commit 60b1d44 into llvm:main Apr 12, 2025
11 checks passed
@thurstond
Copy link
Contributor

Dialect/Tosa/canonicalize.mlir is failing on an asan bot (https://lab.llvm.org/buildbot/#/builders/169/builds/10405)

@thurstond
Copy link
Contributor

thurstond commented Apr 13, 2025

Based on the ASan output, I think after the replaceOp on line 775, it's no longer valid to do getSize() on sliceOp:

   775      rewriter.replaceOp(sliceOp, newSliceOp.getResult());
   776
   777      // Remove const_shape size op when it no longer has use point.
   778      Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();

From mlir/lib/IR/PatternMatch.cpp:

/// This method replaces the results of the operation with the specified list of
/// values. The number of provided values must match the number of results of
/// the operation. The replaced op is erased.
void RewriterBase::replaceOp(Operation *op, ValueRange newValues) {
  assert(op->getNumResults() == newValues.size() &&
         "incorrect # of replacement values");

  // Replace all result uses. Also notifies the listener of modifications.
  replaceAllOpUsesWith(op, newValues);

  // Erase op and notify listener.
  eraseOp(op);
}

ASan log:

==mlir-opt==1182057==ERROR: AddressSanitizer: heap-use-after-free on address 0x76a7f48bbc7c at pc 0x6010214a6254 bp 0x7ffc067d9790 sp 0x7ffc067d9788
READ of size 4 at 0x76a7f48bbc7c thread T0
    #0 0x6010214a6253 in getOpOperands /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/include/mlir/IR/Operation.h:384:12
    #1 0x6010214a6253 in getOperands /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/include/mlir/IR/Operation.h:379:43
    #2 0x6010214a6253 in operand_begin /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/include/mlir/IR/Operation.h:374:45
    #3 0x6010214a6253 in getODSOperands /home/b/sanitizer-x86_64-linux-fast/build/llvm_build_asan_ubsan/tools/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h.inc:14332:39
    #4 0x6010214a6253 in mlir::tosa::SliceOp::getSize() /home/b/sanitizer-x86_64-linux-fast/build/llvm_build_asan_ubsan/tools/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.h.inc:14345:69
    #5 0x60102176a3ff in SliceDynamicSizeCanonicalization::matchAndRewrite(mlir::tosa::SliceOp, mlir::PatternRewriter&) const /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp:778:41
...

freed by thread T0 here:
    #0 0x60101ac8aeb6 in free /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/compiler-rt/lib/asan/asan_malloc_linux.cpp:51:3
    #1 0x601022f7a71d in erase /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/llvm/include/llvm/ADT/ilist.h:205:5
    #2 0x601022f7a71d in erase /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/llvm/include/llvm/ADT/ilist.h:209:39
    #3 0x601022f7a71d in mlir::Operation::erase() /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/IR/Operation.cpp:541:29
    #4 0x601022fb2d6a in operator() /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/IR/PatternMatch.cpp:184:9
    #5 0x601022fb2d6a in operator() /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/IR/PatternMatch.cpp:223:5
    #6 0x601022fb2d6a in __invoke<(lambda at /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/IR/PatternMatch.cpp:190:48) &, mlir::Operation *> /home/b/sanitizer-x86_64-linux-fast/build/libcxx_install_asan_ubsan/include/c++/v1/__type_traits/invoke.h:179:25
    #7 0x601022fb2d6a in __call<(lambda at /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/IR/PatternMatch.cpp:190:48) &, mlir::Operation *> /home/b/sanitizer-x86_64-linux-fast/build/libcxx_install_asan_ubsan/include/c++/v1/__type_traits/invoke.h:251:5
    #8 0x601022fb2d6a in __invoke_r<void, (lambda at /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/IR/PatternMatch.cpp:190:48) &, mlir::Operation *> /home/b/sanitizer-x86_64-linux-fast/build/libcxx_install_asan_ubsan/include/c++/v1/__type_traits/invoke.h:273:10
    #9 0x601022fb2d6a in operator() /home/b/sanitizer-x86_64-linux-fast/build/libcxx_install_asan_ubsan/include/c++/v1/__functional/function.h:167:12
    #10 0x601022fb2d6a in std::__1::__function::__func<mlir::RewriterBase::eraseOp(mlir::Operation*)::$_0, std::__1::allocator<mlir::RewriterBase::eraseOp(mlir::Operation*)::$_0>, void (mlir::Operation*)>::operator()(mlir::Operation*&&) /home/b/sanitizer-x86_64-linux-fast/build/libcxx_install_asan_ubsan/include/c++/v1/__functional/function.h:319:10
    #11 0x601022fae6a3 in operator() /home/b/sanitizer-x86_64-linux-fast/build/libcxx_install_asan_ubsan/include/c++/v1/__functional/function.h:436:12
    #12 0x601022fae6a3 in operator() /home/b/sanitizer-x86_64-linux-fast/build/libcxx_install_asan_ubsan/include/c++/v1/__functional/function.h:995:10
      #13 0x5c8e0439e5a4 in executeAction<(anonymous namespace)::GreedyPatternRewriteIteration, long &> /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/include/mlir/IR/MLIRContext.h:280:7
    #14 0x5c8e0439e5a4 in simplify /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:872:10
    #15 0x5c8e0439e5a4 in mlir::applyPatternsGreedily(mlir::Region&, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp:919:47
    #16 0x5c8e0001d2ad in mlir::applyPatternsGreedily(mlir::Operation*, mlir::FrozenRewritePatternSet const&, mlir::GreedyRewriteConfig, bool*) /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h:174:15
    #17 0x5c8e0430af42 in (anonymous namespace)::Canonicalizer::runOnOperation() /home/b/sanitizer-x86_64-linux-fast/build/llvm-project/mlir/lib/Transforms/Canonicalizer.cpp:64:9
...

thurstond added a commit that referenced this pull request Apr 13, 2025
…put shape" (#135525)

Reverts #135429 due buildbot breakage:
https://lab.llvm.org/buildbot/#/builders/169/builds/10405

Based on the ASan output, I think after the replaceOp on line 775, it's
no longer valid to do getSize() on sliceOp:
```
   775      rewriter.replaceOp(sliceOp, newSliceOp.getResult());
   776
   777      // Remove const_shape size op when it no longer has use point.
   778      Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
```
@thurstond
Copy link
Contributor

I reverted this in #135525 to hopefully bring the buildbots back to green.

Please reland with a fix at your convenience. Thanks!

llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Apr 13, 2025
… static output shape" (#135525)

Reverts llvm/llvm-project#135429 due buildbot breakage:
https://lab.llvm.org/buildbot/#/builders/169/builds/10405

Based on the ASan output, I think after the replaceOp on line 775, it's
no longer valid to do getSize() on sliceOp:
```
   775      rewriter.replaceOp(sliceOp, newSliceOp.getResult());
   776
   777      // Remove const_shape size op when it no longer has use point.
   778      Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
```
sahas3 added a commit to sahas3/llvm-project that referenced this pull request Apr 13, 2025
@sahas3
Copy link
Member Author

sahas3 commented Apr 13, 2025

Thanks for looking into the ASAN build failure @thurstond. I've created #135560 to re-enable this patch and verified ASAN works fine locally.

@thurstond
Copy link
Contributor

Thanks for looking into the ASAN build failure @thurstond. I've created #135560 to re-enable this patch and verified ASAN works fine locally.

Thank you! :-)

sahas3 added a commit to sahas3/llvm-project that referenced this pull request Apr 13, 2025
sahas3 added a commit that referenced this pull request Apr 13, 2025
Removed the calls to `sizeOp` after replacing `SliceOp`:

```
// Remove const_shape size op when it no longer has use point.
Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
```

Turns out as part of canonicalization, trivially dead ops are removed
anyway, so the above piece of code isn't actually needed.
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
…put shape" (llvm#135525)

Reverts llvm#135429 due buildbot breakage:
https://lab.llvm.org/buildbot/#/builders/169/builds/10405

Based on the ASan output, I think after the replaceOp on line 775, it's
no longer valid to do getSize() on sliceOp:
```
   775      rewriter.replaceOp(sliceOp, newSliceOp.getResult());
   776
   777      // Remove const_shape size op when it no longer has use point.
   778      Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
```
var-const pushed a commit to ldionne/llvm-project that referenced this pull request Apr 17, 2025
Removed the calls to `sizeOp` after replacing `SliceOp`:

```
// Remove const_shape size op when it no longer has use point.
Operation *sizeConstShape = sliceOp.getSize().getDefiningOp();
```

Turns out as part of canonicalization, trivially dead ops are removed
anyway, so the above piece of code isn't actually needed.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants