@@ -44,6 +44,19 @@ static bool isLessThanTargetBitWidth(Operation *op, unsigned targetBitWidth) {
44
44
return true ;
45
45
}
46
46
47
+ static bool isLessThanOrEqualTargetBitWidth (Type t, unsigned targetBitWidth) {
48
+ VectorType vecType = dyn_cast<VectorType>(t);
49
+ // Reject index since getElementTypeBitWidth will abort for Index types.
50
+ if (!vecType || vecType.getElementType ().isIndex ())
51
+ return false ;
52
+ // There are no dimension to fold if it is a 0-D vector.
53
+ if (vecType.getRank () == 0 )
54
+ return false ;
55
+ unsigned trailingVecDimBitWidth =
56
+ vecType.getShape ().back () * vecType.getElementTypeBitWidth ();
57
+ return trailingVecDimBitWidth <= targetBitWidth;
58
+ }
59
+
47
60
namespace {
48
61
struct LinearizeConstant final : OpConversionPattern<arith::ConstantOp> {
49
62
using OpConversionPattern::OpConversionPattern;
@@ -355,6 +368,88 @@ struct LinearizeVectorExtract final
355
368
return success ();
356
369
}
357
370
371
+ private:
372
+ unsigned targetVectorBitWidth;
373
+ };
374
+
375
+ // / This pattern converts the InsertOp to a ShuffleOp that works on a
376
+ // / linearized vector.
377
+ // / Following,
378
+ // / vector.insert %source %destination [ position ]
379
+ // / is converted to :
380
+ // / %source_1d = vector.shape_cast %source
381
+ // / %destination_1d = vector.shape_cast %destination
382
+ // / %out_1d = vector.shuffle %destination_1d, %source_1d [ shuffle_indices_1d
383
+ // / ] %out_nd = vector.shape_cast %out_1d
384
+ // / `shuffle_indices_1d` is computed using the position of the original insert.
385
+ struct LinearizeVectorInsert final
386
+ : public OpConversionPattern<vector::InsertOp> {
387
+ using OpConversionPattern::OpConversionPattern;
388
+ LinearizeVectorInsert (
389
+ const TypeConverter &typeConverter, MLIRContext *context,
390
+ unsigned targetVectBitWidth = std::numeric_limits<unsigned >::max(),
391
+ PatternBenefit benefit = 1 )
392
+ : OpConversionPattern(typeConverter, context, benefit),
393
+ targetVectorBitWidth (targetVectBitWidth) {}
394
+ LogicalResult
395
+ matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
396
+ ConversionPatternRewriter &rewriter) const override {
397
+ Type dstTy = getTypeConverter ()->convertType (insertOp.getDestVectorType ());
398
+ assert (!(insertOp.getDestVectorType ().isScalable () ||
399
+ cast<VectorType>(dstTy).isScalable ()) &&
400
+ " scalable vectors are not supported." );
401
+
402
+ if (!isLessThanOrEqualTargetBitWidth (insertOp.getSourceType (),
403
+ targetVectorBitWidth))
404
+ return rewriter.notifyMatchFailure (
405
+ insertOp, " Can't flatten since targetBitWidth < OpSize" );
406
+
407
+ // dynamic position is not supported
408
+ if (insertOp.hasDynamicPosition ())
409
+ return rewriter.notifyMatchFailure (insertOp,
410
+ " dynamic position is not supported." );
411
+ auto srcTy = insertOp.getSourceType ();
412
+ auto srcAsVec = dyn_cast<VectorType>(srcTy);
413
+ uint64_t srcSize = 0 ;
414
+ if (srcAsVec) {
415
+ srcSize = srcAsVec.getNumElements ();
416
+ } else {
417
+ return rewriter.notifyMatchFailure (insertOp,
418
+ " scalars are not supported." );
419
+ }
420
+
421
+ auto dstShape = insertOp.getDestVectorType ().getShape ();
422
+ const auto dstSize = insertOp.getDestVectorType ().getNumElements ();
423
+ auto dstSizeForOffsets = dstSize;
424
+
425
+ // compute linearized offset
426
+ int64_t linearizedOffset = 0 ;
427
+ auto offsetsNd = insertOp.getStaticPosition ();
428
+ for (auto [dim, offset] : llvm::enumerate (offsetsNd)) {
429
+ dstSizeForOffsets /= dstShape[dim];
430
+ linearizedOffset += offset * dstSizeForOffsets;
431
+ }
432
+
433
+ llvm::SmallVector<int64_t , 2 > indices (dstSize);
434
+ auto origValsUntil = indices.begin ();
435
+ std::advance (origValsUntil, linearizedOffset);
436
+ std::iota (indices.begin (), origValsUntil,
437
+ 0 ); // original values that remain [0, offset)
438
+ auto newValsUntil = origValsUntil;
439
+ std::advance (newValsUntil, srcSize);
440
+ std::iota (origValsUntil, newValsUntil,
441
+ dstSize); // new values [offset, offset+srcNumElements)
442
+ std::iota (newValsUntil, indices.end (),
443
+ linearizedOffset + srcSize); // the rest of original values
444
+ // [offset+srcNumElements, end)
445
+
446
+ rewriter.replaceOpWithNewOp <vector::ShuffleOp>(
447
+ insertOp, dstTy, adaptor.getDest (), adaptor.getSource (),
448
+ rewriter.getI64ArrayAttr (indices));
449
+
450
+ return success ();
451
+ }
452
+
358
453
private:
359
454
unsigned targetVectorBitWidth;
360
455
};
@@ -410,6 +505,6 @@ void mlir::vector::populateVectorLinearizeShuffleLikeOpsPatterns(
410
505
: true ;
411
506
});
412
507
patterns.add <LinearizeVectorShuffle, LinearizeVectorExtract,
413
- LinearizeVectorExtractStridedSlice>(
508
+ LinearizeVectorInsert, LinearizeVectorExtractStridedSlice>(
414
509
typeConverter, patterns.getContext (), targetBitWidth);
415
510
}
0 commit comments