Skip to content

Commit 46b90a7

Browse files
committed
[mlir] make remaining memref dialect ops produce strided layouts
The three following ops in the memref dialect: transpose, expand_shape, collapse_shape, have been originally designed to operate on memrefs with strided layouts but had to go through the affine map representation as the type did not support anything else. Make these ops produce memref values with StridedLayoutAttr instead now that it is available.

Depends On D133938

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D133947
1 parent 2791162 commit 46b90a7

File tree

13 files changed

+98
-143
lines changed

13 files changed

+98
-143
lines changed

mlir/include/mlir/IR/BuiltinTypes.h

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -466,11 +466,6 @@ bool isStrided(MemRefType t);
466466
/// Return null if the layout is not compatible with a strided layout.
467467
AffineMap getStridedLinearLayoutMap(MemRefType t);
468468

469-
/// Helper determining if a memref is static-shape and contiguous-row-major
470-
/// layout, while still allowing for an arbitrary offset (any static or
471-
/// dynamic value).
472-
bool isStaticShapeAndContiguousRowMajor(MemRefType memrefType);
473-
474469
} // namespace mlir
475470

476471
#endif // MLIR_IR_BUILTINTYPES_H

mlir/lib/Conversion/MemRefToLLVM/MemRefToLLVM.cpp

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -961,7 +961,25 @@ struct MemRefCopyOpLowering : public ConvertOpToLLVMPattern<memref::CopyOp> {
961961
auto srcType = op.getSource().getType().cast<BaseMemRefType>();
962962
auto targetType = op.getTarget().getType().cast<BaseMemRefType>();
963963

964-
auto isContiguousMemrefType = [](BaseMemRefType type) {
964+
auto isStaticShapeAndContiguousRowMajor = [](MemRefType type) {
965+
if (!type.hasStaticShape())
966+
return false;
967+
968+
SmallVector<int64_t> strides;
969+
int64_t offset;
970+
if (failed(getStridesAndOffset(type, strides, offset)))
971+
return false;
972+
973+
int64_t runningStride = 1;
974+
for (unsigned i = strides.size(); i > 0; --i) {
975+
if (strides[i - 1] != runningStride)
976+
return false;
977+
runningStride *= type.getDimSize(i - 1);
978+
}
979+
return true;
980+
};
981+
982+
auto isContiguousMemrefType = [&](BaseMemRefType type) {
965983
auto memrefType = type.dyn_cast<mlir::MemRefType>();
966984
// We can use memcpy for memrefs if they have an identity layout or are
967985
// contiguous with an arbitrary offset. Ignore empty memrefs, which is a

mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

Lines changed: 34 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1761,7 +1761,7 @@ SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
17611761

17621762
/// Compute the layout map after expanding a given source MemRef type with the
17631763
/// specified reassociation indices.
1764-
static FailureOr<AffineMap>
1764+
static FailureOr<StridedLayoutAttr>
17651765
computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
17661766
ArrayRef<ReassociationIndices> reassociation) {
17671767
int64_t srcOffset;
@@ -1798,8 +1798,7 @@ computeExpandedLayoutMap(MemRefType srcType, ArrayRef<int64_t> resultShape,
17981798
}
17991799
auto resultStrides = llvm::to_vector<8>(llvm::reverse(reverseResultStrides));
18001800
resultStrides.resize(resultShape.size(), 1);
1801-
return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1802-
srcType.getContext());
1801+
return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
18031802
}
18041803

18051804
static FailureOr<MemRefType>
@@ -1814,14 +1813,12 @@ computeExpandedType(MemRefType srcType, ArrayRef<int64_t> resultShape,
18141813
}
18151814

18161815
// Source may not be contiguous. Compute the layout map.
1817-
FailureOr<AffineMap> computedLayout =
1816+
FailureOr<StridedLayoutAttr> computedLayout =
18181817
computeExpandedLayoutMap(srcType, resultShape, reassociation);
18191818
if (failed(computedLayout))
18201819
return failure();
1821-
auto computedType =
1822-
MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1823-
srcType.getMemorySpaceAsInt());
1824-
return canonicalizeStridedLayout(computedType);
1820+
return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1821+
srcType.getMemorySpace());
18251822
}
18261823

18271824
void ExpandShapeOp::build(OpBuilder &builder, OperationState &result,
@@ -1855,10 +1852,9 @@ LogicalResult ExpandShapeOp::verify() {
18551852
return emitOpError("invalid source layout map");
18561853

18571854
// Check actual result type.
1858-
auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
1859-
if (*expectedResultType != canonicalizedResultType)
1855+
if (*expectedResultType != resultType)
18601856
return emitOpError("expected expanded type to be ")
1861-
<< *expectedResultType << " but found " << canonicalizedResultType;
1857+
<< *expectedResultType << " but found " << resultType;
18621858

18631859
return success();
18641860
}
@@ -1877,7 +1873,7 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
18771873
/// not possible to check this by inspecting a MemRefType in the general case.
18781874
/// If non-contiguity cannot be checked statically, the collapse is assumed to
18791875
/// be valid (and thus accepted by this function) unless `strict = true`.
1880-
static FailureOr<AffineMap>
1876+
static FailureOr<StridedLayoutAttr>
18811877
computeCollapsedLayoutMap(MemRefType srcType,
18821878
ArrayRef<ReassociationIndices> reassociation,
18831879
bool strict = false) {
@@ -1940,13 +1936,12 @@ computeCollapsedLayoutMap(MemRefType srcType,
19401936
return failure();
19411937
}
19421938
}
1943-
return makeStridedLinearLayoutMap(resultStrides, srcOffset,
1944-
srcType.getContext());
1939+
return StridedLayoutAttr::get(srcType.getContext(), srcOffset, resultStrides);
19451940
}
19461941

19471942
bool CollapseShapeOp::isGuaranteedCollapsible(
19481943
MemRefType srcType, ArrayRef<ReassociationIndices> reassociation) {
1949-
// MemRefs with standard layout are always collapsible.
1944+
// MemRefs with identity layout are always collapsible.
19501945
if (srcType.getLayout().isIdentity())
19511946
return true;
19521947

@@ -1978,14 +1973,12 @@ computeCollapsedType(MemRefType srcType,
19781973
// Source may not be fully contiguous. Compute the layout map.
19791974
// Note: Dimensions that are collapsed into a single dim are assumed to be
19801975
// contiguous.
1981-
FailureOr<AffineMap> computedLayout =
1976+
FailureOr<StridedLayoutAttr> computedLayout =
19821977
computeCollapsedLayoutMap(srcType, reassociation);
19831978
assert(succeeded(computedLayout) &&
19841979
"invalid source layout map or collapsing non-contiguous dims");
1985-
auto computedType =
1986-
MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1987-
srcType.getMemorySpaceAsInt());
1988-
return canonicalizeStridedLayout(computedType);
1980+
return MemRefType::get(resultShape, srcType.getElementType(), *computedLayout,
1981+
srcType.getMemorySpace());
19891982
}
19901983

19911984
void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
@@ -2021,21 +2014,19 @@ LogicalResult CollapseShapeOp::verify() {
20212014
// Source may not be fully contiguous. Compute the layout map.
20222015
// Note: Dimensions that are collapsed into a single dim are assumed to be
20232016
// contiguous.
2024-
FailureOr<AffineMap> computedLayout =
2017+
FailureOr<StridedLayoutAttr> computedLayout =
20252018
computeCollapsedLayoutMap(srcType, getReassociationIndices());
20262019
if (failed(computedLayout))
20272020
return emitOpError(
20282021
"invalid source layout map or collapsing non-contiguous dims");
2029-
auto computedType =
2022+
expectedResultType =
20302023
MemRefType::get(resultType.getShape(), srcType.getElementType(),
2031-
*computedLayout, srcType.getMemorySpaceAsInt());
2032-
expectedResultType = canonicalizeStridedLayout(computedType);
2024+
*computedLayout, srcType.getMemorySpace());
20332025
}
20342026

2035-
auto canonicalizedResultType = canonicalizeStridedLayout(resultType);
2036-
if (expectedResultType != canonicalizedResultType)
2027+
if (expectedResultType != resultType)
20372028
return emitOpError("expected collapsed type to be ")
2038-
<< expectedResultType << " but found " << canonicalizedResultType;
2029+
<< expectedResultType << " but found " << resultType;
20392030

20402031
return success();
20412032
}
@@ -2709,24 +2700,26 @@ static MemRefType inferTransposeResultType(MemRefType memRefType,
27092700
AffineMap permutationMap) {
27102701
auto rank = memRefType.getRank();
27112702
auto originalSizes = memRefType.getShape();
2712-
// Compute permuted sizes.
2713-
SmallVector<int64_t, 4> sizes(rank, 0);
2714-
for (const auto &en : llvm::enumerate(permutationMap.getResults()))
2715-
sizes[en.index()] =
2716-
originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2717-
2718-
// Compute permuted strides.
27192703
int64_t offset;
2720-
SmallVector<int64_t, 4> strides;
2721-
auto res = getStridesAndOffset(memRefType, strides, offset);
2722-
assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2704+
SmallVector<int64_t, 4> originalStrides;
2705+
auto res = getStridesAndOffset(memRefType, originalStrides, offset);
2706+
assert(succeeded(res) &&
2707+
originalStrides.size() == static_cast<unsigned>(rank));
27232708
(void)res;
2724-
auto map =
2725-
makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2726-
map = permutationMap ? map.compose(permutationMap) : map;
2709+
2710+
// Compute permuted sizes and strides.
2711+
SmallVector<int64_t> sizes(rank, 0);
2712+
SmallVector<int64_t> strides(rank, 1);
2713+
for (const auto &en : llvm::enumerate(permutationMap.getResults())) {
2714+
unsigned position = en.value().cast<AffineDimExpr>().getPosition();
2715+
sizes[en.index()] = originalSizes[position];
2716+
strides[en.index()] = originalStrides[position];
2717+
}
2718+
27272719
return MemRefType::Builder(memRefType)
27282720
.setShape(sizes)
2729-
.setLayout(AffineMapAttr::get(map));
2721+
.setLayout(
2722+
StridedLayoutAttr::get(memRefType.getContext(), offset, strides));
27302723
}
27312724

27322725
void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,

mlir/lib/Dialect/Tensor/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -136,11 +136,10 @@ struct CollapseShapeOpInterface
136136
int64_t offset;
137137
if (failed(getStridesAndOffset(bufferType, strides, offset)))
138138
return failure();
139-
AffineMap resultLayout =
140-
makeStridedLinearLayoutMap({}, offset, op->getContext());
141-
resultType =
142-
MemRefType::get({}, tensorResultType.getElementType(), resultLayout,
143-
bufferType.getMemorySpaceAsInt());
139+
resultType = MemRefType::get(
140+
{}, tensorResultType.getElementType(),
141+
StridedLayoutAttr::get(op->getContext(), offset, {}),
142+
bufferType.getMemorySpace());
144143
}
145144

146145
replaceOpWithNewBufferizedOp<memref::CollapseShapeOp>(

mlir/lib/IR/AsmPrinter.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2250,7 +2250,8 @@ void AsmPrinter::Impl::printType(Type type) {
22502250
os << 'x';
22512251
}
22522252
printType(memrefTy.getElementType());
2253-
if (!memrefTy.getLayout().isIdentity()) {
2253+
MemRefLayoutAttrInterface layout = memrefTy.getLayout();
2254+
if (!layout.isa<AffineMapAttr>() || !layout.isIdentity()) {
22542255
os << ", ";
22552256
printAttribute(memrefTy.getLayout(), AttrTypeElision::May);
22562257
}

mlir/lib/IR/BuiltinTypes.cpp

Lines changed: 0 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1027,40 +1027,3 @@ AffineMap mlir::getStridedLinearLayoutMap(MemRefType t) {
10271027
return AffineMap();
10281028
return makeStridedLinearLayoutMap(strides, offset, t.getContext());
10291029
}
1030-
1031-
/// Return the AffineExpr representation of the offset, assuming `memRefType`
1032-
/// is a strided memref.
1033-
static AffineExpr getOffsetExpr(MemRefType memrefType) {
1034-
SmallVector<AffineExpr> strides;
1035-
AffineExpr offset;
1036-
if (failed(getStridesAndOffset(memrefType, strides, offset)))
1037-
assert(false && "expected strided memref");
1038-
return offset;
1039-
}
1040-
1041-
/// Helper to construct a contiguous MemRefType of `shape`, `elementType` and
1042-
/// `offset` AffineExpr.
1043-
static MemRefType makeContiguousRowMajorMemRefType(MLIRContext *context,
1044-
ArrayRef<int64_t> shape,
1045-
Type elementType,
1046-
AffineExpr offset) {
1047-
AffineExpr canonical = makeCanonicalStridedLayoutExpr(shape, context);
1048-
AffineExpr contiguousRowMajor = canonical + offset;
1049-
AffineMap contiguousRowMajorMap =
1050-
AffineMap::inferFromExprList({contiguousRowMajor})[0];
1051-
return MemRefType::get(shape, elementType, contiguousRowMajorMap);
1052-
}
1053-
1054-
/// Helper determining if a memref is static-shape and contiguous-row-major
1055-
/// layout, while still allowing for an arbitrary offset (any static or
1056-
/// dynamic value).
1057-
bool mlir::isStaticShapeAndContiguousRowMajor(MemRefType memrefType) {
1058-
if (!memrefType.hasStaticShape())
1059-
return false;
1060-
AffineExpr offset = getOffsetExpr(memrefType);
1061-
MemRefType contiguousRowMajorMemRefType = makeContiguousRowMajorMemRefType(
1062-
memrefType.getContext(), memrefType.getShape(),
1063-
memrefType.getElementType(), offset);
1064-
return canonicalizeStridedLayout(memrefType) ==
1065-
canonicalizeStridedLayout(contiguousRowMajorMemRefType);
1066-
}

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -609,7 +609,7 @@ func.func @address_space(%arg0 : memref<32xf32, affine_map<(d0) -> (d0)>, 7>) {
609609
// CHECK: llvm.extractvalue {{.*}}[3, 2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
610610
// CHECK: llvm.insertvalue {{.*}}[3, 1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
611611
func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
612-
%0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d0 * s2 + d1)>>
612+
%0 = memref.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
613613
return
614614
}
615615

@@ -725,12 +725,12 @@ func.func @collapse_shape_static(%arg0: memref<1x3x4x1x5xf32>) -> memref<3x4x5xf
725725
// -----
726726

727727
func.func @collapse_shape_dynamic_with_non_identity_layout(
728-
%arg0 : memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>>) ->
729-
memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> {
728+
%arg0 : memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>>) ->
729+
memref<4x?xf32, strided<[?, ?], offset: ?>> {
730730
%0 = memref.collapse_shape %arg0 [[0], [1, 2]]:
731-
memref<4x?x?xf32, affine_map<(d0, d1, d2)[s0, s1] -> (d0 * s1 + s0 + d1 * 4 + d2)>> into
732-
memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
733-
return %0 : memref<4x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>
731+
memref<4x?x?xf32, strided<[?, 4, 1], offset: ?>> into
732+
memref<4x?xf32, strided<[?, ?], offset: ?>>
733+
return %0 : memref<4x?xf32, strided<[?, ?], offset: ?>>
734734
}
735735
// CHECK-LABEL: func @collapse_shape_dynamic_with_non_identity_layout(
736736
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<2 x i64>, array<2 x i64>)>
@@ -898,12 +898,12 @@ func.func @expand_shape_dynamic(%arg0 : memref<1x?xf32>) -> memref<1x2x?xf32> {
898898
// -----
899899

900900
func.func @expand_shape_dynamic_with_non_identity_layout(
901-
%arg0 : memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>>) ->
902-
memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>> {
901+
%arg0 : memref<1x?xf32, strided<[?, ?], offset: ?>>) ->
902+
memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>> {
903903
%0 = memref.expand_shape %arg0 [[0], [1, 2]]:
904-
memref<1x?xf32, affine_map<(d0, d1)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2)>> into
905-
memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
906-
return %0 : memref<1x2x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3)>>
904+
memref<1x?xf32, strided<[?, ?], offset: ?>> into
905+
memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
906+
return %0 : memref<1x2x?xf32, strided<[?, ?, ?], offset: ?>>
907907
}
908908
// CHECK-LABEL: func @expand_shape_dynamic_with_non_identity_layout(
909909
// CHECK: llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<3 x i64>, array<3 x i64>)>
@@ -982,10 +982,10 @@ func.func @atomic_rmw(%I : memref<10xi32>, %ival : i32, %F : memref<10xf32>, %fv
982982
// -----
983983

984984
// CHECK-LABEL: func @collapse_static_shape_with_non_identity_layout
985-
func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>>) -> memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>> {
985+
func.func @collapse_static_shape_with_non_identity_layout(%arg: memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>>) -> memref<64xf32, strided<[1], offset: ?>> {
986986
// CHECK-NOT: memref.collapse_shape
987-
%1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, affine_map<(d0, d1, d2, d3)[s0] -> (d0 * 64 + s0 + d1 * 64 + d2 * 8 + d3)>> into memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
988-
return %1 : memref<64xf32, affine_map<(d0)[s0] -> (d0 + s0)>>
987+
%1 = memref.collapse_shape %arg [[0, 1, 2, 3]] : memref<1x1x8x8xf32, strided<[64, 64, 8, 1], offset: ?>> into memref<64xf32, strided<[1], offset: ?>>
988+
return %1 : memref<64xf32, strided<[1], offset: ?>>
989989
}
990990

991991
// -----
@@ -1069,13 +1069,11 @@ func.func @memref_copy_contiguous(%in: memref<16x2xi32>, %offset: index) {
10691069
// -----
10701070

10711071
// CHECK-LABEL: func @memref_copy_0d_offset
1072-
#map0 = affine_map<(d0) -> (d0 + 1)>
1073-
#map1 = affine_map<() -> (1)>
10741072
func.func @memref_copy_0d_offset(%in: memref<2xi32>) {
10751073
%buf = memref.alloc() : memref<i32>
1076-
%sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, #map0>
1077-
%scalar = memref.collapse_shape %sub [] : memref<1xi32, #map0> into memref<i32, #map1>
1078-
memref.copy %scalar, %buf : memref<i32, #map1> to memref<i32>
1074+
%sub = memref.subview %in[1] [1] [1] : memref<2xi32> to memref<1xi32, strided<[1], offset: 1>>
1075+
%scalar = memref.collapse_shape %sub [] : memref<1xi32, strided<[1], offset: 1>> into memref<i32, strided<[], offset: 1>>
1076+
memref.copy %scalar, %buf : memref<i32, strided<[], offset: 1>> to memref<i32>
10791077
// CHECK: llvm.intr.memcpy
10801078
return
10811079
}

mlir/test/Dialect/Bufferization/Transforms/one-shot-bufferize-alloc-tensor-elimination.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@ func.func @buffer_forwarding_conflict(
2323
%f = linalg.fill ins(%f0 : f32) outs(%a : tensor<?xf32>) -> tensor<?xf32>
2424

2525
// CHECK: memref.copy %[[FUNC_ARG]], %[[ALLOC]] : memref<?xf32> to memref<?xf32>
26-
// CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32>
27-
// CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32>
26+
// CHECK: %[[SV0_ALLOC:.*]] = memref.subview %[[ALLOC]][0] [%[[sz]]] [1] : memref<?xf32> to memref<?xf32, strided<[1]>>
27+
// CHECK: memref.copy %[[EXTRACT_SLICE_ALLOC]], %[[SV0_ALLOC]] : memref<?xf32> to memref<?xf32, strided<[1]>>
2828
%r0 = tensor.insert_slice %f into %t[0][%sz][1]: tensor<?xf32> into tensor<?xf32>
2929

3030
// CHECK: %[[T_SUBVIEW:.*]] = memref.subview %[[FUNC_ARG]][42] [%[[sz]]] [1]

mlir/test/Dialect/Linalg/roundtrip.mlir

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,6 @@
66
// Test that we can lower all the way to LLVM without crashing, don't check results here.
77
// DISABLED: mlir-opt %s --convert-linalg-to-llvm -o=/dev/null 2>&1
88

9-
// CHECK: #[[$strided3DT:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>
10-
119
func.func @views(%arg0: index) {
1210
%c0 = arith.constant 0 : index
1311
%0 = arith.muli %arg0, %arg0 : index
@@ -70,12 +68,12 @@ func.func @fill_view(%arg0: memref<?xf32, strided<[1], offset: ?>>, %arg1: f32)
7068
// -----
7169

7270
func.func @transpose(%arg0: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>>) {
73-
%0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, affine_map<(d0, d1, d2)[s0, s1, s2] -> (d2 * s1 + s0 + d1 * s2 + d0)>>
71+
%0 = memref.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
7472
return
7573
}
7674
// CHECK-LABEL: func @transpose
7775
// CHECK: memref.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) :
78-
// CHECK-SAME: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, #[[$strided3DT]]>
76+
// CHECK-SAME: memref<?x?x?xf32, strided<[?, ?, 1], offset: ?>> to memref<?x?x?xf32, strided<[1, ?, ?], offset: ?>>
7977

8078
// -----
8179

mlir/test/Dialect/MemRef/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ func.func @expand_shape_to_smaller_rank(%arg0: memref<1xf32>) {
424424

425425
func.func @expand_shape_invalid_result_layout(
426426
%arg0: memref<30x20xf32, strided<[4000, 2], offset: 100>>) {
427-
// expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 60000 + d1 * 4000 + d2 * 2 + 100)>>' but found 'memref<2x15x20xf32, affine_map<(d0, d1, d2) -> (d0 * 5000 + d1 * 4000 + d2 * 2 + 100)>>'}}
427+
// expected-error @+1 {{expected expanded type to be 'memref<2x15x20xf32, strided<[60000, 4000, 2], offset: 100>>' but found 'memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>'}}
428428
%0 = memref.expand_shape %arg0 [[0, 1], [2]] :
429429
memref<30x20xf32, strided<[4000, 2], offset: 100>>
430430
into memref<2x15x20xf32, strided<[5000, 4000, 2], offset: 100>>

0 commit comments

Comments (0)