Skip to content

Commit c1826ae

Browse files
authored
[mlir][tensor] Add new helper hooks for RelayoutOp (#109642)
Implements two helper hooks for PackOp and UnPackOP, `getAllOuterDims` and `getTiledOuterDims`, and adds them to RelayoutOp (that both PackOp an UnPackOp inherit from). This improves code re-use and also clarifies the meaning of "outer dims" and "tiled outer dims".
1 parent 029b9b6 commit c1826ae

File tree

3 files changed

+61
-14
lines changed

3 files changed

+61
-14
lines changed

mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
18141814
}
18151815

18161816
//===----------------------------------------------------------------------===//
1817-
// PackOp
1817+
// RelayoutOp
18181818
//===----------------------------------------------------------------------===//
18191819

18201820
class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,27 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
18511851
/// a sentinel `kDynamic` is introduced at that position in
18521852
/// the returned vector.
18531853
SmallVector<int64_t> getStaticTiles();
1854+
1855+
/// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
1856+
/// dims excluding the trailing dims corresponding to `innerTiles`. Note
1857+
/// that this will include both tiled and non-tiled dimensions. The order
1858+
/// of the output dimensions is consistent with the shape of the packed
1859+
/// tensor.
1860+
ArrayRef<int64_t> getAllOuterDims();
1861+
1862+
/// Similar to `getAllOuterDims`, but only retrieve the outer dims that
1863+
/// have been tiled. Also, the order of the output dimensions is consistent
1864+
/// with `inner_dims_pos` rather than the packed tensor.
1865+
SmallVector<int64_t> getTiledOuterDims();
18541866
}];
18551867

18561868
let hasVerifier = 1;
18571869
}
18581870

1871+
//===----------------------------------------------------------------------===//
1872+
// PackOp
1873+
//===----------------------------------------------------------------------===//
1874+
18591875
def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
18601876
AttrSizedOperandSegments]> {
18611877
let summary = "tensor pack operation";

mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
10301030
return input;
10311031
}
10321032

1033+
assert(llvm::all_of(packOp.getAllOuterDims(),
1034+
[](int64_t val) { return val == 1; }) &&
1035+
"some outer dims are != 1");
1036+
10331037
Location loc = packOp.getLoc();
10341038
ShapedType inputType = packOp.getSourceType();
10351039
int64_t inputRank = inputType.getRank();
1036-
assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
1037-
[](int64_t val) { return val == 1; }));
10381040

10391041
SmallVector<int64_t> paddedShape;
10401042
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11261128

11271129
// TODO: support the case that outer dimensions are not all 1s. A
11281130
// tensor.expand_shape will be generated in this case.
1129-
auto innerDimsPos = packOp.getInnerDimsPos();
1130-
int64_t srcRank = packOp.getSourceRank();
1131-
auto destShape = packOp.getDestType().getShape();
1132-
if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
1133-
return destShape[index] != 1;
1134-
})) {
1131+
if (llvm::any_of(packOp.getTiledOuterDims(),
1132+
[](int64_t dim) { return dim != 1; })) {
11351133
return rewriter.notifyMatchFailure(
11361134
packOp, "require the tiled outer dimensions of the result are all 1s");
11371135
}
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11451143
packOp.getDimAndTileMapping();
11461144
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
11471145
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
1146+
int64_t srcRank = packOp.getSourceRank();
11481147
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
11491148
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
11501149
SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
11731172
loc, readType, input, readOffsets, readSizes, readStrides);
11741173

11751174
// 2. Transpose the tile to match the inner tile order.
1176-
11771175
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
1178-
inputShape, innerDimsPos, packOp.getOuterDimsPerm());
1176+
inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
11791177

11801178
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
11811179
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
12081206
int64_t destRank = unpackOp.getDestRank();
12091207
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
12101208
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
1211-
if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
1212-
return srcShape[index] != 1;
1213-
})) {
1209+
if (llvm::any_of(unpackOp.getTiledOuterDims(),
1210+
[](int64_t dim) { return dim != 1; })) {
12141211
return rewriter.notifyMatchFailure(
12151212
unpackOp,
12161213
"require the tiled outer dimensions of the result are all 1s");

mlir/lib/Dialect/Tensor/IR/TensorOps.cpp

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3987,6 +3987,23 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
39873987
return getStaticTilesImpl(*this);
39883988
}
39893989

3990+
ArrayRef<int64_t> PackOp::getAllOuterDims() {
3991+
ShapedType inputType = getSourceType();
3992+
int64_t inputRank = inputType.getRank();
3993+
return getDestType().getShape().take_front(inputRank);
3994+
}
3995+
3996+
SmallVector<int64_t> PackOp::getTiledOuterDims() {
3997+
auto innerDimsPos = getInnerDimsPos();
3998+
auto packedShape = getDestType().getShape();
3999+
SmallVector<int64_t> res;
4000+
4001+
for (auto index : innerDimsPos)
4002+
res.push_back(packedShape[index]);
4003+
4004+
return res;
4005+
}
4006+
39904007
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
39914008
ArrayRef<int64_t> innerDimsPos,
39924009
ArrayRef<int64_t> outputShape,
@@ -4411,6 +4428,23 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
44114428
return getStaticTilesImpl(*this);
44124429
}
44134430

4431+
ArrayRef<int64_t> UnPackOp::getAllOuterDims() {
4432+
ShapedType destType = getDestType();
4433+
int64_t destRank = destType.getRank();
4434+
return getSourceType().getShape().take_front(destRank);
4435+
}
4436+
4437+
SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
4438+
auto innerDimsPos = getInnerDimsPos();
4439+
auto packedShape = getSourceType().getShape();
4440+
SmallVector<int64_t> res;
4441+
4442+
for (auto index : innerDimsPos)
4443+
res.push_back(packedShape[index]);
4444+
4445+
return res;
4446+
}
4447+
44144448
LogicalResult UnPackOp::verify() {
44154449
return commonVerifierPackAndUnPackOp(*this);
44164450
}

0 commit comments

Comments
 (0)