Skip to content

[mlir] Rewrites for I2 to I8 signed and unsigned extension #121298

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 18 commits into from
Jan 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
235 changes: 173 additions & 62 deletions mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1090,15 +1090,20 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();

// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
(dstElemBitwidth % srcElemBitwidth) != 0)
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
if (dstElemBitwidth < 8)
return rewriter.notifyMatchFailure(
op, "the bitwidth of dstType must be greater than or equal to 8");
if (dstElemBitwidth % srcElemBitwidth != 0)
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
return rewriter.notifyMatchFailure(
op, "only src bitwidth of 2 or 4 is supported at this moment");

const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
const int numSrcElemsPerByte = 8 / srcElemBitwidth;
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
return rewriter.notifyMatchFailure(
op, "Not an even number of i4 elements in trailing dim");
op, "the trailing dimension of the input vector of sub-bytes must be a "
"multiple of 8 / <sub-byte-width>");

return success();
}
Expand Down Expand Up @@ -1179,70 +1184,166 @@ Value BitCastRewriter::genericRewriteStep(
return runningResult;
}

/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");
/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
/// Where aligned means it satisfies the alignedConversionPreconditions.
///
/// Example:
/// vector<16x16xi2> -> vector<16x4xi8>
/// vector<16x16xi4> -> vector<16x8xi8>
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
Value subByteVec) {
auto srcVecType = cast<VectorType>(subByteVec.getType());
int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
assert(8 % srcBitwidth == 0 &&
"Unsupported sub-byte type (not a divisor of i8)");
int64_t numSrcElemsPerByte = 8 / srcBitwidth;
SmallVector<int64_t> vecShape(srcVecType.getShape());
// Adjust last dimension of the vector, so the total size remains the same.
vecShape.back() = vecShape.back() / numSrcElemsPerByte;
auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
}

// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
/// Extracts a signed N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index.
/// The `bitIdx` starts at 0 from the LSB and moves to the left.
///
/// Example for a single element:
/// Extract numBits=2 starting at bitIdx=2
/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
/// target = [. . . . ^ ^ . .]
///
/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
///
/// src = [01 01 11 10]
/// shl = arith.shl(src, 4) -> [11 10 00 00]
/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
Location loc, Value src,
int bitIdx, int numBits) {
auto srcType = cast<VectorType>(src.getType());
Value shl = src;
int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
"Invalid bitIdx range");
if (bitsToShiftLeft != 0) {
Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
}

// 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
// byte are place in one vector and the high i4 elements in another vector.
constexpr int8_t bitsToShift = 4;
auto shiftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(i8VecType, bitsToShift));
Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
int8_t bitsToShiftRight = 8 - numBits;
Value shiftRightValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
return shr;
}

// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
/// starting at the specified bit index.
/// The `bitIdx` starts at 0 from the LSB and moves to the left.
///
/// Example for a single element:
/// Extract numBits=2 starting at bitIdx=2
/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
/// target = [. . . . ^ ^ . .]
///
/// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
///
/// src = [01 01 10 10]
/// mask = [00 00 00 11]
/// shr = arith.shrui(src, 2) = [00 01 01 10]
/// result = arith.andi(shr, mask) = [00 00 00 10]
/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
/// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
/// left when the index is 0.
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
Location loc, Value src,
int bitIdx, int numBits) {
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
"Invalid bitIdx range");
auto srcType = cast<VectorType>(src.getType());
int8_t bitsToShiftRight = bitIdx;
Value shr = src;
if (bitsToShiftRight != 0) {
Value shiftRightValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
}
if (bitIdx + numBits == 8) {
return shr;
}
uint8_t lowBitsMask = (1 << numBits) - 1;
Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(srcType, lowBitsMask));
return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
}

/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
/// bitwise ops that take advantage of high-level information to avoid leaving
/// LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
using ExtractNBitsFn =
std::function<Value(PatternRewriter &, Location, Value, int, int)>;

/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
Value srcValue, const ExtractNBitsFn &extFn) {
auto srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(4) &&
"Expected i4 type");

// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
constexpr int64_t i4Toi8BitwidthFactor = 2;
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);

// 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
// byte are placed in one vector and the high i4 elements in another vector.
constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
lowBitsMaskValues);
constexpr int8_t highBitsToShift = 4;
auto highShiftValues = rewriter.create<arith::ConstantOp>(
loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

// 2. Extend i4 elements to i8 elements. Low i4 elemens of each
// byte are place in one vector and the high i4 elements in another vector.
Value low = extFn(rewriter, loc, i8Vector, 0, 4);
Value high = extFn(rewriter, loc, i8Vector, 4, 4);

// 3. Interleave low and high i8 elements.
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
Value srcValue, const ExtractNBitsFn &extFn) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
assert(srcVecType.getElementType().isSignlessInteger(2) &&
"Expected i2 type");

// 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);

// 2. Extract each i2 element
// Positon 0 (bits 0-1)
Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
// Position 1 (bits 2-3)
Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
// Position 2 (bits 4-5)
Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
// Position 3 (bits 6-7)
Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);

// 3. Interleave all 4 elements by first interleaving
// even elements and then odd
// vec0 = [0,0,0,0],...
// vec1 = [1,1,1,1],...
// vec2 = [2,2,2,2],...
// vec3 = [3,3,3,3],...
// 02 = [0,2,0,2,0,2,0,2],...
// 13 = [1,3,1,3,1,3,1,3],...
// 0213 = [0,1,2,3,...],...
Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
}

/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
/// ops that take advantage of high-level information to avoid leaving LLVM to
/// scramble with peephole optimizations.
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
Value srcValue) {
VectorType srcVecType = cast<VectorType>(srcValue.getType());
Expand Down Expand Up @@ -1443,13 +1544,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
return failure();

// Perform the rewrite.
Location loc = conversionOp.getLoc();
const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
: extractNBitsPerByteAndExtendToI8;
Value subByteExt;
if (isSigned) {
subByteExt =
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
} else {
subByteExt =
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
case 2:
subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
break;
case 4:
subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
break;
default:
return failure();
}

// Finalize the rewrite.
Expand Down Expand Up @@ -1490,6 +1597,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
return failure();

// TODO: Add support for truncating to i2.
if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
return failure();

// Check general alignment preconditions. We invert the src/dst type order
// to reuse the existing precondition logic.
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
Expand Down
Loading
Loading