|
| 1 | +//===- EliminateConstantWeightPack.cpp - Eliminate Const Weight *-- C++-*-===// |
| 2 | +// |
| 3 | +// This file is only temporarily used to extend upstream or upcoming utility in |
| 4 | +// TilingInterface, which finally aims for upstream. |
| 5 | +// |
| 6 | +//===----------------------------------------------------------------------===// |
| 7 | + |
| 8 | +#include <numeric> |
| 9 | + |
| 10 | +#include "mlir/Dialect/Func/IR/FuncOps.h" |
| 11 | +#include "mlir/Dialect/Linalg/IR/Linalg.h" |
| 12 | +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" |
| 13 | +#include "mlir/Dialect/Tensor/IR/Tensor.h" |
| 14 | +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" |
| 15 | +#include "mlir/Dialect/Utils/StaticValueUtils.h" |
| 16 | +#include "mlir/IR/PatternMatch.h" |
| 17 | +#include "mlir/Transforms/DialectConversion.h" |
| 18 | +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" |
| 19 | + |
| 20 | +#include "gc/Dialect/Linalgx/Utils.h" |
| 21 | +#include "gc/Transforms/Passes.h" |
| 22 | +#include "gc/Transforms/Transforms.h" |
| 23 | + |
| 24 | +namespace mlir { |
| 25 | +namespace gc { |
| 26 | +#define GEN_PASS_DEF_ELIMINATECONSTANTWEIGHTPACK |
| 27 | +#include "gc/Transforms/Passes.h.inc" |
| 28 | + |
| 29 | +using namespace mlir; |
| 30 | + |
| 31 | +class EliminateConstantWeightPack |
| 32 | + : public impl::EliminateConstantWeightPackBase< |
| 33 | + EliminateConstantWeightPack> { |
| 34 | +public: |
| 35 | + using impl::EliminateConstantWeightPackBase< |
| 36 | + EliminateConstantWeightPack>::EliminateConstantWeightPackBase; |
| 37 | + void runOnOperation() final; |
| 38 | +}; |
| 39 | + |
| 40 | +void EliminateConstantWeightPack::runOnOperation() { |
| 41 | + MLIRContext *ctx = &getContext(); |
| 42 | + IRRewriter rewriter(ctx); |
| 43 | + mlir::Operation *graph = getOperation(); |
| 44 | + ValueTypeRange<Block::BlockArgListType> finalArgTypes = |
| 45 | + graph->getBlock()->getArgumentTypes(); |
| 46 | + bool updated = false; |
| 47 | + graph->walk([&](Operation *op) { |
| 48 | + if (auto packedMatmul = dyn_cast<linalg::GenericOp>(op)) { |
| 49 | + if (linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(), |
| 50 | + linalgx::PackingType::MM2D4D) || |
| 51 | + linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(), |
| 52 | + linalgx::PackingType::MM4D) || |
| 53 | + linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(), |
| 54 | + linalgx::PackingType::VNNI_MM2D) || |
| 55 | + linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(), |
| 56 | + linalgx::PackingType::VNNI_MM4D)) { |
| 57 | + auto srcVal = packedMatmul.getDpsInputOperands()[1]->get(); |
| 58 | + mlir::Operation *argPack = nullptr; |
| 59 | + while (auto pack = srcVal.getDefiningOp<tensor::PackOp>()) { |
| 60 | + srcVal = pack.getSource(); |
| 61 | + argPack = pack; |
| 62 | + } |
| 63 | + if (!isa<BlockArgument>(srcVal) || !argPack) |
| 64 | + return WalkResult::skip(); |
| 65 | + // querying the block |
| 66 | + auto parentBlock = packedMatmul.getOperation()->getBlock(); |
| 67 | + auto blockArgs = parentBlock->getArguments(); |
| 68 | + auto found = std::find(blockArgs.begin(), blockArgs.end(), srcVal); |
| 69 | + assert(found != blockArgs.end()); |
| 70 | + size_t idx = std::distance(blockArgs.begin(), found); |
| 71 | + assert(idx < blockArgs.size() && "Within index."); |
| 72 | + |
| 73 | + auto ty = dyn_cast<TensorType>(srcVal.getType()); |
| 74 | + auto newArgTy = dyn_cast<TensorType>( |
| 75 | + packedMatmul.getDpsInputOperands()[1]->get().getType()); |
| 76 | + OpBuilder::InsertionGuard guard(rewriter); |
| 77 | + rewriter.setInsertionPoint(argPack); |
| 78 | + Value argReplace = rewriter.create<tensor::EmptyOp>( |
| 79 | + argPack->getLoc(), ty.getShape(), ty.getElementType()); |
| 80 | + rewriter.replaceAllUsesWith(srcVal, argReplace); |
| 81 | + parentBlock->eraseArgument(idx); |
| 82 | + parentBlock->addArgument(newArgTy, argPack->getLoc()); |
| 83 | + Value newPackedArg = parentBlock->getArguments().back(); |
| 84 | + rewriter.replaceAllUsesWith( |
| 85 | + packedMatmul.getDpsInputOperands()[1]->get(), newPackedArg); |
| 86 | + updated = true; |
| 87 | + finalArgTypes = parentBlock->getArgumentTypes(); |
| 88 | + } |
| 89 | + } |
| 90 | + return WalkResult::advance(); |
| 91 | + }); |
| 92 | + // Get funcOp |
| 93 | + if (updated) { |
| 94 | + func::FuncOp func = getOperation(); |
| 95 | + FunctionType computeFuncType = func.getFunctionType(); |
| 96 | + func.setType( |
| 97 | + FunctionType::get(ctx, finalArgTypes, computeFuncType.getResults())); |
| 98 | + } |
| 99 | +} |
| 100 | + |
| 101 | +} // namespace gc |
| 102 | +} // namespace mlir |
0 commit comments