Skip to content

[mlir] [IR] Allow zero strides in StridedLayoutAttr #116463

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions mlir/lib/IR/BuiltinAttributes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,6 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
LogicalResult
StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
int64_t offset, ArrayRef<int64_t> strides) {
if (llvm::is_contained(strides, 0))
return emitError() << "strides must not be zero";

return success();
}

Expand Down Expand Up @@ -1815,7 +1812,6 @@ AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
for (const auto &en : llvm::enumerate(strides)) {
auto dim = en.index();
auto stride = en.value();
assert(stride != 0 && "Invalid stride specification");
auto d = getAffineDimExpr(dim, context);
AffineExpr mult;
// Static case.
Expand Down
14 changes: 0 additions & 14 deletions mlir/lib/IR/BuiltinTypes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -803,20 +803,6 @@ static LogicalResult getStridesAndOffset(MemRefType t,
for (auto &stride : strides)
stride = simplifyAffineExpr(stride, numDims, numSymbols);

// In practice, a strided memref must be internally non-aliasing. Test
// against 0 as a proxy.
// TODO: static cases can have more advanced checks.
// TODO: dynamic cases would require a way to compare symbolic
// expressions and would probably need an affine set context propagated
// everywhere.
if (llvm::any_of(strides, [](AffineExpr e) {
return e == getAffineConstantExpr(0, e.getContext());
})) {
offset = AffineExpr();
strides.clear();
return failure();
}

return success();
}

Expand Down
4 changes: 2 additions & 2 deletions mlir/test/Dialect/Affine/memref-stride-calculation.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ func.func @f(%0: index) {
%26 = memref.alloc(%0)[] : memref<?xf32, affine_map<(i)[M]->(i)>>
// CHECK: MemRefType offset: 0 strides: 1
%27 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(M)>>
// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (s0)>> cannot be converted to strided form
// CHECK: MemRefType offset: ? strides: 0
%28 = memref.alloc()[%0] : memref<5xf32, affine_map<(i)[M]->(123)>>
// CHECK: MemRefType memref<5xf32, affine_map<(d0)[s0] -> (123)>> cannot be converted to strided form
// CHECK: MemRefType offset: 123 strides: 0
%29 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(M)>>
// CHECK: MemRefType offset: ? strides:
%30 = memref.alloc()[%0] : memref<f32, affine_map<()[M]->(123)>>
Expand Down
10 changes: 0 additions & 10 deletions mlir/test/Dialect/MemRef/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -245,16 +245,6 @@ func.func @memref_reinterpret_cast_no_map_but_strides(%in: memref<?x?xf32>) {

// -----

func.func @memref_reinterpret_cast_non_strided_layout(%in: memref<?x?xf32>) {
// expected-error @+1 {{expected result type to have strided layout but found 'memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>}}
%out = memref.reinterpret_cast %in to
offset: [0], sizes: [9, 10], strides: [42, 1]
: memref<?x?xf32> to memref<9x10xf32, affine_map<(d0, d1) -> (d0)>>
return
}

// -----

func.func @memref_reshape_element_type_mismatch(
%buf: memref<*xf32>, %shape: memref<1xi32>) {
// expected-error @+1 {{element types of source and destination memref types should be the same}}
Expand Down
5 changes: 0 additions & 5 deletions mlir/test/IR/invalid-builtin-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,6 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<

// -----

// expected-error @below {{strides must not be zero}}
func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>

// -----

// expected-error @below {{expected the number of strides to match the rank}}
func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>

Expand Down
Loading