
Commit 886294a

[mlir][linalg] Add pattern to propagate pack up through tensor.pad (#82035)
This mirrors the existing pattern for pushing unpack down through padding, restricted to cases where none of the padded dimensions are tiled by the pack. It also reformats the propagation test to make it easier to read.
1 parent 767433b · commit 886294a
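
To illustrate the rewrite, here is a minimal sketch in IR form; the shapes, value names, and the f32 zero padding value are illustrative assumptions, not taken from the commit's tests. Before propagation, the pad feeds the pack and only the untiled dimension is padded:

    %cst = arith.constant 0.0 : f32
    // Pad the second (untiled) dimension: 16x6 -> 16x7.
    %padded = tensor.pad %src low[0, 0] high[0, 1] {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<16x6xf32> to tensor<16x7xf32>
    // Pack the first dimension with an inner tile of 8: 16x7 -> 2x7x8.
    %empty = tensor.empty() : tensor<2x7x8xf32>
    %packed = tensor.pack %padded inner_dims_pos = [0] inner_tiles = [8]
        into %empty : tensor<16x7xf32> -> tensor<2x7x8xf32>

After the pattern fires, the pack is applied directly to the original source and the pad is applied to the packed tensor instead, with a zero entry appended to `low` and `high` for the inner tile (point loop) dimension:

    %cst = arith.constant 0.0 : f32
    // Pack first: 16x6 -> 2x6x8.
    %empty = tensor.empty() : tensor<2x6x8xf32>
    %packed = tensor.pack %src inner_dims_pos = [0] inner_tiles = [8]
        into %empty : tensor<16x6xf32> -> tensor<2x6x8xf32>
    // Pad the untiled outer dimension; the trailing 0 covers the new tile dim.
    %padded = tensor.pad %packed low[0, 0, 0] high[0, 1, 0] {
    ^bb0(%i: index, %j: index, %k: index):
      tensor.yield %cst : f32
    } : tensor<2x6x8xf32> to tensor<2x7x8xf32>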

File tree

2 files changed (+538, -389 lines)

mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp

Lines changed: 86 additions & 3 deletions
@@ -470,6 +470,88 @@ struct BubbleUpPackOpThroughGenericOpPattern
   ControlPropagationFn controlFn;
 };
 
+/// Propagate a tensor.pack operation up through a tensor.pad. The idea is to
+/// add as many zero padding entries to `high` and `low` as there are point
+/// loops.
+class BubbleUpPackThroughPadOp final : public OpRewritePattern<tensor::PackOp> {
+public:
+  BubbleUpPackThroughPadOp(MLIRContext *context, ControlPropagationFn fun)
+      : OpRewritePattern<tensor::PackOp>(context), controlFn(std::move(fun)) {}
+
+  LogicalResult matchAndRewrite(tensor::PackOp packOp,
+                                PatternRewriter &rewriter) const override {
+    auto padOp = packOp.getSource().getDefiningOp<tensor::PadOp>();
+    if (!padOp)
+      return failure();
+
+    // User-controlled propagation function.
+    if (!controlFn(padOp))
+      return failure();
+
+    if (!padOp.getResult().hasOneUse())
+      return failure();
+
+    // TODO: Enable padding when the padding values are the same.
+    if (packOp.getPaddingValue())
+      return failure();
+
+    // Fail for non-constant padding values. The body of the pad could
+    // depend on the padding indices and/or properties of the padded
+    // tensor, so for now we fail.
+    // TODO: Support non-constant padding values.
+    Value paddingVal = padOp.getConstantPaddingValue();
+    if (!paddingVal)
+      return failure();
+
+    if (!packOp.getDest().getDefiningOp<tensor::EmptyOp>())
+      return failure();
+
+    ArrayRef<int64_t> innerDimsPos = packOp.getInnerDimsPos();
+    ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+
+    // Bail out if any of the padded dimensions is tiled.
+    llvm::SmallBitVector paddedDims = padOp.getPaddedDims();
+    llvm::SmallBitVector innerDims(paddedDims.size());
+    for (int64_t dim : innerDimsPos)
+      innerDims.flip(dim);
+    if (paddedDims.anyCommon(innerDims))
+      return failure();
+
+    Location loc = padOp->getLoc();
+    OpBuilder::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPoint(padOp);
+
+    auto empty = tensor::PackOp::createDestinationTensor(
+        rewriter, loc, padOp.getSource(), packOp.getMixedTiles(), innerDimsPos,
+        outerDimsPerm);
+    Value packedSource = rewriter.create<tensor::PackOp>(
+        loc, padOp.getSource(), empty, innerDimsPos, packOp.getMixedTiles(),
+        /*padding=*/std::nullopt, outerDimsPerm);
+
+    // If we have `outer_dims_perm`, we need to adjust the padded dimensions.
+    SmallVector<OpFoldResult> lowPad = padOp.getMixedLowPad();
+    SmallVector<OpFoldResult> highPad = padOp.getMixedHighPad();
+    if (!outerDimsPerm.empty()) {
+      applyPermutationToVector<OpFoldResult>(lowPad, outerDimsPerm);
+      applyPermutationToVector<OpFoldResult>(highPad, outerDimsPerm);
+    }
+    // The tiled dimensions were verified to be unpadded above, so here we
+    // just append 0 for the inner tile dimensions.
+    size_t pointLoopsSize = innerDimsPos.size();
+    lowPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+    highPad.append(pointLoopsSize, rewriter.getIndexAttr(0));
+
+    auto newPadOp = rewriter.create<tensor::PadOp>(
+        loc, /*result=*/Type(), packedSource, lowPad, highPad, paddingVal,
+        padOp.getNofold());
+    rewriter.replaceOp(packOp, newPadOp.getResult());
+    return success();
+  }
+
+private:
+  ControlPropagationFn controlFn;
+};
+
 // TODO: Relax this restriction. We should unpack a generic op also
 // in the presence of multiple unpack ops as producers.
 /// Return the unpacked operand, if present, for the current generic op.

@@ -690,7 +772,8 @@ struct PushDownUnPackThroughPadOp : public OpRewritePattern<tensor::PadOp> {
 void mlir::linalg::populateDataLayoutPropagationPatterns(
     RewritePatternSet &patterns,
     const ControlPropagationFn &controlPackUnPackPropagation) {
-  patterns.insert<BubbleUpPackOpThroughGenericOpPattern,
-                  PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
-      patterns.getContext(), controlPackUnPackPropagation);
+  patterns
+      .insert<BubbleUpPackOpThroughGenericOpPattern, BubbleUpPackThroughPadOp,
+              PushDownUnPackOpThroughGenericOp, PushDownUnPackThroughPadOp>(
+          patterns.getContext(), controlPackUnPackPropagation);
 }
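
The `paddedDims.anyCommon(innerDims)` guard means IR where a padded dimension is also tiled is deliberately left untouched. A hypothetical example with made-up shapes: dim 0 below is both padded (high[2, 0]) and tiled (inner_dims_pos = [0]), so moving the pack above the pad would require padding inside individual tiles, and the pattern bails out:

    %cst = arith.constant 0.0 : f32
    // dim 0 is padded 14 -> 16 and then tiled by 8; the pattern does not fire.
    %padded = tensor.pad %src low[0, 0] high[2, 0] {
    ^bb0(%i: index, %j: index):
      tensor.yield %cst : f32
    } : tensor<14x6xf32> to tensor<16x6xf32>
    %empty = tensor.empty() : tensor<2x6x8xf32>
    %packed = tensor.pack %padded inner_dims_pos = [0] inner_tiles = [8]
        into %empty : tensor<16x6xf32> -> tensor<2x6x8xf32>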
