Skip to content

Commit 0d83b20

Browse files
committed
Feedback
1 parent 4a9ae1c commit 0d83b20

File tree

6 files changed

+61
-13
lines changed

6 files changed

+61
-13
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -690,9 +690,10 @@ def Vector_ExtractOp :
690690
Takes an n-D vector and a k-D position and extracts the (n-k)-D vector at
691691
the proper position. Degenerates to an element type if n-k is zero.
692692

693-
Dynamic indices must be greater or equal to zero and less than the size of
694-
the corresponding dimension. The result is undefined if any index is
695-
out-of-bounds.
693+
Static and dynamic indices must be greater or equal to zero and less than
694+
the size of the corresponding dimension. The result is undefined if any
695+
index is out-of-bounds. The value `-1` represents a poison index, which
696+
specifies that the extracted element is poison.
696697

697698
Example:
698699

@@ -702,6 +703,7 @@ def Vector_ExtractOp :
702703
%3 = vector.extract %1[]: vector<f32> from vector<f32>
703704
%4 = vector.extract %0[%a, %b, %c]: f32 from vector<4x8x16xf32>
704705
%5 = vector.extract %0[2, %b]: vector<16xf32> from vector<4x8x16xf32>
706+
%6 = vector.extract %10[-1, %c]: f32 from vector<4x16xf32>
705707
```
706708
}];
707709

@@ -880,9 +882,10 @@ def Vector_InsertOp :
880882
and inserts the n-D source into the (n+k)-D destination at the proper
881883
position. Degenerates to a scalar or a 0-d vector source type when n = 0.
882884

883-
Dynamic indices must be greater or equal to zero and less than the size of
884-
the corresponding dimension. The result is undefined if any index is
885-
out-of-bounds.
885+
Static and dynamic indices must be greater or equal to zero and less than
886+
the size of the corresponding dimension. The result is undefined if any
887+
index is out-of-bounds. The value `-1` represents a poison index, which
888+
specifies that the resulting vector is poison.
886889

887890
Example:
888891

@@ -892,6 +895,7 @@ def Vector_InsertOp :
892895
%8 = vector.insert %6, %7[] : f32 into vector<f32>
893896
%11 = vector.insert %9, %10[%a, %b, %c] : vector<f32> into vector<4x8x16xf32>
894897
%12 = vector.insert %4, %10[2, %b] : vector<16xf32> into vector<4x8x16xf32>
898+
%13 = vector.insert %20, %1[-1, %c] : f32 into vector<4x16xf32>
895899
```
896900
}];
897901

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,7 +1368,7 @@ LogicalResult vector::ExtractOp::verify() {
13681368
return emitOpError("expected position attribute #")
13691369
<< (idx + 1)
13701370
<< " to be a non-negative integer smaller than the "
1371-
"corresponding vector dimension";
1371+
"corresponding vector dimension or poison (-1)";
13721372
}
13731373
}
13741374
}
@@ -2264,11 +2264,7 @@ LogicalResult foldExtractFromFromElements(ExtractOp extractOp,
22642264
template <typename OpTy>
22652265
LogicalResult foldPoisonIndexInsertExtractOp(OpTy op,
22662266
PatternRewriter &rewriter) {
2267-
auto hasPoisonIndex = [](int64_t index) {
2268-
return index == OpTy::kPoisonIndex;
2269-
};
2270-
2271-
if (llvm::none_of(op.getStaticPosition(), hasPoisonIndex))
2267+
if (!llvm::is_contained(op.getStaticPosition(), OpTy::kPoisonIndex))
22722268
return failure();
22732269

22742270
rewriter.replaceOpWithNewOp<ub::PoisonOp>(op, op.getResult().getType());

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1250,6 +1250,16 @@ func.func @extract_scalar_from_vec_1d_f32(%arg0: vector<16xf32>) -> f32 {
12501250

12511251
// -----
12521252

1253+
func.func @extract_poison_idx(%arg0: vector<16xf32>) -> f32 {
1254+
%0 = vector.extract %arg0[-1]: f32 from vector<16xf32>
1255+
return %0 : f32
1256+
}
1257+
// CHECK-LABEL: @extract_poison_idx
1258+
// CHECK: %[[IDX:.*]] = llvm.mlir.constant(-1 : i64) : i64
1259+
// CHECK: llvm.extractelement {{.*}}[%[[IDX]] : i64] : vector<16xf32>
1260+
1261+
// -----
1262+
12531263
func.func @extract_scalar_from_vec_1d_f32_scalable(%arg0: vector<[16]xf32>) -> f32 {
12541264
%0 = vector.extract %arg0[15]: f32 from vector<[16]xf32>
12551265
return %0 : f32

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@ func.func @extract(%arg0 : vector<2xf32>) -> (vector<1xf32>, f32) {
175175

176176
// -----
177177

178+
func.func @extract_poison_idx(%arg0 : vector<4xf32>) -> f32 {
179+
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
180+
%0 = vector.extract %arg0[-1] : f32 from vector<4xf32>
181+
return %0: f32
182+
}
183+
184+
// -----
185+
178186
// CHECK-LABEL: @extract_size1_vector
179187
// CHECK-SAME: %[[ARG0:.+]]: vector<1xf32>
180188
// CHECK: %[[R:.+]] = builtin.unrealized_conversion_cast %[[ARG0]]
@@ -256,6 +264,14 @@ func.func @insert(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
256264

257265
// -----
258266

267+
func.func @insert_poison_idx(%arg0 : vector<4xf32>, %arg1: f32) -> vector<4xf32> {
268+
// expected-error@+1 {{index -1 out of bounds for 'vector<4xf32>'}}
269+
%1 = vector.insert %arg1, %arg0[-1] : f32 into vector<4xf32>
270+
return %1: vector<4xf32>
271+
}
272+
273+
// -----
274+
259275
// CHECK-LABEL: @insert_index_vector
260276
// CHECK: spirv.CompositeInsert %{{.+}}, %{{.+}}[2 : i32] : i32 into vector<4xi32>
261277
func.func @insert_index_vector(%arg0 : vector<4xindex>, %arg1: index) -> vector<4xindex> {

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -152,6 +152,17 @@ func.func @extract_vector_poison_idx(%a: vector<4x5xf32>) -> vector<5xf32> {
152152

153153
// -----
154154

155+
// CHECK-LABEL: @extract_multiple_poison_idx
156+
func.func @extract_multiple_poison_idx(%a: vector<4x5x8xf32>)
157+
-> vector<8xf32> {
158+
// CHECK-NOT: vector.extract
159+
// CHECK-NEXT: ub.poison : vector<8xf32>
160+
%0 = vector.extract %a[-1, -1] : vector<8xf32> from vector<4x5x8xf32>
161+
return %0 : vector<8xf32>
162+
}
163+
164+
// -----
165+
155166
// CHECK-LABEL: extract_from_create_mask_dynamic_position_all_false
156167
// CHECK-SAME: %[[DIM0:.*]]: index, %[[INDEX:.*]]: index
157168
func.func @extract_from_create_mask_dynamic_position_all_false(%dim0: index, %index: index) -> vector<6xi1> {
@@ -2833,6 +2844,17 @@ func.func @insert_vector_poison_idx(%a: vector<4x5xf32>, %b: vector<5xf32>)
28332844

28342845
// -----
28352846

2847+
// CHECK-LABEL: @insert_multiple_poison_idx
2848+
func.func @insert_multiple_poison_idx(%a: vector<4x5x8xf32>, %b: vector<8xf32>)
2849+
-> vector<4x5x8xf32> {
2850+
// CHECK-NOT: vector.insert
2851+
// CHECK-NEXT: ub.poison : vector<4x5x8xf32>
2852+
%0 = vector.insert %b, %a[-1, -1] : vector<8xf32> into vector<4x5x8xf32>
2853+
return %0 : vector<4x5x8xf32>
2854+
}
2855+
2856+
// -----
2857+
28362858
// CHECK-LABEL: @contiguous_extract_strided_slices_to_extract
28372859
// CHECK: %[[EXTRACT:.+]] = vector.extract {{.*}}[0, 0, 0, 0, 0] : vector<4xi32> from vector<8x1x2x1x1x4xi32>
28382860
// CHECK-NEXT: return %[[EXTRACT]] : vector<4xi32>

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ func.func @extract_0d(%arg0: vector<f32>) {
186186
// -----
187187

188188
func.func @extract_position_overflow(%arg0: vector<4x8x16xf32>) {
189-
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension}}
189+
// expected-error@+1 {{expected position attribute #3 to be a non-negative integer smaller than the corresponding vector dimension or poison (-1)}}
190190
%1 = vector.extract %arg0[0, 0, -5] : f32 from vector<4x8x16xf32>
191191
}
192192

0 commit comments

Comments
 (0)