
Commit 00779c1

fix layout propagation on expand shape
1 parent 474baa3 commit 00779c1

2 files changed: +131 −94 lines changed

lib/gc/Analysis/GlobalAnalysis.cpp

Lines changed: 79 additions & 44 deletions
@@ -55,13 +55,16 @@ llvm::raw_ostream &operator<<(llvm::raw_ostream &ss,
   return ss;
 }
 
-// inferring the relationship of two indexing map
-// j -> i, means j is represented as the same symbol as i
-// we don't allow duplicate in symbols
-// e.g. if 2 j corresponding to 1 i, then return failure
+// infer the relation between two indexing maps
+// returns target dim -> base dim, meaning the target dim refers to the same
+// loop dim as the base dim; duplication is not allowed, e.g. 2 target dims
+// corresponding to 1 base dim results in failure
 static FailureOr<DenseMap<int64_t, int64_t>>
 inferIndexingMapRelation(AffineMap indexingMapBase,
                          AffineMap indexingMapTarget) {
+  // symbols are not allowed to occur in either map
+  if (indexingMapBase.getNumSymbols() != 0 ||
+      indexingMapTarget.getNumSymbols() != 0)
+    return failure();
   DenseMap<int64_t, int64_t> res;
   ArrayRef<AffineExpr> resultsBase = indexingMapBase.getResults();
   ArrayRef<AffineExpr> resultsTarget = indexingMapTarget.getResults();
@@ -70,6 +73,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase,
       auto base = dyn_cast<AffineDimExpr>(resultsBase[i]);
       auto target = dyn_cast<AffineDimExpr>(resultsTarget[j]);
       if (base && target && base.getPosition() == target.getPosition()) {
+        // dim j already mapped to certain i
         if (res.find(j) != res.end())
           return failure();
         res[j] = i;
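
The relation is built by matching identical AffineDimExpr positions across the two maps' results. A minimal standalone sketch of the same matching on plain integers (illustrative names only, not the project's API; each vector entry stands for the dim position an AffineDimExpr refers to):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <optional>
#include <vector>

// Match every target result dim against the base result dims; fail if the
// same target dim would be mapped twice, as the check above does.
static std::optional<std::map<int64_t, int64_t>>
inferRelation(const std::vector<int64_t> &baseDims,
              const std::vector<int64_t> &targetDims) {
  std::map<int64_t, int64_t> res;
  for (int64_t i = 0; i < static_cast<int64_t>(baseDims.size()); ++i)
    for (int64_t j = 0; j < static_cast<int64_t>(targetDims.size()); ++j)
      if (baseDims[i] == targetDims[j]) {
        if (res.count(j))
          return std::nullopt; // target dim j already mapped to a base dim
        res[j] = i;
      }
  return res;
}

int main() {
  // base map (d0, d1, d2) vs. target map (d2, d0): relation {0 -> 2, 1 -> 0}.
  auto rel = inferRelation({0, 1, 2}, {2, 0});
  for (auto [target, base] : *rel)
    std::cout << "target " << target << " -> base " << base << "\n";
  // A base map repeating a dim, e.g. (d0, d0) vs. (d0), would map target
  // dim 0 twice and yields std::nullopt.
}
```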
@@ -91,7 +95,7 @@ inferIndexingMapRelation(AffineMap indexingMapBase,
   return res;
 }
 
-// given j --> i and max rank of i, return i --> j
+// given target --> base and max rank of base, return base --> target
 static DenseMap<int64_t, int64_t>
 getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
                     size_t maxRank) {
@@ -109,7 +113,7 @@ getReversedIndexMap(const DenseMap<int64_t, int64_t> &indexMap,
   return res;
 }
 
-static FailureOr<TensorLayout>
+static TensorLayout
 inferTargetLayout(TensorLayout layoutBase,
                   const DenseMap<int64_t, int64_t> &indexMap) {
   SmallVector<int64_t> baseOuterAxis = layoutBase.getOuterAxis();
@@ -177,6 +181,39 @@ getPackingAxis(int64_t numRank, bool transposed) {
   return std::make_pair(outerAxisPerm, innerAxisPos);
 }
 
+// copied from mlir
+static SmallVector<int64_t>
+projectToInnerMostNonUnitDimsPos(ArrayRef<int64_t> dimsPos,
+                                 ArrayRef<ReassociationIndices> reassocIndices,
+                                 ArrayRef<int64_t> targetShape) {
+  SmallVector<int64_t> projectedDimsPos;
+  for (auto pos : dimsPos) {
+    // In the case all dims are unit, this will return the inner-most one.
+    int64_t projectedPos = reassocIndices[pos].back();
+    for (auto i : llvm::reverse(reassocIndices[pos])) {
+      int64_t dim = targetShape[i];
+      if (dim > 1 || ShapedType::isDynamic(dim)) {
+        projectedPos = i;
+        break;
+      }
+    }
+    projectedDimsPos.push_back(projectedPos);
+  }
+  return projectedDimsPos;
+}
+
+/// Check if all dims in dimsPos are divisible by the corresponding tile sizes.
+static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
+                                       ArrayRef<int64_t> shape,
+                                       ArrayRef<int64_t> tileSizes) {
+  for (auto [pos, tileSize] : llvm::zip_equal(dimsPos, tileSizes)) {
+    int64_t dim = shape[pos];
+    if (ShapedType::isDynamic(dim) || (dim % tileSize) != 0)
+      return false;
+  }
+  return true;
+}
+
 GlobalAnalysis::GlobalAnalysis(Operation *root) {
   root->walk([&](Operation *op) {
     if (auto linalgOp = dyn_cast<linalg::LinalgOp>(op)) {
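
To see how the two new helpers combine, here is a self-contained sketch with ReassociationIndices modeled as nested vectors and dynamic dims as -1 (the real code uses MLIR's containers and ShapedType::kDynamic; the shapes below are made up for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Reassoc = std::vector<std::vector<int64_t>>;
static constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamic

// For each original dim position, pick the inner-most expanded dim that is
// non-unit (or dynamic); if all expanded dims are unit, keep the inner-most.
static std::vector<int64_t>
projectToInnerMostNonUnit(const std::vector<int64_t> &dimsPos,
                          const Reassoc &reassoc,
                          const std::vector<int64_t> &targetShape) {
  std::vector<int64_t> projected;
  for (int64_t pos : dimsPos) {
    int64_t projectedPos = reassoc[pos].back();
    for (auto it = reassoc[pos].rbegin(); it != reassoc[pos].rend(); ++it) {
      int64_t dim = targetShape[*it];
      if (dim > 1 || dim == kDynamic) {
        projectedPos = *it;
        break;
      }
    }
    projected.push_back(projectedPos);
  }
  return projected;
}

// Reject the propagation if any projected dim is dynamic or not tile-divisible.
static bool divisibleByTiles(const std::vector<int64_t> &dimsPos,
                             const std::vector<int64_t> &shape,
                             const std::vector<int64_t> &tiles) {
  for (size_t k = 0; k < dimsPos.size(); ++k) {
    int64_t dim = shape[dimsPos[k]];
    if (dim == kDynamic || dim % tiles[k] != 0)
      return false;
  }
  return true;
}

int main() {
  // tensor<128x64> expanded to tensor<2x64x64>: dim 0 -> {0, 1}, dim 1 -> {2}.
  Reassoc reassoc = {{0, 1}, {2}};
  std::vector<int64_t> outShape = {2, 64, 64};
  // A layout tiling original dims {0, 1} projects onto expanded dims {1, 2}.
  auto projected = projectToInnerMostNonUnit({0, 1}, reassoc, outShape);
  assert((projected == std::vector<int64_t>{1, 2}));
  // 64 % 32 == 0 and 64 % 16 == 0, so the layout can be propagated.
  assert(divisibleByTiles(projected, outShape, {32, 16}));
}
```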
@@ -198,9 +235,8 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       }
       // ------ Get Current Op's Suggested Layout & Do Propagation ------
       IRRewriter rewriter(linalgOp);
-      // TODO: extend to packed/vnni matmul ops
       if (supportedContractionNamedOpList(linalgOp)) {
-        // get input and output rank
+        // infer layout for linalg contraction named ops
         auto ARank = cast<ShapedType>(linalgOp.getDpsInputs()[0].getType())
                          .getShape()
                          .size();
@@ -242,29 +278,36 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
         OperatorLayout suggestedLayout({ALayout, BLayout}, {CLayout});
         layoutCache[linalgOp] = suggestedLayout;
       } else if (!mlir::linalg::isaContractionOpInterface(linalgOp) &&
+                 !mlir::linalg::isaConvolutionOpInterface(linalgOp) &&
                  !supportedContractionNamedOpList(linalgOp)) {
+        // infer layout for non-contraction/non-convolution linalg named ops
+        // and linalg generic ops
         SmallVector<TensorLayout> inputLayouts, outputLayouts;
         size_t targetIdx = getTargetInputIdx(curInputLayouts);
-        // TODO(yifei): wisely choose the input format basis
-        // Let's only refer to input[0] for now
         for (size_t i = 0; i < curInputs.size(); ++i) {
           // getMatchingIndexingMap
           if (i != targetIdx) {
-            auto res = inferIndexingMapRelation(
+            auto indexRelation = inferIndexingMapRelation(
                 linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
                 linalgOp.getMatchingIndexingMap(curInputs[i]));
+            if (failed(indexRelation)) {
+              return WalkResult::skip();
+            }
            TensorLayout inputLayout =
-                *inferTargetLayout(curInputLayouts[targetIdx], *res);
+                inferTargetLayout(curInputLayouts[targetIdx], *indexRelation);
            inputLayouts.push_back(inputLayout);
          } else {
            inputLayouts.push_back(curInputLayouts[targetIdx]);
          }
        }
-        auto res_out = inferIndexingMapRelation(
+        auto indexRelation = inferIndexingMapRelation(
            linalgOp.getMatchingIndexingMap(curInputs[targetIdx]),
            linalgOp.getIndexingMapMatchingResult(curResults[0]));
+        if (failed(indexRelation)) {
+          return WalkResult::skip();
+        }
        TensorLayout outputLayout =
-            *inferTargetLayout(curInputLayouts[targetIdx], *res_out);
+            inferTargetLayout(curInputLayouts[targetIdx], *indexRelation);
        outputLayouts.push_back(outputLayout);
        OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
        layoutCache[linalgOp] = suggestedLayout;
@@ -283,52 +326,44 @@ GlobalAnalysis::GlobalAnalysis(Operation *root) {
       OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
       layoutCache[padOp] = suggestedLayout;
     } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
-      auto reassociation = expandShapeOp.getReassociation();
+      SmallVector<ReassociationIndices> reassocIndices =
+          expandShapeOp.getReassociationIndices();
       auto staticOutputShape = expandShapeOp.getStaticOutputShape();
       auto parent = expandShapeOp.getSrc().getDefiningOp();
       auto inputShape = expandShapeOp.getSrcType().getShape();
       TensorLayout curInputLayout =
           layoutCache.find(parent) != layoutCache.end()
               ? layoutCache[parent].getOutputLayout(0)
               : TensorLayout::createPlainLayout(inputShape.size());
-      DenseMap<int64_t, int64_t> outputInputIdxMapping, inputOutputIndexMapping;
-      int64_t accumulationOffset = 0;
-      for (int64_t i = 0; i < static_cast<int64_t>(reassociation.size()); ++i) {
-        auto subReassociation = llvm::cast<ArrayAttr>(reassociation[i]);
-        for (int64_t j = 0; j < static_cast<int64_t>(subReassociation.size());
-             ++j) {
-          if (staticOutputShape[accumulationOffset + j] == inputShape[i]) {
-            outputInputIdxMapping[accumulationOffset + j] = i;
-            inputOutputIndexMapping[i] = accumulationOffset + j;
-          }
-        }
-        accumulationOffset += subReassociation.size();
+      SmallVector<int64_t> innerTileSizes;
+      auto intTileSizes = getConstantIntValues(curInputLayout.getTileSizes());
+      if (intTileSizes) {
+        innerTileSizes = *intTileSizes;
       }
-      auto inputOuterAxis = curInputLayout.getOuterAxis();
-      auto inputInnerAxis = curInputLayout.getInnerAxis();
-      int64_t diffDifference = staticOutputShape.size() - inputShape.size();
-      int64_t startIdx = 0;
-      SmallVector<int64_t> outputOuterAxis, outputInnerAxis;
-      for (int64_t i = 0; i < static_cast<int64_t>(staticOutputShape.size());
-           ++i) {
-        if (outputInputIdxMapping.find(i) != outputInputIdxMapping.end()) {
-          outputOuterAxis.push_back(inputOuterAxis[outputInputIdxMapping[i]] +
-                                    diffDifference);
-        } else {
-          outputOuterAxis.push_back(startIdx++);
-        }
+      ArrayRef<int64_t> innerDimsPos = curInputLayout.getInnerAxis();
+      ArrayRef<int64_t> outerDimsPerm = curInputLayout.getOuterAxis();
+      SmallVector<int64_t> projectedInnerDimsPos =
+          projectToInnerMostNonUnitDimsPos(curInputLayout.getInnerAxis(),
+                                           reassocIndices, staticOutputShape);
+
+      if (!isDimsDivisibleByTileSizes(projectedInnerDimsPos, staticOutputShape,
+                                      innerTileSizes)) {
+        return WalkResult::skip();
       }
-      for (int64_t i = 0; i < static_cast<int64_t>(inputInnerAxis.size());
-           ++i) {
-        outputInnerAxis.push_back(inputOutputIndexMapping[inputInnerAxis[i]]);
+      SmallVector<int64_t> newOuterDimsPerm;
+      for (auto outerPos : outerDimsPerm) {
+        newOuterDimsPerm.insert(newOuterDimsPerm.end(),
+                                reassocIndices[outerPos].begin(),
+                                reassocIndices[outerPos].end());
       }
-      TensorLayout outputLayout(outputOuterAxis, outputInnerAxis,
+      TensorLayout outputLayout(newOuterDimsPerm, projectedInnerDimsPos,
                                 curInputLayout.getTileSizes());
       SmallVector<TensorLayout> inputLayouts{curInputLayout},
           outputLayouts{outputLayout};
       OperatorLayout suggestedLayout(inputLayouts, outputLayouts);
       layoutCache[expandShapeOp] = suggestedLayout;
     }
+    return WalkResult::advance();
   });
 }
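
As a hedged walk-through of the expand_shape branch above: each outer axis of the input layout is expanded into its whole reassociation group, while the tiled inner axes are projected onto the inner-most non-unit expanded dims and keep their tile sizes. A simplified sketch under those assumptions (TensorLayout reduced to plain vectors, example shapes invented):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Reassoc = std::vector<std::vector<int64_t>>;

// Expand every outer-axis position of the input layout into the expanded dims
// of its reassociation group, mirroring the newOuterDimsPerm loop above.
static std::vector<int64_t>
expandOuterPerm(const std::vector<int64_t> &outerPerm, const Reassoc &reassoc) {
  std::vector<int64_t> result;
  for (int64_t pos : outerPerm)
    result.insert(result.end(), reassoc[pos].begin(), reassoc[pos].end());
  return result;
}

int main() {
  // tensor<128x64> -> tensor<2x64x64>, reassociation [[0, 1], [2]].
  Reassoc reassoc = {{0, 1}, {2}};
  // Input layout: plain outer order {0, 1}, tiles {32, 16} on inner dims {0, 1}.
  std::vector<int64_t> inOuter = {0, 1};
  auto newOuter = expandOuterPerm(inOuter, reassoc);
  assert((newOuter == std::vector<int64_t>{0, 1, 2}));
  // Combined with the projected inner dims {1, 2} from the earlier sketch,
  // the suggested output layout keeps tile sizes {32, 16} on expanded dims
  // {1, 2} with outer order {0, 1, 2}.
}
```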

lib/gc/Transforms/PropagateLayout.cpp

Lines changed: 52 additions & 50 deletions
@@ -76,6 +76,21 @@ static SmallVector<int64_t> getPackedPermAxes(ArrayRef<int64_t> plainPermAxes,
   return result;
 }
 
+static int64_t applyPermutationAndReindexReassoc(
+    SmallVector<ReassociationIndices> &reassocIndices,
+    ArrayRef<int64_t> permutation) {
+  if (!permutation.empty())
+    applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
+  int64_t nextPos = 0;
+  for (ReassociationIndices &indices : reassocIndices) {
+    for (auto &index : indices) {
+      index = nextPos;
+      nextPos += 1;
+    }
+  }
+  return nextPos;
+}
+
 // extends linalg::pack(...) for named ops
 FailureOr<linalg::PackResult> packNamedOp(RewriterBase &rewriter,
                                           linalg::LinalgOp linalgOp,
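
A standalone sketch of what this helper computes, assuming applyPermutationToVector places input group permutation[i] at position i of the result (that convention, like every name below, is an assumption for illustration):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

using Reassoc = std::vector<std::vector<int64_t>>;

// Permute the reassociation groups, then renumber all expanded-dim indices
// consecutively in the new order; the returned value is the next free index.
static int64_t permuteAndReindex(Reassoc &reassoc,
                                 const std::vector<int64_t> &perm) {
  if (!perm.empty()) {
    Reassoc permuted;
    for (int64_t p : perm)
      permuted.push_back(reassoc[p]);
    reassoc = permuted;
  }
  int64_t nextPos = 0;
  for (auto &group : reassoc)
    for (auto &index : group)
      index = nextPos++;
  return nextPos;
}

int main() {
  // expand_shape reassociation [[0, 1], [2]] with outer-dim permutation {1, 0}:
  Reassoc reassoc = {{0, 1}, {2}};
  int64_t nextPos = permuteAndReindex(reassoc, {1, 0});
  // Groups swap and get renumbered to [[0], [1, 2]]; nextPos == 3 is where the
  // packed inner-tile dims are appended afterwards.
  assert((reassoc == Reassoc{{0}, {1, 2}}));
  assert(nextPos == 3);
}
```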
@@ -250,26 +265,10 @@ class PropagateLayoutOnNamedOps
   void runOnOperation() final;
 };
 
-LogicalResult graphAlreadyPacked(MLIRContext *ctx, mlir::Operation *graph) {
-  IRRewriter rewriter(ctx);
-  auto walk = graph->walk([&](Operation *op) {
-    if (mlir::gc::utils::isPackableNamedOp(op) && op->hasAttr("packed")) {
-      LLVM_DEBUG(llvm::dbgs()
-                 << "Graph already packed. Stop layout propagation.\n");
-      return WalkResult::interrupt();
-    }
-    return WalkResult::advance();
-  });
-  if (walk.wasInterrupted()) {
-    return failure();
-  }
-  return success();
-}
-
 LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
                                        ControlPackNamedOpsFn controlFn) {
   IRRewriter rewriter(ctx);
-  auto walk = graph->walk([&](Operation *op) {
+  graph->walk([&](Operation *op) {
     if (mlir::gc::utils::isPackableNamedOp(op)) {
       LLVM_DEBUG(llvm::dbgs() << "Op " << op->getName() << " visited.\n");
       FailureOr<OperatorLayout> opLayout = controlFn(op);
@@ -300,38 +299,44 @@ LogicalResult namedOpLayoutPropagation(MLIRContext *ctx, mlir::Operation *graph,
             packNamedOp(rewriter, linalgOp, *opLayout);
         if (failed(packedOp)) {
           return WalkResult::skip();
-        } else {
-          packedOp->packedLinalgOp->setAttr("packed",
-                                            rewriter.getBoolAttr(true));
         }
       } else if (auto expandShapeOp = dyn_cast<tensor::ExpandShapeOp>(op)) {
-        // Location loc = expandShapeOp->getLoc();
-        // auto inputLayout = opLayout->getSupportedInputLayouts()[0];
-        // auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
-        // Value dest = tensor::PackOp::createDestinationTensor(
-        //     rewriter, loc, expandShapeOp.getSrc(),
-        //     inputLayout.getTileSizes(), inputLayout.getInnerAxis(),
-        //     inputLayout.getOuterAxis());
-        // Value packedSource = rewriter.create<tensor::PackOp>(
-        //     loc, expandShapeOp.getSrc(), dest, inputLayout.getInnerAxis(),
-        //     inputLayout.getTileSizes(), std::nullopt,
-        //     inputLayout.getOuterAxis());
-        // auto resultType = RankedTensorType::get(
-        //     expandShapeOp.getStaticOutputShape(),
-        //     expandShapeOp.getSrcType().getElementType());
-        // RankedTensorType resultPackType = tensor::PackOp::inferPackedType(
-        //     resultType, vector::getAsIntegers(outputLayout.getTileSizes()),
-        //     outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
-        // auto reassocExpand = getReassociationIndicesForReshape(
-        //     cast<ShapedType>(dest.getType()), resultPackType);
-        // auto packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
-        //     loc, expandShapeOp.getSrcType().getElementType(), packedSource,
-        //     *reassocExpand);
-        // Value result = rewriter.create<tensor::UnPackOp>(
-        //     packedExpandShape->getLoc(), packedExpandShape,
-        //     packedExpandShape, outputLayout.getInnerAxis(),
-        //     outputLayout.getTileSizes(), outputLayout.getOuterAxis());
-        // rewriter.replaceOp(expandShapeOp, result);
+        Location loc = expandShapeOp->getLoc();
+        auto inputLayout = opLayout->getSupportedInputLayouts()[0];
+        auto outputLayout = opLayout->getSupportedOutputLayouts()[0];
+        LLVM_DEBUG(llvm::dbgs() << "Input layout: " << inputLayout << ".\n");
+        LLVM_DEBUG(llvm::dbgs() << "Output layout: " << outputLayout << ".\n");
+        Value curSrc = expandShapeOp.getSrc();
+        Value curDst = expandShapeOp.getResult();
+        Value dest = tensor::PackOp::createDestinationTensor(
+            rewriter, loc, curSrc, inputLayout.getTileSizes(),
+            inputLayout.getInnerAxis(), inputLayout.getOuterAxis());
+        Value packedSource = rewriter.create<tensor::PackOp>(
+            loc, curSrc, dest, inputLayout.getInnerAxis(),
+            inputLayout.getTileSizes(), std::nullopt,
+            inputLayout.getOuterAxis());
+        SmallVector<ReassociationIndices> newReassocIndices =
+            expandShapeOp.getReassociationIndices();
+        int64_t nextPos = applyPermutationAndReindexReassoc(
+            newReassocIndices, inputLayout.getOuterAxis());
+        // Then add direct mapping for the inner tile dims.
+        for (size_t i = 0; i < inputLayout.getInnerAxis().size(); ++i) {
+          newReassocIndices.push_back({nextPos});
+          nextPos += 1;
+        }
+        RankedTensorType newExpandType = tensor::PackOp::inferPackedType(
+            dyn_cast<RankedTensorType>(curDst.getType()),
+            *getConstantIntValues(outputLayout.getTileSizes()),
+            outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+        Value packedExpandShape = rewriter.create<tensor::ExpandShapeOp>(
+            loc, newExpandType, packedSource, newReassocIndices);
+        auto unpackDst = tensor::UnPackOp::createDestinationTensor(
+            rewriter, loc, packedExpandShape, outputLayout.getTileSizes(),
+            outputLayout.getInnerAxis(), outputLayout.getOuterAxis());
+        auto newUnPackOp = rewriter.create<tensor::UnPackOp>(
+            loc, packedExpandShape, unpackDst, outputLayout.getInnerAxis(),
+            outputLayout.getTileSizes(), outputLayout.getOuterAxis());
+        rewriter.replaceOp(expandShapeOp, newUnPackOp);
       }
     }
     return WalkResult::advance();
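
Shape-wise, the rewrite turns the original tensor.expand_shape into tensor.pack -> tensor.expand_shape (on the packed type) -> tensor.unpack. The sketch below approximates the packed shape computation (tile the inner dims, permute the outer dims, append tile sizes) for the expanded result, assuming static shapes that divide the tiles evenly; it is a simplification for illustration, not the upstream tensor::PackOp::inferPackedType implementation:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Compute the shape tensor.pack would produce: divide the dims listed in
// innerDimsPos by their tiles, permute the outer dims by outerDimsPerm
// (empty = identity), then append the tile sizes as trailing dims.
static std::vector<int64_t>
packedShape(std::vector<int64_t> shape, const std::vector<int64_t> &innerDimsPos,
            const std::vector<int64_t> &tiles,
            const std::vector<int64_t> &outerDimsPerm) {
  for (size_t k = 0; k < innerDimsPos.size(); ++k)
    shape[innerDimsPos[k]] /= tiles[k]; // divisibility was checked beforehand
  if (!outerDimsPerm.empty()) {
    std::vector<int64_t> permuted;
    for (int64_t p : outerDimsPerm)
      permuted.push_back(shape[p]);
    shape = permuted;
  }
  shape.insert(shape.end(), tiles.begin(), tiles.end());
  return shape;
}

int main() {
  // Output layout from the running example: expanded tensor<2x64x64>,
  // inner dims {1, 2} tiled by {32, 16}, plain outer order {0, 1, 2}.
  auto s = packedShape({2, 64, 64}, {1, 2}, {32, 16}, {0, 1, 2});
  // The packed expand_shape result is tensor<2x2x4x32x16>, and the trailing
  // tensor.unpack restores the plain 2x64x64 shape for later consumers.
  assert((s == std::vector<int64_t>{2, 2, 4, 32, 16}));
}
```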
@@ -619,9 +624,6 @@ struct UpliftPackOverBroadcast : public OpRewritePattern<tensor::PackOp> {
 void PropagateLayoutOnNamedOps::runOnOperation() {
   MLIRContext *ctx = &getContext();
   mlir::Operation *graph = getOperation();
-  // stage0: check if the graph has been packed
-  if (failed(graphAlreadyPacked(ctx, graph)))
-    return;
   // stage1: pack matmul
   RewritePatternSet packMatmulPatterns(&getContext());
   mlir::linalg::ControlBlockPackMatmulFn packMatmulControlFn =
