Skip to content

Commit 4619e21

Browse files
authored
[mlir][memref] Transpose: allow affine map layouts in result, extend folder (llvm#76294)
Currently, the `memref.transpose` verifier forces the result type of the Op to have an explicit `StridedLayoutAttr` via the method `inferTransposeResultType`. This means that the example Op given in the documentation is actually invalid because it uses an `AffineMap` to specify the layout. It also means that we can't "un-transpose" a transposed memref back to the implicit layout form, because the verifier will always enforce the explicit strided layout. This patch makes the following changes: 1. The verifier checks whether the canonicalized strided layout of the result Type is identitcal to the canonicalized infered result type layout. This way, it's only important that the two Types have the same strided layout, not necessarily the same representation of it. 2. The folder is extended to support folding away the trivial case of identity permutation and to fold one transposition into another by composing the permutation maps.
1 parent 061b777 commit 4619e21

File tree

4 files changed

+67
-15
lines changed

4 files changed

+67
-15
lines changed

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
31483148
setNameFn(getResult(), "transpose");
31493149
}
31503150

3151-
/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
3151+
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
31523152
static MemRefType inferTransposeResultType(MemRefType memRefType,
31533153
AffineMap permutationMap) {
31543154
auto rank = memRefType.getRank();
@@ -3157,13 +3157,8 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
31573157
assert(originalStrides.size() == static_cast<unsigned>(rank));
31583158

31593159
// Compute permuted sizes and strides.
3160-
SmallVector<int64_t> sizes(rank, 0);
3161-
SmallVector<int64_t> strides(rank, 1);
3162-
for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
3163-
unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
3164-
sizes[en.index()] = originalSizes[position];
3165-
strides[en.index()] = originalStrides[position];
3166-
}
3160+
auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
3161+
auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);
31673162

31683163
return MemRefType::Builder(memRefType)
31693164
.setShape(sizes)
@@ -3216,18 +3211,34 @@ LogicalResult TransposeOp::verify() {
32163211
return emitOpError("expected a permutation map of same rank as the input");
32173212

32183213
auto srcType = llvm::cast<MemRefType>(getIn().getType());
3219-
auto dstType = llvm::cast<MemRefType>(getType());
3220-
auto transposedType = inferTransposeResultType(srcType, getPermutation());
3221-
if (dstType != transposedType)
3222-
return emitOpError("output type ")
3223-
<< dstType << " does not match transposed input type " << srcType
3224-
<< ", " << transposedType;
3214+
auto resultType = llvm::cast<MemRefType>(getType());
3215+
auto canonicalResultType = canonicalizeStridedLayout(
3216+
inferTransposeResultType(srcType, getPermutation()));
3217+
3218+
if (canonicalizeStridedLayout(resultType) != canonicalResultType)
3219+
return emitOpError("result type ")
3220+
<< resultType
3221+
<< " is not equivalent to the canonical transposed input type "
3222+
<< canonicalResultType;
32253223
return success();
32263224
}
32273225

32283226
OpFoldResult TransposeOp::fold(FoldAdaptor) {
3227+
// First check for identity permutation, we can fold it away if input and
3228+
// result types are identical already.
3229+
if (getPermutation().isIdentity() && getType() == getIn().getType())
3230+
return getIn();
32293231
if (succeeded(foldMemRefCast(*this)))
32303232
return getResult();
3233+
// Fold two consecutive memref.transpose Ops into one by composing their
3234+
// permutation maps.
3235+
if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
3236+
AffineMap composedPermutation =
3237+
getPermutation().compose(otherTransposeOp.getPermutation());
3238+
getInMutable().assign(otherTransposeOp.getIn());
3239+
setPermutation(composedPermutation);
3240+
return getResult();
3241+
}
32313242
return {};
32323243
}
32333244

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,3 +988,38 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
988988
// CHECK: return %[[cast]]
989989
return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
990990
}
991+
992+
// -----
993+
994+
// CHECK-LABEL: func @fold_double_transpose(
995+
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
996+
func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
997+
// CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
998+
%0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
999+
%1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
1000+
// CHECK: return %[[ONETRANSPOSE]]
1001+
return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
1002+
}
1003+
1004+
// -----
1005+
1006+
// CHECK-LABEL: func @fold_double_transpose2(
1007+
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
1008+
func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
1009+
// CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
1010+
%0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>>
1011+
%1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
1012+
// CHECK: return %[[ONETRANSPOSE]]
1013+
return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
1014+
}
1015+
1016+
// -----
1017+
1018+
// CHECK-LABEL: func @fold_identity_transpose(
1019+
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
1020+
func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
1021+
%0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
1022+
%1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
1023+
// CHECK: return %[[arg0]]
1024+
return %1 : memref<1x2x3x4x5xf32>
1025+
}

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
142142
// -----
143143

144144
func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
145-
// expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
145+
// expected-error @+1 {{result type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' is not equivalent to the canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
146146
memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
147147
}
148148

mlir/test/Dialect/MemRef/ops.mlir

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
378378
%dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
379379
return %dst : memref<?xf32, 1>
380380
}
381+
382+
// CHECK-LABEL: func @memref_transpose_map
383+
func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
384+
%dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
385+
return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
386+
}

0 commit comments

Comments
 (0)