
Commit 57e4f22

avoid double propagation
1 parent d6d6e7e commit 57e4f22

File tree

2 files changed: +68 -32 lines changed


lib/gc/Analysis/GlobalAnalysis.cpp

Lines changed: 18 additions & 28 deletions
@@ -22,27 +22,15 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
   SmallVector<int64_t> innerAxis = layoutCache.getInnerAxis();
   SmallVector<OpFoldResult> tileSizes = layoutCache.getTileSizes();
   ss << "[";
-  for (size_t i = 0; i < outerAxis.size(); ++i) {
-    if (i != 0) {
-      ss << ", ";
-    }
-    ss << outerAxis[i];
-  }
-  for (size_t i = 0; i < innerAxis.size(); ++i) {
-    ss << (i == 0 ? "; " : ", ");
-    ss << innerAxis[i];
+  llvm::interleaveComma(outerAxis, ss);
+  if (!innerAxis.empty()) {
+    ss << "; ";
+    llvm::interleaveComma(innerAxis, ss);
   }
   ss << "]";
   if (!tileSizes.empty()) {
     ss << "; {";
-    for (size_t i = 0; i < tileSizes.size(); ++i) {
-      if (i != 0) {
-        ss << ", ";
-      }
-      if (getConstantIntValue(tileSizes[i]).has_value()) {
-        ss << *getConstantIntValue(tileSizes[i]);
-      }
-    }
+    llvm::interleaveComma(tileSizes, ss);
     ss << "}";
   }
   return ss;

@@ -58,11 +46,11 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
                              const OperatorLayout &opLayout) {
   for (auto &&[idx, layoutCache] :
        llvm::enumerate(opLayout.getSupportedInputLayouts())) {
-    ss << "input " << idx << "'s layoutCache: " << layoutCache << "\n";
+    ss << "input " << idx << "'s layout: " << layoutCache << "\n";
   }
   for (auto &&[idx, layoutCache] :
        llvm::enumerate(opLayout.getSupportedOutputLayouts())) {
-    ss << "output " << idx << "'s layoutCache: " << layoutCache << "\n";
+    ss << "output " << idx << "'s layout: " << layoutCache << "\n";
   }
   return ss;
 }

@@ -156,15 +144,15 @@ inferTargetLayout(TensorLayout layoutBase,
 }
 
 static size_t getTargetInputIdx(ArrayRef<TensorLayout> curInputLayouts) {
-  for (auto i = 0; i < curInputLayouts.size(); ++i) {
+  for (size_t i = 0; i < curInputLayouts.size(); ++i) {
     if (!curInputLayouts[i].isPlainLayout()) {
       return i;
     }
   }
   return 0;
 }
 
-static bool supportedContractionOpList(linalg::LinalgOp &linalgOp) {
+static bool supportedContractionNamedOpList(linalg::LinalgOp &linalgOp) {
   if (isa<linalg::MatmulOp, linalg::MatmulTransposeAOp,
           linalg::MatmulTransposeBOp, linalg::BatchMatmulOp,
           linalg::BatchMatmulTransposeAOp, linalg::BatchMatmulTransposeBOp>(

@@ -211,7 +199,7 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
     // ------ Get Current Op's Suggested Layout & Do Propagation ------
     IRRewriter rewriter(linalgOp);
     // TODO: extend to packed/vnni matmul ops
-    if (supportedContractionOpList(linalgOp)) {
+    if (supportedContractionNamedOpList(linalgOp)) {
       // get input and output rank
       auto ARank = cast<ShapedType>(linalgOp.getDpsInputs()[0].getType())
                        .getShape()

@@ -253,7 +241,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
                                        rewriter.getIndexAttr(iin)});
       OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout});
       layoutCache[linalgOp] = suggestedLayout;
-    } else if (!mlir::linalg::isaContractionOpInterface(linalgOp)) {
+    } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) &&
+               !supportedContractionNamedOpList(linalgOp)) {
       SmallVector<TensorLayout> inputLayouts, outputLayouts;
       size_t targetIdx = getTargetInputIdx(curInputLayouts);
       // TODO(yifei): wisely choose the input format basis

@@ -345,11 +334,12 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
 
 namespace utils {
 bool isPackableNamedOp(Operation *op) {
-  if ((isa<linalg::LinalgOp>(op) &&
-       !mlir::linalg::isaContractionOpInterface(
-           dyn_cast<linalg::LinalgOp>(op)) &&
-       !isa<linalgx::Mm4DVnniOp>(op)) ||
-      isa<tensor::ExpandShapeOp>(op) || isa<tensor::PadOp>(op))
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+    if (!supportedContractionNamedOpList(linalgOp)) {
+      return true;
+    }
+  } else if (isa<tensor::ExpandShapeOp, tensor::CollapseShapeOp, tensor::PadOp>(
+                 op))
     return true;
   return false;
 }
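For reference, a minimal standalone sketch of the llvm::interleaveComma printing that the reworked operator<< above relies on; the axis and tile values are invented for illustration, and plain int64_t tile sizes stand in for the OpFoldResult handling of the real code.

// Sketch only: reproduces the "[outer; inner]; {tiles}" string built by the
// operator<< above using llvm::interleaveComma; the values are made up.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/raw_ostream.h"

int main() {
  llvm::SmallVector<int64_t> outerAxis = {0, 1};
  llvm::SmallVector<int64_t> innerAxis = {1};
  llvm::SmallVector<int64_t> tileSizes = {32};

  llvm::raw_ostream &ss = llvm::outs();
  ss << "[";
  llvm::interleaveComma(outerAxis, ss); // "0, 1"
  if (!innerAxis.empty()) {
    ss << "; ";
    llvm::interleaveComma(innerAxis, ss); // "1"
  }
  ss << "]";
  if (!tileSizes.empty()) {
    ss << "; {";
    llvm::interleaveComma(tileSizes, ss); // "32"
    ss << "}";
  }
  ss << "\n"; // prints: [0, 1; 1]; {32}
  return 0;
}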

lib/gc/Transforms/PropagateLayout.cpp

Lines changed: 50 additions & 4 deletions
@@ -6,7 +6,6 @@
 //
 //===----------------------------------------------------------------------===//
 
-#include <iostream>
 #include <numeric>
 
 #include "gc/Transforms/Transforms.h"

@@ -209,8 +208,34 @@ FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
 }
 
 // check whether the op is already packed or not
-static bool checkPacked(Operation *op, const OperatorLayout &layout) {
+static bool checkPacked(Operation *op, const OperatorLayout &opLayout) {
   // check whether rank match
+  if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
+    assert(linalgOp.getDpsInits().size() ==
+               opLayout.getSupportedOutputLayouts().size() &&
+           linalgOp.getDpsInputs().size() ==
+               opLayout.getSupportedInputLayouts().size());
+    for (auto [index, layout] :
+         llvm::enumerate(opLayout.getSupportedInputLayouts())) {
+      // if dimension mismatch, then the op itself is already packed
+      if (layout.getOuterAxis().size() !=
+          cast<RankedTensorType>(linalgOp.getDpsInputs()[index].getType())
+              .getShape()
+              .size())
+        return true;
+    }
+    for (auto [index, layout] :
+         llvm::enumerate(opLayout.getSupportedOutputLayouts())) {
+      // if dimension mismatch, then the op itself is already packed
+      if (layout.getOuterAxis().size() !=
+          cast<RankedTensorType>(linalgOp.getDpsInits()[index].getType())
+              .getShape()
+              .size())
+        return true;
+    }
+  } else {
+    assert(op->getNumOperands() == 1 && op->getNumResults() == 1);
+  }
   return false;
 }
 
@@ -225,11 +250,28 @@ class PropagateLayoutOnNamedOps
   void runOnOperation() final;
 };
 
+LogicalResult graphAlreadyPacked(MLIRContext *ctx, mlir::Operation *graph) {
+  IRRewriter rewriter(ctx);
+  auto walk = graph->walk([&](Operation *op) {
+    if (mlir::gc::utils::isPackableNamedOp(op) && op->hasAttr("packed")) {
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Graph already packed. Stop layout propagation.\n");
+      return WalkResult::interrupt();
+    }
+    return WalkResult::advance();
+  });
+  if (walk.wasInterrupted()) {
+    return failure();
+  }
+  return success();
+}
+
 LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
                                        ControlPackNamedOpsFn controlFn) {
   IRRewriter rewriter(ctx);
   auto walk = graph->walk([&](Operation *op) {
     if (mlir::gc::utils::isPackableNamedOp(op)) {
+      LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " visited.\n");
       FailureOr<OperatorLayout> opLayout = controlFn(op);
       if (failed(opLayout)) {
         LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName()

@@ -258,6 +300,9 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
             packNamedOp(rewriter, linalgOp, *opLayout);
         if (failed(packedOp)) {
           return WalkResult::skip();
+        } else {
+          packedOp->packedLinalgOp->setAttr("packed",
+                                            rewriter.getBoolAttr(true));
         }
       } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
         // Location loc = expandShapeOp->getLoc();

@@ -291,8 +336,6 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
     }
     return WalkResult::advance();
   });
-  if (walk.wasSkipped())
-    return failure();
   return success();
 }
 
@@ -576,6 +619,9 @@ struct UpliftPackOverBroadcast : public OpRewritePattern<tensor::PackOp> {
 void PropagateLayoutOnNamedOps::runOnOperation() {
   MLIRContext *ctx = &getContext();
   mlir::Operation *graph = getOperation();
+  // stage0: check if the graph has been packed
+  if (failed(graphAlreadyPacked(ctx, graph)))
+    return;
   // stage1: pack matmul
   RewritePatternSet packMatmulPatterns(&getContext());
   mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn =
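As a summary of the new guard against double propagation, a condensed sketch follows, assuming mlir/IR/Operation.h and the mlir namespace; it only restates the interplay already visible in the diff (tag every successfully packed op with a "packed" attribute, then bail out of the pass when any op already carries it) and drops the isPackableNamedOp filter and debug output of the real code.

// Sketch only, not the pass itself: once packNamedOp succeeds, the diff above
// marks the result via setAttr("packed", rewriter.getBoolAttr(true)); the
// helper below shows how a later run detects that marker and stops early.
#include "mlir/IR/Operation.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

static LogicalResult graphAlreadyPackedSketch(Operation *graph) {
  auto walk = graph->walk([&](Operation *op) {
    if (op->hasAttr("packed"))
      return WalkResult::interrupt(); // one already-packed op is enough
    return WalkResult::advance();
  });
  // an interrupted walk becomes failure(), so runOnOperation() returns early
  return failure(walk.wasInterrupted());
}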

0 commit comments
