@@ -731,9 +731,64 @@ struct ConcatSliceOptimization : public OpRewritePattern<tosa::SliceOp> {
731
731
}
732
732
};
733
733
734
+ // Update size operand of tosa.slice if size has dynamic dims but corresponding
735
+ // output dim is static
736
+ struct SliceDynamicSizeCanonicalization
737
+ : public OpRewritePattern<tosa::SliceOp> {
738
+ using OpRewritePattern<tosa::SliceOp>::OpRewritePattern;
739
+
740
+ LogicalResult matchAndRewrite (tosa::SliceOp sliceOp,
741
+ PatternRewriter &rewriter) const override {
742
+ ShapedType resultType = cast<ShapedType>(sliceOp.getType ());
743
+
744
+ ElementsAttr sizeElems;
745
+ if (!matchPattern (sliceOp.getSize (), m_Constant (&sizeElems))) {
746
+ return rewriter.notifyMatchFailure (
747
+ sliceOp, " size of slice must be a static ranked shape" );
748
+ }
749
+
750
+ llvm::SmallVector<int64_t > sliceSizes =
751
+ llvm::to_vector (sizeElems.getValues <int64_t >());
752
+
753
+ bool replaceSliceSize{false };
754
+ // if size op has -1 indicating dynamic shape but corresponding dim on the
755
+ // output is statically known, update size to match with known output dim
756
+ // shape
757
+ for (const auto i : llvm::enumerate (sliceSizes)) {
758
+ int64_t size = i.value ();
759
+ size_t index = i.index ();
760
+ if (size == -1 && !resultType.isDynamicDim (index)) {
761
+ sliceSizes[index] = resultType.getDimSize (index);
762
+ replaceSliceSize = true ;
763
+ }
764
+ }
765
+
766
+ if (!replaceSliceSize) {
767
+ return rewriter.notifyMatchFailure (
768
+ sliceOp, " no dimension of size of slice is dynamic that resolves "
769
+ " to static output shape" );
770
+ }
771
+
772
+ auto size_op = getTosaConstShape (rewriter, sliceOp.getLoc (), sliceSizes);
773
+ auto newSliceOp = rewriter.create <tosa::SliceOp>(
774
+ sliceOp.getLoc (), sliceOp.getType (), sliceOp.getInput1 (),
775
+ sliceOp.getStart (), size_op);
776
+
777
+ rewriter.replaceOp (sliceOp, newSliceOp.getResult ());
778
+
779
+ // Remove const_shape size op when it no longer has use point.
780
+ Operation *sizeConstShape = sliceOp.getSize ().getDefiningOp ();
781
+ if (sizeConstShape->getResult (0 ).hasOneUse ())
782
+ rewriter.eraseOp (sizeConstShape);
783
+
784
+ return success ();
785
+ }
786
+ };
787
+
734
788
void SliceOp::getCanonicalizationPatterns (RewritePatternSet &results,
735
789
MLIRContext *context) {
736
- results.add <ConcatSliceOptimization>(context);
790
+ results.add <ConcatSliceOptimization, SliceDynamicSizeCanonicalization>(
791
+ context);
737
792
}
738
793
739
794
// ===----------------------------------------------------------------------===//
0 commit comments