-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][memref] Transpose: allow affine map layouts in result, extend folder #76294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-memref @llvm/pr-subscribers-mlir Author: Felix Schneider (ubfx) ChangesCurrently, the %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> It also means that we can't "un-transpose" a transposed memref back to the implicit layout form, because the verifier will always enforce the explicit strided layout. This patch makes the following changes:
Full diff: https://github.com/llvm/llvm-project/pull/76294.diff 4 Files Affected:
diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
index a332fe253ba645..8d7cb6e1cc92cc 100644
--- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
+++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
@@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
setNameFn(getResult(), "transpose");
}
-/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
+/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto rank = memRefType.getRank();
@@ -3157,18 +3157,14 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
assert(originalStrides.size() == static_cast<unsigned>(rank));
// Compute permuted sizes and strides.
- SmallVector<int64_t> sizes(rank, 0);
- SmallVector<int64_t> strides(rank, 1);
- for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
- unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
- sizes[en.index()] = originalSizes[position];
- strides[en.index()] = originalStrides[position];
- }
+ auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
+ auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
- return MemRefType::Builder(memRefType)
- .setShape(sizes)
- .setLayout(
- StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
+ auto stridedTy = MemRefType::Builder(memRefType)
+ .setShape(sizes)
+ .setLayout(StridedLayoutAttr::get(
+ memRefType.getContext(), offset, strides));
+ return canonicalizeStridedLayout(stridedTy);
}
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
@@ -3216,18 +3212,34 @@ LogicalResult TransposeOp::verify() {
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = llvm::cast<MemRefType>(getIn().getType());
- auto dstType = llvm::cast<MemRefType>(getType());
- auto transposedType = inferTransposeResultType(srcType, getPermutation());
- if (dstType != transposedType)
- return emitOpError("output type ")
- << dstType << " does not match transposed input type " << srcType
- << ", " << transposedType;
+ auto canonicalDstType =
+ canonicalizeStridedLayout(llvm::cast<MemRefType>(getType()));
+ auto inferedDstType = inferTransposeResultType(srcType, getPermutation());
+
+ if (canonicalDstType != inferedDstType)
+ return emitOpError("canonicalized output type ")
+ << canonicalDstType
+ << " does not match canonical transposed input type " << srcType
+ << ", " << inferedDstType;
return success();
}
OpFoldResult TransposeOp::fold(FoldAdaptor) {
+ // First check for identity permutation, we can fold it away if input and
+ // result types are identical already.
+ if (getPermutation().isIdentity() && getType() == getIn().getType())
+ return getIn();
if (succeeded(foldMemRefCast(*this)))
return getResult();
+ // Fold two consecutive memref.transpose Ops into one by composing their
+ // permutation maps.
+ if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
+ AffineMap composedPermutation =
+ otherTransposeOp.getPermutation().compose(getPermutation());
+ getInMutable().assign(otherTransposeOp.getIn());
+ setPermutation(composedPermutation);
+ return getResult();
+ }
return {};
}
diff --git a/mlir/test/Dialect/MemRef/canonicalize.mlir b/mlir/test/Dialect/MemRef/canonicalize.mlir
index d3406c630f6dd7..3471a1f912e7ea 100644
--- a/mlir/test/Dialect/MemRef/canonicalize.mlir
+++ b/mlir/test/Dialect/MemRef/canonicalize.mlir
@@ -988,3 +988,26 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
// CHECK: return %[[cast]]
return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}
+
+// -----
+
+// CHECK-LABEL: func @fold_double_transpose(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
+ // CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
+ %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+ %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+ // CHECK: return %[[ONETRANSPOSE]]
+ return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
+}
+
+// -----
+
+// CHECK-LABEL: func @fold_identity_transpose(
+// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
+func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
+ %0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
+ %1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
+ // CHECK: return %[[arg0]]
+ return %1 : memref<1x2x3x4x5xf32>
+}
diff --git a/mlir/test/Dialect/MemRef/invalid.mlir b/mlir/test/Dialect/MemRef/invalid.mlir
index f9b870f77266e1..25e08eda8f4dac 100644
--- a/mlir/test/Dialect/MemRef/invalid.mlir
+++ b/mlir/test/Dialect/MemRef/invalid.mlir
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
// -----
func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
- // expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
+ // expected-error @+1 {{canonicalized output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>', 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}
diff --git a/mlir/test/Dialect/MemRef/ops.mlir b/mlir/test/Dialect/MemRef/ops.mlir
index 7e2018ca58dc4a..a7730b71a0eacf 100644
--- a/mlir/test/Dialect/MemRef/ops.mlir
+++ b/mlir/test/Dialect/MemRef/ops.mlir
@@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
%dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
return %dst : memref<?xf32, 1>
}
+
+// CHECK-LABEL: func @memref_transpose_map
+func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
+ %dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+ return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
+}
\ No newline at end of file
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. Please be patient and wait for other reviewers to have any other comments.
fff9104
to
5067ba1
Compare
I changed the patch so that |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think this was just overlooked when we transitioned to the first-class support for strided layouts and the intention would have been to have the map result type have the strided layout if the input had.
…folder (llvm#76294) Currently, the `memref.transpose` verifier forces the result type of the Op to have an explicit `StridedLayoutAttr` via the method `inferTransposeResultType`. This means that the example Op given in the documentation is actually invalid because it uses an `AffineMap` to specify the layout. It also means that we can't "un-transpose" a transposed memref back to the implicit layout form, because the verifier will always enforce the explicit strided layout. This patch makes the following changes: 1. The verifier checks whether the canonicalized strided layout of the result Type is identitcal to the canonicalized infered result type layout. This way, it's only important that the two Types have the same strided layout, not necessarily the same representation of it. 2. The folder is extended to support folding away the trivial case of identity permutation and to fold one transposition into another by composing the permutation maps.
Currently, the
memref.transpose
verifier forces the result type of the Op to have an explicitStridedLayoutAttr
via the methodinferTransposeResultType
. This means that things like the example Op given in the documentation (https://mlir.llvm.org/docs/Dialects/MemRef/#memreftranspose-memreftransposeop) is actually invalid because it uses anAffineMap
to specify the layout:It also means that we can't "un-transpose" a transposed memref back to the implicit layout form, because the verifier will always enforce the explicit strided layout.
This patch makes the following changes:
inferTransposeResultType()
returns aMemRefType
with canonicalized strided layout, i.e the strides are turned into a linearizing affine expression.