@@ -59,6 +59,37 @@ vectorizeConvolution(RewriterBase &rewriter, LinalgOp convOp,
59
59
ArrayRef<bool > inputVecScalableFlags = {},
60
60
bool flatten1DDepthwiseConv = false );
61
61
62
+ // / Vectorize tensor::InsertSliceOp with:
63
+ // / * vector::TransferReadOp + vector::TransferWriteOp
64
+ // / The vector sizes are either:
65
+ // / * user-provided in `inputVectorSizes`, or
66
+ // / * inferred from the static dims in the input and output tensors.
67
+ // / Bails out if:
68
+ // / * vector sizes are not user-provided, and
69
+ // / * at least one dim is dynamic (in both the input and output tensors).
70
+ // /
71
+ // / Before:
72
+ // / !t_in_type = tensor<1x2x3xf32>
73
+ // / !t_out_type = tensor<9x8x7x1x2x3xf32>
74
+ // / !v_type = vector<1x2x3xf32>
75
+ // / %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
76
+ // / into !t_out_type
77
+ // / After:
78
+ // / %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
79
+ // / %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
80
+ static LogicalResult
81
+ vectorizeAsInsertSliceOp (RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
82
+ ArrayRef<int64_t > inputVectorSizes,
83
+ SmallVectorImpl<Value> &newResults);
84
+
85
+ // / Returns the effective Pad value for the input op, provided it's a scalar.
86
+ // /
87
+ // / Many Ops exhibit pad-like behaviour, but this isn't always explicit. If
88
+ // / this Op performs padding, retrieve the padding value provided that it's
89
+ // / a scalar and static/fixed for all the padded values. Returns an empty value
90
+ // / otherwise.
91
+ static Value getStaticPadVal (Operation *op);
92
+
62
93
// / Return the unique instance of OpType in `block` if it is indeed unique.
63
94
// / Return null if none or more than 1 instances exist.
64
95
template <typename OpType>
@@ -1557,6 +1588,7 @@ static LogicalResult
1557
1588
vectorizeAsTensorPackOp (RewriterBase &rewriter, tensor::PackOp packOp,
1558
1589
ArrayRef<int64_t > inputVectorSizes,
1559
1590
SmallVectorImpl<Value> &newResults) {
1591
+ // TODO: Introduce a parent class that will handle the insertion point update.
1560
1592
OpBuilder::InsertionGuard g (rewriter);
1561
1593
rewriter.setInsertionPoint (packOp);
1562
1594
@@ -1633,6 +1665,7 @@ vectorizeAsTensorUnpackOp(RewriterBase &rewriter, tensor::UnPackOp unpackOp,
1633
1665
ArrayRef<int64_t > inputVectorSizes,
1634
1666
SmallVectorImpl<Value> &newResults) {
1635
1667
1668
+ // TODO: Introduce a parent class that will handle the insertion point update.
1636
1669
OpBuilder::InsertionGuard g (rewriter);
1637
1670
rewriter.setInsertionPoint (unpackOp);
1638
1671
@@ -1763,7 +1796,7 @@ vectorizeAsTensorPadOp(RewriterBase &rewriter, tensor::PadOp padOp,
1763
1796
auto padValue = padOp.getConstantPaddingValue ();
1764
1797
Location loc = padOp.getLoc ();
1765
1798
1766
- // transfer_write_in_bounds(transfer_read_masked(pad_source, pad_value))
1799
+ // TODO: Introduce a parent class that will handle the insertion point update.
1767
1800
OpBuilder::InsertionGuard g (rewriter);
1768
1801
rewriter.setInsertionPoint (padOp);
1769
1802
@@ -1874,6 +1907,38 @@ vectorizeUnPackOpPrecondition(tensor::UnPackOp unpackOp,
1874
1907
return success ();
1875
1908
}
1876
1909
1910
+ static LogicalResult
1911
+ vectorizeInsertSliceOpPrecondition (tensor::InsertSliceOp sliceOp,
1912
+ ArrayRef<int64_t > inputVectorSizes) {
1913
+
1914
+ TypedValue<RankedTensorType> source = sliceOp.getSource ();
1915
+ auto sourceType = source.getType ();
1916
+ if (!VectorType::isValidElementType (sourceType.getElementType ()))
1917
+ return failure ();
1918
+
1919
+ // Get the pad value.
1920
+ // TransferReadOp (which is used to vectorize InsertSliceOp), requires a
1921
+ // scalar padding value. Note that:
1922
+ // * for in-bounds accesses,
1923
+ // the value is actually irrelevant. There are 2 cases in which xfer.read
1924
+ // accesses are known to be in-bounds:
1925
+ // 1. The source shape is static (output vector sizes would be based on
1926
+ // the source shape and hence all memory accesses would be in-bounds),
1927
+ // 2. Masking is used, i.e. the output vector sizes are user-provided. In
1928
+ // this case it is safe to assume that all memory accesses are in-bounds.
1929
+ //
1930
+ // When the value is not known and not needed, use 0. Otherwise, bail out.
1931
+ Value padValue = getStaticPadVal (sliceOp);
1932
+ bool isOutOfBoundsRead =
1933
+ !sourceType.hasStaticShape () && inputVectorSizes.empty ();
1934
+
1935
+ if (!padValue && isOutOfBoundsRead) {
1936
+ LDBG (" Failed to get a pad value for out-of-bounds read access\n " );
1937
+ return failure ();
1938
+ }
1939
+ return success ();
1940
+ }
1941
+
1877
1942
static LogicalResult vectorizeLinalgOpPrecondition (
1878
1943
LinalgOp linalgOp, ArrayRef<int64_t > inputVectorSizes,
1879
1944
bool vectorizeNDExtract, bool flatten1DDepthwiseConv) {
@@ -2144,6 +2209,9 @@ LogicalResult mlir::linalg::vectorizeOpPrecondition(
2144
2209
.Case <tensor::UnPackOp>([&](auto unpackOp) {
2145
2210
return vectorizeUnPackOpPrecondition (unpackOp, inputVectorSizes);
2146
2211
})
2212
+ .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
2213
+ return vectorizeInsertSliceOpPrecondition (sliceOp, inputVectorSizes);
2214
+ })
2147
2215
.Default ([](auto ) { return failure (); });
2148
2216
}
2149
2217
@@ -2163,8 +2231,8 @@ static void convertAffineApply(RewriterBase &rewriter, LinalgOp linalgOp) {
2163
2231
}
2164
2232
2165
2233
/// Returns true if `op` is of a kind listed below, i.e. a Linalg op or one of
/// the tensor pad / pack / unpack / insert_slice ops.
bool mlir::linalg::hasVectorizationImpl(Operation *op) {
  const bool isSupportedKind =
      isa<linalg::LinalgOp, tensor::PadOp, tensor::PackOp, tensor::UnPackOp,
          tensor::InsertSliceOp>(op);
  return isSupportedKind;
}
2169
2237
2170
2238
// / Emit a suitable vector form for an operation. If provided,
@@ -2244,6 +2312,10 @@ LogicalResult mlir::linalg::vectorize(RewriterBase &rewriter, Operation *op,
2244
2312
return vectorizeAsTensorPackOp (rewriter, packOp, inputVectorSizes,
2245
2313
results);
2246
2314
})
2315
+ .Case <tensor::InsertSliceOp>([&](auto sliceOp) {
2316
+ return vectorizeAsInsertSliceOp (rewriter, sliceOp, inputVectorSizes,
2317
+ results);
2318
+ })
2247
2319
.Case <tensor::UnPackOp>([&](auto unpackOp) {
2248
2320
return vectorizeAsTensorUnpackOp (rewriter, unpackOp,
2249
2321
inputVectorSizes, results);
@@ -2540,6 +2612,9 @@ struct PadOpVectorizationWithTransferWritePattern
2540
2612
// / this Op performs padding, retrieve the padding value provided that it's
2541
2613
// / a scalar and static/fixed for all the padded values. Returns an empty value
2542
2614
// / otherwise.
2615
+ // /
2616
+ // / TODO: This is used twice (when checking vectorization pre-conditions and
2617
+ // / when vectorizing). Cache results instead of re-running.
2543
2618
static Value getStaticPadVal (Operation *op) {
2544
2619
if (!op)
2545
2620
return {};
@@ -2583,113 +2658,118 @@ static Value getStaticPadVal(Operation *op) {
2583
2658
return {};
2584
2659
}
2585
2660
2586
- // / Rewrite tensor.insert.slice as a vector.transfer_read +
2587
- // / vector.transfer_write pair. The vector size is inferred from the static
2588
- // / dims in the input and output tensors. If a dim is dynamic in both the input
2589
- // / and output tensors, bails out.
2590
- // /
2591
- // / Before:
2592
- // / !t_in_type = tensor<1x2x3xf32>
2593
- // / !t_out_type = tensor<9x8x7x1x2x3xf32>
2594
- // / !v_type = vector<1x2x3xf32>
2595
- // / %inserted_slice = tensor.insert_slice %src into %dest ... : !t_in_type
2596
- // / into !t_out_type
2597
- // / After:
2598
- // / %read = vector.transfer_read %src[...], %pad ... : !t_in_type, !v_type
2599
- // / %write = vector.transfer_write %read, %dest ... : !v_type, !t_out_type
2600
- // /
2601
- // / TODO: Support masking
2602
- struct InsertSliceVectorizePattern
2603
- : public OpRewritePattern<tensor::InsertSliceOp> {
2604
- using OpRewritePattern<tensor::InsertSliceOp>::OpRewritePattern;
2661
+ static LogicalResult
2662
+ vectorizeAsInsertSliceOp (RewriterBase &rewriter, tensor::InsertSliceOp sliceOp,
2663
+ ArrayRef<int64_t > inputVectorSizes,
2664
+ SmallVectorImpl<Value> &newResults) {
2665
+ // TODO: Introduce a parent class that will handle the insertion point update.
2666
+ OpBuilder::InsertionGuard g (rewriter);
2667
+ rewriter.setInsertionPoint (sliceOp);
2605
2668
2606
- LogicalResult matchAndRewrite (tensor::InsertSliceOp sliceOp,
2607
- PatternRewriter &rewriter) const final {
2608
- auto sourceType = sliceOp.getSource ().getType ();
2609
- if (!VectorType::isValidElementType (sourceType.getElementType ()))
2610
- return failure ();
2669
+ TypedValue<RankedTensorType> source = sliceOp.getSource ();
2670
+ auto sourceType = source.getType ();
2671
+ auto resultType = sliceOp.getResultType ();
2611
2672
2612
- auto resultType = sliceOp.getResultType ();
2613
-
2614
- // 1. Get the pad value.
2615
- // TransferReadOp requires a scalar padding value. Note that:
2616
- // * for in-bounds access, the value is actually irrelevant.
2617
- // There are 2 cases in which xfer.read accesses are known to be in-bounds:
2618
- // 1. The source shape is static (output vector sizes would be based on
2619
- // the source shape and hence all memory accesses would be in-bounds),
2620
- // 2. Masking is used (output vector sizes would be user-provided, in which
2621
- // case it is assumed that all memory accesses are in-bounds). This
2622
- // remains a TODO.
2623
- //
2624
- // When the value is not known and not needed, use 0. Otherwise, bail out.
2625
- Value padValue = getStaticPadVal (sliceOp);
2626
- bool isOutOfBoundsRead = !sourceType.hasStaticShape ();
2627
-
2628
- if (!padValue && isOutOfBoundsRead) {
2629
- LDBG (" Failed to get a pad value for out-of-bounds read access\n " );
2673
+ Value padValue = getStaticPadVal (sliceOp);
2674
+
2675
+ if (!padValue) {
2676
+ auto elemType = sourceType.getElementType ();
2677
+ padValue = rewriter.create <arith::ConstantOp>(
2678
+ sliceOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2679
+ }
2680
+
2681
+ // 2. Get the vector shape and in-bounds attributes
2682
+ SmallVector<int64_t > vecShape;
2683
+ SmallVector<bool > readInBounds;
2684
+ SmallVector<bool > writeInBounds;
2685
+ size_t rankDiff = resultType.getRank () - sourceType.getRank ();
2686
+ for (int64_t i = 0 , end = sourceType.getRank (); i < end; ++i) {
2687
+ if (!inputVectorSizes.empty ()) {
2688
+ vecShape.push_back (inputVectorSizes[i]);
2689
+ readInBounds.push_back (false );
2690
+ writeInBounds.push_back (false );
2691
+ } else if (!sourceType.isDynamicDim (i)) {
2692
+ vecShape.push_back (sourceType.getDimSize (i));
2693
+ // Source shape is statically known: Neither read nor write are
2694
+ // out-of-bounds.
2695
+ readInBounds.push_back (true );
2696
+ writeInBounds.push_back (true );
2697
+ } else if (!resultType.isDynamicDim (i)) {
2698
+ // Source shape is not statically known, but result shape is.
2699
+ // Vectorize with size of result shape. This may be larger than the
2700
+ // source size.
2701
+ // FIXME: Using rankDiff implies that the source tensor is inserted at
2702
+ // the end of the destination tensor. However, that's not required.
2703
+ vecShape.push_back (resultType.getDimSize (rankDiff + i));
2704
+ // Read may be out-of-bounds because the result size could be larger
2705
+ // than the source size.
2706
+ readInBounds.push_back (false );
2707
+ // Write will be in-bounds provided that the corresponding write idx is 0.
2708
+ // To keep this logic simple, conservatively mark as out-of-bounds.
2709
+ writeInBounds.push_back (false );
2710
+ } else {
2711
+ // Neither source nor result dim of padOp is static. Cannot vectorize
2712
+ // the copy.
2713
+ // TODO: Add support for masking
2630
2714
return failure ();
2631
2715
}
2716
+ }
2717
+ auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
2632
2718
2633
- if (!padValue) {
2634
- auto elemType = sourceType.getElementType ();
2635
- padValue = rewriter.create <arith::ConstantOp>(
2636
- sliceOp.getLoc (), elemType, rewriter.getZeroAttr (elemType));
2637
- }
2719
+ // 3. Generate TransferReadOp.
2720
+ SmallVector<Value> readIndices (
2721
+ vecType.getRank (),
2722
+ rewriter.create <arith::ConstantIndexOp>(sliceOp.getLoc (), 0 ));
2723
+ Operation *read = rewriter.create <vector::TransferReadOp>(
2724
+ sliceOp.getLoc (), vecType, source, readIndices, padValue,
2725
+ ArrayRef<bool >{readInBounds});
2638
2726
2639
- // 2. Get the vector shape and in-bounds attributes
2640
- SmallVector<int64_t > vecShape;
2641
- SmallVector<bool > readInBounds;
2642
- SmallVector<bool > writeInBounds;
2643
- size_t rankDiff = resultType.getRank () - sourceType.getRank ();
2644
- for (unsigned i = 0 ; i < sourceType.getRank (); ++i) {
2645
- if (!sourceType.isDynamicDim (i)) {
2646
- vecShape.push_back (sourceType.getDimSize (i));
2647
- // Source shape is statically known: Neither read nor write are
2648
- // out-of-bounds.
2649
- readInBounds.push_back (true );
2650
- writeInBounds.push_back (true );
2651
- } else if (!resultType.isDynamicDim (i)) {
2652
- // Source shape is not statically known, but result shape is.
2653
- // Vectorize with size of result shape. This may be larger than the
2654
- // source size.
2655
- // FIXME: Using rankDiff implies that the source tensor is inserted at
2656
- // the end of the destination tensor. However, that's not required.
2657
- vecShape.push_back (resultType.getDimSize (rankDiff + i));
2658
- // Read may be out-of-bounds because the result size could be larger
2659
- // than the source size.
2660
- readInBounds.push_back (false );
2661
- // Write will in-bounds provided that the corresponding write idx is 0.
2662
- // To keep this logic simple, conservatively mark as out-of-bounds.
2663
- writeInBounds.push_back (false );
2664
- } else {
2665
- // Neither source nor result dim of padOp is static. Cannot vectorize
2666
- // the copy.
2667
- // TODO: Add support for masking
2668
- return failure ();
2669
- }
2727
+ // If vector sizes are user provided, make sure to mask xfer_read.
2728
+ if (!inputVectorSizes.empty ()) {
2729
+ auto *srcDefOp = source.getDefiningOp ();
2730
+ if (!srcDefOp) {
2731
+ LDBG (" Unable to get the defining Op of " << sliceOp);
2732
+ return failure ();
2670
2733
}
2671
- auto vecType = VectorType::get (vecShape, sourceType.getElementType ());
2672
2734
2673
- // 3. Generate TransferReadOp.
2674
- SmallVector<Value> readIndices (
2675
- vecType.getRank (),
2676
- rewriter.create <arith::ConstantIndexOp>(sliceOp.getLoc (), 0 ));
2677
- auto read = rewriter.create <vector::TransferReadOp>(
2678
- sliceOp.getLoc (), vecType, sliceOp.getSource (), readIndices, padValue,
2679
- ArrayRef<bool >{readInBounds});
2735
+ ReifiedRankedShapedTypeDims reifiedSrcSizes;
2736
+ LogicalResult status =
2737
+ cast<ReifyRankedShapedTypeOpInterface>(srcDefOp).reifyResultShapes (
2738
+ rewriter, reifiedSrcSizes);
2739
+ if (status.failed ()) {
2740
+ LDBG (" Unable to reify result shapes of " << sliceOp);
2741
+ return failure ();
2742
+ }
2680
2743
2681
- // 4. Generate TransferWriteOp.
2682
- auto writeIndices = getValueOrCreateConstantIndexOp (
2683
- rewriter, sliceOp.getLoc (), sliceOp.getMixedOffsets ());
2744
+ // Create the mask
2745
+ SmallVector<int64_t > readMaskShape (
2746
+ sliceOp.getSource ().getType ().getShape ());
2747
+ auto readMaskType = VectorType::get (inputVectorSizes, rewriter.getI1Type ());
2748
+ Value maskOp = rewriter.create <vector::CreateMaskOp>(
2749
+ sliceOp.getLoc (), readMaskType, reifiedSrcSizes[0 ]);
2684
2750
2685
- // 5. Finalize
2686
- rewriter.replaceOpWithNewOp <vector::TransferWriteOp>(
2687
- sliceOp, read, sliceOp.getDest (), writeIndices,
2688
- ArrayRef<bool >{writeInBounds});
2751
+ // Mask the xfer_read Op
2752
+ read = mlir::vector::maskOperation (rewriter, read, maskOp);
2753
+ }
2689
2754
2690
- return success ();
2755
+ // 4. Generate TransferWriteOp.
2756
+ if (!inputVectorSizes.empty () &&
2757
+ ShapedType::isDynamicShape (resultType.getShape ())) {
2758
+ LDBG (" TODO: Masking of xfer_write when vectorising " << sliceOp);
2759
+ return failure ();
2691
2760
}
2692
- };
2761
+
2762
+ auto writeIndices = getValueOrCreateConstantIndexOp (
2763
+ rewriter, sliceOp.getLoc (), sliceOp.getMixedOffsets ());
2764
+
2765
+ // 5. Finalize
2766
+ Operation *write = rewriter.create <vector::TransferWriteOp>(
2767
+ sliceOp.getLoc (), read->getResult (0 ), sliceOp.getDest (), writeIndices,
2768
+ ArrayRef<bool >{writeInBounds});
2769
+ newResults.push_back (write->getResult (0 ));
2770
+
2771
+ return success ();
2772
+ }
2693
2773
2694
2774
// / Rewrite use of tensor::PadOp result in InsertSliceOp. E.g.:
2695
2775
// / ```
@@ -2778,11 +2858,6 @@ struct PadOpVectorizationWithInsertSlicePattern
2778
2858
}
2779
2859
};
2780
2860
2781
- void mlir::linalg::populateInsertSliceVectorizationPatterns (
2782
- RewritePatternSet &patterns) {
2783
- patterns.add <InsertSliceVectorizePattern>(patterns.getContext ());
2784
- }
2785
-
2786
2861
void mlir::linalg::populatePadOpVectorizationPatterns (
2787
2862
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
2788
2863
patterns.add <PadOpVectorizationWithTransferReadPattern,
0 commit comments