@@ -470,6 +470,88 @@ struct BubbleUpPackOpThroughGenericOpPattern
470
470
ControlPropagationFn controlFn;
471
471
};
472
472
473
+ // / Propagate a tensor.pack operation up through a tensor.pad. The idea is to
474
+ // / add as many zero padding dimensions in `high` and `low` based on the number
475
+ // / of point loops.
476
+ class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
477
+ public:
478
+ BubbleUpPackThroughPadOp (MLIRContext *context, ControlPropagationFn fun)
479
+ : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
480
+
481
+ LogicalResult matchAndRewrite (tensor::PackOp packOp,
482
+ PatternRewriter &rewriter) const override {
483
+ auto padOp = packOp.getSource ().getDefiningOp <tensor::PadOp>();
484
+ if (!padOp)
485
+ return failure ();
486
+
487
+ // User controlled propagation function.
488
+ if (!controlFn (padOp))
489
+ return failure ();
490
+
491
+ if (!padOp.getResult ().hasOneUse ())
492
+ return failure ();
493
+
494
+ // TODO: Enable padding when the padding values are the same.
495
+ if (packOp.getPaddingValue ())
496
+ return failure ();
497
+
498
+ // Fail for non-constant padding values. The body of the pad could
499
+ // depend on the padding indices and/or properties of the padded
500
+ // tensor so for now we fail.
501
+ // TODO: Support non-constant padding values.
502
+ Value paddingVal = padOp.getConstantPaddingValue ();
503
+ if (!paddingVal)
504
+ return failure ();
505
+
506
+ if (!packOp.getDest ().getDefiningOp <tensor::EmptyOp>())
507
+ return failure ();
508
+
509
+ ArrayRef<int64_t > innerDimsPos = packOp.getInnerDimsPos ();
510
+ ArrayRef<int64_t > outerDimsPerm = packOp.getOuterDimsPerm ();
511
+
512
+ // Bail out if one of the padded dimension is a tiled one.
513
+ llvm::SmallBitVector paddedDims = padOp.getPaddedDims ();
514
+ llvm::SmallBitVector innerDims (paddedDims.size ());
515
+ for (int64_t dim : innerDimsPos)
516
+ innerDims.flip (dim);
517
+ if (paddedDims.anyCommon (innerDims))
518
+ return failure ();
519
+
520
+ Location loc = padOp->getLoc ();
521
+ OpBuilder::InsertionGuard guard (rewriter);
522
+ rewriter.setInsertionPoint (padOp);
523
+
524
+ auto empty = tensor::PackOp::createDestinationTensor (
525
+ rewriter, loc, padOp.getSource (), packOp.getMixedTiles (), innerDimsPos,
526
+ outerDimsPerm);
527
+ Value packedSource = rewriter.create <tensor::PackOp>(
528
+ loc, padOp.getSource (), empty, innerDimsPos, packOp.getMixedTiles (),
529
+ /* padding=*/ std::nullopt, outerDimsPerm);
530
+
531
+ // If we have `outer_dims_perms` we need to adjust the padded dimensions.
532
+ SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad ();
533
+ SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad ();
534
+ if (!outerDimsPerm.empty ()) {
535
+ applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
536
+ applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
537
+ }
538
+ // The tiled dimensions were verified to be unpadded above, so here we
539
+ // just append 0 for the inner tile dimensions.
540
+ size_t pointLoopsSize = innerDimsPos.size ();
541
+ lowPad.append (pointLoopsSize, rewriter.getIndexAttr (0 ));
542
+ highPad.append (pointLoopsSize, rewriter.getIndexAttr (0 ));
543
+
544
+ auto newPadOp = rewriter.create <tensor::PadOp>(
545
+ loc, /* result=*/ Type (), packedSource, lowPad, highPad, paddingVal,
546
+ padOp.getNofold ());
547
+ rewriter.replaceOp (packOp, newPadOp.getResult ());
548
+ return success ();
549
+ }
550
+
551
+ private:
552
+ ControlPropagationFn controlFn;
553
+ };
554
+
473
555
// TODO: Relax this restriction. We should unpack a generic op also
474
556
// in the presence of multiple unpack ops as producers.
475
557
/// Return the unpacked operand, if present, for the current generic op.
@@ -690,7 +772,8 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
690
772
void mlir::linalg::populateDataLayoutPropagationPatterns (
691
773
RewritePatternSet &patterns,
692
774
const ControlPropagationFn &controlPackUnPackPropagation) {
693
- patterns.insert <BubbleUpPackOpThroughGenericOpPattern,
694
- PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
695
- patterns.getContext (), controlPackUnPackPropagation);
775
+ patterns
776
+ .insert <BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
777
+ PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
778
+ patterns.getContext (), controlPackUnPackPropagation);
696
779
}
0 commit comments