Skip to content

[mlir][Vector] Add support for trunci to narrow type emulation #82565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 27, 2024
126 changes: 121 additions & 5 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,

// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
if (resultBitwidth % 8 != 0)
unsigned bitwidth = preconditionType.getElementTypeBitWidth();
if (bitwidth % 8 != 0)
return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");

return success();
Expand Down Expand Up @@ -768,6 +768,10 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");

if ((srcType.getShape().back() % 2) != 0)
return rewriter.notifyMatchFailure(
op, "Not an even number of i4 elements in trailing dim");

return success();
}

Expand Down Expand Up @@ -876,6 +880,58 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and bitwise ops
/// that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
///
/// \p srcValue must be a rank-1 vector of signless i8 with an even number of
/// elements (both enforced by asserts / the alignment precondition). Returns
/// a value of type vector<2Nxi4> holding the truncated elements.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
                                Value srcValue) {
  VectorType srcVecType = cast<VectorType>(srcValue.getType());
  assert(srcVecType.getElementType().isSignlessInteger(8) &&
         "Expected i8 type");

  // 1. De-interleave low and high i8 elements. Each i4 result byte is built
  // from a pair of consecutive i8 elements, so split even/odd positions into
  // two half-length vectors.
  int64_t vecDimSize = srcVecType.getShape().back();
  SmallVector<int64_t> deinterleaveLowMaskValues;
  SmallVector<int64_t> deinterleaveHighMaskValues;
  // Odd trailing sizes are rejected by alignedConversionPrecondition; the
  // assert documents the invariant locally.
  assert((vecDimSize % 2) == 0 && "Odd number of i4 elements");
  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
  for (int i = 0, end = vecDimSize; i < end; i += 2) {
    deinterleaveLowMaskValues.push_back(i);
    deinterleaveHighMaskValues.push_back(i + 1);
  }

  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, srcValue, srcValue,
      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
      loc, srcValue, srcValue,
      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));

  // 2. Zero out the upper side of each low i8 element so only the low i4
  // payload survives.
  constexpr int8_t i8LowBitMask = 0x0F;
  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
      loc,
      DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
  Value zeroOutLow =
      rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);

  // 3. Move high i4 values to upper side of the byte. The shift also discards
  // the high bits of those elements, so no masking is needed here.
  constexpr int8_t bitsToShift = 4;
  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
  auto shiftValues = rewriter.create<arith::ConstantOp>(
      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
  Value shlHigh =
      rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);

  // 4. Merge high and low i4 values into packed bytes.
  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);

  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>. cloneWith keeps the
  // source shape (2X elements) and only swaps the element type to i4.
  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
}

namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
Expand Down Expand Up @@ -1019,7 +1075,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {

LogicalResult matchAndRewrite(ConversionOpType conversionOp,
PatternRewriter &rewriter) const override {
// Set up the BitCastRewriter and verify the preconditions.
// Verify the preconditions.
Value srcValue = conversionOp.getIn();
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
Expand All @@ -1043,6 +1099,65 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
}
};

/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
///
/// For example:
///   arith.trunci %in : vector<8xi32> to vector<8xi4>
/// is rewritten as
///
///   %cst = arith.constant dense<15> : vector<4xi8>
///   %cst_0 = arith.constant dense<4> : vector<4xi8>
///   %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
///   %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
///   %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
///   %3 = arith.andi %1, %cst : vector<4xi8>
///   %4 = arith.shli %2, %cst_0 : vector<4xi8>
///   %5 = arith.ori %3, %4 : vector<4xi8>
///   %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
///
struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
                                PatternRewriter &rewriter) const override {
    // Both operand and result must be vectors.
    Value inValue = truncOp.getIn();
    auto inVecType = dyn_cast<VectorType>(inValue.getType());
    auto outVecType = dyn_cast<VectorType>(truncOp.getType());
    if (!inVecType || !outVecType)
      return failure();

    // Multi-dimensional vectors would need `vector.deinterleave`, which is
    // not available yet, so bail out on anything but rank 1.
    if (inVecType.getRank() != 1)
      return failure();

    if (failed(commonConversionPrecondition(rewriter, inVecType, truncOp)))
      return failure();

    // The alignment precondition is phrased for extensions (narrow -> wide),
    // so pass the types in inverted order to reuse its logic for truncation.
    if (failed(alignedConversionPrecondition(rewriter, outVecType, inVecType,
                                             truncOp)))
      return failure();

    // First shrink the wide integer elements down to i8 with a regular trunc.
    Location loc = truncOp.getLoc();
    auto byteVecType = inVecType.cloneWith(std::nullopt, rewriter.getI8Type());
    Value truncatedToI8 =
        rewriter.create<arith::TruncIOp>(loc, byteVecType, inValue);

    // Then lower the remaining i8 -> i4 step via shuffles and bitwise ops,
    // and replace the original op with the packed result.
    rewriter.replaceOp(truncOp,
                       rewriteI8ToI4Trunc(rewriter, loc, truncatedToI8));
    return success();
  }
};

/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
/// For example:
Expand Down Expand Up @@ -1115,8 +1230,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
patterns.getContext(), benefit.getBenefit() + 1);
RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
benefit.getBenefit() + 1);
}

void vector::populateVectorTransposeNarrowTypeRewritePatterns(
Expand Down
42 changes: 42 additions & 0 deletions mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,48 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
return %0 : vector<8x32xf32>
}

// CHECK-LABEL: func.func @aligned_trunci(
// Wide (i32) truncation: first a regular trunci down to i8, then the
// deinterleave-shuffle / mask / shift / or / bitcast sequence packs pairs of
// bytes into i4 elements.
func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
// CHECK: %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
// CHECK: %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
%0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
return %0 : vector<8xi4>
}

// CHECK-LABEL: func.func @aligned_trunci_base_case(
// i8 source: no leading trunci is needed, the input feeds the shuffle /
// mask / shift / or / bitcast packing sequence directly.
func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
%0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
return %0 : vector<8xi4>
}

// CHECK-LABEL: func.func @aligned_trunci_2d(
// Rank > 1 is not supported (needs `vector.deinterleave`): the trunci must
// survive unrewritten and none of the packing ops may appear. Note the
// bitwise ops live in the `arith` dialect; the previous CHECK-NOTs used
// `vector.andi`/`vector.shli`/`vector.ori`, which can never match and so
// guarded nothing.
func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
// CHECK-NOT: vector.shuffle
// CHECK-NOT: arith.andi
// CHECK-NOT: arith.shli
// CHECK-NOT: arith.ori
// CHECK: arith.trunci
%0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
return %0 : vector<8x32xi4>
}

// CHECK-LABEL: func.func @i4_transpose(
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
Expand Down