Skip to content

Commit 54d81e4

Browse files
committed
[mlir] Allow negative strides and offset in StridedLayoutAttr
Negative strides are useful for creating reverse-view of array. We don't have specific example for negative offset yet but will add it for consistency. Differential Revision: https://reviews.llvm.org/D134147
1 parent 8a774c3 commit 54d81e4

File tree

6 files changed

+79
-22
lines changed

6 files changed

+79
-22
lines changed

mlir/lib/AsmParser/AttributeParser.cpp

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1174,17 +1174,23 @@ Attribute Parser::parseStridedLayoutAttr() {
11741174

11751175
SMLoc loc = getToken().getLoc();
11761176
auto emitWrongTokenError = [&] {
1177-
emitError(loc, "expected a non-negative 64-bit signed integer or '?'");
1177+
emitError(loc, "expected a 64-bit signed integer or '?'");
11781178
return llvm::None;
11791179
};
11801180

1181+
bool negative = consumeIf(Token::minus);
1182+
11811183
if (getToken().is(Token::integer)) {
11821184
Optional<uint64_t> value = getToken().getUInt64IntegerValue();
11831185
if (!value ||
11841186
*value > static_cast<uint64_t>(std::numeric_limits<int64_t>::max()))
11851187
return emitWrongTokenError();
11861188
consumeToken();
1187-
return static_cast<int64_t>(*value);
1189+
auto result = static_cast<int64_t>(*value);
1190+
if (negative)
1191+
result = -result;
1192+
1193+
return result;
11881194
}
11891195

11901196
return emitWrongTokenError();

mlir/lib/IR/BuiltinAttributes.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -269,14 +269,9 @@ AffineMap StridedLayoutAttr::getAffineMap() const {
269269
LogicalResult
270270
StridedLayoutAttr::verify(function_ref<InFlightDiagnostic()> emitError,
271271
int64_t offset, ArrayRef<int64_t> strides) {
272-
if (offset < 0 && offset != ShapedType::kDynamicStrideOrOffset)
273-
return emitError() << "offset must be non-negative or dynamic";
272+
if (llvm::any_of(strides, [&](int64_t stride) { return stride == 0; }))
273+
return emitError() << "strides must not be zero";
274274

275-
if (llvm::any_of(strides, [&](int64_t stride) {
276-
return stride <= 0 && stride != ShapedType::kDynamicStrideOrOffset;
277-
})) {
278-
return emitError() << "strides must be positive or dynamic";
279-
}
280275
return success();
281276
}
282277

mlir/test/Conversion/MemRefToLLVM/memref-to-llvm.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -540,6 +540,31 @@ func.func @subview_rank_reducing_leading_operands(%0 : memref<5x3xf32>) {
540540

541541
// -----
542542

543+
// CHECK-LABEL: func @subview_negative_stride
544+
// CHECK-SAME: (%[[ARG:.*]]: memref<7xf32>)
545+
func.func @subview_negative_stride(%arg0 : memref<7xf32>) -> memref<7xf32, strided<[-1], offset: 6>> {
546+
// CHECK: %[[ORIG:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : memref<7xf32> to !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
547+
// CHECK: %[[NEW1:.*]] = llvm.mlir.undef : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
548+
// CHECK: %[[PTR1:.*]] = llvm.extractvalue %[[ORIG]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
549+
// CHECK: %[[PTR2:.*]] = llvm.bitcast %[[PTR1]] : !llvm.ptr<f32> to !llvm.ptr<f32>
550+
// CHECK: %[[NEW2:.*]] = llvm.insertvalue %[[PTR2]], %[[NEW1]][0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
551+
// CHECK: %[[PTR3:.*]] = llvm.extractvalue %[[ORIG]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
552+
// CHECK: %[[PTR4:.*]] = llvm.bitcast %[[PTR3]] : !llvm.ptr<f32> to !llvm.ptr<f32>
553+
// CHECK: %[[NEW3:.*]] = llvm.insertvalue %[[PTR4]], %[[NEW2]][1] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
554+
// CHECK: %[[OFFSET:.*]] = llvm.mlir.constant(6 : index) : i64
555+
// CHECK: %[[NEW4:.*]] = llvm.insertvalue %[[OFFSET]], %[[NEW3]][2] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
556+
// CHECK: %[[SIZE:.*]] = llvm.mlir.constant(7 : i64) : i64
557+
// CHECK: %[[STRIDE:.*]] = llvm.mlir.constant(-1 : i64) : i64
558+
// CHECK: %[[NEW5:.*]] = llvm.insertvalue %[[SIZE]], %[[NEW4]][3, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
559+
// CHECK: %[[NEW6:.*]] = llvm.insertvalue %[[STRIDE]], %[[NEW5]][4, 0] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)>
560+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[NEW6]] : !llvm.struct<(ptr<f32>, ptr<f32>, i64, array<1 x i64>, array<1 x i64>)> to memref<7xf32, strided<[-1], offset: 6>>
561+
// CHECK: return %[[RES]] : memref<7xf32, strided<[-1], offset: 6>>
562+
%0 = memref.subview %arg0[6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
563+
return %0 : memref<7xf32, strided<[-1], offset: 6>>
564+
}
565+
566+
// -----
567+
543568
// CHECK-LABEL: func @assume_alignment
544569
func.func @assume_alignment(%0 : memref<4x4xf16>) {
545570
// CHECK: %[[PTR:.*]] = llvm.extractvalue %[[MEMREF:.*]][1] : !llvm.struct<(ptr<f16>, ptr<f16>, i64, array<2 x i64>, array<2 x i64>)>

mlir/test/Dialect/Builtin/types.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,7 @@ func.func private @f6() -> memref<?x?xf32, strided<[42, 1], offset: 0>>
1616
func.func private @f7() -> memref<f32, strided<[]>>
1717
// CHECK: memref<f32, strided<[], offset: ?>>
1818
func.func private @f8() -> memref<f32, strided<[], offset: ?>>
19+
// CHECK: memref<?xf32, strided<[-1], offset: ?>>
20+
func.func private @f9() -> memref<?xf32, strided<[-1], offset: ?>>
21+
// CHECK: memref<f32, strided<[], offset: -1>>
22+
func.func private @f10() -> memref<f32, strided<[], offset: -1>>

mlir/test/Dialect/MemRef/canonicalize.mlir

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,43 @@ func.func @multiple_reducing_dims_all_dynamic(%arg0 : memref<?x?x?xf32, strided<
127127
// CHECK: %[[REDUCED2:.+]] = memref.subview %[[REDUCED1]][0, 0] [1, %{{.+}}] [1, 1]
128128
// CHECK-SAME: : memref<1x?xf32, strided<[?, ?], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
129129

130+
// -----
131+
132+
func.func @subview_negative_stride1(%arg0 : memref<?xf32>) -> memref<?xf32, strided<[?], offset: ?>>
133+
{
134+
%c0 = arith.constant 0 : index
135+
%c1 = arith.constant -1 : index
136+
%1 = memref.dim %arg0, %c0 : memref<?xf32>
137+
%2 = arith.addi %1, %c1 : index
138+
%3 = memref.subview %arg0[%2] [%1] [%c1] : memref<?xf32> to memref<?xf32, strided<[?], offset: ?>>
139+
return %3 : memref<?xf32, strided<[?], offset: ?>>
140+
}
141+
// CHECK: func @subview_negative_stride1
142+
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
143+
// CHECK: %[[C1:.*]] = arith.constant 0
144+
// CHECK: %[[C2:.*]] = arith.constant -1
145+
// CHECK: %[[DIM1:.*]] = memref.dim %[[ARG0]], %[[C1]] : memref<?xf32>
146+
// CHECK: %[[DIM2:.*]] = arith.addi %[[DIM1]], %[[C2]] : index
147+
// CHECK: %[[RES1:.*]] = memref.subview %[[ARG0]][%[[DIM2]]] [%[[DIM1]]] [-1] : memref<?xf32> to memref<?xf32, strided<[-1], offset: ?>>
148+
// CHECK: %[[RES2:.*]] = memref.cast %[[RES1]] : memref<?xf32, strided<[-1], offset: ?>> to memref<?xf32, strided<[?], offset: ?>>
149+
// CHECK: return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>
150+
151+
// -----
152+
153+
func.func @subview_negative_stride2(%arg0 : memref<7xf32>) -> memref<?xf32, strided<[?], offset: ?>>
154+
{
155+
%c0 = arith.constant 0 : index
156+
%c1 = arith.constant -1 : index
157+
%1 = memref.dim %arg0, %c0 : memref<7xf32>
158+
%2 = arith.addi %1, %c1 : index
159+
%3 = memref.subview %arg0[%2] [%1] [%c1] : memref<7xf32> to memref<?xf32, strided<[?], offset: ?>>
160+
return %3 : memref<?xf32, strided<[?], offset: ?>>
161+
}
162+
// CHECK: func @subview_negative_stride2
163+
// CHECK-SAME: (%[[ARG0:.*]]: memref<7xf32>)
164+
// CHECK: %[[RES1:.*]] = memref.subview %[[ARG0]][6] [7] [-1] : memref<7xf32> to memref<7xf32, strided<[-1], offset: 6>>
165+
// CHECK: %[[RES2:.*]] = memref.cast %[[RES1]] : memref<7xf32, strided<[-1], offset: 6>> to memref<?xf32, strided<[?], offset: ?>>
166+
// CHECK: return %[[RES2]] : memref<?xf32, strided<[?], offset: ?>>
130167

131168
// -----
132169

mlir/test/IR/invalid-builtin-types.mlir

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ func.func private @memref_unfinished_strided() -> memref<?x?xf32, strided<>>
7474

7575
// -----
7676

77-
// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
77+
// expected-error @below {{expected a 64-bit signed integer or '?'}}
7878
func.func private @memref_unfinished_stride_list() -> memref<?x?xf32, strided<[>>
7979

8080
// -----
@@ -89,7 +89,7 @@ func.func private @memref_missing_offset_colon() -> memref<?x?xf32, strided<[],
8989

9090
// -----
9191

92-
// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
92+
// expected-error @below {{expected a 64-bit signed integer or '?'}}
9393
func.func private @memref_missing_offset_value() -> memref<?x?xf32, strided<[], offset: >>
9494

9595
// -----
@@ -99,21 +99,11 @@ func.func private @memref_incorrect_strided_ending() -> memref<?x?xf32, strided<
9999

100100
// -----
101101

102-
// expected-error @below {{strides must be positive or dynamic}}
102+
// expected-error @below {{strides must not be zero}}
103103
func.func private @memref_zero_stride() -> memref<?x?xf32, strided<[0, 0]>>
104104

105105
// -----
106106

107-
// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
108-
func.func private @memref_negative_stride() -> memref<?x?xf32, strided<[-2, -2]>>
109-
110-
// -----
111-
112-
// expected-error @below {{expected a non-negative 64-bit signed integer or '?'}}
113-
func.func private @memref_negative_offset() -> memref<?x?xf32, strided<[2, 1], offset: -2>>
114-
115-
// -----
116-
117107
// expected-error @below {{expected the number of strides to match the rank}}
118108
func.func private @memref_strided_rank_mismatch() -> memref<?x?xf32, strided<[1]>>
119109

0 commit comments

Comments
 (0)