@@ -59,6 +59,30 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
                       ArrayRef<bool> inputVecScalableFlags = {},
                       bool flatten1DDepthwiseConv = false);
 
+/// Vectorize tensor::InsertSliceOp with:
+///   * vector::TransferReadOp + vector::TransferWriteOp
+/// The vector sizes are either:
+///   * user-provided in `inputVectorSizes`, or
+///   * inferred from the static dims in the input and output tensors.
+/// Bails out if:
+///   * vector sizes are not user-provided, and
+///   * at least one dim is dynamic (in both the input and output tensors).
+///
+/// Before:
+///     !t_in_type = tensor<1x2x3xf32>
+///     !t_out_type = tensor<9x8x7x1x2x3xf32>
+///     !v_type = vector<1x2x3xf32>
+///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
+///     into !t_out_type
+/// After:
+///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
+///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults);
+
 /// Return the unique instance of OpType in `block` if it is indeed unique.
 /// Return null if none or more than 1 instances exist.
 template <typename OpType>
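
Note (illustrative sketch, not part of the patch): when vector sizes are user-provided, the rewrite also masks the read with the source's runtime sizes. A minimal sketch of the expected output IR, assuming a 2-D dynamically shaped source, vector sizes [8, 16], and hypothetical SSA names (%src, %dest, %d0/%d1 for the runtime source sizes, %off0/%off1 for the insertion offsets):

  %c0 = arith.constant 0 : index
  %pad = arith.constant 0.0 : f32
  %mask = vector.create_mask %d0, %d1 : vector<8x16xi1>
  %read = vector.mask %mask {
    vector.transfer_read %src[%c0, %c0], %pad {in_bounds = [false, false]}
        : tensor<?x?xf32>, vector<8x16xf32>
  } : vector<8x16xi1> -> vector<8x16xf32>
  %res = vector.transfer_write %read, %dest[%off0, %off1]
      {in_bounds = [false, false]} : vector<8x16xf32>, tensor<16x32xf32>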
@@ -1557,6 +1581,7 @@ static LogicalResult
 vectorizeAsTensorPackOp(RewriterBase &rewriter, tensor::PackOp packOp,
                         ArrayRef<int64_t> inputVectorSizes,
                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(packOp);
 
@@ -1633,6 +1658,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
                           ArrayRef<int64_t> inputVectorSizes,
                           SmallVectorImpl<Value> &newResults) {
 
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(unpackOp);
 
@@ -1763,7 +1789,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
   auto padValue = padOp.getConstantPaddingValue();
   Location loc = padOp.getLoc();
 
-  // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
+  // TODO: Introduce a parent class that will handle the insertion point update.
   OpBuilder::InsertionGuard g(rewriter);
   rewriter.setInsertionPoint(padOp);
 
@@ -1874,6 +1900,15 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
   return success();
 }
 
+/// Need to check if the inner-tiles are static/constant.
+static LogicalResult
+vectorizeInsertSliceOpPrecondition(tensor::InsertSliceOp sliceOp,
+                                   ArrayRef<int64_t> inputVectorSizes) {
+
+  // TODO: Move pre-conditions from the vectorization logic
+  return success();
+}
+
 static LogicalResult vectorizeLinalgOpPrecondition(
     LinalgOp linalgOp, ArrayRef<int64_t> inputVectorSizes,
     bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -2144,6 +2179,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
       .Case<tensor::UnPackOp>([&](auto unpackOp) {
         return vectorizeUnPackOpPrecondition(unpackOp, inputVectorSizes);
       })
+      .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+        return vectorizeInsertSliceOpPrecondition(sliceOp, inputVectorSizes);
+      })
       .Default([](auto) { return failure(); });
 }
 
@@ -2163,8 +2201,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
 }
 
 bool mlir::linalg::hasVectorizationImpl(Operation *op) {
-  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp>(
-      op);
+  return isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
+             tensor::InsertSliceOp>(op);
 }
 
 /// Emit a suitable vector form for an operation. If provided,
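
Usage sketch (assumed driver code, not part of this diff; funcOp is an assumed enclosing function op): with tensor.insert_slice added to hasVectorizationImpl, a caller can route it through the same collect-then-vectorize loop used for linalg, pad, pack and unpack ops:

  SmallVector<Operation *> candidates;
  funcOp.walk([&](Operation *op) {
    if (linalg::hasVectorizationImpl(op))
      candidates.push_back(op);
  });
  IRRewriter rewriter(funcOp.getContext());
  for (Operation *op : candidates)
    // Empty vector sizes: infer them from the static dims of the tensors.
    (void)linalg::vectorize(rewriter, op, /*inputVectorSizes=*/{},
                            /*inputScalableVecDims=*/{});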
@@ -2178,6 +2216,7 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
                                       ArrayRef<bool> inputScalableVecDims,
                                       bool vectorizeNDExtract,
                                       bool flatten1DDepthwiseConv) {
+  rewriter.getInsertionPoint();
   LDBG("Attempting to vectorize:\n" << *op << "\n");
   LDBG("Input vector sizes: ");
   LLVM_DEBUG(llvm::interleaveComma(inputVectorSizes, llvm::dbgs()));
@@ -2244,6 +2283,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
             return vectorizeAsTensorPackOp(rewriter, packOp, inputVectorSizes,
                                            results);
           })
+          .Case<tensor::InsertSliceOp>([&](auto sliceOp) {
+            return vectorizeAsInsertSliceOp(rewriter, sliceOp, inputVectorSizes,
+                                            results);
+          })
           .Case<tensor::UnPackOp>([&](auto unpackOp) {
             return vectorizeAsTensorUnpackOp(rewriter, unpackOp,
                                              inputVectorSizes, results);
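
For reference, a hypothetical transform-dialect driver exercising the new case (syntax assumed from the existing transform.structured.vectorize op; not part of this diff):

  %0 = transform.structured.match ops{["tensor.insert_slice"]} in %arg1
      : (!transform.any_op) -> !transform.any_op
  transform.structured.vectorize %0 vector_sizes [8, 16] : !transform.any_op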
@@ -2583,113 +2626,139 @@ static Value getStaticPadVal(Operation *op) {
   return {};
 }
 
-/// Rewrite tensor.insert.slice as a vector.transfer_read +
-/// vector.transfer_write pair. The vector size is inferred from the static
-/// dims in the input and output tensors. If a dim is dynamic in both the input
-/// and output tensors, bails out.
-///
-/// Before:
-///     !t_in_type = tensor<1x2x3xf32>
-///     !t_out_type = tensor<9x8x7x1x2x3xf32>
-///     !v_type = vector<1x2x3xf32>
-///     %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
-///     into !t_out_type
-/// After:
-///     %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
-///     %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
-///
-/// TODO: Support masking
-struct InsertSliceVectorizePattern
-    : public OpRewritePattern<tensor::InsertSliceOp> {
-  using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
+static LogicalResult
+vectorizeAsInsertSliceOp(RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
+                         ArrayRef<int64_t> inputVectorSizes,
+                         SmallVectorImpl<Value> &newResults) {
+  // TODO: Introduce a parent class that will handle the insertion point update.
+  OpBuilder::InsertionGuard g(rewriter);
+  rewriter.setInsertionPoint(sliceOp);
 
-  LogicalResult matchAndRewrite(tensor::InsertSliceOp sliceOp,
-                                PatternRewriter &rewriter) const final {
-    auto sourceType = sliceOp.getSource().getType();
-    if (!VectorType::isValidElementType(sourceType.getElementType()))
-      return failure();
+  TypedValue<RankedTensorType> source = sliceOp.getSource();
+  auto sourceType = source.getType();
+  if (!VectorType::isValidElementType(sourceType.getElementType()))
+    return failure();
 
-    auto resultType = sliceOp.getResultType();
-
-    // 1. Get the pad value.
-    // TransferReadOp requires a scalar padding value. Note that:
-    //    * for in-bounds access, the value is actually irrelevant.
-    // There are 2 cases in which xfer.read accesses are known to be in-bounds:
-    //  1. The source shape is static (output vector sizes would be based on
-    //     the source shape and hence all memory accesses would be in-bounds),
-    //  2. Masking is used (output vector sizes would be user-provided, in which
-    //     case it is assumed that all memory accesses are in-bounds). This
-    //     remains a TODO.
-    //
-    // When the value is not known and not needed, use 0. Otherwise, bail out.
-    Value padValue = getStaticPadVal(sliceOp);
-    bool isOutOfBoundsRead = !sourceType.hasStaticShape();
-
-    if (!padValue && isOutOfBoundsRead) {
-      LDBG("Failed to get a pad value for out-of-bounds read access\n");
+  auto resultType = sliceOp.getResultType();
+
+  // 1. Get the pad value.
+  // TransferReadOp requires a scalar padding value. Note that:
+  //    * for in-bounds access, the value is actually irrelevant.
+  // There are 2 cases in which xfer.read accesses are known to be in-bounds:
+  //  1. The source shape is static (output vector sizes would be based on
+  //     the source shape and hence all memory accesses would be in-bounds),
+  //  2. Masking is used (output vector sizes would be user-provided, in which
+  //     case it is assumed that all memory accesses are in-bounds). This
+  //     remains a TODO.
+  //
+  // When the value is not known and not needed, use 0. Otherwise, bail out.
+  Value padValue = getStaticPadVal(sliceOp);
+  bool isOutOfBoundsRead =
+      !sourceType.hasStaticShape() && inputVectorSizes.empty();
+
+  if (!padValue && isOutOfBoundsRead) {
+    LDBG("Failed to get a pad value for out-of-bounds read access\n");
+    return failure();
+  }
+
+  if (!padValue) {
+    auto elemType = sourceType.getElementType();
+    padValue = rewriter.create<arith::ConstantOp>(
+        sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
+  }
+
+  // 2. Get the vector shape and in-bounds attributes
+  SmallVector<int64_t> vecShape;
+  SmallVector<bool> readInBounds;
+  SmallVector<bool> writeInBounds;
+  size_t rankDiff = resultType.getRank() - sourceType.getRank();
+  for (unsigned i = 0; i < sourceType.getRank(); ++i) {
+    if (!inputVectorSizes.empty()) {
+      vecShape.push_back(inputVectorSizes[i]);
+      readInBounds.push_back(false);
+      writeInBounds.push_back(false);
+    } else if (!sourceType.isDynamicDim(i)) {
+      vecShape.push_back(sourceType.getDimSize(i));
+      // Source shape is statically known: Neither read nor write are
+      // out-of-bounds.
+      readInBounds.push_back(true);
+      writeInBounds.push_back(true);
+    } else if (!resultType.isDynamicDim(i)) {
+      // Source shape is not statically known, but result shape is.
+      // Vectorize with size of result shape. This may be larger than the
+      // source size.
+      // FIXME: Using rankDiff implies that the source tensor is inserted at
+      // the end of the destination tensor. However, that's not required.
+      vecShape.push_back(resultType.getDimSize(rankDiff + i));
+      // Read may be out-of-bounds because the result size could be larger
+      // than the source size.
+      readInBounds.push_back(false);
+      // Write will be in-bounds provided that the corresponding write idx is 0.
+      // To keep this logic simple, conservatively mark as out-of-bounds.
+      writeInBounds.push_back(false);
+    } else {
+      // Neither source nor result dim of padOp is static. Cannot vectorize
+      // the copy.
+      // TODO: Add support for masking
       return failure();
     }
+  }
+  auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-    if (!padValue) {
-      auto elemType = sourceType.getElementType();
-      padValue = rewriter.create<arith::ConstantOp>(
-          sliceOp.getLoc(), elemType, rewriter.getZeroAttr(elemType));
-    }
+  // 3. Generate TransferReadOp.
+  SmallVector<Value> readIndices(
+      vecType.getRank(),
+      rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
+  Operation *read = rewriter.create<vector::TransferReadOp>(
+      sliceOp.getLoc(), vecType, source, readIndices, padValue,
+      ArrayRef<bool>{readInBounds});
 
-    // 2. Get the vector shape and in-bounds attributes
-    SmallVector<int64_t> vecShape;
-    SmallVector<bool> readInBounds;
-    SmallVector<bool> writeInBounds;
-    size_t rankDiff = resultType.getRank() - sourceType.getRank();
-    for (unsigned i = 0; i < sourceType.getRank(); ++i) {
-      if (!sourceType.isDynamicDim(i)) {
-        vecShape.push_back(sourceType.getDimSize(i));
-        // Source shape is statically known: Neither read nor write are
-        // out-of-bounds.
-        readInBounds.push_back(true);
-        writeInBounds.push_back(true);
-      } else if (!resultType.isDynamicDim(i)) {
-        // Source shape is not statically known, but result shape is.
-        // Vectorize with size of result shape. This may be larger than the
-        // source size.
-        // FIXME: Using rankDiff implies that the source tensor is inserted at
-        // the end of the destination tensor. However, that's not required.
-        vecShape.push_back(resultType.getDimSize(rankDiff + i));
-        // Read may be out-of-bounds because the result size could be larger
-        // than the source size.
-        readInBounds.push_back(false);
-        // Write will in-bounds provided that the corresponding write idx is 0.
-        // To keep this logic simple, conservatively mark as out-of-bounds.
-        writeInBounds.push_back(false);
-      } else {
-        // Neither source nor result dim of padOp is static. Cannot vectorize
-        // the copy.
-        // TODO: Add support for masking
-        return failure();
-      }
+  // If vector sizes are user provided, make sure to mask xfer_read.
+  if (!inputVectorSizes.empty()) {
+    auto *srcDefOp = source.getDefiningOp();
+    if (!srcDefOp) {
+      LDBG("Unable to get the defining Op of " << sliceOp);
+      return failure();
     }
-    auto vecType = VectorType::get(vecShape, sourceType.getElementType());
 
-    // 3. Generate TransferReadOp.
-    SmallVector<Value> readIndices(
-        vecType.getRank(),
-        rewriter.create<arith::ConstantIndexOp>(sliceOp.getLoc(), 0));
-    auto read = rewriter.create<vector::TransferReadOp>(
-        sliceOp.getLoc(), vecType, sliceOp.getSource(), readIndices, padValue,
-        ArrayRef<bool>{readInBounds});
+    ReifiedRankedShapedTypeDims reifiedSrcSizes;
+    LogicalResult status =
+        cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes(
+            rewriter, reifiedSrcSizes);
+    if (status.failed()) {
+      LDBG("Unable to reify result shapes of " << sliceOp);
+      return failure();
+    }
 
-    // 4. Generate TransferWriteOp.
-    auto writeIndices = getValueOrCreateConstantIndexOp(
-        rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+    // Create the mask
+    SmallVector<int64_t> readMaskShape(
+        sliceOp.getSource().getType().getShape());
+    auto readMaskType = VectorType::get(inputVectorSizes, rewriter.getI1Type());
+    Value maskOp = rewriter.create<vector::CreateMaskOp>(
+        sliceOp.getLoc(), readMaskType, reifiedSrcSizes[0]);
 
-    // 5. Finalize
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
-        sliceOp, read, sliceOp.getDest(), writeIndices,
-        ArrayRef<bool>{writeInBounds});
+    // Mask the xfer_read Op
+    read = mlir::vector::maskOperation(rewriter, read, maskOp);
+  }
 
-    return success();
+  // 4. Generate TransferWriteOp.
+  if (!inputVectorSizes.empty() &&
+      ShapedType::isDynamicShape(resultType.getShape())) {
+    LDBG("TODO: Masking of xfer_write when vectorising " << sliceOp);
+    return failure();
   }
-};
+
+  auto writeIndices = getValueOrCreateConstantIndexOp(
+      rewriter, sliceOp.getLoc(), sliceOp.getMixedOffsets());
+
+  // 5. Finalize
+  Operation *write = rewriter.create<vector::TransferWriteOp>(
+      sliceOp.getLoc(), read->getResult(0), sliceOp.getDest(), writeIndices,
+      ArrayRef<bool>{writeInBounds});
+  newResults.push_back(write->getResult(0));
+
+  return success();
+}
 
 /// Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
 /// ```
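
To illustrate the result-shape branch above (no user-provided sizes, a dynamic source dim whose corresponding result dim is static), a sketch with assumed shapes and names; %pad is the statically known pad value that the possibly out-of-bounds read requires:

  // %src : tensor<1x2x?xf32>, %dest : tensor<9x8x7x1x2x7xf32>
  %read = vector.transfer_read %src[%c0, %c0, %c0], %pad
      {in_bounds = [true, true, false]} : tensor<1x2x?xf32>, vector<1x2x7xf32>
  %res = vector.transfer_write %read, %dest[%c0, %c0, %c0, %c0, %c0, %c0]
      {in_bounds = [true, true, false]}
      : vector<1x2x7xf32>, tensor<9x8x7x1x2x7xf32>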
@@ -2778,11 +2847,6 @@ struct PadOpVectorizationWithInsertSlicePattern
   }
 };
 
-void mlir::linalg::populateInsertSliceVectorizationPatterns(
-    RewritePatternSet &patterns) {
-  patterns.add<InsertSliceVectorizePattern>(patterns.getContext());
-}
-
 void mlir::linalg::populatePadOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
   patterns.add<PadOpVectorizationWithTransferReadPattern,
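
With the dedicated pattern removed, clients reach insert_slice vectorization through linalg::vectorize() (see hasVectorizationImpl above) rather than through a pattern set; the pad patterns are still populated as before. A minimal sketch of assumed client code (ctx and funcOp are hypothetical):

  RewritePatternSet patterns(ctx);
  linalg::populatePadOpVectorizationPatterns(patterns);
  (void)applyPatternsAndFoldGreedily(funcOp, std::move(patterns));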