@@ -561,6 +561,126 @@ struct DropUnitDims : public OpRewritePattern<GenericOp> {
 };
 } // namespace
 
+//===---------------------------------------------------------------------===//
+// Drop dimensions that are unit-extents within tensor operations.
+//===---------------------------------------------------------------------===//
+
+namespace {
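+/// A sketch of the rewrite this pattern performs. The shapes and the elided
+/// pad bodies ({ ... }) below are illustrative only, not taken from this
+/// patch's tests. Under the reshape-based strategy,
+///
+///   %0 = tensor.pad %arg0 low[0, 1] high[0, 1] { ... }
+///       : tensor<1x8xf32> to tensor<1x10xf32>
+///
+/// becomes
+///
+///   %collapsed = tensor.collapse_shape %arg0 [[0, 1]]
+///       : tensor<1x8xf32> into tensor<8xf32>
+///   %padded = tensor.pad %collapsed low[1] high[1] { ... }
+///       : tensor<8xf32> to tensor<10xf32>
+///   %0 = tensor.expand_shape %padded [[0, 1]]
+///       : tensor<10xf32> into tensor<1x10xf32>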
+struct DropPadUnitDims : public OpRewritePattern<tensor::PadOp> {
+  DropPadUnitDims(MLIRContext *context, ControlDropUnitDims options = {},
+                  PatternBenefit benefit = 1)
+      : OpRewritePattern(context, benefit), options(std::move(options)) {}
+
+  LogicalResult matchAndRewrite(tensor::PadOp padOp,
+                                PatternRewriter &rewriter) const override {
+    // 1a. Get the allowed list of dimensions to drop from the `options`.
+    SmallVector<unsigned> allowedUnitDims = options.controlFn(padOp);
+    if (allowedUnitDims.empty()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "control function returns no allowed unit dims to prune");
+    }
+
+    if (padOp.getSourceType().getEncoding()) {
+      return rewriter.notifyMatchFailure(
+          padOp, "cannot collapse dims of tensor with encoding");
+    }
+
+    // Fail for non-constant padding values. The body of the pad could
+    // depend on the padding indices and/or properties of the padded
+    // tensor, so for now we fail.
+    // TODO: Support non-constant padding values.
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal) {
+      return rewriter.notifyMatchFailure(
+          padOp, "unimplemented: non-constant padding value");
+    }
+
+    ArrayRef<int64_t> sourceShape = padOp.getSourceType().getShape();
+    int64_t padRank = sourceShape.size();
+
+    auto isStaticZero = [](OpFoldResult f) {
+      std::optional<int64_t> maybeInt = getConstantIntValue(f);
+      return maybeInt && *maybeInt == 0;
+    };
+
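+    // A dim is droppable only if the control function allows it, its source
+    // extent is statically 1, and both its low and high pads are statically
+    // zero. For example, dim 0 of a tensor<1x8xf32> padded low[0, 1]
+    // high[0, 1] qualifies, but with low[1, 0] it would not.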
+    llvm::SmallDenseSet<unsigned> unitDimsFilter(allowedUnitDims.begin(),
+                                                 allowedUnitDims.end());
+    llvm::SmallDenseSet<unsigned> unitDims;
+    SmallVector<int64_t> newShape;
+    SmallVector<OpFoldResult> newLowPad;
+    SmallVector<OpFoldResult> newHighPad;
+    for (const auto [dim, size, low, high] :
+         zip_equal(llvm::seq(static_cast<int64_t>(0), padRank), sourceShape,
+                   padOp.getMixedLowPad(), padOp.getMixedHighPad())) {
+      if (unitDimsFilter.contains(dim) && size == 1 && isStaticZero(low) &&
+          isStaticZero(high)) {
+        unitDims.insert(dim);
+      } else {
+        newShape.push_back(size);
+        newLowPad.push_back(low);
+        newHighPad.push_back(high);
+      }
+    }
+
+    if (unitDims.empty()) {
+      return rewriter.notifyMatchFailure(padOp, "no unit dims to collapse");
+    }
+
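+    // Build the reassociation map that folds each dropped unit dim into an
+    // adjacent kept dim. As a worked example (illustrative): with
+    // padRank = 4 and unitDims = {0, 2}, the loops below yield
+    // [[0, 1, 2], [3]] -- leading unit dims fold into the first kept dim,
+    // and every later unit dim folds into the preceding kept dim.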
+    ReassociationIndices reassociationGroup;
+    SmallVector<ReassociationIndices> reassociationMap;
+    int64_t dim = 0;
+    while (dim < padRank && unitDims.contains(dim))
+      reassociationGroup.push_back(dim++);
+    while (dim < padRank) {
+      assert(!unitDims.contains(dim) && "expected non unit-extent");
+      reassociationGroup.push_back(dim);
+      dim++;
+      // Fold all following dimensions that are unit-extent.
+      while (dim < padRank && unitDims.contains(dim))
+        reassociationGroup.push_back(dim++);
+      reassociationMap.push_back(reassociationGroup);
+      reassociationGroup.clear();
+    }
+
+    Value collapsedSource =
+        collapseValue(rewriter, padOp.getLoc(), padOp.getSource(), newShape,
+                      reassociationMap, options.rankReductionStrategy);
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        padOp.getLoc(), /*result=*/Type(), collapsedSource, newLowPad,
+        newHighPad, paddingVal, padOp.getNofold());
+
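+    // Pick the destination used to restore the original rank. Under the
+    // ExtractInsertSlice strategy, expandValue re-inserts the new pad's
+    // result into a destination tensor, so build a fresh tensor.empty of
+    // the expanded shape below; under reshapes, the original result is
+    // expected to supply only the expanded type.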
+    Value dest = padOp.getResult();
+    if (options.rankReductionStrategy ==
+        ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice) {
+      SmallVector<OpFoldResult> expandedSizes;
+      int64_t numUnitDims = 0;
+      for (auto dim : llvm::seq(static_cast<int64_t>(0), padRank)) {
+        if (unitDims.contains(dim)) {
+          expandedSizes.push_back(rewriter.getIndexAttr(1));
+          numUnitDims++;
+          continue;
+        }
+        expandedSizes.push_back(tensor::getMixedSize(
+            rewriter, padOp.getLoc(), newPadOp, dim - numUnitDims));
+      }
+      dest = rewriter.create<tensor::EmptyOp>(
+          padOp.getLoc(), expandedSizes,
+          padOp.getResultType().getElementType());
+    }
+
+    Value expandedValue =
+        expandValue(rewriter, padOp.getLoc(), newPadOp.getResult(), dest,
+                    reassociationMap, options.rankReductionStrategy);
+    rewriter.replaceOp(padOp, expandedValue);
+    return success();
+  }
+
+private:
+  ControlDropUnitDims options;
+};
+} // namespace
+
 namespace {
 /// Convert `extract_slice` operations to rank-reduced versions.
 struct RankReducedExtractSliceOp
@@ -640,6 +760,7 @@ populateFoldUnitExtentDimsViaReshapesPatterns(RewritePatternSet &patterns,
                                               ControlDropUnitDims &options) {
   auto *context = patterns.getContext();
   patterns.add<DropUnitDims>(context, options);
+  patterns.add<DropPadUnitDims>(context, options);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
   patterns.add<RankReducedExtractSliceOp,
                RankReducedInsertSliceOp<tensor::InsertSliceOp>,
@@ -661,6 +782,7 @@ populateFoldUnitExtentDimsViaSlicesPatterns(RewritePatternSet &patterns,
   options.rankReductionStrategy =
       ControlDropUnitDims::RankReductionStrategy::ExtractInsertSlice;
   patterns.add<DropUnitDims>(context, options);
+  patterns.add<DropPadUnitDims>(context, options);
   // TODO: Patterns unrelated to unit dim folding should be factored out.
   linalg::FillOp::getCanonicalizationPatterns(patterns, context);
   tensor::EmptyOp::getCanonicalizationPatterns(patterns, context);
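
For reference, a minimal sketch of how these populate functions are typically driven, assuming MLIR's standard greedy rewrite driver; the wrapper function foldUnitDims below is hypothetical and not part of this patch:

    // Hypothetical driver; only the populate functions come from this file.
    #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
    #include "mlir/Transforms/GreedyPatternRewriteDriver.h"

    using namespace mlir;

    static LogicalResult foldUnitDims(Operation *op) {
      RewritePatternSet patterns(op->getContext());
      linalg::ControlDropUnitDims options;
      linalg::populateFoldUnitExtentDimsViaReshapesPatterns(patterns, options);
      // Apply to a fixed point; DropPadUnitDims fires alongside the other
      // unit-dim folding patterns registered above.
      return applyPatternsAndFoldGreedily(op, std::move(patterns));
    }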