@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
-  if (resultBitwidth % 8 != 0)
+  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+  if (bitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
@@ -768,6 +768,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
+  if ((srcType.getShape().back() % 2) != 0)
+    return rewriter.notifyMatchFailure(
+        op, "Not an even number of i4 elements in trailing dim");
+
   return success();
 }
 
@@ -876,6 +880,58 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
+/// that take advantage of high-level information to avoid leaving LLVM to
+/// scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+                                Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(8) &&
+         "Expected i8 type");
+
+  // 1. De-interleave low and high i8 elements.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> deinterleaveLowMaskValues;
+  SmallVector<int64_t> deinterleaveHighMaskValues;
+  assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
+  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+  for (int i = 0, end = vecDimSize; i < end; i += 2) {
+    deinterleaveLowMaskValues.push_back(i);
+    deinterleaveHighMaskValues.push_back(i + 1);
+  }
+
+  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+  // 2. Zero out the upper side of each low i8 element.
+  constexpr int8_t i8LowBitMask = 0x0F;
+  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+  Value zeroOutLow =
+      rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+
+  // 3. Move high i4 values to upper side of the byte.
+  constexpr int8_t bitsToShift = 4;
+  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+  Value shlHigh =
+      rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+  // 4. Merge high and low i4 values.
+  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+
+  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
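The numbered steps in rewriteI8ToI4Trunc are the vector-wide form of a simple scalar nibble-packing loop: even elements keep their low nibble, odd elements are shifted into the high nibble, and the pair is OR'd into one byte. As a rough illustration only (not part of the patch; the function and variable names below are made up), the scalar equivalent of steps 1-5 looks like this:

#include <cassert>
#include <cstdint>
#include <vector>

// Pack pairs of bytes that each carry an i4 value in their low nibble into
// single bytes: the even-indexed element lands in the low nibble, the
// odd-indexed element in the high nibble (mirrors steps 1-4 above).
std::vector<uint8_t> packI4Pairs(const std::vector<uint8_t> &src) {
  assert(src.size() % 2 == 0 && "Odd number of i4 elements");
  std::vector<uint8_t> packed;
  packed.reserve(src.size() / 2);
  for (size_t i = 0; i < src.size(); i += 2) {
    uint8_t low = src[i] & 0x0F;                  // 2. zero out the upper nibble
    uint8_t high = uint8_t(src[i + 1] << 4);      // 3. move the high i4 up
    packed.push_back(uint8_t(low | high));        // 4. merge
  }
  return packed; // 5. the vector rewrite then bitcasts this to twice as many i4s
}

int main() {
  std::vector<uint8_t> in = {0x01, 0x02, 0x0F, 0x07};
  std::vector<uint8_t> out = packI4Pairs(in);
  assert(out.size() == 2 && out[0] == 0x21 && out[1] == 0x7F);
  return 0;
}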
@@ -1019,7 +1075,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
-    // Set up the BitCastRewriter and verify the preconditions.
+    // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1043,6 +1099,65 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///   arith.trunci %in : vector<8xi32> to vector<8xi4>
+///     is rewritten as
+///
+///       %cst = arith.constant dense<15> : vector<4xi8>
+///       %cst_0 = arith.constant dense<4> : vector<4xi8>
+///       %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+///       %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+///       %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+///       %3 = arith.andi %1, %cst : vector<4xi8>
+///       %4 = arith.shli %2, %cst_0 : vector<4xi8>
+///       %5 = arith.ori %3, %4 : vector<4xi8>
+///       %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = truncOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+    if (!srcVecType || !dstVecType)
+      return failure();
+
+    // Only single dim vectors are supported until we have
+    // `vector.deinterleave`.
+    if (srcVecType.getRank() != 1)
+      return failure();
+
+    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+      return failure();
+
+    // Check general alignment preconditions. We invert the src/dst type order
+    // to reuse the existing precondition logic.
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+                                             truncOp)))
+      return failure();
+
+    // Create a new iX -> i8 truncation op.
+    Location loc = truncOp.getLoc();
+    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    Value i8TruncVal =
+        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+    // Rewrite the i8 -> i4 truncation part.
+    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+    // Finalize the rewrite.
+    rewriter.replaceOp(truncOp, subByteTrunc);
+    return success();
+  }
+};
+
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
 /// For example:
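The pattern above leans on truncation being composable: truncating iX -> i8 and then packing i8 -> i4 keeps exactly the same low nibble as a direct iX -> i4 truncation, which is what makes splitting the op into an iX -> i8 arith.trunci plus rewriteI8ToI4Trunc legal. A tiny standalone check of that reasoning (illustrative only, not from the patch):

#include <cassert>
#include <cstdint>

// Truncation only keeps the low bits, so i32 -> i8 -> i4 and i32 -> i4 agree
// on the resulting nibble for every input.
int main() {
  for (uint32_t x = 0; x < (1u << 16); ++x) {
    uint8_t viaI8 = uint8_t(uint8_t(x) & 0x0F); // i32 -> i8 -> i4
    uint8_t direct = uint8_t(x & 0x0F);         // i32 -> i4
    assert(viaI8 == direct);
  }
  return 0;
}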
@@ -1115,8 +1230,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
   patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+                                              benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
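For context, these rewrites are normally pulled in through populateVectorNarrowTypeRewritePatterns and applied with a greedy driver; the aligned patterns, including the new trunci one, are registered at benefit + 1 so they win over the generic path. A minimal sketch of such a call site, assuming standard MLIR headers and an existing func::FuncOp (the helper name and include paths are illustrative, not part of this change):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

// Collect the narrow-type rewrites and run them greedily over a function.
static LogicalResult runNarrowTypeRewrites(func::FuncOp funcOp) {
  RewritePatternSet patterns(funcOp.getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  return applyPatternsAndFoldGreedily(funcOp, std::move(patterns));
}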