Skip to content

Commit 2721f5a

Browse files
authored
[mlir][vector] Prevent folding of OOB values in insert/extract (#135498)
Out of bound position values should not be folded in vector.extract and vector.insert operations, as only in bounds constants and -1 are valid. Fixes #134516
1 parent 64de852 commit 2721f5a

File tree

2 files changed

+50
-3
lines changed

2 files changed

+50
-3
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1997,6 +1997,11 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
19971997
std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
19981998
OperandRange dynamicPosition = op.getDynamicPosition();
19991999
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
2000+
ArrayRef<int64_t> vectorShape;
2001+
if constexpr (std::is_same_v<OpType, ExtractOp>)
2002+
vectorShape = op.getSourceVectorType().getShape();
2003+
else
2004+
vectorShape = op.getDestVectorType().getShape();
20002005

20012006
// If the dynamic operands is empty, it is returned directly.
20022007
if (!dynamicPosition.size())
@@ -2013,9 +2018,13 @@ static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
20132018
Attribute positionAttr = dynamicPositionAttr[index];
20142019
Value position = dynamicPosition[index++];
20152020
if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
2016-
staticPosition[i] = attr.getInt();
2017-
opChange = true;
2018-
continue;
2021+
int64_t value = attr.getInt();
2022+
// Do not fold if the value is out of bounds.
2023+
if (value >= 0 && value < vectorShape[i]) {
2024+
staticPosition[i] = attr.getInt();
2025+
opChange = true;
2026+
continue;
2027+
}
20192028
}
20202029
operands.push_back(position);
20212030
}

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3311,3 +3311,41 @@ func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi3
33113311
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
33123312
return %res : vector<4x1xi32>
33133313
}
3314+
3315+
// -----
3316+
3317+
// Check that out of bounds indices are not folded for vector.insert.
3318+
3319+
// CHECK-LABEL: @fold_insert_oob
3320+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> vector<4x1x2xi32> {
3321+
// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
3322+
// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
3323+
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
3324+
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, %[[OOB1]], %[[OOB2]]] : i32 into vector<4x1x2xi32>
3325+
// CHECK: return %[[RES]] : vector<4x1x2xi32>
3326+
func.func @fold_insert_oob(%arg : vector<4x1x2xi32>) -> vector<4x1x2xi32> {
3327+
%c0 = arith.constant 0 : index
3328+
%c-2 = arith.constant -2 : index
3329+
%c2 = arith.constant 2 : index
3330+
%c1 = arith.constant 1 : i32
3331+
%res = vector.insert %c1, %arg[%c0, %c-2, %c2] : i32 into vector<4x1x2xi32>
3332+
return %res : vector<4x1x2xi32>
3333+
}
3334+
3335+
// -----
3336+
3337+
// Check that out of bounds indices are not folded for vector.extract.
3338+
3339+
// CHECK-LABEL: @fold_extract_oob
3340+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1x2xi32>) -> i32 {
3341+
// CHECK: %[[OOB1:.*]] = arith.constant -2 : index
3342+
// CHECK: %[[OOB2:.*]] = arith.constant 2 : index
3343+
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, %[[OOB1]], %[[OOB2]]] : i32 from vector<4x1x2xi32>
3344+
// CHECK: return %[[RES]] : i32
3345+
func.func @fold_extract_oob(%arg : vector<4x1x2xi32>) -> i32 {
3346+
%c0 = arith.constant 0 : index
3347+
%c-2 = arith.constant -2 : index
3348+
%c2 = arith.constant 2 : index
3349+
%res = vector.extract %arg[%c0, %c-2, %c2] : i32 from vector<4x1x2xi32>
3350+
return %res : i32
3351+
}

0 commit comments

Comments
 (0)