Skip to content

Commit a2a5c93

Browse files
update test.
1 parent f6a113d commit a2a5c93

File tree

3 files changed

+57
-54
lines changed

3 files changed

+57
-54
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1986,17 +1986,20 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
19861986
return fromElementsOp.getElements()[flatIndex];
19871987
}
19881988

1989-
// If the dynamic operands of `extractOp` or `insertOp` is result of
1989+
// If the dynamic indices of `extractOp` or `insertOp` are result of
19901990
// `constantOp`, then fold it.
19911991
template <typename OpType, typename AdaptorType>
19921992
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
19931993
SmallVectorImpl<Value> &operands) {
1994-
auto staticPosition = op.getStaticPosition().vec();
1994+
std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
19951995
OperandRange dynamicPosition = op.getDynamicPosition();
19961996
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();
1997+
19971998
// If the dynamic operands is empty, it is returned directly.
19981999
if (!dynamicPosition.size())
19992000
return {};
2001+
2002+
// `index` is used to iterate over the `dynamicPosition`.
20002003
unsigned index = 0;
20012004

20022005
// `opChange` is a flag. If it is true, it means to update `op` in place.

mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir

Lines changed: 43 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1473,6 +1473,26 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
14731473
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
14741474
// CHECK: return %[[T3]] : index
14751475

1476+
1477+
// -----
1478+
1479+
func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 {
1480+
%0 = arith.constant 0 : index
1481+
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
1482+
return %1 : i32
1483+
}
1484+
1485+
// Compile-time if the indices of extractOp if constants, the constants will be collapsed,
1486+
// the constants are folded away, hence the lowering works.
1487+
1488+
// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
1489+
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
1490+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
1491+
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
1492+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
1493+
// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
1494+
// CHECK: return %[[RES]] : i32
1495+
14761496
// -----
14771497

14781498
//===----------------------------------------------------------------------===//
@@ -1726,6 +1746,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
17261746

17271747
// -----
17281748

1749+
func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
1750+
%0 = arith.constant 0 : index
1751+
%1 = arith.constant 1 : i32
1752+
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
1753+
return %res : vector<4x1xi32>
1754+
}
1755+
1756+
// Compile-time if the indices of insertOp if constants, the constants will be collapsed,
1757+
// the constants are folded away, hence the lowering works.
1758+
1759+
// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
1760+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
1761+
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
1762+
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
1763+
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
1764+
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
1765+
// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
1766+
// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
1767+
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
1768+
// CHECK: return %[[RES]] : vector<4x1xi32>
1769+
1770+
// -----
1771+
17291772
//===----------------------------------------------------------------------===//
17301773
// vector.type_cast
17311774
//
@@ -4125,42 +4168,3 @@ func.func @step_scalable() -> vector<[4]xindex> {
41254168
%0 = vector.step : vector<[4]xindex>
41264169
return %0 : vector<[4]xindex>
41274170
}
4128-
4129-
// -----
4130-
4131-
// CHECK-LABEL: func @fold_extract_constant_indices
4132-
4133-
func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
4134-
%0 = arith.constant 0 : index
4135-
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
4136-
return %1 : i32
4137-
}
4138-
4139-
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
4140-
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
4141-
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
4142-
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
4143-
// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
4144-
// CHECK: return %[[RES]] : i32
4145-
4146-
// -----
4147-
4148-
// CHECK-LABEL: func @fold_insert_constant_indices
4149-
4150-
func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
4151-
%0 = arith.constant 0 : index
4152-
%1 = arith.constant 1 : i32
4153-
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
4154-
return %res : vector<4x1xi32>
4155-
}
4156-
4157-
4158-
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
4159-
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
4160-
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
4161-
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
4162-
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
4163-
// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
4164-
// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
4165-
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
4166-
// CHECK: return %[[RES]] : vector<4x1xi32>

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3045,30 +3045,26 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
30453045

30463046
// -----
30473047

3048-
// CHECK-LABEL: func @fold_extract_constant_indices
3049-
3048+
// CHECK-LABEL: @fold_extract_constant_indices
3049+
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
3050+
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
3051+
// CHECK: return %[[RES]] : i32
30503052
func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
30513053
%0 = arith.constant 0 : index
30523054
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
30533055
return %1 : i32
30543056
}
30553057

3056-
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
3057-
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
3058-
// CHECK: return %[[RES]] : i32
3059-
30603058
// -----
30613059

3062-
// CHECK-LABEL: func @fold_insert_constant_indices
3063-
3060+
// CHECK-LABEL: @fold_insert_constant_indices
3061+
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
3062+
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
3063+
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
3064+
// CHECK: return %[[RES]] : vector<4x1xi32>
30643065
func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
30653066
%0 = arith.constant 0 : index
30663067
%1 = arith.constant 1 : i32
30673068
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
30683069
return %res : vector<4x1xi32>
30693070
}
3070-
3071-
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
3072-
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
3073-
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
3074-
// CHECK: return %[[RES]] : vector<4x1xi32>

0 commit comments

Comments
 (0)