Skip to content

[mlir][memref] Transpose: allow affine map layouts in result, extend folder #76294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 25 additions & 14 deletions mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3148,7 +3148,7 @@ void TransposeOp::getAsmResultNames(
setNameFn(getResult(), "transpose");
}

/// Build a strided memref type by applying `permutationMap` tp `memRefType`.
/// Build a strided memref type by applying `permutationMap` to `memRefType`.
static MemRefType inferTransposeResultType(MemRefType memRefType,
AffineMap permutationMap) {
auto rank = memRefType.getRank();
Expand All @@ -3157,13 +3157,8 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
assert(originalStrides.size() == static_cast<unsigned>(rank));

// Compute permuted sizes and strides.
SmallVector<int64_t> sizes(rank, 0);
SmallVector<int64_t> strides(rank, 1);
for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
unsigned position = cast<AffineDimExpr>(en.value()).getPosition();
sizes[en.index()] = originalSizes[position];
strides[en.index()] = originalStrides[position];
}
auto sizes = applyPermutationMap<int64_t>(permutationMap, originalSizes);
auto strides = applyPermutationMap<int64_t>(permutationMap, originalStrides);

return MemRefType::Builder(memRefType)
.setShape(sizes)
Expand Down Expand Up @@ -3216,18 +3211,34 @@ LogicalResult TransposeOp::verify() {
return emitOpError("expected a permutation map of same rank as the input");

auto srcType = llvm::cast<MemRefType>(getIn().getType());
auto dstType = llvm::cast<MemRefType>(getType());
auto transposedType = inferTransposeResultType(srcType, getPermutation());
if (dstType != transposedType)
return emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType
<< ", " << transposedType;
auto resultType = llvm::cast<MemRefType>(getType());
auto canonicalResultType = canonicalizeStridedLayout(
inferTransposeResultType(srcType, getPermutation()));

if (canonicalizeStridedLayout(resultType) != canonicalResultType)
return emitOpError("result type ")
<< resultType
<< " is not equivalent to the canonical transposed input type "
<< canonicalResultType;
return success();
}

OpFoldResult TransposeOp::fold(FoldAdaptor) {
// First check for identity permutation, we can fold it away if input and
// result types are identical already.
if (getPermutation().isIdentity() && getType() == getIn().getType())
return getIn();
if (succeeded(foldMemRefCast(*this)))
return getResult();
// Fold two consecutive memref.transpose Ops into one by composing their
// permutation maps.
if (auto otherTransposeOp = getIn().getDefiningOp<memref::TransposeOp>()) {
AffineMap composedPermutation =
getPermutation().compose(otherTransposeOp.getPermutation());
getInMutable().assign(otherTransposeOp.getIn());
setPermutation(composedPermutation);
return getResult();
}
return {};
}

Expand Down
35 changes: 35 additions & 0 deletions mlir/test/Dialect/MemRef/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -988,3 +988,38 @@ func.func @subview_rank_reduction(%arg0: memref<1x384x384xf32>, %idx: index)
// CHECK: return %[[cast]]
return %0 : memref<?x?xf32, strided<[384, 1], offset: ?>>
}

// -----

// CHECK-LABEL: func @fold_double_transpose(
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_double_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
// CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
%0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
%1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
// CHECK: return %[[ONETRANSPOSE]]
return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
}

// -----

// CHECK-LABEL: func @fold_double_transpose2(
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_double_transpose2(%arg0: memref<1x2x3x4x5xf32>) -> memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>> {
// CHECK: %[[ONETRANSPOSE:.+]] = memref.transpose %[[arg0]] (d0, d1, d2, d3, d4) -> (d4, d2, d1, d3, d0)
%0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d0, d1, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>>
%1 = memref.transpose %0 (d0, d1, d4, d3, d2) -> (d4, d2, d1, d3, d0) : memref<1x2x5x4x3xf32, strided<[120, 60, 1, 5, 20]>> to memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
// CHECK: return %[[ONETRANSPOSE]]
return %1 : memref<5x3x2x4x1xf32, strided<[1, 20, 60, 5, 120]>>
}

// -----

// CHECK-LABEL: func @fold_identity_transpose(
// CHECK-SAME: %[[arg0:.*]]: memref<1x2x3x4x5xf32>
func.func @fold_identity_transpose(%arg0: memref<1x2x3x4x5xf32>) -> memref<1x2x3x4x5xf32> {
%0 = memref.transpose %arg0 (d0, d1, d2, d3, d4) -> (d1, d0, d4, d3, d2) : memref<1x2x3x4x5xf32> to memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>>
%1 = memref.transpose %0 (d1, d0, d4, d3, d2) -> (d0, d1, d2, d3, d4) : memref<2x1x5x4x3xf32, strided<[60, 120, 1, 5, 20]>> to memref<1x2x3x4x5xf32>
// CHECK: return %[[arg0]]
return %1 : memref<1x2x3x4x5xf32>
}
2 changes: 1 addition & 1 deletion mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ func.func @transpose_bad_rank(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(o
// -----

func.func @transpose_wrong_type(%v : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>) {
// expected-error @+1 {{output type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' does not match transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>'}}
// expected-error @+1 {{result type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)>>' is not equivalent to the canonical transposed input type 'memref<?x?xf32, affine_map<(d0, d1)[s0, s1] -> (d0 + s0 + d1 * s1)>>'}}
memref.transpose %v (i, j) -> (j, i) : memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>> to memref<?x?xf32, affine_map<(i, j)[off, M]->(off + M * i + j)>>
}

Expand Down
6 changes: 6 additions & 0 deletions mlir/test/Dialect/MemRef/ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -378,3 +378,9 @@ func.func @memref_memory_space_cast(%src : memref<?xf32>) -> memref<?xf32, 1> {
%dst = memref.memory_space_cast %src : memref<?xf32> to memref<?xf32, 1>
return %dst : memref<?xf32, 1>
}

// CHECK-LABEL: func @memref_transpose_map
func.func @memref_transpose_map(%src : memref<?x?xf32>) -> memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>> {
%dst = memref.transpose %src (i, j) -> (j, i) : memref<?x?xf32> to memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
return %dst : memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>
}