13
13
#include < type_traits>
14
14
15
15
#include " mlir/Dialect/AffineOps/AffineOps.h"
16
+ #include " mlir/Dialect/StandardOps/Ops.h"
16
17
#include " mlir/Dialect/VectorOps/VectorOps.h"
17
18
#include " mlir/Dialect/VectorOps/VectorTransforms.h"
18
19
#include " mlir/Dialect/VectorOps/VectorUtils.h"
28
29
#include " mlir/IR/PatternMatch.h"
29
30
#include " mlir/IR/Types.h"
30
31
#include " mlir/Support/Functional.h"
32
+ #include " mlir/Support/MathExtras.h"
31
33
#include " mlir/Support/STLExtras.h"
32
34
33
35
#include " llvm/Support/CommandLine.h"
@@ -657,6 +659,131 @@ struct TupleGetFolderOp : public OpRewritePattern<vector::TupleGetOp> {
657
659
}
658
660
};
659
661
662
+ // / Progressive lowering of ExtractSlicesOp to tuple of StridedSliceOp.
663
+ // / One:
664
+ // / %x = vector.extract_slices %0
665
+ // / is replaced by:
666
+ // / %a = vector.strided_slice %0
667
+ // / %b = vector.strided_slice %0
668
+ // / ..
669
+ // / %x = vector.tuple %a, %b, ..
670
+ class ExtractSlicesOpLowering
671
+ : public OpRewritePattern<vector::ExtractSlicesOp> {
672
+ public:
673
+ using OpRewritePattern<vector::ExtractSlicesOp>::OpRewritePattern;
674
+
675
+ // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
676
+ PatternMatchResult matchAndRewrite (vector::ExtractSlicesOp op,
677
+ PatternRewriter &rewriter) const override {
678
+ auto loc = op.getLoc ();
679
+
680
+ VectorType vectorType = op.getSourceVectorType ();
681
+ int64_t rank = vectorType.getRank ();
682
+ auto shape = vectorType.getShape ();
683
+
684
+ SmallVector<int64_t , 4 > sizes;
685
+ op.getSizes (sizes);
686
+ SmallVector<int64_t , 4 > strides;
687
+ op.getStrides (strides); // all-ones at the moment
688
+
689
+ // Compute the number of slices in each dimension.
690
+ SmallVector<int64_t , 4 > sliceDimCounts (rank);
691
+ for (int64_t r = 0 ; r < rank; ++r)
692
+ sliceDimCounts[r] = ceilDiv (shape[r], sizes[r]);
693
+
694
+ // For each element in the tuple, generate the proper strided slice.
695
+ auto basis = computeStrides (sliceDimCounts);
696
+ TupleType tupleType = op.getResultTupleType ();
697
+ int64_t tupleSize = tupleType.size ();
698
+ SmallVector<Value, 4 > tupleValues (tupleSize);
699
+ for (int64_t i = 0 ; i < tupleSize; ++i) {
700
+ // De-linearize w.r.t. 'basis'.
701
+ auto vectorOffsets = delinearize (i, basis);
702
+ // Convert from unrolled vector-space offsets to element-space offsets.
703
+ auto elementOffsets = mlir::functional::zipMap (
704
+ [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
705
+ // Compute the size of each slice.
706
+ SmallVector<int64_t , 4 > sliceSizes (rank);
707
+ for (int64_t r = 0 ; r < rank; ++r)
708
+ sliceSizes[r] = std::min (sizes[r], shape[r] - elementOffsets[r]);
709
+ // Insert in tuple.
710
+ tupleValues[i] = rewriter.create <vector::StridedSliceOp>(
711
+ loc, op.vector (), elementOffsets, sliceSizes, strides);
712
+ }
713
+
714
+ rewriter.replaceOpWithNewOp <vector::TupleOp>(op, tupleType, tupleValues);
715
+ return matchSuccess ();
716
+ }
717
+ };
718
+
719
+ // / Progressive lowering of InsertSlicesOp to series of InsertStridedSliceOp.
720
+ // / One:
721
+ // / %x = vector.insert_slices %0
722
+ // / is replaced by:
723
+ // / %r0 = vector.splat 0
724
+ // %t1 = vector.tuple_get %0, 0
725
+ // / %r1 = vector.insert_strided_slice %r0, %t1
726
+ // %t2 = vector.tuple_get %0, 1
727
+ // / %r2 = vector.insert_strided_slice %r1, %t2
728
+ // / ..
729
+ // / %x = ..
730
+ class InsertSlicesOpLowering : public OpRewritePattern <vector::InsertSlicesOp> {
731
+ public:
732
+ using OpRewritePattern<vector::InsertSlicesOp>::OpRewritePattern;
733
+
734
+ // TODO(ajcbik): refactor slice utilities out into VectorUtils.h
735
+ PatternMatchResult matchAndRewrite (vector::InsertSlicesOp op,
736
+ PatternRewriter &rewriter) const override {
737
+ auto loc = op.getLoc ();
738
+
739
+ VectorType vectorType = op.getResultVectorType ();
740
+ int64_t rank = vectorType.getRank ();
741
+ auto shape = vectorType.getShape ();
742
+
743
+ SmallVector<int64_t , 4 > sizes;
744
+ op.getSizes (sizes);
745
+ SmallVector<int64_t , 4 > strides;
746
+ op.getStrides (strides); // all-ones at the moment
747
+
748
+ // Compute the number of slices in each dimension.
749
+ SmallVector<int64_t , 4 > sliceDimCounts (rank);
750
+ for (int64_t r = 0 ; r < rank; ++r)
751
+ sliceDimCounts[r] = ceilDiv (shape[r], sizes[r]);
752
+
753
+ // Prepare result.
754
+ auto elemType = vectorType.getElementType ();
755
+ Value zero = rewriter.create <ConstantOp>(loc, elemType,
756
+ rewriter.getZeroAttr (elemType));
757
+ Value result = rewriter.create <SplatOp>(loc, vectorType, zero);
758
+
759
+ // For each element in the tuple, extract the proper strided slice.
760
+ auto basis = computeStrides (sliceDimCounts);
761
+ TupleType tupleType = op.getSourceTupleType ();
762
+ int64_t tupleSize = tupleType.size ();
763
+ SmallVector<Value, 4 > tupleValues (tupleSize);
764
+ for (int64_t i = 0 ; i < tupleSize; ++i) {
765
+ // De-linearize w.r.t. 'basis'.
766
+ auto vectorOffsets = delinearize (i, basis);
767
+ // Convert from unrolled vector-space offsets to element-space offsets.
768
+ auto elementOffsets = mlir::functional::zipMap (
769
+ [](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, sizes);
770
+ // Compute the size of each slice.
771
+ SmallVector<int64_t , 4 > sliceSizes (rank);
772
+ for (int64_t r = 0 ; r < rank; ++r)
773
+ sliceSizes[r] = std::min (sizes[r], shape[r] - elementOffsets[r]);
774
+ // Extract from tuple into the result.
775
+ auto index = rewriter.getI64IntegerAttr (i);
776
+ auto tupleGet = rewriter.create <vector::TupleGetOp>(
777
+ loc, tupleType.getType (i), op.getOperand (), index);
778
+ result = rewriter.create <vector::InsertStridedSliceOp>(
779
+ loc, tupleGet, result, elementOffsets, strides);
780
+ }
781
+
782
+ rewriter.replaceOp (op, result);
783
+ return matchSuccess ();
784
+ }
785
+ };
786
+
660
787
} // namespace
661
788
662
789
// TODO(andydavis) Add pattern to rewrite ExtractSlices(ConstantMaskOp).
@@ -666,3 +793,8 @@ void mlir::vector::populateVectorToVectorTransformationPatterns(
666
793
patterns.insert <SplitTransferReadOp, SplitTransferWriteOp, TupleGetFolderOp>(
667
794
context);
668
795
}
796
+
797
+ void mlir::vector::populateVectorSlicesLoweringPatterns (
798
+ OwningRewritePatternList &patterns, MLIRContext *context) {
799
+ patterns.insert <ExtractSlicesOpLowering, InsertSlicesOpLowering>(context);
800
+ }
0 commit comments