Skip to content

[mlir][vector]add extractInsertFoldConstantOp fold function and apply it to extractOp and insertOp. #124399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
45 changes: 45 additions & 0 deletions mlir/lib/Dialect/Vector/IR/VectorOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1986,6 +1986,45 @@ static Value foldScalarExtractFromFromElements(ExtractOp extractOp) {
return fromElementsOp.getElements()[flatIndex];
}

// If the dynamic indices of `extractOp` or `insertOp` are result of
// `constantOp`, then fold it.
template <typename OpType, typename AdaptorType>
static Value extractInsertFoldConstantOp(OpType op, AdaptorType adaptor,
SmallVectorImpl<Value> &operands) {
std::vector<int64_t> staticPosition = op.getStaticPosition().vec();
OperandRange dynamicPosition = op.getDynamicPosition();
ArrayRef<Attribute> dynamicPositionAttr = adaptor.getDynamicPosition();

// If the dynamic operands is empty, it is returned directly.
if (!dynamicPosition.size())
return {};

// `index` is used to iterate over the `dynamicPosition`.
unsigned index = 0;

// `opChange` is a flag. If it is true, it means to update `op` in place.
bool opChange = false;
for (unsigned i = 0, e = staticPosition.size(); i < e; ++i) {
if (!ShapedType::isDynamic(staticPosition[i]))
continue;
Attribute positionAttr = dynamicPositionAttr[index];
Value position = dynamicPosition[index++];
if (auto attr = mlir::dyn_cast_if_present<IntegerAttr>(positionAttr)) {
staticPosition[i] = attr.getInt();
opChange = true;
continue;
}
operands.push_back(position);
}

if (opChange) {
op.setStaticPosition(staticPosition);
op.getOperation()->setOperands(operands);
return op.getResult();
}
return {};
}

/// Fold an insert or extract operation into an poison value when a poison index
/// is found at any dimension of the static position.
static ub::PoisonAttr
Expand Down Expand Up @@ -2022,6 +2061,9 @@ OpFoldResult ExtractOp::fold(FoldAdaptor adaptor) {
return val;
if (auto val = foldScalarExtractFromFromElements(*this))
return val;
SmallVector<Value> operands = {getVector()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
return OpFoldResult();
}

Expand Down Expand Up @@ -3069,6 +3111,9 @@ OpFoldResult vector::InsertOp::fold(FoldAdaptor adaptor) {
// (type mismatch).
if (getNumIndices() == 0 && getSourceType() == getType())
return getSource();
SmallVector<Value> operands = {getSource(), getDest()};
if (auto val = extractInsertFoldConstantOp(*this, adaptor, operands))
return val;
if (auto res = foldPoisonIndexInsertExtractOp(
getContext(), adaptor.getStaticPosition(), kPoisonIndex))
return res;
Expand Down
43 changes: 43 additions & 0 deletions mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1473,6 +1473,26 @@ func.func @extract_scalar_from_vec_0d_index(%arg0: vector<index>) -> index {
// CHECK: %[[T3:.*]] = builtin.unrealized_conversion_cast %[[T2]] : i64 to index
// CHECK: return %[[T3]] : index


// -----

func.func @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<32x1xi32>) -> i32 {
%0 = arith.constant 0 : index
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
return %1 : i32
}

// Compile-time if the indices of extractOp if constants, the constants will be collapsed,
// the constants are folded away, hence the lowering works.

// CHECK-LABEL: @extract_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<32x1xi32> to !llvm.array<32 x vector<1xi32>>
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<32 x vector<1xi32>>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[RES:.*]] = llvm.extractelement %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
// CHECK: return %[[RES]] : i32

// -----

//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1726,6 +1746,29 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1

// -----

func.func @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : i32
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}

// Compile-time if the indices of insertOp if constants, the constants will be collapsed,
// the constants are folded away, hence the lowering works.

// CHECK-LABEL: @insert_scalar_from_vec_2d_f32_dynamic_idxs_compile_time_const
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
// CHECK: %[[CAST:.*]] = builtin.unrealized_conversion_cast %[[ARG]] : vector<4x1xi32> to !llvm.array<4 x vector<1xi32>>
// CHECK: %[[C1:.*]] = arith.constant 1 : i32
// CHECK: %[[VEC_0:.*]] = llvm.extractvalue %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
// CHECK: %[[C0:.*]] = llvm.mlir.constant(0 : i64) : i64
// CHECK: %[[VEC_1:.*]] = llvm.insertelement %[[C1]], %[[VEC_0]]{{\[}}%[[C0]] : i64] : vector<1xi32>
// CHECK: %[[VEC_2:.*]] = llvm.insertvalue %[[VEC_1]], %[[CAST]][0] : !llvm.array<4 x vector<1xi32>>
// CHECK: %[[RES:.*]] = builtin.unrealized_conversion_cast %[[VEC_2]] : !llvm.array<4 x vector<1xi32>> to vector<4x1xi32>
// CHECK: return %[[RES]] : vector<4x1xi32>

// -----

//===----------------------------------------------------------------------===//
// vector.type_cast
//
Expand Down
26 changes: 26 additions & 0 deletions mlir/test/Dialect/Vector/canonicalize.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -3042,3 +3042,29 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
return
}

// -----

// CHECK-LABEL: @fold_extract_constant_indices
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>
// CHECK: return %[[RES]] : i32
func.func @fold_extract_constant_indices(%arg : vector<32x1xi32>) -> i32 {
%0 = arith.constant 0 : index
%1 = vector.extract %arg[%0, %0] : i32 from vector<32x1xi32>
return %1 : i32
}

// -----

// CHECK-LABEL: @fold_insert_constant_indices
// CHECK-SAME: %[[ARG:.*]]: vector<4x1xi32>) -> vector<4x1xi32> {
// CHECK: %[[VAL:.*]] = arith.constant 1 : i32
// CHECK: %[[RES:.*]] = vector.insert %[[VAL]], %[[ARG]] [0, 0] : i32 into vector<4x1xi32>
// CHECK: return %[[RES]] : vector<4x1xi32>
func.func @fold_insert_constant_indices(%arg : vector<4x1xi32>) -> vector<4x1xi32> {
%0 = arith.constant 0 : index
%1 = arith.constant 1 : i32
%res = vector.insert %1, %arg[%0, %0] : i32 into vector<4x1xi32>
return %res : vector<4x1xi32>
}
3 changes: 1 addition & 2 deletions mlir/test/Dialect/Vector/vector-warp-distribute.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -778,12 +778,11 @@ func.func @warp_constant(%laneid: index) -> (vector<1xf32>) {

// CHECK-PROP-LABEL: func.func @vector_extract_1d(
// CHECK-PROP-DAG: %[[C5_I32:.*]] = arith.constant 5 : i32
// CHECK-PROP-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-PROP: %[[R:.*]] = gpu.warp_execute_on_lane_0(%{{.*}})[32] -> (vector<2xf32>) {
// CHECK-PROP: %[[V:.*]] = "some_def"() : () -> vector<64xf32>
// CHECK-PROP: gpu.yield %[[V]] : vector<64xf32>
// CHECK-PROP: }
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][%[[C1]]] : f32 from vector<2xf32>
// CHECK-PROP: %[[E:.*]] = vector.extract %[[R]][1] : f32 from vector<2xf32>
// CHECK-PROP: %[[SHUFFLED:.*]], %{{.*}} = gpu.shuffle idx %[[E]], %[[C5_I32]]
// CHECK-PROP: return %[[SHUFFLED]] : f32
func.func @vector_extract_1d(%laneid: index) -> (f32) {
Expand Down