[mlir][tensor] Generalize/restrict GeneralizeOuterUnitDimsPackOpPattern
#114315
Changes from all commits: d4a892f, cb2f34f, 5efa829, 2d9bc62
@@ -1142,75 +1142,100 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
     tensor::PackOp packOp, PatternRewriter &rewriter) const {
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
-  if (llvm::any_of(packOp.getTiledOuterDims(),
+  if (llvm::any_of(packOp.getAllOuterDims(),
                    [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
-        packOp, "require the tiled outer dimensions of the result are all 1s");
+        packOp, "not all outer dimensions of the result are 1s");
   }
 
-  // 1. Use rank-reduced tensor.extract_slice op to extract the tile and untiled
-  // outer dims.
+  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
+  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   Location loc = packOp.getLoc();
+
   Value input = getPackOpSourceOrPaddedSource(rewriter, packOp);
   auto inputShape = packOp.getSourceType().getShape();
   DenseMap<int64_t, OpFoldResult> dimAndTileMapping =
       packOp.getDimAndTileMapping();
-  Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
-  Attribute oneIdxAttr = rewriter.getIndexAttr(1);
   int64_t srcRank = packOp.getSourceRank();
+  int64_t destRank = packOp.getDestRank();
+  size_t numTiles = destRank - srcRank;
+
+  // 1. Use rank-reduced tensor.extract_slice op to extract the tile:
[Review thread attached to the line above]

Reviewer: Can you elaborate on the advantage of slicing out the inner tile
here? We end up reintroducing the outer unit dims shortly after with the
insert_slice.

Author: I am merely a messenger here ;-) (as in, this preserves the existing
behavior). To me, it makes more sense to only transpose the tile worth of
data, as that's the intent of tensor.pack.

Reviewer: Ah really, that makes sense then. If it's preserving existing
behavior leaving it is fine, but should look into removing it at some point.
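To illustrate the point discussed above (a hypothetical MLIR example, not
taken from the PR): for an all-unit-outer-dims pack, the slice extracted in
step 1 covers the whole source, but it is rank-reducing, so the subsequent
linalg.transpose only has to shuffle tile-sized data.

    // All outer dims of the packed tensor are 1.
    %packed = tensor.pack %src inner_dims_pos = [3, 2] inner_tiles = [8, 32]
        into %dest : tensor<1x1x32x8xf32> -> tensor<1x1x1x1x8x32xf32>

    // Step 1: a "full" but rank-reducing slice. The sizes [1, 1, 32, 8]
    // span the entire source; the unit outer dims are dropped from the type.
    %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
        : tensor<1x1x32x8xf32> to tensor<32x8xf32>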
+  //    %extracted_tile = tensor.extract_slice(%pack_op_input)
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
-  SmallVector<OpFoldResult> readSizes;
-  SmallVector<OpFoldResult> transShapeForEmpty;
-  SmallVector<int64_t> readShapeForExtractSlice;
+
+  // The sizes attribute for ExtractSliceOp. The leading sizes are set to 1 as
+  // all outer dims are 1.
+  SmallVector<OpFoldResult> extractSliceSizes(srcRank - numTiles, oneIdxAttr);
+  // The shape of the output for ExtractSliceOp. All leading unit dims are
+  // effectively rank-reduced, hence skipped.
+  SmallVector<int64_t> outputShapeForExtractSlice;
+
+  // Extract the trailing sizes and shape dims for ExtractSliceOp. These should
+  // be equal to the inner tile sizes.
   for (auto i : llvm::seq<unsigned>(0, srcRank)) {
     if (dimAndTileMapping.count(i)) {
-      readShapeForExtractSlice.push_back(
-          getConstantIntValue(dimAndTileMapping[i])
-              .value_or(ShapedType::kDynamic));
-      readSizes.push_back(dimAndTileMapping[i]);
-      transShapeForEmpty.push_back(dimAndTileMapping[i]);
-      continue;
-    }
-    if (ShapedType::isDynamic(inputShape[i])) {
-      readSizes.push_back(
-          rewriter.create<tensor::DimOp>(loc, input, i).getResult());
-    } else {
-      readSizes.push_back(rewriter.getIndexAttr(inputShape[i]));
-    }
-    if (inputShape[i] != 1) {
-      readShapeForExtractSlice.push_back(inputShape[i]);
-      transShapeForEmpty.push_back(rewriter.getIndexAttr(inputShape[i]));
+      auto [tileSize, tileSizeOfr] =
+          getSimplifiedOfrAndStaticSizePair(dimAndTileMapping[i], rewriter);
+      extractSliceSizes.push_back(tileSizeOfr);
+      outputShapeForExtractSlice.push_back(tileSize);
     }
   }
 
   Type elemType = packOp.getSourceType().getElementType();
-  auto readType = RankedTensorType::get(readShapeForExtractSlice, elemType);
+  auto readType = RankedTensorType::get(outputShapeForExtractSlice, elemType);
 
   Value tile = rewriter.create<tensor::ExtractSliceOp>(
-      loc, readType, input, readOffsets, readSizes, readStrides);
+      loc, readType, input, readOffsets, extractSliceSizes, readStrides);
 
-  // 2. Transpose the tile to match the inner tile order.
+  // 2. Transpose the tile to match the inner tile order:
+  //    %init = tensor.empty()
+  //    %transposed_tile = linalg.transpose ins(%extracted_tile), outs(%init)
+  // NOTE: Outer dims are 1 and hence effectively ignored.
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
       inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
 
-  applyPermutationToVector<OpFoldResult>(transShapeForEmpty, perm);
+  // 2.1 Create tensor.empty (init value for TransposeOp)
+  SmallVector<OpFoldResult> transShapeForEmptyOp;
+
+  // Acquire tensor shape required to create EmptyOp. This will match the inner
+  // tile sizes.
+  size_t idx = numTiles;
+  while (idx != 0) {
+    transShapeForEmptyOp.push_back(extractSliceSizes[srcRank - idx]);
+    idx--;
+  }
+
+  applyPermutationToVector<OpFoldResult>(transShapeForEmptyOp, perm);
   Value empty =
-      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmpty, elemType);
+      rewriter.create<tensor::EmptyOp>(loc, transShapeForEmptyOp, elemType);
+
+  // 2.2 Create linalg.transpose
   auto transposedOp =
       rewriter.create<linalg::TransposeOp>(loc, tile, empty, perm);
 
-  // 3. Insert the inner tile to the destination.
-  int64_t destRank = packOp.getDestRank();
+  // 3. Insert the inner tile to the destination:
+  //    %inserted_tile = tensor.insert_slice(%transposed_tile)
   SmallVector<OpFoldResult> writeStrides(destRank, oneIdxAttr);
   SmallVector<OpFoldResult> writeOffsets(destRank, zeroIdxAttr);
-  SmallVector<OpFoldResult> writeSizes =
-      tensor::getMixedSizes(rewriter, loc, packOp.getDest());
+  // Outer dims are all 1s!
+  SmallVector<OpFoldResult> writeSizes(destRank - dimAndTileMapping.size(),
+                                       oneIdxAttr);
+  SmallVector<int64_t> writeShape;
+
+  for (auto tileSize : packOp.getMixedTiles()) {
+    auto [tileSizeStatic, tileSizeOfr] =
+        getSimplifiedOfrAndStaticSizePair(tileSize, rewriter);
+    writeSizes.push_back(tileSizeOfr);
+    writeShape.push_back(tileSizeStatic);
+  }
 
+  // 4. Replace tensor.packOp with tensor.insert_slice created above
   auto insert = rewriter.create<tensor::InsertSliceOp>(
       loc, transposedOp.getResult()[0], packOp.getDest(), writeOffsets,
       writeSizes, writeStrides);
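Putting the four steps together, this is roughly what the pattern now
produces for the static example above (a sketch modelled on the
decompose-tensor-pack tests; SSA names are illustrative):

    // 1. Rank-reduced extract_slice of the inner tile.
    %tile = tensor.extract_slice %src[0, 0, 0, 0] [1, 1, 32, 8] [1, 1, 1, 1]
        : tensor<1x1x32x8xf32> to tensor<32x8xf32>
    // 2. Transpose the tile into the inner-tile order [8, 32].
    %init = tensor.empty() : tensor<8x32xf32>
    %transposed = linalg.transpose ins(%tile : tensor<32x8xf32>)
        outs(%init : tensor<8x32xf32>) permutation = [1, 0]
    // 3./4. Insert the tile; the leading 1s in the sizes are the outer dims.
    %res = tensor.insert_slice %transposed into
        %dest[0, 0, 0, 0, 0, 0] [1, 1, 1, 1, 8, 32] [1, 1, 1, 1, 1, 1]
        : tensor<8x32xf32> into tensor<1x1x1x1x8x32xf32>

Note also the "restrict" half of the patch title: with the switch from
getTiledOuterDims to getAllOuterDims, a pack with a non-unit untiled outer
dim, e.g. tensor<2x8xf32> -> tensor<2x1x8xf32> with inner_dims_pos = [1],
previously matched this pattern but is now rejected.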
[Review comment]

Reviewer: This may not be in the scope of this PR to change this, but what
exactly is this extract_slice for? It seems to me like it is just taking a
full slice. Is the slice ever not full?
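For what it's worth, in the current lowering the slice does appear to always
be a full (though rank-reducing) slice of the source returned by
getPackOpSourceOrPaddedSource. A sketch of the padded case (adapted from the
decompose-tensor-pack tests; names illustrative), where the slice is a
same-type full slice and so folds away trivially:

    %packed = tensor.pack %src padding_value(%pad : f32)
        inner_dims_pos = [0, 1] inner_tiles = [8, 2]
        into %dest : tensor<5x1xf32> -> tensor<1x1x8x2xf32>

    // ... lowers to (after the no-op extract_slice and identity transpose
    // have been folded away):
    %padded = tensor.pad %src low[0, 0] high[3, 1] {
    ^bb0(%i: index, %j: index):
      tensor.yield %pad : f32
    } : tensor<5x1xf32> to tensor<8x2xf32>
    %res = tensor.insert_slice %padded into %dest[0, 0, 0, 0] [1, 1, 8, 2]
        [1, 1, 1, 1] : tensor<8x2xf32> into tensor<1x1x8x2xf32>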