Skip to content

Commit 7371f69

Browse files
authored
[MLIR][Vector]: Generalize conversion of vector.insert to LLVM in line with vector.extract (#128915)
This does for `vector.insert` what #117731 did for `vector.extract`. It is a bit more complicated because the 1-D vector being inserted into may itself first need to be extracted from the destination aggregate. As the tests show, this fixes two previously unsupported cases: (1) dynamic indices in the innermost dimension, and (2) 0-D vectors. Signed-off-by: Benoit Jacob <[email protected]>
1 parent eb84c11 commit 7371f69

File tree

2 files changed

+101
-55
lines changed

2 files changed

+101
-55
lines changed

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 69 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,6 @@
3535
using namespace mlir;
3636
using namespace mlir::vector;
3737

38-
// Helper to reduce vector type by *all* but one rank at back.
39-
static VectorType reducedVectorTypeBack(VectorType tp) {
40-
assert((tp.getRank() > 1) && "unlowerable vector type");
41-
return VectorType::get(tp.getShape().take_back(), tp.getElementType(),
42-
tp.getScalableDims().take_back());
43-
}
44-
4538
// Helper that picks the proper sequence for inserting.
4639
static Value insertOne(ConversionPatternRewriter &rewriter,
4740
const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
12231216
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
12241217
ConversionPatternRewriter &rewriter) const override {
12251218
auto loc = insertOp->getLoc();
1226-
auto sourceType = insertOp.getSourceType();
12271219
auto destVectorType = insertOp.getDestVectorType();
12281220
auto llvmResultType = typeConverter->convertType(destVectorType);
12291221
// Bail if result type cannot be lowered.
@@ -1233,53 +1225,81 @@ class VectorInsertOpConversion
12331225
SmallVector<OpFoldResult> positionVec = getMixedValues(
12341226
adaptor.getStaticPosition(), adaptor.getDynamicPosition(), rewriter);
12351227

1236-
// Overwrite entire vector with value. Should be handled by folder, but
1237-
// just to be safe.
1238-
ArrayRef<OpFoldResult> position(positionVec);
1239-
if (position.empty()) {
1240-
rewriter.replaceOp(insertOp, adaptor.getSource());
1241-
return success();
1242-
}
1243-
1244-
// One-shot insertion of a vector into an array (only requires insertvalue).
1245-
if (isa<VectorType>(sourceType)) {
1246-
if (insertOp.hasDynamicPosition())
1247-
return failure();
1248-
1249-
Value inserted = rewriter.create<LLVM::InsertValueOp>(
1250-
loc, adaptor.getDest(), adaptor.getSource(), getAsIntegers(position));
1251-
rewriter.replaceOp(insertOp, inserted);
1252-
return success();
1228+
// The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1229+
// its explanatory comment about how N-D vectors are converted as nested
1230+
// aggregates (llvm.array's) of 1D vectors.
1231+
//
1232+
// The innermost dimension of the destination vector, when converted to a
1233+
// nested aggregate form, will always be a 1D vector.
1234+
//
1235+
// * If the insertion is happening into the innermost dimension of the
1236+
// destination vector:
1237+
// - If the destination is a nested aggregate, extract a 1D vector out of
1238+
// the aggregate. This can be done using llvm.extractvalue. The
1239+
// destination is now guaranteed to be a 1D vector, to which we are
1240+
// inserting.
1241+
// - Do the insertion into the 1D destination vector, and make the result
1242+
// the new source nested aggregate. This can be done using
1243+
// llvm.insertelement.
1244+
// * Insert the source nested aggregate into the destination nested
1245+
// aggregate.
1246+
1247+
// Determine if a 1D vector must be extracted from / inserted into an aggregate.
1248+
bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1249+
// Determine if we need to insert a scalar into the 1D vector.
1250+
bool insertIntoInnermostDim =
1251+
static_cast<int64_t>(positionVec.size()) == destVectorType.getRank();
1252+
1253+
ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate(
1254+
positionVec.begin(),
1255+
insertIntoInnermostDim ? positionVec.size() - 1 : positionVec.size());
1256+
OpFoldResult positionOfScalarWithin1DVector;
1257+
if (destVectorType.getRank() == 0) {
1258+
// Since the LLVM type converter converts 0D vectors to 1D vectors, we
1259+
// need to create a 0 here as the position into the 1D vector.
1260+
Type idxType = typeConverter->convertType(rewriter.getIndexType());
1261+
positionOfScalarWithin1DVector = rewriter.getZeroAttr(idxType);
1262+
} else if (insertIntoInnermostDim) {
1263+
positionOfScalarWithin1DVector = positionVec.back();
12531264
}
12541265

1255-
// Potential extraction of 1-D vector from array.
1256-
Value extracted = adaptor.getDest();
1257-
auto oneDVectorType = destVectorType;
1258-
if (position.size() > 1) {
1259-
if (insertOp.hasDynamicPosition())
1260-
return failure();
1261-
1262-
oneDVectorType = reducedVectorTypeBack(destVectorType);
1263-
extracted = rewriter.create<LLVM::ExtractValueOp>(
1264-
loc, extracted, getAsIntegers(position.drop_back()));
1266+
// We are going to mutate this 1D vector until it is either the final
1267+
// result (in the non-aggregate case) or the value that needs to be
1268+
// inserted into the aggregate result.
1269+
Value sourceAggregate = adaptor.getSource();
1270+
if (insertIntoInnermostDim) {
1271+
// Scalar-into-1D-vector case, so we know we will have to create an
1272+
// InsertElementOp. The question is into what destination.
1273+
if (isNestedAggregate) {
1274+
// Aggregate case: the destination for the InsertElementOp needs to be
1275+
// extracted from the aggregate.
1276+
if (!llvm::all_of(positionOf1DVectorWithinAggregate,
1277+
llvm::IsaPred<Attribute>)) {
1278+
// llvm.extractvalue only accepts constant (static) position indices.
1279+
return failure();
1280+
}
1281+
sourceAggregate = rewriter.create<LLVM::ExtractValueOp>(
1282+
loc, adaptor.getDest(),
1283+
getAsIntegers(positionOf1DVectorWithinAggregate));
1284+
} else {
1285+
// No-aggregate case. The destination for the InsertElementOp is just
1286+
// the insertOp's destination.
1287+
sourceAggregate = adaptor.getDest();
1288+
}
1289+
// Insert the scalar into the 1D vector.
1290+
sourceAggregate = rewriter.create<LLVM::InsertElementOp>(
1291+
loc, sourceAggregate.getType(), sourceAggregate, adaptor.getSource(),
1292+
getAsLLVMValue(rewriter, loc, positionOfScalarWithin1DVector));
12651293
}
12661294

1267-
// Insertion of an element into a 1-D LLVM vector.
1268-
Value inserted = rewriter.create<LLVM::InsertElementOp>(
1269-
loc, typeConverter->convertType(oneDVectorType), extracted,
1270-
adaptor.getSource(), getAsLLVMValue(rewriter, loc, position.back()));
1271-
1272-
// Potential insertion of resulting 1-D vector into array.
1273-
if (position.size() > 1) {
1274-
if (insertOp.hasDynamicPosition())
1275-
return failure();
1276-
1277-
inserted = rewriter.create<LLVM::InsertValueOp>(
1278-
loc, adaptor.getDest(), inserted,
1279-
getAsIntegers(position.drop_back()));
1295+
Value result = sourceAggregate;
1296+
if (isNestedAggregate) {
1297+
result = rewriter.create<LLVM::InsertValueOp>(
1298+
loc, adaptor.getDest(), sourceAggregate,
1299+
getAsIntegers(positionOf1DVectorWithinAggregate));
12801300
}
12811301

1282-
rewriter.replaceOp(insertOp, inserted);
1302+
rewriter.replaceOp(insertOp, result);
12831303
return success();
12841304
}
12851305
};

mlir/test/Conversion/VectorToLLVM/vector-to-llvm-interface.mlir

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,16 @@ func.func @insertelement_into_vec_1d_f32_scalable_idx_as_index_scalable(%arg0: f
628628
// vector.insert
629629
//===----------------------------------------------------------------------===//
630630

631+
func.func @insert_scalar_into_vec_0d(%src: f32, %dst: vector<f32>) -> vector<f32> {
632+
%0 = vector.insert %src, %dst[] : f32 into vector<f32>
633+
return %0 : vector<f32>
634+
}
635+
636+
// CHECK-LABEL: @insert_scalar_into_vec_0d
637+
// CHECK: llvm.insertelement {{.*}} : vector<1xf32>
638+
639+
// -----
640+
631641
func.func @insert_scalar_into_vec_1d_f32(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
632642
%0 = vector.insert %arg0, %arg1[3] : f32 into vector<4xf32>
633643
return %0 : vector<4xf32>
@@ -780,10 +790,10 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx(%arg0: vector<1x16xf32>, %a
780790
return %0 : vector<1x16xf32>
781791
}
782792

783-
// Multi-dim vectors are not supported but this test shouldn't crash.
784-
785793
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx(
786-
// CHECK: vector.insert
794+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
795+
// CHECK: llvm.insertelement {{.*}} : vector<16xf32>
796+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<16xf32>>
787797

788798
// -----
789799

@@ -793,10 +803,26 @@ func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(%arg0: vector<1x[1
793803
return %0 : vector<1x[16]xf32>
794804
}
795805

796-
// Multi-dim vectors are not supported but this test shouldn't crash.
797-
798806
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_scalable(
799-
// CHECK: vector.insert
807+
// CHECK: llvm.extractvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
808+
// CHECK: llvm.insertelement {{.*}} : vector<[16]xf32>
809+
// CHECK: llvm.insertvalue {{.*}} : !llvm.array<1 x vector<[16]xf32>>
810+
811+
812+
// -----
813+
814+
func.func @insert_scalar_into_vec_2d_f32_dynamic_idx_fail(%arg0: vector<2x16xf32>, %arg1: f32, %idx: index)
815+
-> vector<2x16xf32> {
816+
%0 = vector.insert %arg1, %arg0[%idx, 0]: f32 into vector<2x16xf32>
817+
return %0 : vector<2x16xf32>
818+
}
819+
820+
// Currently fails to convert because of the dynamic index in a non-innermost
821+
// dimension: that dimension converts to an llvm.array, and llvm.extractvalue
822+
// does not support dynamic indices.
823+
824+
// CHECK-LABEL: @insert_scalar_into_vec_2d_f32_dynamic_idx_fail
825+
// CHECK: vector.insert
800826

801827
// -----
802828

0 commit comments

Comments
 (0)