35
35
using namespace mlir ;
36
36
using namespace mlir ::vector;
37
37
38
- // Helper to reduce vector type by *all* but one rank at back.
39
- static VectorType reducedVectorTypeBack (VectorType tp) {
40
- assert ((tp.getRank () > 1 ) && " unlowerable vector type" );
41
- return VectorType::get (tp.getShape ().take_back (), tp.getElementType (),
42
- tp.getScalableDims ().take_back ());
43
- }
44
-
45
38
// Helper that picks the proper sequence for inserting.
46
39
static Value insertOne (ConversionPatternRewriter &rewriter,
47
40
const LLVMTypeConverter &typeConverter, Location loc,
@@ -1223,7 +1216,6 @@ class VectorInsertOpConversion
1223
1216
matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
1224
1217
ConversionPatternRewriter &rewriter) const override {
1225
1218
auto loc = insertOp->getLoc ();
1226
- auto sourceType = insertOp.getSourceType ();
1227
1219
auto destVectorType = insertOp.getDestVectorType ();
1228
1220
auto llvmResultType = typeConverter->convertType (destVectorType);
1229
1221
// Bail if result type cannot be lowered.
@@ -1233,53 +1225,81 @@ class VectorInsertOpConversion
1233
1225
SmallVector<OpFoldResult> positionVec = getMixedValues (
1234
1226
adaptor.getStaticPosition (), adaptor.getDynamicPosition (), rewriter);
1235
1227
1236
- // Overwrite entire vector with value. Should be handled by folder, but
1237
- // just to be safe.
1238
- ArrayRef<OpFoldResult> position (positionVec);
1239
- if (position.empty ()) {
1240
- rewriter.replaceOp (insertOp, adaptor.getSource ());
1241
- return success ();
1242
- }
1243
-
1244
- // One-shot insertion of a vector into an array (only requires insertvalue).
1245
- if (isa<VectorType>(sourceType)) {
1246
- if (insertOp.hasDynamicPosition ())
1247
- return failure ();
1248
-
1249
- Value inserted = rewriter.create <LLVM::InsertValueOp>(
1250
- loc, adaptor.getDest (), adaptor.getSource (), getAsIntegers (position));
1251
- rewriter.replaceOp (insertOp, inserted);
1252
- return success ();
1228
+ // The logic in this pattern mirrors VectorExtractOpConversion. Refer to
1229
+ // its explanatory comment about how N-D vectors are converted as nested
1230
+ // aggregates (llvm.array's) of 1D vectors.
1231
+ //
1232
+ // The innermost dimension of the destination vector, when converted to a
1233
+ // nested aggregate form, will always be a 1D vector.
1234
+ //
1235
+ // * If the insertion is happening into the innermost dimension of the
1236
+ // destination vector:
1237
+ // - If the destination is a nested aggregate, extract a 1D vector out of
1238
+ // the aggregate. This can be done using llvm.extractvalue. The
1239
+ // destination is now guaranteed to be a 1D vector, to which we are
1240
+ // inserting.
1241
+ // - Do the insertion into the 1D destination vector, and make the result
1242
+ // the new source nested aggregate. This can be done using
1243
+ // llvm.insertelement.
1244
+ // * If the destination is a nested aggregate, insert the source nested
1245
+ // aggregate into it. This can be done using llvm.insertvalue.
1246
+
1247
+ // Determine if we need to extract a 1D vector from, and insert it back into, the aggregate.
1248
+ bool isNestedAggregate = isa<LLVM::LLVMArrayType>(llvmResultType);
1249
+ // Determine if we need to insert a scalar into the 1D vector.
1250
+ bool insertIntoInnermostDim =
1251
+ static_cast <int64_t >(positionVec.size ()) == destVectorType.getRank ();
1252
+
1253
+ ArrayRef<OpFoldResult> positionOf1DVectorWithinAggregate (
1254
+ positionVec.begin (),
1255
+ insertIntoInnermostDim ? positionVec.size () - 1 : positionVec.size ());
1256
+ OpFoldResult positionOfScalarWithin1DVector;
1257
+ if (destVectorType.getRank () == 0 ) {
1258
+ // Since the LLVM type converter converts 0D vectors to 1D vectors, we
1259
+ // need to create a 0 here as the position into the 1D vector.
1260
+ Type idxType = typeConverter->convertType (rewriter.getIndexType ());
1261
+ positionOfScalarWithin1DVector = rewriter.getZeroAttr (idxType);
1262
+ } else if (insertIntoInnermostDim) {
1263
+ positionOfScalarWithin1DVector = positionVec.back ();
1253
1264
}
1254
1265
1255
- // Potential extraction of 1-D vector from array.
1256
- Value extracted = adaptor.getDest ();
1257
- auto oneDVectorType = destVectorType;
1258
- if (position.size () > 1 ) {
1259
- if (insertOp.hasDynamicPosition ())
1260
- return failure ();
1261
-
1262
- oneDVectorType = reducedVectorTypeBack (destVectorType);
1263
- extracted = rewriter.create <LLVM::ExtractValueOp>(
1264
- loc, extracted, getAsIntegers (position.drop_back ()));
1266
+ // We are going to mutate this value until it is either the final
1267
+ // result (in the non-aggregate case) or the value that needs to be
1268
+ // inserted into the aggregate result.
1269
+ Value sourceAggregate = adaptor.getSource ();
1270
+ if (insertIntoInnermostDim) {
1271
+ // Scalar-into-1D-vector case, so we know we will have to create a
1272
+ // InsertElementOp. The question is into what destination.
1273
+ if (isNestedAggregate) {
1274
+ // Aggregate case: the destination for the InsertElementOp needs to be
1275
+ // extracted from the aggregate.
1276
+ if (!llvm::all_of (positionOf1DVectorWithinAggregate,
1277
+ llvm::IsaPred<Attribute>)) {
1278
+ // llvm.extractvalue does not support dynamic positions.
1279
+ return failure ();
1280
+ }
1281
+ sourceAggregate = rewriter.create <LLVM::ExtractValueOp>(
1282
+ loc, adaptor.getDest (),
1283
+ getAsIntegers (positionOf1DVectorWithinAggregate));
1284
+ } else {
1285
+ // No-aggregate case. The destination for the InsertElementOp is just
1286
+ // the insertOp's destination.
1287
+ sourceAggregate = adaptor.getDest ();
1288
+ }
1289
+ // Insert the scalar into the 1D vector.
1290
+ sourceAggregate = rewriter.create <LLVM::InsertElementOp>(
1291
+ loc, sourceAggregate.getType (), sourceAggregate, adaptor.getSource (),
1292
+ getAsLLVMValue (rewriter, loc, positionOfScalarWithin1DVector));
1265
1293
}
1266
1294
1267
- // Insertion of an element into a 1-D LLVM vector.
1268
- Value inserted = rewriter.create <LLVM::InsertElementOp>(
1269
- loc, typeConverter->convertType (oneDVectorType), extracted,
1270
- adaptor.getSource (), getAsLLVMValue (rewriter, loc, position.back ()));
1271
-
1272
- // Potential insertion of resulting 1-D vector into array.
1273
- if (position.size () > 1 ) {
1274
- if (insertOp.hasDynamicPosition ())
1275
- return failure ();
1276
-
1277
- inserted = rewriter.create <LLVM::InsertValueOp>(
1278
- loc, adaptor.getDest (), inserted,
1279
- getAsIntegers (position.drop_back ()));
1295
+ Value result = sourceAggregate;
1296
+ if (isNestedAggregate) {
1297
+ result = rewriter.create <LLVM::InsertValueOp>(
1298
+ loc, adaptor.getDest (), sourceAggregate,
1299
+ getAsIntegers (positionOf1DVectorWithinAggregate));
1280
1300
}
1281
1301
1282
- rewriter.replaceOp (insertOp, inserted );
1302
+ rewriter.replaceOp (insertOp, result );
1283
1303
return success ();
1284
1304
}
1285
1305
};
0 commit comments