-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][tensor] Add new helper hooks for RelayoutOp #109642
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir][tensor] Add new helper hooks for RelayoutOp #109642
Conversation
Implements two helper hooks for PackOp and UnPackOP, `getAllOuterDims` and `getTiledOuterDims`, and adds them to RelayoutOp (that both PackOp an UnPackOp inherit from). This improves code re-use and also clarifies the meaning of "outer dims" and "tiled outer dims".
@llvm/pr-subscribers-mlir-linalg @llvm/pr-subscribers-mlir-tensor Author: Andrzej Warzyński (banach-space) ChangesImplements two helper hooks for PackOp and UnPackOP, This improves code re-use and also clarifies the meaning of "outer dims" Full diff: https://github.com/llvm/llvm-project/pull/109642.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
}
//===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
//===----------------------------------------------------------------------===//
class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// a sentinel `kDynamic` is introduced at that position in
/// the returned vector.
SmallVector<int64_t> getStaticTiles();
+
+ /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+ /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+ /// that this will include both tiled and non-tiled dimensions.
+ ArrayRef<int64_t> getAllOuterDims() {
+ ShapedType inputType = getSourceType();
+ int64_t inputRank = inputType.getRank();
+ return getDestType().getShape().take_front(inputRank);
+ }
+
+ /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
+ /// have been tiled.
+ SmallVector<int64_t> getTiledOuterDims();
}];
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
AttrSizedOperandSegments]> {
let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
return input;
}
+ assert(llvm::all_of(packOp.getAllOuterDims(),
+ [](int64_t val) { return val == 1; }) &&
+ "some outer dims are != 1");
+
Location loc = packOp.getLoc();
ShapedType inputType = packOp.getSourceType();
int64_t inputRank = inputType.getRank();
- assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
- [](int64_t val) { return val == 1; }));
SmallVector<int64_t> paddedShape;
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
- auto innerDimsPos = packOp.getInnerDimsPos();
- int64_t srcRank = packOp.getSourceRank();
- auto destShape = packOp.getDestType().getShape();
- if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
- return destShape[index] != 1;
- })) {
+ if (llvm::any_of(packOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+ int64_t srcRank = packOp.getSourceRank();
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
loc, readType, input, readOffsets, readSizes, readStrides);
// 2. Transpose the tile to match the inner tile order.
-
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
- inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+ inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
- if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
- return srcShape[index] != 1;
- })) {
+ if (llvm::any_of(unpackOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
unpackOp,
"require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+ auto innerDimsPos = getInnerDimsPos();
+ auto destShape = getDestType().getShape();
+ SmallVector<int64_t> res;
+
+ for (auto index : innerDimsPos)
+ res.push_back(destShape[index]);
+
+ return res;
+}
+
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+ auto innerDimsPos = getInnerDimsPos();
+ auto destShape = getSourceType().getShape();
+ SmallVector<int64_t> res;
+
+ for (auto index : innerDimsPos)
+ res.push_back(destShape[index]);
+
+ return res;
+}
+
LogicalResult UnPackOp::verify() {
return commonVerifierPackAndUnPackOp(*this);
}
|
@llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesImplements two helper hooks for PackOp and UnPackOP, This improves code re-use and also clarifies the meaning of "outer dims" Full diff: https://github.com/llvm/llvm-project/pull/109642.diff 3 Files Affected:
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
}
//===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
//===----------------------------------------------------------------------===//
class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
/// a sentinel `kDynamic` is introduced at that position in
/// the returned vector.
SmallVector<int64_t> getStaticTiles();
+
+ /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+ /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+ /// that this will include both tiled and non-tiled dimensions.
+ ArrayRef<int64_t> getAllOuterDims() {
+ ShapedType inputType = getSourceType();
+ int64_t inputRank = inputType.getRank();
+ return getDestType().getShape().take_front(inputRank);
+ }
+
+ /// Similar to `getAllOuterDims`, but only retrieve the outer dims that
+ /// have been tiled.
+ SmallVector<int64_t> getTiledOuterDims();
}];
let hasVerifier = 1;
}
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
AttrSizedOperandSegments]> {
let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
return input;
}
+ assert(llvm::all_of(packOp.getAllOuterDims(),
+ [](int64_t val) { return val == 1; }) &&
+ "some outer dims are != 1");
+
Location loc = packOp.getLoc();
ShapedType inputType = packOp.getSourceType();
int64_t inputRank = inputType.getRank();
- assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
- [](int64_t val) { return val == 1; }));
SmallVector<int64_t> paddedShape;
DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
// TODO: support the case that outer dimensions are not all 1s. A
// tensor.expand_shape will be generated in this case.
- auto innerDimsPos = packOp.getInnerDimsPos();
- int64_t srcRank = packOp.getSourceRank();
- auto destShape = packOp.getDestType().getShape();
- if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
- return destShape[index] != 1;
- })) {
+ if (llvm::any_of(packOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
packOp, "require the tiled outer dimensions of the result are all 1s");
}
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
packOp.getDimAndTileMapping();
Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+ int64_t srcRank = packOp.getSourceRank();
SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
loc, readType, input, readOffsets, readSizes, readStrides);
// 2. Transpose the tile to match the inner tile order.
-
SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
- inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+ inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
int64_t destRank = unpackOp.getDestRank();
ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
- if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
- return srcShape[index] != 1;
- })) {
+ if (llvm::any_of(unpackOp.getTiledOuterDims(),
+ [](int64_t dim) { return dim != 1; })) {
return rewriter.notifyMatchFailure(
unpackOp,
"require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+ auto innerDimsPos = getInnerDimsPos();
+ auto destShape = getDestType().getShape();
+ SmallVector<int64_t> res;
+
+ for (auto index : innerDimsPos)
+ res.push_back(destShape[index]);
+
+ return res;
+}
+
bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
ArrayRef<int64_t> innerDimsPos,
ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
return getStaticTilesImpl(*this);
}
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+ auto innerDimsPos = getInnerDimsPos();
+ auto destShape = getSourceType().getShape();
+ SmallVector<int64_t> res;
+
+ for (auto index : innerDimsPos)
+ res.push_back(destShape[index]);
+
+ return res;
+}
+
LogicalResult UnPackOp::verify() {
return commonVerifierPackAndUnPackOp(*this);
}
|
Remove empty space
Add comments, specialize getAllOuterDims
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, thanks!
Use Adam's suggestion for var names
Implements two helper hooks for PackOp and UnPackOP,
getAllOuterDims
and
getTiledOuterDims
, and adds them to RelayoutOp (that both PackOpan UnPackOp inherit from).
This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".