@@ -642,19 +642,19 @@ struct BitCastRewriter {

  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);

-   /// Verify that the preconditions for the rewrite are met.
-   LogicalResult precondition(PatternRewriter &rewriter,
-                              VectorType preconditionVectorType, Operation *op);
+   /// Verify that general preconditions for the rewrite are met.
+   LogicalResult commonPrecondition(PatternRewriter &rewriter,
+                                    VectorType preconditionType, Operation *op);

  /// Precompute the metadata for the rewrite.
  SmallVector<BitCastRewriter::Metadata>
  precomputeMetadata(IntegerType shuffledElementType);

  /// Rewrite one step of the sequence:
  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
-   Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
-                     Value runningResult,
-                     const BitCastRewriter::Metadata &metadata);
+   Value genericRewriteStep(PatternRewriter &rewriter, Location loc,
+                            Value initialValue, Value runningResult,
+                            const BitCastRewriter::Metadata &metadata);

private:
  /// Underlying enumerator that encodes the provenance of the bits in each
@@ -719,21 +719,57 @@ BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
  LDBG("\n" << enumerator.sourceElementRanges);
}

- LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
-                                             VectorType precondition,
-                                             Operation *op) {
-   if (precondition.getRank() != 1 || precondition.isScalable())
+ /// Verify that the precondition type meets the common preconditions for any
+ /// conversion.
+ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
+                                                   VectorType preconditionType,
+                                                   Operation *op) {
+   if (!preconditionType || preconditionType.getRank() != 1 ||
+       preconditionType.isScalable())
    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");

  // TODO: consider relaxing this restriction in the future if we find ways
  // to really work with subbyte elements across the MLIR/LLVM boundary.
-   int64_t resultBitwidth = precondition.getElementTypeBitWidth();
+   unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
  if (resultBitwidth % 8 != 0)
    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

  return success();
}

+ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
+                                                   VectorType preconditionType,
+                                                   Operation *op) {
+   if (!enumerator.sourceVectorType || !enumerator.targetVectorType)
+     return rewriter.notifyMatchFailure(op, "types are not vector");
+
+   return commonConversionPrecondition(rewriter, preconditionType, op);
+ }
+
+ /// Verify that source and destination element types meet the precondition for
+ /// the supported aligned conversion cases. Alignment means that either the
+ /// source element type is a multiple of the destination element type or the
+ /// other way around.
+ ///
+ /// NOTE: This method assumes that common conversion preconditions are met.
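+ ///
+ /// For example, i4 -> i8, i4 -> i16, and i4 -> f32 satisfy this check (the
+ /// destination bitwidth is at least one byte and a multiple of 4 bits), while
+ /// i4 -> i6 is rejected (narrower than a byte) and i8 -> i32 is rejected
+ /// (only i4 sources are supported for now).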
+ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
+                                                    VectorType srcType,
+                                                    VectorType dstType,
+                                                    Operation *op) {
+   if (!srcType || !dstType)
+     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+   unsigned srcElemBitwidth = srcType.getElementTypeBitWidth();
+   unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
+   unsigned byteBitwidth = 8;
+
+   // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
+   if (srcElemBitwidth != 4 || dstElemBitwidth < byteBitwidth ||
+       (dstElemBitwidth % srcElemBitwidth) != 0)
+     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
+
+   return success();
+ }
+
SmallVector<BitCastRewriter::Metadata>
BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  SmallVector<BitCastRewriter::Metadata> result;
@@ -775,9 +811,9 @@ BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
  return result;
}

- Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
-                                    Value initialValue, Value runningResult,
-                                    const BitCastRewriter::Metadata &metadata) {
+ Value BitCastRewriter::genericRewriteStep(
+     PatternRewriter &rewriter, Location loc, Value initialValue,
+     Value runningResult, const BitCastRewriter::Metadata &metadata) {
  // Create vector.shuffle from the metadata.
  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, initialValue, initialValue, metadata.shuffles);
@@ -810,6 +846,44 @@ Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
  return runningResult;
}

+ /// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
+ /// bitwise ops that take advantage of high-level information to avoid leaving
+ /// LLVM to scramble with peephole optimizations.
+ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
+                                     Value srcValue) {
+   VectorType srcVecType = cast<VectorType>(srcValue.getType());
+   assert(srcVecType.getElementType().isSignlessInteger(4) &&
+          "Expected i4 type");
+
+   // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
+   int64_t vecDimSize = srcVecType.getShape().back();
+   SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
+   constexpr int64_t i4Toi8BitwidthFactor = 2;
+   i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
+   auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
+   Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
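+   // Under little-endian packing, byte k of the i8 vector holds i4 elements
+   // 2*k (low nibble) and 2*k+1 (high nibble); the interleave in step 3
+   // relies on this layout.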
+
+   // 2. Extend i4 elements to i8 elements using shifts. Low i4 elements of each
+   // byte are placed in one vector and the high i4 elements in another vector.
+   constexpr int8_t bitsToShift = 4;
+   auto shiftValues = rewriter.create<arith::ConstantOp>(
+       loc, DenseElementsAttr::get(i8VecType, bitsToShift));
+   Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
+   Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
+   Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
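+   // E.g. for byte 0x2F (high nibble 0x2, low nibble 0xF): shl gives 0xF0, so
+   // low = shrsi(0xF0, 4) = 0xFF (-1) and high = shrsi(0x2F, 4) = 0x02 (2),
+   // i.e. both nibbles end up sign-extended to i8.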
+
+   // 3. Interleave low and high i8 elements using a shuffle.
+   SmallVector<int64_t> interleaveMaskValues;
+   interleaveMaskValues.reserve(vecDimSize);
+   for (int i = 0, end = vecDimSize / 2; i < end; ++i) {
+     interleaveMaskValues.push_back(i);
+     interleaveMaskValues.push_back(i + (vecDimSize / 2));
+   }
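+   // E.g. for vecDimSize = 8 this builds the mask [0, 4, 1, 5, 2, 6, 3, 7],
+   // picking low[0], high[0], low[1], high[1], ... to restore element order.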
+
+   return rewriter.create<vector::ShuffleOp>(
+       loc, low, high, rewriter.getI64ArrayAttr(interleaveMaskValues));
+ }
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -829,7 +903,7 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
-     if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+     if (failed(bcr.commonPrecondition(rewriter, targetVectorType, bitCastOp)))
      return failure();

    // Perform the rewrite.
@@ -839,8 +913,8 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
    Value runningResult;
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
-       runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
-                                       runningResult, metadata);
+       runningResult = bcr.genericRewriteStep(
+           rewriter, bitCastOp->getLoc(), truncValue, runningResult, metadata);
    }

    // Finalize the rewrite.
@@ -893,7 +967,7 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
    VectorType targetVectorType = bitCastOp.getResultVectorType();
    BitCastRewriter bcr(sourceVectorType, targetVectorType);
-     if (failed(bcr.precondition(
+     if (failed(bcr.commonPrecondition(
            rewriter, cast<VectorType>(extOp.getOut().getType()), bitCastOp)))
      return failure();

@@ -904,8 +978,8 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
        cast<IntegerType>(getElementTypeOrSelf(sourceValue.getType()));
    for (const BitCastRewriter::Metadata &metadata :
         bcr.precomputeMetadata(shuffledElementType)) {
-       runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(),
-                                       sourceValue, runningResult, metadata);
+       runningResult = bcr.genericRewriteStep(
+           rewriter, bitCastOp->getLoc(), sourceValue, runningResult, metadata);
    }

    // Finalize the rewrite.
@@ -923,6 +997,62 @@ struct RewriteExtOfBitCast : OpRewritePattern<ExtOpType> {
    return success();
  }
};

+
+ /// Rewrite the i4 -> i8 part of any conversion into a sequence of shuffles and
+ /// bitwise ops that take advantage of high-level information to avoid leaving
+ /// LLVM to scramble with peephole optimizations.
+ ///
+ /// For example:
+ ///   arith.extsi %in : vector<8xi4> to vector<8xi32>
+ /// is rewritten as
+ ///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+ ///   %1 = arith.shli %0, 4 : vector<4xi8>
+ ///   %2 = arith.shrsi %1, 4 : vector<4xi8>
+ ///   %3 = arith.shrsi %0, 4 : vector<4xi8>
+ ///   %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+ ///        : vector<4xi8>, vector<4xi8>
+ ///   %5 = arith.extsi %4 : vector<8xi8> to vector<8xi32>
+ ///
+ ///   arith.sitofp %in : vector<8xi4> to vector<8xf32>
+ /// is rewritten as
+ ///   %0 = vector.bitcast %in : vector<8xi4> to vector<4xi8>
+ ///   %1 = arith.shli %0, 4 : vector<4xi8>
+ ///   %2 = arith.shrsi %1, 4 : vector<4xi8>
+ ///   %3 = arith.shrsi %0, 4 : vector<4xi8>
+ ///   %4 = vector.shuffle %2, %3 [0, 4, 1, 5, 2, 6, 3, 7]
+ ///        : vector<4xi8>, vector<4xi8>
+ ///   %5 = arith.sitofp %4 : vector<8xi8> to vector<8xf32>
+ ///
+ template <typename ConversionOpType>
+ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
+   using OpRewritePattern<ConversionOpType>::OpRewritePattern;
+
+   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
+                                 PatternRewriter &rewriter) const override {
+     // Verify the common conversion preconditions.
+     Value srcValue = conversionOp.getIn();
+     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
+     if (failed(
+             commonConversionPrecondition(rewriter, dstVecType, conversionOp)))
+       return failure();
+
+     // Check general alignment preconditions.
+     if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
+                                              conversionOp)))
+       return failure();
+
+     // Perform the rewrite.
+     Value subByteExt =
+         rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
+
+     // Finalize the rewrite.
+     rewriter.replaceOpWithNewOp<ConversionOpType>(
+         conversionOp, conversionOp.getType(), subByteExt);
+     return success();
+   }
+ };
+

} // namespace

//===----------------------------------------------------------------------===//
@@ -944,4 +1074,10 @@ void vector::populateVectorNarrowTypeRewritePatterns(
  patterns.add<RewriteBitCastOfTruncI, RewriteExtOfBitCast<arith::ExtUIOp>,
               RewriteExtOfBitCast<arith::ExtSIOp>>(patterns.getContext(),
                                                    benefit);
+
+   // Patterns for aligned cases. We set a higher benefit so that they take
+   // priority over the generic patterns, as they are expected to generate
+   // better code for the aligned cases.
+   patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
+                RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
+       patterns.getContext(), benefit.getBenefit() + 1);
}
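
For reference, a minimal sketch of how these patterns are typically driven (standard MLIR greedy-rewrite boilerplate; the surrounding pass and `funcOp` are illustrative assumptions, not part of this patch):

    // Inside a pass's runOnOperation(), assuming funcOp is the root op.
    RewritePatternSet patterns(&getContext());
    vector::populateVectorNarrowTypeRewritePatterns(patterns);
    // Where both an aligned and a generic pattern match, the aligned one is
    // tried first because it was registered with benefit + 1.
    if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
      signalPassFailure();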