Commit 9d0acb8
[mlir][Vector] Add support for trunci to narrow type emulation (#82565)
This PR adds support for `arith.trunci` to vector narrow type emulation for iX -> i4 truncations, for X >= 8. For now, the pattern only works for 1D vectors and is based on `vector.shuffle` ops. We would need `vector.deinterleave` to add n-D vector support.
1 parent a3748d6 commit 9d0acb8
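
At the byte level, the emulation is equivalent to first truncating each element to i8 and then packing the i4 payloads of element pairs into single bytes. Below is a minimal standalone C++ model of that semantics, using plain scalars instead of vector ops; the loop structure and names are illustrative, not from the patch:

```cpp
// Scalar model of the emulated iX -> i4 truncation: trunc to i8 first, then
// pack pairs of i4 payloads into one byte each. The packed bytes are what
// vector.bitcast later reinterprets as a vector of i4.
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int32_t> in = {5, -3, 100, 42, 7, -8, 0, 15}; // vector<8xi32>
  std::vector<uint8_t> packed;                              // vector<4xi8>
  for (size_t i = 0; i < in.size(); i += 2) {
    uint8_t lo = (uint8_t)in[i];     // iX -> i8 truncation
    uint8_t hi = (uint8_t)in[i + 1];
    packed.push_back((lo & 0x0F) | (uint8_t)(hi << 4));
  }
  // Each nibble equals the direct iX -> i4 truncation (value mod 16).
  for (size_t i = 0; i < in.size(); ++i) {
    uint8_t nibble = (i % 2) ? (packed[i / 2] >> 4) : (packed[i / 2] & 0x0F);
    assert(nibble == ((uint8_t)in[i] & 0x0F));
  }
  return 0;
}
```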

2 files changed: +163 -5 lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 121 additions & 5 deletions
```diff
@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
-  if (resultBitwidth % 8 != 0)
+  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+  if (bitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
@@ -768,6 +768,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
       (dstElemBitwidth % srcElemBitwidth) != 0)
     return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
 
+  if ((srcType.getShape().back() % 2) != 0)
+    return rewriter.notifyMatchFailure(
+        op, "Not an even number of i4 elements in trailing dim");
+
   return success();
 }
 
@@ -876,6 +880,58 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
+/// that take advantage of high-level information to avoid leaving LLVM to
+/// scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+                                Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(8) &&
+         "Expected i8 type");
+
+  // 1. De-interleave low and high i8 elements.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> deinterleaveLowMaskValues;
+  SmallVector<int64_t> deinterleaveHighMaskValues;
+  assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
+  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+  for (int i = 0, end = vecDimSize; i < end; i += 2) {
+    deinterleaveLowMaskValues.push_back(i);
+    deinterleaveHighMaskValues.push_back(i + 1);
+  }
+
+  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+  // 2. Zero out the upper side of each low i8 element.
+  constexpr int8_t i8LowBitMask = 0x0F;
+  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+  Value zeroOutLow =
+      rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+
+  // 3. Move high i4 values to upper side of the byte.
+  constexpr int8_t bitsToShift = 4;
+  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+  Value shlHigh =
+      rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+  // 4. Merge high and low i4 values.
+  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+
+  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
```
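
The five steps above reduce to three byte-wise operations per de-interleaved element pair. A small standalone sketch mapping the function's intermediate values onto plain C++ (the names mirror the code; the wrapper itself is illustrative, not part of the patch):

```cpp
// Per element pair (after the two deinterleaving shuffles), steps 2-4 are:
//   zeroOutLow = low  & 0x0F        (arith.andi with dense<15>)
//   shlHigh    = high << 4          (arith.shli with dense<4>)
//   merged     = zeroOutLow | shlHigh   (arith.ori)
// Step 5's vector.bitcast then reads each merged byte as two i4 values,
// low nibble first.
#include <cassert>
#include <cstdint>

int main() {
  uint8_t low = 0xF1, high = 0xA2;        // i8 elements; i4 payloads 0x1, 0x2
  uint8_t zeroOutLow = low & 0x0F;        // step 2
  uint8_t shlHigh = (uint8_t)(high << 4); // step 3
  uint8_t merged = zeroOutLow | shlHigh;  // step 4
  assert(merged == 0x21);                 // two i4s packed into one byte
  return 0;
}
```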
```diff
@@ -1019,7 +1075,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
-    // Set up the BitCastRewriter and verify the preconditions.
+    // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1043,6 +1099,65 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.trunci %in : vector<8xi32> to vector<8xi4>
+///      is rewritten as
+///
+///   %cst = arith.constant dense<15> : vector<4xi8>
+///   %cst_0 = arith.constant dense<4> : vector<4xi8>
+///   %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+///   %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+///   %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+///   %3 = arith.andi %1, %cst : vector<4xi8>
+///   %4 = arith.shli %2, %cst_0 : vector<4xi8>
+///   %5 = arith.ori %3, %4 : vector<4xi8>
+///   %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = truncOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+    if (!srcVecType || !dstVecType)
+      return failure();
+
+    // Only single dim vectors are supported until we have
+    // `vector.deinterleave`.
+    if (srcVecType.getRank() != 1)
+      return failure();
+
+    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+      return failure();
+
+    // Check general alignment preconditions. We invert the src/dst type order
+    // to reuse the existing precondition logic.
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+                                             truncOp)))
+      return failure();
+
+    // Create a new iX -> i8 truncation op.
+    Location loc = truncOp.getLoc();
+    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    Value i8TruncVal =
+        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+    // Rewrite the i8 -> i4 truncation part.
+    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+    // Finalize the rewrite.
+    rewriter.replaceOp(truncOp, subByteTrunc);
+    return success();
+  }
+};
+
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
 /// For example:
@@ -1115,8 +1230,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
   patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+                                              benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
```
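
With the pattern registered above, downstream users pick it up through the existing `vector::populateVectorNarrowTypeRewritePatterns` entry point. A hedged sketch of the typical way to apply these rewrites; only the populate call is API from this file, the wrapper function is standard greedy-driver boilerplate and not part of this commit:

```cpp
// Hypothetical helper: collect the narrow-type rewrite patterns (now
// including RewriteAlignedSubByteIntTrunc, registered at benefit + 1 so the
// aligned path is preferred) and run them with the greedy rewrite driver.
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;

static LogicalResult runNarrowTypeRewrites(Operation *root) {
  RewritePatternSet patterns(root->getContext());
  vector::populateVectorNarrowTypeRewritePatterns(patterns);
  return applyPatternsAndFoldGreedily(root, std::move(patterns));
}
```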

mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir

Lines changed: 42 additions & 0 deletions
```diff
@@ -262,6 +262,48 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
   return %0 : vector<8x32xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_trunci(
+func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_base_case(
+func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.andi
+// CHECK-NOT: vector.shli
+// CHECK-NOT: vector.ori
+// CHECK: arith.trunci
+  %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+  return %0 : vector<8x32xi4>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
 // CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
```
