|
19 | 19 | #include "mlir/IR/PatternMatch.h"
|
20 | 20 | #include "mlir/Transforms/DialectConversion.h"
|
21 | 21 | #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
|
| 22 | +#include "llvm/ADT/DenseMap.h" |
22 | 23 |
|
23 | 24 | #include "gc/Dialect/Linalgx/LinalgxDialect.h"
|
24 | 25 | #include "gc/Dialect/Linalgx/LinalgxOps.h"
|
@@ -495,6 +496,83 @@ struct PackVNNI<linalg::GenericOp>
|
495 | 496 | }
|
496 | 497 | };
|
497 | 498 |
|
| 499 | +/* |
| 500 | +Match patterns like broadcast + pack, uplift pack |
| 501 | +*/ |
| 502 | +struct UpliftPackOverBroadcast : public OpRewritePattern<tensor::PackOp> { |
| 503 | + UpliftPackOverBroadcast(MLIRContext *context, PatternBenefit benefit = 1) |
| 504 | + : OpRewritePattern<tensor::PackOp>(context, benefit) {} |
| 505 | + LogicalResult matchAndRewrite(tensor::PackOp pack, |
| 506 | + PatternRewriter &rewriter) const override { |
| 507 | + auto broadcastOp = pack.getSource().getDefiningOp<linalg::BroadcastOp>(); |
| 508 | + if (!broadcastOp || !broadcastOp.getResult()[0].hasOneUse()) { |
| 509 | + return failure(); |
| 510 | + } |
| 511 | + SmallVector<int64_t> innerTileSizes = pack.getStaticTiles(); |
| 512 | + SmallVector<int64_t> innerDimsPos(pack.getInnerDimsPos()); |
| 513 | + SmallVector<int64_t> outerDimsPerm(pack.getOuterDimsPerm()); |
| 514 | + int64_t rank = |
| 515 | + cast<ShapedType>(pack.getSource().getType()).getShape().size(); |
| 516 | + if (outerDimsPerm.empty()) { |
| 517 | + outerDimsPerm.resize(rank); |
| 518 | + std::iota(outerDimsPerm.begin(), outerDimsPerm.end(), 0); |
| 519 | + } |
| 520 | + ArrayRef<int64_t> broadcastAxis = broadcastOp.getDimensions(); |
| 521 | + SmallVector<int64_t> newInnerDimsPos, newOuterDimsPerm, packedBroadcastAxis; |
| 522 | + SmallVector<OpFoldResult> newInnerTileSizes; |
| 523 | + llvm::SmallDenseMap<int64_t, int64_t> axisMapping; |
| 524 | + int64_t axisCounter = 0; |
| 525 | + for (int64_t axis = 0; axis < rank; ++axis) { |
| 526 | + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == |
| 527 | + broadcastAxis.end()) { |
| 528 | + // if the axis is not broadcasted, keep it |
| 529 | + axisMapping[axis] = axisCounter++; |
| 530 | + } |
| 531 | + } |
| 532 | + // update broadcast dims |
| 533 | + for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) { |
| 534 | + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) != |
| 535 | + broadcastAxis.end()) { |
| 536 | + packedBroadcastAxis.push_back(index); |
| 537 | + } |
| 538 | + } |
| 539 | + for (auto [index, axis] : llvm::enumerate(innerDimsPos)) { |
| 540 | + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) != |
| 541 | + broadcastAxis.end()) { |
| 542 | + packedBroadcastAxis.push_back(index + rank); |
| 543 | + } |
| 544 | + } |
| 545 | + // update packing axis |
| 546 | + for (auto [index, axis] : llvm::enumerate(outerDimsPerm)) { |
| 547 | + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == |
| 548 | + broadcastAxis.end()) { |
| 549 | + newOuterDimsPerm.push_back(axisMapping[axis]); |
| 550 | + } |
| 551 | + } |
| 552 | + for (auto [index, axis] : llvm::enumerate(innerDimsPos)) { |
| 553 | + if (std::find(broadcastAxis.begin(), broadcastAxis.end(), axis) == |
| 554 | + broadcastAxis.end()) { |
| 555 | + newInnerDimsPos.push_back(axisMapping[axis]); |
| 556 | + newInnerTileSizes.push_back( |
| 557 | + rewriter.getIndexAttr(innerTileSizes[index])); |
| 558 | + } |
| 559 | + } |
| 560 | + // replace ops |
| 561 | + auto loc = broadcastOp.getLoc(); |
| 562 | + auto dest = tensor::PackOp::createDestinationTensor( |
| 563 | + rewriter, loc, broadcastOp.getDpsInputs()[0], newInnerTileSizes, |
| 564 | + newInnerDimsPos, newOuterDimsPerm); |
| 565 | + Value packedSource = rewriter.create<tensor::PackOp>( |
| 566 | + loc, broadcastOp.getDpsInputs()[0], dest, newInnerDimsPos, |
| 567 | + newInnerTileSizes, |
| 568 | + /*padding=*/std::nullopt, newOuterDimsPerm); |
| 569 | + auto newBroadcastOp = rewriter.create<linalg::BroadcastOp>( |
| 570 | + loc, packedSource, pack.getDest(), packedBroadcastAxis); |
| 571 | + rewriter.replaceOp(pack, newBroadcastOp.getResults()); |
| 572 | + return success(); |
| 573 | + } |
| 574 | +}; |
| 575 | + |
498 | 576 | void PropagateLayoutOnNamedOps::runOnOperation() {
|
499 | 577 | MLIRContext *ctx = &getContext();
|
500 | 578 | mlir::Operation *graph = getOperation();
|
@@ -541,6 +619,12 @@ void PropagateLayoutOnNamedOps::runOnOperation() {
|
541 | 619 | };
|
542 | 620 | if (failed(namedOpLayoutPropagation(ctx, graph, layoutControlFn)))
|
543 | 621 | return signalPassFailure();
|
| 622 | + |
| 623 | + // stage4: uplift pack through broadcast |
| 624 | + RewritePatternSet upliftPatterns(&getContext()); |
| 625 | + upliftPatterns.add<UpliftPackOverBroadcast>(ctx); |
| 626 | + if (failed(applyPatternsAndFoldGreedily(graph, std::move(upliftPatterns)))) |
| 627 | + return signalPassFailure(); |
544 | 628 | }
|
545 | 629 |
|
546 | 630 | } // namespace gc
|
|
0 commit comments