Skip to content

[mlir][tensor] Add new helper hooks for RelayoutOp #109642

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Sep 24, 2024

Conversation

banach-space
Copy link
Contributor

Implements two helper hooks for PackOp and UnPackOP, getAllOuterDims
and getTiledOuterDims, and adds them to RelayoutOp (that both PackOp
an UnPackOp inherit from).

This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".

Implements two helper hooks for PackOp and UnPackOP, `getAllOuterDims`
and `getTiledOuterDims`, and adds them to RelayoutOp (that both PackOp
an UnPackOp inherit from).

This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".
@banach-space banach-space changed the title [mlir][tensor] Add new helper hooks to RelayoutOp [mlir][tensor] Add new helper hooks for RelayoutOp Sep 23, 2024
@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir-tensor

Author: Andrzej Warzyński (banach-space)

Changes

Implements two helper hooks for PackOp and UnPackOP, getAllOuterDims
and getTiledOuterDims, and adds them to RelayoutOp (that both PackOp
an UnPackOp inherit from).

This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".


Full diff: https://github.com/llvm/llvm-project/pull/109642.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+18-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+10-13)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+22)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
 }
 
 //===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
 //===----------------------------------------------------------------------===//
 
 class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// a sentinel `kDynamic` is introduced at that position in
     /// the returned vector.
     SmallVector<int64_t> getStaticTiles();
+
+    /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+    /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+    /// that this will include both tiled and non-tiled dimensions.
+    ArrayRef<int64_t> getAllOuterDims() {
+      ShapedType inputType = getSourceType();
+      int64_t inputRank = inputType.getRank();
+      return getDestType().getShape().take_front(inputRank);
+    }
+
+    /// Similar to `getAllOuterDims`, but only retrieve the outer dims  that
+    /// have been tiled.
+    SmallVector<int64_t> getTiledOuterDims();
   }];
 
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
 def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     AttrSizedOperandSegments]> {
   let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     return input;
   }
 
+  assert(llvm::all_of(packOp.getAllOuterDims(),
+                      [](int64_t val) { return val == 1; }) &&
+         "some outer dims are != 1");
+
   Location loc = packOp.getLoc();
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
-  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
-                      [](int64_t val) { return val == 1; }));
 
   SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  int64_t srcRank = packOp.getSourceRank();
-  auto destShape = packOp.getDestType().getShape();
-  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
-        return destShape[index] != 1;
-      })) {
+  if (llvm::any_of(packOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       packOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+  int64_t srcRank = packOp.getSourceRank();
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       loc, readType, input, readOffsets, readSizes, readStrides);
 
   // 2. Transpose the tile to match the inner tile order.
-
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
-      inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+      inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
-  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
-        return srcShape[index] != 1;
-      })) {
+  if (llvm::any_of(unpackOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         unpackOp,
         "require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getDestType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
                                  ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getSourceType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 LogicalResult UnPackOp::verify() {
   return commonVerifierPackAndUnPackOp(*this);
 }

@llvmbot
Copy link
Member

llvmbot commented Sep 23, 2024

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Implements two helper hooks for PackOp and UnPackOP, getAllOuterDims
and getTiledOuterDims, and adds them to RelayoutOp (that both PackOp
an UnPackOp inherit from).

This improves code re-use and also clarifies the meaning of "outer dims"
and "tiled outer dims".


Full diff: https://github.com/llvm/llvm-project/pull/109642.diff

3 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td (+18-1)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp (+10-13)
  • (modified) mlir/lib/Dialect/Tensor/IR/TensorOps.cpp (+22)
diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
index cafc3d91fd1e9d..9fee75c6a2ca3d 100644
--- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
+++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td
@@ -1814,7 +1814,7 @@ def Tensor_SplatOp : Tensor_Op<"splat", [
 }
 
 //===----------------------------------------------------------------------===//
-// PackOp
+// RelayoutOp
 //===----------------------------------------------------------------------===//
 
 class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
@@ -1851,11 +1851,28 @@ class Tensor_RelayoutOp<string mnemonic, list<Trait> traits = []> :
     /// a sentinel `kDynamic` is introduced at that position in
     /// the returned vector.
     SmallVector<int64_t> getStaticTiles();
+
+    /// Retrieve all outer dims for this Pack/UnPack Op, i.e. all the leading
+    /// dims excluding the trailing dims corresponding to `innerTiles`. Note
+    /// that this will include both tiled and non-tiled dimensions.
+    ArrayRef<int64_t> getAllOuterDims() {
+      ShapedType inputType = getSourceType();
+      int64_t inputRank = inputType.getRank();
+      return getDestType().getShape().take_front(inputRank);
+    }
+
+    /// Similar to `getAllOuterDims`, but only retrieve the outer dims  that
+    /// have been tiled.
+    SmallVector<int64_t> getTiledOuterDims();
   }];
 
   let hasVerifier = 1;
 }
 
+//===----------------------------------------------------------------------===//
+// PackOp
+//===----------------------------------------------------------------------===//
+
 def Tensor_PackOp : Tensor_RelayoutOp<"pack", [
     AttrSizedOperandSegments]> {
   let summary = "tensor pack operation";
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
index 77f0ea9d2236ea..e0dea8e78d55c1 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Transforms.cpp
@@ -1030,11 +1030,13 @@ static Value getPackOpSourceOrPaddedSource(OpBuilder &builder,
     return input;
   }
 
+  assert(llvm::all_of(packOp.getAllOuterDims(),
+                      [](int64_t val) { return val == 1; }) &&
+         "some outer dims are != 1");
+
   Location loc = packOp.getLoc();
   ShapedType inputType = packOp.getSourceType();
   int64_t inputRank = inputType.getRank();
-  assert(llvm::all_of(packOp.getDestType().getShape().take_front(inputRank),
-                      [](int64_t val) { return val == 1; }));
 
   SmallVector<int64_t> paddedShape;
   DenseMap<int64_t, OpFoldResult> tileAndPosMapping =
@@ -1126,12 +1128,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
 
   // TODO: support the case that outer dimensions are not all 1s. A
   // tensor.expand_shape will be generated in this case.
-  auto innerDimsPos = packOp.getInnerDimsPos();
-  int64_t srcRank = packOp.getSourceRank();
-  auto destShape = packOp.getDestType().getShape();
-  if (llvm::any_of(innerDimsPos, [destShape](int64_t index) {
-        return destShape[index] != 1;
-      })) {
+  if (llvm::any_of(packOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         packOp, "require the tiled outer dimensions of the result are all 1s");
   }
@@ -1145,6 +1143,7 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       packOp.getDimAndTileMapping();
   Attribute zeroIdxAttr = rewriter.getIndexAttr(0);
   Attribute oneIdxAttr = rewriter.getIndexAttr(1);
+  int64_t srcRank = packOp.getSourceRank();
   SmallVector<OpFoldResult> readOffsets(srcRank, zeroIdxAttr);
   SmallVector<OpFoldResult> readStrides(srcRank, oneIdxAttr);
   SmallVector<OpFoldResult> readSizes;
@@ -1173,9 +1172,8 @@ LogicalResult GeneralizeOuterUnitDimsPackOpPattern::matchAndRewrite(
       loc, readType, input, readOffsets, readSizes, readStrides);
 
   // 2. Transpose the tile to match the inner tile order.
-
   SmallVector<int64_t> perm = getPackUnpackRankReducedPerm(
-      inputShape, innerDimsPos, packOp.getOuterDimsPerm());
+      inputShape, packOp.getInnerDimsPos(), packOp.getOuterDimsPerm());
 
   LLVM_DEBUG(DBGS() << "Pack permutation: " << packOp << "\n";
              llvm::interleaveComma(perm, DBGS() << "perm: "); DBGSNL(););
@@ -1208,9 +1206,8 @@ LogicalResult GeneralizeOuterUnitDimsUnPackOpPattern::matchAndRewrite(
   int64_t destRank = unpackOp.getDestRank();
   ArrayRef<int64_t> srcShape = unpackOp.getSourceType().getShape();
   ArrayRef<int64_t> innerDimsPos = unpackOp.getInnerDimsPos();
-  if (llvm::any_of(innerDimsPos, [srcShape](int64_t index) {
-        return srcShape[index] != 1;
-      })) {
+  if (llvm::any_of(unpackOp.getTiledOuterDims(),
+                   [](int64_t dim) { return dim != 1; })) {
     return rewriter.notifyMatchFailure(
         unpackOp,
         "require the tiled outer dimensions of the result are all 1s");
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 47f540e092e990..bc7deb1614d18d 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -3987,6 +3987,17 @@ SmallVector<int64_t> PackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> PackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getDestType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 bool PackOp::requirePaddingValue(ArrayRef<int64_t> inputShape,
                                  ArrayRef<int64_t> innerDimsPos,
                                  ArrayRef<int64_t> outputShape,
@@ -4411,6 +4422,17 @@ SmallVector<int64_t> UnPackOp::getStaticTiles() {
   return getStaticTilesImpl(*this);
 }
 
+SmallVector<int64_t> UnPackOp::getTiledOuterDims() {
+  auto innerDimsPos = getInnerDimsPos();
+  auto destShape = getSourceType().getShape();
+  SmallVector<int64_t> res;
+
+  for (auto index : innerDimsPos)
+    res.push_back(destShape[index]);
+
+  return res;
+}
+
 LogicalResult UnPackOp::verify() {
   return commonVerifierPackAndUnPackOp(*this);
 }

Copy link
Contributor

@adam-smnk adam-smnk left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks!

@banach-space banach-space merged commit c1826ae into llvm:main Sep 24, 2024
5 of 7 checks passed
@banach-space banach-space deleted the andrzej/pack_op_add_helpers branch September 24, 2024 12:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants