Skip to content

Commit 876455c

Browse files
yifeizh2zhczhong
authored andcommitted
temp remove const weight pack
1 parent 90cb1dd commit 876455c

File tree

4 files changed

+118
-0
lines changed

4 files changed

+118
-0
lines changed

include/gc/Transforms/Passes.td

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,18 @@ def PropagateLayoutOnNamedOps : Pass<"propagate-layout-on-named-ops"> {
181181
];
182182
}
183183

184+
def EliminateConstantWeightPack : Pass<"eliminate-constant-weight-pack", "func::FuncOp"> {
185+
let summary = "EliminateConstantWeightPack.";
186+
let description = [{
187+
EliminateConstantWeightPack.
188+
}];
189+
let dependentDialects = [
190+
"mlir::tensor::TensorDialect",
191+
"mlir::linalg::LinalgDialect",
192+
"mlir::func::FuncDialect"
193+
];
194+
}
195+
184196
def PostProcessPackUnpack : Pass<"post-process-pack-unpack"> {
185197
let summary = "Fold and simplify pack and unpack ops.";
186198
let description = [{

lib/gc/Transforms/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ gc_add_mlir_library(GcPasses
1515
OneDNNGraphToLinalg.cpp
1616
Pipeline.cpp
1717
IterativeTilingAndFusion.cpp
18+
EliminateConstantWeightPack.cpp
1819
TilingUsingInterfaceX.cpp
1920
VerifyTargetDescription.cpp
2021
PropagateLayout.cpp
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
//===- EliminateConstantWeightPack.cpp - Eliminate Const Weight *-- C++-*-===//
2+
//
3+
// This file is only temporarily used to extend upstream or upcoming utility in
4+
// TilingInterface, which finally aims for upstream.
5+
//
6+
//===----------------------------------------------------------------------===//
7+
8+
#include <numeric>
9+
10+
#include "mlir/Dialect/Func/IR/FuncOps.h"
11+
#include "mlir/Dialect/Linalg/IR/Linalg.h"
12+
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
13+
#include "mlir/Dialect/Tensor/IR/Tensor.h"
14+
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
15+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
16+
#include "mlir/IR/PatternMatch.h"
17+
#include "mlir/Transforms/DialectConversion.h"
18+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
19+
20+
#include "gc/Dialect/Linalgx/Utils.h"
21+
#include "gc/Transforms/Passes.h"
22+
#include "gc/Transforms/Transforms.h"
23+
24+
namespace mlir {
25+
namespace gc {
26+
#define GEN_PASS_DEF_ELIMINATECONSTANTWEIGHTPACK
27+
#include "gc/Transforms/Passes.h.inc"
28+
29+
using namespace mlir;
30+
31+
class EliminateConstantWeightPack
32+
: public impl::EliminateConstantWeightPackBase<
33+
EliminateConstantWeightPack> {
34+
public:
35+
using impl::EliminateConstantWeightPackBase<
36+
EliminateConstantWeightPack>::EliminateConstantWeightPackBase;
37+
void runOnOperation() final;
38+
};
39+
40+
void EliminateConstantWeightPack::runOnOperation() {
41+
MLIRContext *ctx = &getContext();
42+
IRRewriter rewriter(ctx);
43+
mlir::Operation *graph = getOperation();
44+
ValueTypeRange<Block::BlockArgListType> finalArgTypes =
45+
graph->getBlock()->getArgumentTypes();
46+
bool updated = false;
47+
graph->walk([&](Operation *op) {
48+
if (auto packedMatmul = dyn_cast<linalg::GenericOp>(op)) {
49+
if (linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(),
50+
linalgx::PackingType::MM2D4D) ||
51+
linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(),
52+
linalgx::PackingType::MM4D) ||
53+
linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(),
54+
linalgx::PackingType::VNNI_MM2D) ||
55+
linalgx::isGenericPackedMatmulOp(packedMatmul.getOperation(),
56+
linalgx::PackingType::VNNI_MM4D)) {
57+
auto srcVal = packedMatmul.getDpsInputOperands()[1]->get();
58+
mlir::Operation *argPack = nullptr;
59+
while (auto pack = srcVal.getDefiningOp<tensor::PackOp>()) {
60+
srcVal = pack.getSource();
61+
argPack = pack;
62+
}
63+
if (!isa<BlockArgument>(srcVal) || !argPack)
64+
return WalkResult::skip();
65+
// querying the block
66+
auto parentBlock = packedMatmul.getOperation()->getBlock();
67+
auto blockArgs = parentBlock->getArguments();
68+
auto found = std::find(blockArgs.begin(), blockArgs.end(), srcVal);
69+
assert(found != blockArgs.end());
70+
size_t idx = std::distance(blockArgs.begin(), found);
71+
assert(idx < blockArgs.size() && "Within index.");
72+
73+
auto ty = dyn_cast<TensorType>(srcVal.getType());
74+
auto newArgTy = dyn_cast<TensorType>(
75+
packedMatmul.getDpsInputOperands()[1]->get().getType());
76+
OpBuilder::InsertionGuard guard(rewriter);
77+
rewriter.setInsertionPoint(argPack);
78+
Value argReplace = rewriter.create<tensor::EmptyOp>(
79+
argPack->getLoc(), ty.getShape(), ty.getElementType());
80+
rewriter.replaceAllUsesWith(srcVal, argReplace);
81+
parentBlock->eraseArgument(idx);
82+
parentBlock->addArgument(newArgTy, argPack->getLoc());
83+
Value newPackedArg = parentBlock->getArguments().back();
84+
rewriter.replaceAllUsesWith(
85+
packedMatmul.getDpsInputOperands()[1]->get(), newPackedArg);
86+
updated = true;
87+
finalArgTypes = parentBlock->getArgumentTypes();
88+
}
89+
}
90+
return WalkResult::advance();
91+
});
92+
// Get funcOp
93+
if (updated) {
94+
func::FuncOp func = getOperation();
95+
FunctionType computeFuncType = func.getFunctionType();
96+
func.setType(
97+
FunctionType::get(ctx, finalArgTypes, computeFuncType.getResults()));
98+
}
99+
}
100+
101+
} // namespace gc
102+
} // namespace mlir

lib/gc/Transforms/Pipeline.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ void populateTensorPasses(mlir::OpPassManager &pm) {
5555
// todo: layout propagation pass
5656
pm.addPass(createPropagateLayoutOnNamedOps());
5757
pm.addPass(createPostProcessPackUnpack());
58+
pm.addNestedPass<func::FuncOp>(createEliminateConstantWeightPack());
59+
populateCleanUpPasses(pm);
60+
pm.addPass(createPrintIRPass());
5861
// todo: tensor constant propagation pass
5962
// linalg.matmul lowering to (scf.loop + linalg.brgemm) pass
6063
pm.addNestedPass<func::FuncOp>(createDeepTileContractionOp());

0 commit comments

Comments
 (0)