Skip to content

Commit 929eb50

Browse files
[mlir] Rewrites for I2 to I8 signed and unsigned extension (#121298)
Adds rewrites for i2 to i8 signed and unsigned extension, similar to the ones that already exist for i4 to i8 conversion. I use this for i6 quantized models, and this gives me roughly a 2x speedup for an i6 4096x4096 dequantization-matmul on an AMD 5950x. I didn't add the rewrite for i8 to i2 truncation because I currently don't use it, but if this is needed, I can add it as well. --------- Co-authored-by: Andrzej Warzyński <[email protected]>
1 parent d1d2564 commit 929eb50

File tree

2 files changed

+357
-66
lines changed

2 files changed

+357
-66
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 173 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,15 +1090,20 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
10901090
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
10911091
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
10921092

1093-
// Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
1094-
if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
1095-
(dstElemBitwidth % srcElemBitwidth) != 0)
1096-
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1093+
if (dstElemBitwidth < 8)
1094+
return rewriter.notifyMatchFailure(
1095+
op, "the bitwidth of dstType must be greater than or equal to 8");
1096+
if (dstElemBitwidth % srcElemBitwidth != 0)
1097+
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1098+
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1099+
return rewriter.notifyMatchFailure(
1100+
op, "only src bitwidth of 2 or 4 is supported at this moment");
10971101

1098-
const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
1099-
if ((subByteVecType.getShape().back() % numSrcElemsPerDestElem) != 0)
1102+
const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1103+
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
11001104
return rewriter.notifyMatchFailure(
1101-
op, "Not an even number of i4 elements in trailing dim");
1105+
op, "the trailing dimension of the input vector of sub-bytes must be a "
1106+
"multiple of 8 / <sub-byte-width>");
11021107

11031108
return success();
11041109
}
@@ -1179,70 +1184,166 @@ Value BitCastRewriter::genericRewriteStep(
11791184
return runningResult;
11801185
}
11811186

1182-
/// Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1183-
/// bitwise ops that take advantage of high-level information to avoid leaving
1184-
/// LLVM to scramble with peephole optimizations.
1185-
static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
1186-
Value srcValue) {
1187-
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1188-
assert(srcVecType.getElementType().isSignlessInteger(4) &&
1189-
"Expected i4 type");
1187+
/// Bitcasts the aligned `subByteVec` vector to a vector of i8.
1188+
/// Where aligned means it satisfies the alignedConversionPreconditions.
1189+
///
1190+
/// Example:
1191+
/// vector<16x16xi2> -> vector<16x4xi8>
1192+
/// vector<16x16xi4> -> vector<16x8xi8>
1193+
static Value bitcastSubByteVectorToI8(PatternRewriter &rewriter, Location loc,
1194+
Value subByteVec) {
1195+
auto srcVecType = cast<VectorType>(subByteVec.getType());
1196+
int64_t srcBitwidth = srcVecType.getElementType().getIntOrFloatBitWidth();
1197+
assert(8 % srcBitwidth == 0 &&
1198+
"Unsupported sub-byte type (not a divisor of i8)");
1199+
int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1200+
SmallVector<int64_t> vecShape(srcVecType.getShape());
1201+
// Adjust last dimension of the vector, so the total size remains the same.
1202+
vecShape.back() = vecShape.back() / numSrcElemsPerByte;
1203+
auto i8VecType = VectorType::get(vecShape, rewriter.getI8Type());
1204+
return rewriter.create<vector::BitCastOp>(loc, i8VecType, subByteVec);
1205+
}
11901206

1191-
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1192-
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1193-
constexpr int64_t i4Toi8BitwidthFactor = 2;
1194-
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
1195-
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1196-
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
1207+
/// Extracts a signed N-bit sequence from each element of a vector of bytes,
1208+
/// starting at the specified bit index.
1209+
/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1210+
///
1211+
/// Example for a single element:
1212+
/// Extract numBits=2 starting at bitIdx=2
1213+
/// src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1214+
/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1215+
/// target = [. . . . ^ ^ . .]
1216+
///
1217+
/// The target sequence is [11](decimal=-1) as signed 2-bit integer.
1218+
/// So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1219+
///
1220+
/// src = [01 01 11 10]
1221+
/// shl = arith.shl(src, 4) -> [11 10 00 00]
1222+
/// result = arith.shrsi(shl, 6) -> [11 11 11 11]
1223+
static Value extractNBitsPerByteAndSignExtendToI8(PatternRewriter &rewriter,
1224+
Location loc, Value src,
1225+
int bitIdx, int numBits) {
1226+
auto srcType = cast<VectorType>(src.getType());
1227+
Value shl = src;
1228+
int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1229+
assert(bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1230+
"Invalid bitIdx range");
1231+
if (bitsToShiftLeft != 0) {
1232+
Value shiftLeftValues = rewriter.create<arith::ConstantOp>(
1233+
loc, DenseElementsAttr::get(srcType, bitsToShiftLeft));
1234+
shl = rewriter.create<arith::ShLIOp>(loc, src, shiftLeftValues);
1235+
}
11971236

1198-
// 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1199-
// byte are place in one vector and the high i4 elements in another vector.
1200-
constexpr int8_t bitsToShift = 4;
1201-
auto shiftValues = rewriter.create<arith::ConstantOp>(
1202-
loc, DenseElementsAttr::get(i8VecType, bitsToShift));
1203-
Value shl = rewriter.create<arith::ShLIOp>(loc, i8Vector, shiftValues);
1204-
Value low = rewriter.create<arith::ShRSIOp>(loc, shl, shiftValues);
1205-
Value high = rewriter.create<arith::ShRSIOp>(loc, i8Vector, shiftValues);
1237+
int8_t bitsToShiftRight = 8 - numBits;
1238+
Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1239+
loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1240+
Value shr = rewriter.create<arith::ShRSIOp>(loc, shl, shiftRightValues);
1241+
return shr;
1242+
}
12061243

1207-
// 3. Interleave low and high i8 elements.
1208-
return rewriter.create<vector::InterleaveOp>(loc, low, high);
1244+
/// Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1245+
/// starting at the specified bit index.
1246+
/// The `bitIdx` starts at 0 from the LSB and moves to the left.
1247+
///
1248+
/// Example for a single element:
1249+
/// Extract numBits=2 starting at bitIdx=2
1250+
/// src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1251+
/// indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1252+
/// target = [. . . . ^ ^ . .]
1253+
///
1254+
/// The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1255+
/// So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1256+
///
1257+
/// src = [01 01 10 10]
1258+
/// mask = [00 00 00 11]
1259+
/// shr = arith.shrui(src, 2) = [00 01 01 10]
1260+
/// result = arith.andi(shr, mask) = [00 00 00 10]
1261+
/// NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1262+
/// achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1263+
/// However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1264+
/// left when the index is 0.
1265+
static Value extractNBitsPerByteAndExtendToI8(PatternRewriter &rewriter,
1266+
Location loc, Value src,
1267+
int bitIdx, int numBits) {
1268+
assert(bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1269+
"Invalid bitIdx range");
1270+
auto srcType = cast<VectorType>(src.getType());
1271+
int8_t bitsToShiftRight = bitIdx;
1272+
Value shr = src;
1273+
if (bitsToShiftRight != 0) {
1274+
Value shiftRightValues = rewriter.create<arith::ConstantOp>(
1275+
loc, DenseElementsAttr::get(srcType, bitsToShiftRight));
1276+
shr = rewriter.create<arith::ShRUIOp>(loc, src, shiftRightValues);
1277+
}
1278+
if (bitIdx + numBits == 8) {
1279+
return shr;
1280+
}
1281+
uint8_t lowBitsMask = (1 << numBits) - 1;
1282+
Value lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1283+
loc, DenseElementsAttr::get(srcType, lowBitsMask));
1284+
return rewriter.create<arith::AndIOp>(loc, shr, lowBitsMaskValues);
12091285
}
12101286

1211-
/// Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1212-
/// bitwise ops that take advantage of high-level information to avoid leaving
1213-
/// LLVM to scramble with peephole optimizations.
1214-
static Value rewriteI4ToI8UnsignedExt(PatternRewriter &rewriter, Location loc,
1215-
Value srcValue) {
1216-
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1287+
using ExtractNBitsFn =
1288+
std::function<Value(PatternRewriter &, Location, Value, int, int)>;
1289+
1290+
/// Rewrite the i4 -> i8 extension into a sequence of shuffles and
1291+
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1292+
static Value rewriteI4ToI8Ext(PatternRewriter &rewriter, Location loc,
1293+
Value srcValue, const ExtractNBitsFn &extFn) {
1294+
auto srcVecType = cast<VectorType>(srcValue.getType());
12171295
assert(srcVecType.getElementType().isSignlessInteger(4) &&
12181296
"Expected i4 type");
12191297

12201298
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1221-
SmallVector<int64_t> i8VecShape = llvm::to_vector(srcVecType.getShape());
1222-
constexpr int64_t i4Toi8BitwidthFactor = 2;
1223-
i8VecShape.back() = i8VecShape.back() / i4Toi8BitwidthFactor;
1224-
auto i8VecType = VectorType::get(i8VecShape, rewriter.getI8Type());
1225-
Value i8Vector = rewriter.create<vector::BitCastOp>(loc, i8VecType, srcValue);
1226-
1227-
// 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
1228-
// byte are placed in one vector and the high i4 elements in another vector.
1229-
constexpr uint8_t lowBitsMask = 15; // Equivalent to [00001111] bit mask
1230-
auto lowBitsMaskValues = rewriter.create<arith::ConstantOp>(
1231-
loc, DenseElementsAttr::get(i8VecType, lowBitsMask));
1232-
Value low = rewriter.create<arith::AndIOp>(loc, i8VecType, i8Vector,
1233-
lowBitsMaskValues);
1234-
constexpr int8_t highBitsToShift = 4;
1235-
auto highShiftValues = rewriter.create<arith::ConstantOp>(
1236-
loc, DenseElementsAttr::get(i8VecType, highBitsToShift));
1237-
Value high = rewriter.create<arith::ShRUIOp>(loc, i8Vector, highShiftValues);
1299+
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1300+
1301+
// 2. Extend i4 elements to i8 elements. Low i4 elemens of each
1302+
// byte are place in one vector and the high i4 elements in another vector.
1303+
Value low = extFn(rewriter, loc, i8Vector, 0, 4);
1304+
Value high = extFn(rewriter, loc, i8Vector, 4, 4);
12381305

12391306
// 3. Interleave low and high i8 elements.
12401307
return rewriter.create<vector::InterleaveOp>(loc, low, high);
12411308
}
12421309

1310+
/// Rewrite the i2 -> i8 extension into a sequence of shuffles and
1311+
/// bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1312+
static Value rewriteI2ToI8Ext(PatternRewriter &rewriter, Location loc,
1313+
Value srcValue, const ExtractNBitsFn &extFn) {
1314+
VectorType srcVecType = cast<VectorType>(srcValue.getType());
1315+
assert(srcVecType.getElementType().isSignlessInteger(2) &&
1316+
"Expected i2 type");
1317+
1318+
// 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1319+
Value i8Vector = bitcastSubByteVectorToI8(rewriter, loc, srcValue);
1320+
1321+
// 2. Extract each i2 element
1322+
// Positon 0 (bits 0-1)
1323+
Value vec0 = extFn(rewriter, loc, i8Vector, 0, 2);
1324+
// Position 1 (bits 2-3)
1325+
Value vec1 = extFn(rewriter, loc, i8Vector, 2, 2);
1326+
// Position 2 (bits 4-5)
1327+
Value vec2 = extFn(rewriter, loc, i8Vector, 4, 2);
1328+
// Position 3 (bits 6-7)
1329+
Value vec3 = extFn(rewriter, loc, i8Vector, 6, 2);
1330+
1331+
// 3. Interleave all 4 elements by first interleaving
1332+
// even elements and then odd
1333+
// vec0 = [0,0,0,0],...
1334+
// vec1 = [1,1,1,1],...
1335+
// vec2 = [2,2,2,2],...
1336+
// vec3 = [3,3,3,3],...
1337+
// 02 = [0,2,0,2,0,2,0,2],...
1338+
// 13 = [1,3,1,3,1,3,1,3],...
1339+
// 0213 = [0,1,2,3,...],...
1340+
Value interleave02 = rewriter.create<vector::InterleaveOp>(loc, vec0, vec2);
1341+
Value interleave13 = rewriter.create<vector::InterleaveOp>(loc, vec1, vec3);
1342+
return rewriter.create<vector::InterleaveOp>(loc, interleave02, interleave13);
1343+
}
1344+
12431345
/// Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1244-
/// ops that take advantage of high-level information to avoid leaving LLVM to
1245-
/// scramble with peephole optimizations.
1346+
/// ops to avoid leaving LLVM to scramble with peephole optimizations.
12461347
static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
12471348
Value srcValue) {
12481349
VectorType srcVecType = cast<VectorType>(srcValue.getType());
@@ -1443,13 +1544,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
14431544
return failure();
14441545

14451546
// Perform the rewrite.
1547+
Location loc = conversionOp.getLoc();
1548+
const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
1549+
: extractNBitsPerByteAndExtendToI8;
14461550
Value subByteExt;
1447-
if (isSigned) {
1448-
subByteExt =
1449-
rewriteI4ToI8SignedExt(rewriter, conversionOp.getLoc(), srcValue);
1450-
} else {
1451-
subByteExt =
1452-
rewriteI4ToI8UnsignedExt(rewriter, conversionOp.getLoc(), srcValue);
1551+
switch (srcVecType.getElementType().getIntOrFloatBitWidth()) {
1552+
case 2:
1553+
subByteExt = rewriteI2ToI8Ext(rewriter, loc, srcValue, extFn);
1554+
break;
1555+
case 4:
1556+
subByteExt = rewriteI4ToI8Ext(rewriter, loc, srcValue, extFn);
1557+
break;
1558+
default:
1559+
return failure();
14531560
}
14541561

14551562
// Finalize the rewrite.
@@ -1490,6 +1597,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
14901597
if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
14911598
return failure();
14921599

1600+
// TODO: Add support for truncating to i2.
1601+
if (dstVecType.getElementType().getIntOrFloatBitWidth() == 2)
1602+
return failure();
1603+
14931604
// Check general alignment preconditions. We invert the src/dst type order
14941605
// to reuse the existing precondition logic.
14951606
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,

0 commit comments

Comments
 (0)