Skip to content

[mlir][memref] Transpose: allow affine map layouts in result, extend folder #76294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 11, 2024

Conversation

ubfx
Copy link
Member

@ubfx ubfx commented Dec 23, 2023

Currently, the memref.transpose verifier forces the result type of the Op to have an explicit StridedLayoutAttr via the method inferTransposeResultType. This means that things like the example Op given in the documentation (https://mlir.llvm.org/docs/Dialects/MemRef/#memreftranspose-memreftransposeop) is actually invalid because it uses an AffineMap to specify the layout:

%1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>

It also means that we can't "un-transpose" a transposed memref back to the implicit layout form, because the verifier will always enforce the explicit strided layout.

This patch makes the following changes:

  1. inferTransposeResultType() returns a MemRefType with canonicalized strided layout, i.e the strides are turned into a linearizing affine expression.
  2. The verifier checks whether the canonicalized strided layout of the result Type is identitcal to the infered (also canonical) result type layout. This way, it's only important that the two Types have the same strided layout, not necessarily the same representation of it.
  3. The folder is extended to support folding away the trivial case of identity permutation and to fold one transposition into another by composing the permutation maps.

@llvmbot
Copy link
Member

llvmbot commented Dec 23, 2023

@llvm/pr-subscribers-mlir-memref

@llvm/pr-subscribers-mlir

Author: Felix Schneider (ubfx)

Changes

Currently, the memref.transpose verifier forces the result type of the Op to have an explicit StridedLayoutAttr via the method inferTransposeResultType. This means that things like the example Op given in the documentation (https://mlir.llvm.org/docs/Dialects/MemRef/#memreftranspose-memreftransposeop) is actually invalid because it uses an AffineMap to specify the layout:

%1 = memref.transpose %0 (i, j) -&gt; (j, i) : memref&lt;?x?xf32&gt; to memref&lt;?x?xf32, affine_map&lt;(d0, d1)[s0] -&gt; (d1 * s0 + d0)&gt;&gt;

It also means that we can't "un-transpose" a transposed memref back to the implicit layout form, because the verifier will always enforce the explicit strided layout.

This patch makes the following changes:

  1. inferTransposeResultType() returns a MemRefType with canonicalized strided layout, i.e the strides are turned into a linearizing affine expression.
  2. The verifier checks whether the canonicalized strided layout of the result Type is identitcal to the infered (also canonical) result type layout. This way, it's only important that the two Types have the same strided layout, not necessarily the same representation of it.
  3. The folder is extended to support folding away the trivial case of identity permutation and to fold one transposition into another by composing the permutation maps.

Full diff: https://github.com/llvm/llvm-project/pull/76294.diff

4 Files Affected:

  • (modified) mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp (+30-18)
  • (modified) mlir/test/Dialect/MemRef/canonicalize.mlir (+23)
  • (modified) mlir/test/Dialect/MemRef/invalid.mlir (+1-1)
  • (modified) mlir/test/Dialect/MemRef/ops.mlir (+6)
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a332fe253ba645..8d7cb6e1cc92cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
   setNameFn(getResult(), "transpose");
 }
 
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutationMap` to `memRefType`.
 static MemRefType inferTransposeResultType(MemRefType memRefType,
                                            AffineMap permutationMap) {
   auto rank = memRefType.getRank();
@@ -3157,18 +3157,14 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
   assert(originalStrides.size() == static_cast<unsigned>(rank));
 
   // Compute permuted sizes and strides.
-  SmallVector<int64_t> sizes(rank, 0);
-  SmallVector<int64_t> strides(rank, 1);
-  for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
-    unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
-    sizes[en.index()] = originalSizes[position];
-    strides[en.index()] = originalStrides[position];
-  }
+  auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
+  auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
 
-  return MemRefType::Builder(memRefType)
-      .setShape(sizes)
-      .setLayout(
-          StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
+  auto stridedTy = MemRefType::Builder(memRefType)
+                       .setShape(sizes)
+                       .setLayout(StridedLayoutAttr::get(
+                           memRefType.getContext(), offset, strides));
+  return canonicalizeStridedLayout(stridedTy);
 }
 
 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -3216,18 +3212,34 @@ LogicalResult TransposeOp::verify() {
     return emitOpError("expected a permutation map of same rank as the input");
 
   auto srcType = llvm::cast<MemRefType>(getIn().getType());
-  auto dstType = llvm::cast<MemRefType>(getType());
-  auto transposedType = inferTransposeResultType(srcType, getPermutation());
-  if (dstType != transposedType)
-    return emitOpError("output type ")
-           << dstType << " does not match transposed input type " << srcType
-           << ", " << transposedType;
+  auto canonicalDstType =
+      canonicalizeStridedLayout(llvm::cast<MemRefType>(getType()));
+  auto inferedDstType = inferTransposeResultType(srcType, getPermutation());
+
+  if (canonicalDstType != inferedDstType)
+    return emitOpError("canonicalized output type ")
+           << canonicalDstType
+           << " does not match canonical transposed input type " << srcType
+           << ", " << inferedDstType;
   return success();
 }
 
 OpFoldResult TransposeOp::fold(FoldAdaptor) {
+  // First check for identity permutation, we can fold it away if input and
+  // result types are identical already.
+  if (getPermutation().isIdentity() && getType() == getIn().getType())
+    return getIn();
   if (succeeded(foldMemRefCast(*this)))
     return getResult();
+  // Fold two consecutive memref.transpose Ops into one by composing their
+  // permutation maps.
+  if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
+    AffineMap composedPermutation =
+        otherTransposeOp.getPermutation().compose(getPermutation());
+    getInMutable().assign(otherTransposeOp.getIn());
+    setPermutation(composedPermutation);
+    return getResult();
+  }
   return {};
 }
 
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d3406c630f6dd7..3471a1f912e7ea 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -988,3 +988,26 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
   // CHECK: return %[[cast]]
   return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
 }
+
+// -----
+
+// CHECK-LABEL: func @fold_double_transpose(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+  // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+  // CHECK: return %[[ONETRANSPOSE]]
+  return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_identity_transpose(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
+  %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+  %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
+  // CHECK: return %[[arg0]]
+  return %1 : memref<1x2x3x4x5xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f9b870f77266e1..25e08eda8f4dac 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
 // -----
 
 func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
-  // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+  // expected-error @+1 {{canonicalized output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>', 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
   memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
 }
 
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7e2018ca58dc4a..a7730b71a0eacf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
   %dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
   return %dst : memref<?xf32, 1>
 }
+
+// CHECK-LABEL: func @memref_transpose_map
+func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
+  %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+  return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+}
\ No newline at end of file

@ubfx ubfx requested a review from lipracer January 3, 2024 11:54
Copy link
Member

@lipracer lipracer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Please be patient and wait for other reviewers to have any other comments.

@ubfx ubfx force-pushed the memref-transpose-fold branch from fff9104 to 5067ba1 Compare January 10, 2024 19:21
@ubfx
Copy link
Member Author

ubfx commented Jan 11, 2024

I changed the patch so that inferTransposedResultType() returns a MemRefType with StridedLayoutAttr, which is the same as it was before this patch. I moved the change to the verifier, where now, we just make sure that the canonicalized Version of the Op's result Ttpe equals the canonicalized version of the inferred result type. This way, all existing code that creates memref::TransposeOp with infered result Type should behave the same after the patch, but we still allow the Op to be used (and folded) with Affine Map layouts.

Copy link
Member

@ftynse ftynse left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this was just overlooked when we transitioned to the first-class support for strided layouts and the intention would have been to have the map result type have the strided layout if the input had.

@ubfx ubfx merged commit 4619e21 into llvm:main Jan 11, 2024
@ubfx ubfx deleted the memref-transpose-fold branch January 11, 2024 18:54
justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…folder (llvm#76294)

Currently, the `memref.transpose` verifier forces the result type of the
Op to have an explicit `StridedLayoutAttr` via the method
`inferTransposeResultType`. This means that the example Op
given in the documentation is actually invalid because it uses an `AffineMap`
to specify the layout.
It also means that we can't "un-transpose" a transposed memref back to
the implicit layout form, because the verifier will always enforce the
explicit strided layout.

This patch makes the following changes:

1. The verifier checks whether the canonicalized strided layout of the
result Type is identitcal to the canonicalized infered result type
layout. This way, it's only important that the two Types have the same
strided layout, not necessarily the same representation of it.
2. The folder is extended to support folding away the trivial case of
identity permutation and to fold one transposition into another by
composing the permutation maps.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants