@@ -1090,15 +1090,20 @@ static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1090
1090
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth ();
1091
1091
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth ();
1092
1092
1093
- // Only {s}i4 -> (size_of({{s}i/f}) >= 8) are supported for now.
1094
- if (srcElemBitwidth != 4 || dstElemBitwidth < 8 ||
1095
- (dstElemBitwidth % srcElemBitwidth) != 0 )
1096
- return rewriter.notifyMatchFailure (op, " Not a supported aligned case" );
1093
+ if (dstElemBitwidth < 8 )
1094
+ return rewriter.notifyMatchFailure (
1095
+ op, " the bitwidth of dstType must be greater than or equal to 8" );
1096
+ if (dstElemBitwidth % srcElemBitwidth != 0 )
1097
+ return rewriter.notifyMatchFailure (op, " unaligned cases are not supported" );
1098
+ if (srcElemBitwidth != 2 && srcElemBitwidth != 4 )
1099
+ return rewriter.notifyMatchFailure (
1100
+ op, " only src bitwidth of 2 or 4 is supported at this moment" );
1097
1101
1098
- const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth;
1099
- if ((subByteVecType.getShape ().back () % numSrcElemsPerDestElem ) != 0 )
1102
+ const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1103
+ if ((subByteVecType.getShape ().back () % numSrcElemsPerByte ) != 0 )
1100
1104
return rewriter.notifyMatchFailure (
1101
- op, " Not an even number of i4 elements in trailing dim" );
1105
+ op, " the trailing dimension of the input vector of sub-bytes must be a "
1106
+ " multiple of 8 / <sub-byte-width>" );
1102
1107
1103
1108
return success ();
1104
1109
}
@@ -1179,70 +1184,166 @@ Value BitCastRewriter::genericRewriteStep(
1179
1184
return runningResult;
1180
1185
}
1181
1186
1182
- // / Rewrite the i4 -> i8 signed extension into a sequence of shuffles and
1183
- // / bitwise ops that take advantage of high-level information to avoid leaving
1184
- // / LLVM to scramble with peephole optimizations.
1185
- static Value rewriteI4ToI8SignedExt (PatternRewriter &rewriter, Location loc,
1186
- Value srcValue) {
1187
- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1188
- assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
1189
- " Expected i4 type" );
1187
+ // / Bitcasts the aligned `subByteVec` vector to a vector of i8.
1188
+ // / Where aligned means it satisfies the alignedConversionPreconditions.
1189
+ // /
1190
+ // / Example:
1191
+ // / vector<16x16xi2> -> vector<16x4xi8>
1192
+ // / vector<16x16xi4> -> vector<16x8xi8>
1193
+ static Value bitcastSubByteVectorToI8 (PatternRewriter &rewriter, Location loc,
1194
+ Value subByteVec) {
1195
+ auto srcVecType = cast<VectorType>(subByteVec.getType ());
1196
+ int64_t srcBitwidth = srcVecType.getElementType ().getIntOrFloatBitWidth ();
1197
+ assert (8 % srcBitwidth == 0 &&
1198
+ " Unsupported sub-byte type (not a divisor of i8)" );
1199
+ int64_t numSrcElemsPerByte = 8 / srcBitwidth;
1200
+ SmallVector<int64_t > vecShape (srcVecType.getShape ());
1201
+ // Adjust last dimension of the vector, so the total size remains the same.
1202
+ vecShape.back () = vecShape.back () / numSrcElemsPerByte;
1203
+ auto i8VecType = VectorType::get (vecShape, rewriter.getI8Type ());
1204
+ return rewriter.create <vector::BitCastOp>(loc, i8VecType, subByteVec);
1205
+ }
1190
1206
1191
- // 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1192
- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1193
- constexpr int64_t i4Toi8BitwidthFactor = 2 ;
1194
- i8VecShape.back () = i8VecShape.back () / i4Toi8BitwidthFactor;
1195
- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1196
- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1207
+ // / Extracts a signed N-bit sequence from each element of a vector of bytes,
1208
+ // / starting at the specified bit index.
1209
+ // / The `bitIdx` starts at 0 from the LSB and moves to the left.
1210
+ // /
1211
+ // / Example for a single element:
1212
+ // / Extract numBits=2 starting at bitIdx=2
1213
+ // / src = [0 | 1 | 0 | 1 | 1 | 1 | 1 | 0]
1214
+ // / indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1215
+ // / target = [. . . . ^ ^ . .]
1216
+ // /
1217
+ // / The target sequence is [11](decimal=-1) as signed 2-bit integer.
1218
+ // / So the result should be [11 11 11 11](decimal=-1) as signed 8-bit integer.
1219
+ // /
1220
+ // / src = [01 01 11 10]
1221
+ // / shl = arith.shl(src, 4) -> [11 10 00 00]
1222
+ // / result = arith.shrsi(shl, 6) -> [11 11 11 11]
1223
+ static Value extractNBitsPerByteAndSignExtendToI8 (PatternRewriter &rewriter,
1224
+ Location loc, Value src,
1225
+ int bitIdx, int numBits) {
1226
+ auto srcType = cast<VectorType>(src.getType ());
1227
+ Value shl = src;
1228
+ int8_t bitsToShiftLeft = 8 - numBits - bitIdx;
1229
+ assert (bitIdx >= 0 && bitsToShiftLeft >= 0 && numBits > 0 && numBits <= 8 &&
1230
+ " Invalid bitIdx range" );
1231
+ if (bitsToShiftLeft != 0 ) {
1232
+ Value shiftLeftValues = rewriter.create <arith::ConstantOp>(
1233
+ loc, DenseElementsAttr::get (srcType, bitsToShiftLeft));
1234
+ shl = rewriter.create <arith::ShLIOp>(loc, src, shiftLeftValues);
1235
+ }
1197
1236
1198
- // 2. Extend i4 elements to i8 elements using shifts. Low i4 elemens of each
1199
- // byte are place in one vector and the high i4 elements in another vector.
1200
- constexpr int8_t bitsToShift = 4 ;
1201
- auto shiftValues = rewriter.create <arith::ConstantOp>(
1202
- loc, DenseElementsAttr::get (i8VecType, bitsToShift));
1203
- Value shl = rewriter.create <arith::ShLIOp>(loc, i8Vector, shiftValues);
1204
- Value low = rewriter.create <arith::ShRSIOp>(loc, shl, shiftValues);
1205
- Value high = rewriter.create <arith::ShRSIOp>(loc, i8Vector, shiftValues);
1237
+ int8_t bitsToShiftRight = 8 - numBits;
1238
+ Value shiftRightValues = rewriter.create <arith::ConstantOp>(
1239
+ loc, DenseElementsAttr::get (srcType, bitsToShiftRight));
1240
+ Value shr = rewriter.create <arith::ShRSIOp>(loc, shl, shiftRightValues);
1241
+ return shr;
1242
+ }
1206
1243
1207
- // 3. Interleave low and high i8 elements.
1208
- return rewriter.create <vector::InterleaveOp>(loc, low, high);
1244
+ // / Extracts an unsigned N-bit sequence from each element of a vector of bytes,
1245
+ // / starting at the specified bit index.
1246
+ // / The `bitIdx` starts at 0 from the LSB and moves to the left.
1247
+ // /
1248
+ // / Example for a single element:
1249
+ // / Extract numBits=2 starting at bitIdx=2
1250
+ // / src = [0 | 1 | 0 | 1 | 1 | 0 | 1 | 0]
1251
+ // / indices = [7 | 6 | 5 | 4 | 3 | 2 | 1 | 0]
1252
+ // / target = [. . . . ^ ^ . .]
1253
+ // /
1254
+ // / The target sequence is [10](decimal=2) as unsigned 2-bit integer.
1255
+ // / So the result should be [00 00 00 10](decimal=2) as unsigned 8-bit integer.
1256
+ // /
1257
+ // / src = [01 01 10 10]
1258
+ // / mask = [00 00 00 11]
1259
+ // / shr = arith.shrui(src, 2) = [00 01 01 10]
1260
+ // / result = arith.andi(shr, mask) = [00 00 00 10]
1261
+ // / NOTE: Similarly to extractNBitsPerByteAndSignExtendToI8, this could be
1262
+ // / achieved by using arith::ShLIOp + arith::ShRUIOp instead of the masking.
1263
+ // / However, by using arith::ShRUIOp + arith::AndIOp, we are eliminating shift
1264
+ // / left when the index is 0.
1265
+ static Value extractNBitsPerByteAndExtendToI8 (PatternRewriter &rewriter,
1266
+ Location loc, Value src,
1267
+ int bitIdx, int numBits) {
1268
+ assert (bitIdx >= 0 && bitIdx <= 8 - numBits && numBits > 0 && numBits <= 8 &&
1269
+ " Invalid bitIdx range" );
1270
+ auto srcType = cast<VectorType>(src.getType ());
1271
+ int8_t bitsToShiftRight = bitIdx;
1272
+ Value shr = src;
1273
+ if (bitsToShiftRight != 0 ) {
1274
+ Value shiftRightValues = rewriter.create <arith::ConstantOp>(
1275
+ loc, DenseElementsAttr::get (srcType, bitsToShiftRight));
1276
+ shr = rewriter.create <arith::ShRUIOp>(loc, src, shiftRightValues);
1277
+ }
1278
+ if (bitIdx + numBits == 8 ) {
1279
+ return shr;
1280
+ }
1281
+ uint8_t lowBitsMask = (1 << numBits) - 1 ;
1282
+ Value lowBitsMaskValues = rewriter.create <arith::ConstantOp>(
1283
+ loc, DenseElementsAttr::get (srcType, lowBitsMask));
1284
+ return rewriter.create <arith::AndIOp>(loc, shr, lowBitsMaskValues);
1209
1285
}
1210
1286
1211
- // / Rewrite the i4 -> i8 unsigned extension into a sequence of shuffles and
1212
- // / bitwise ops that take advantage of high-level information to avoid leaving
1213
- // / LLVM to scramble with peephole optimizations.
1214
- static Value rewriteI4ToI8UnsignedExt (PatternRewriter &rewriter, Location loc,
1215
- Value srcValue) {
1216
- VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1287
+ using ExtractNBitsFn =
1288
+ std::function<Value(PatternRewriter &, Location, Value, int , int )>;
1289
+
1290
+ // / Rewrite the i4 -> i8 extension into a sequence of shuffles and
1291
+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1292
+ static Value rewriteI4ToI8Ext (PatternRewriter &rewriter, Location loc,
1293
+ Value srcValue, const ExtractNBitsFn &extFn) {
1294
+ auto srcVecType = cast<VectorType>(srcValue.getType ());
1217
1295
assert (srcVecType.getElementType ().isSignlessInteger (4 ) &&
1218
1296
" Expected i4 type" );
1219
1297
1220
1298
// 1. Generate a bitcast vector<Xxi4> -> vector<X/2xi8>.
1221
- SmallVector<int64_t > i8VecShape = llvm::to_vector (srcVecType.getShape ());
1222
- constexpr int64_t i4Toi8BitwidthFactor = 2 ;
1223
- i8VecShape.back () = i8VecShape.back () / i4Toi8BitwidthFactor;
1224
- auto i8VecType = VectorType::get (i8VecShape, rewriter.getI8Type ());
1225
- Value i8Vector = rewriter.create <vector::BitCastOp>(loc, i8VecType, srcValue);
1226
-
1227
- // 2 Extend the i4 elements using shifts & masking. Low i4 elements of each
1228
- // byte are placed in one vector and the high i4 elements in another vector.
1229
- constexpr uint8_t lowBitsMask = 15 ; // Equivalent to [00001111] bit mask
1230
- auto lowBitsMaskValues = rewriter.create <arith::ConstantOp>(
1231
- loc, DenseElementsAttr::get (i8VecType, lowBitsMask));
1232
- Value low = rewriter.create <arith::AndIOp>(loc, i8VecType, i8Vector,
1233
- lowBitsMaskValues);
1234
- constexpr int8_t highBitsToShift = 4 ;
1235
- auto highShiftValues = rewriter.create <arith::ConstantOp>(
1236
- loc, DenseElementsAttr::get (i8VecType, highBitsToShift));
1237
- Value high = rewriter.create <arith::ShRUIOp>(loc, i8Vector, highShiftValues);
1299
+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
1300
+
1301
+ // 2. Extend i4 elements to i8 elements. Low i4 elemens of each
1302
+ // byte are place in one vector and the high i4 elements in another vector.
1303
+ Value low = extFn (rewriter, loc, i8Vector, 0 , 4 );
1304
+ Value high = extFn (rewriter, loc, i8Vector, 4 , 4 );
1238
1305
1239
1306
// 3. Interleave low and high i8 elements.
1240
1307
return rewriter.create <vector::InterleaveOp>(loc, low, high);
1241
1308
}
1242
1309
1310
+ // / Rewrite the i2 -> i8 extension into a sequence of shuffles and
1311
+ // / bitwise ops to avoid leaving LLVM to scramble with peephole optimizations.
1312
+ static Value rewriteI2ToI8Ext (PatternRewriter &rewriter, Location loc,
1313
+ Value srcValue, const ExtractNBitsFn &extFn) {
1314
+ VectorType srcVecType = cast<VectorType>(srcValue.getType ());
1315
+ assert (srcVecType.getElementType ().isSignlessInteger (2 ) &&
1316
+ " Expected i2 type" );
1317
+
1318
+ // 1. Generate a bitcast vector<Xxi2> -> vector<X/2xi8>.
1319
+ Value i8Vector = bitcastSubByteVectorToI8 (rewriter, loc, srcValue);
1320
+
1321
+ // 2. Extract each i2 element
1322
+ // Positon 0 (bits 0-1)
1323
+ Value vec0 = extFn (rewriter, loc, i8Vector, 0 , 2 );
1324
+ // Position 1 (bits 2-3)
1325
+ Value vec1 = extFn (rewriter, loc, i8Vector, 2 , 2 );
1326
+ // Position 2 (bits 4-5)
1327
+ Value vec2 = extFn (rewriter, loc, i8Vector, 4 , 2 );
1328
+ // Position 3 (bits 6-7)
1329
+ Value vec3 = extFn (rewriter, loc, i8Vector, 6 , 2 );
1330
+
1331
+ // 3. Interleave all 4 elements by first interleaving
1332
+ // even elements and then odd
1333
+ // vec0 = [0,0,0,0],...
1334
+ // vec1 = [1,1,1,1],...
1335
+ // vec2 = [2,2,2,2],...
1336
+ // vec3 = [3,3,3,3],...
1337
+ // 02 = [0,2,0,2,0,2,0,2],...
1338
+ // 13 = [1,3,1,3,1,3,1,3],...
1339
+ // 0213 = [0,1,2,3,...],...
1340
+ Value interleave02 = rewriter.create <vector::InterleaveOp>(loc, vec0, vec2);
1341
+ Value interleave13 = rewriter.create <vector::InterleaveOp>(loc, vec1, vec3);
1342
+ return rewriter.create <vector::InterleaveOp>(loc, interleave02, interleave13);
1343
+ }
1344
+
1243
1345
// / Rewrite the i8 -> i4 truncation into a deinterleave and series of bitwise
1244
- // / ops that take advantage of high-level information to avoid leaving LLVM to
1245
- // / scramble with peephole optimizations.
1346
+ // / ops to avoid leaving LLVM to scramble with peephole optimizations.
1246
1347
static Value rewriteI8ToI4Trunc (PatternRewriter &rewriter, Location loc,
1247
1348
Value srcValue) {
1248
1349
VectorType srcVecType = cast<VectorType>(srcValue.getType ());
@@ -1443,13 +1544,19 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
1443
1544
return failure ();
1444
1545
1445
1546
// Perform the rewrite.
1547
+ Location loc = conversionOp.getLoc ();
1548
+ const auto &extFn = isSigned ? extractNBitsPerByteAndSignExtendToI8
1549
+ : extractNBitsPerByteAndExtendToI8;
1446
1550
Value subByteExt;
1447
- if (isSigned) {
1448
- subByteExt =
1449
- rewriteI4ToI8SignedExt (rewriter, conversionOp.getLoc (), srcValue);
1450
- } else {
1451
- subByteExt =
1452
- rewriteI4ToI8UnsignedExt (rewriter, conversionOp.getLoc (), srcValue);
1551
+ switch (srcVecType.getElementType ().getIntOrFloatBitWidth ()) {
1552
+ case 2 :
1553
+ subByteExt = rewriteI2ToI8Ext (rewriter, loc, srcValue, extFn);
1554
+ break ;
1555
+ case 4 :
1556
+ subByteExt = rewriteI4ToI8Ext (rewriter, loc, srcValue, extFn);
1557
+ break ;
1558
+ default :
1559
+ return failure ();
1453
1560
}
1454
1561
1455
1562
// Finalize the rewrite.
@@ -1490,6 +1597,10 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
1490
1597
if (failed (commonConversionPrecondition (rewriter, srcVecType, truncOp)))
1491
1598
return failure ();
1492
1599
1600
+ // TODO: Add support for truncating to i2.
1601
+ if (dstVecType.getElementType ().getIntOrFloatBitWidth () == 2 )
1602
+ return failure ();
1603
+
1493
1604
// Check general alignment preconditions. We invert the src/dst type order
1494
1605
// to reuse the existing precondition logic.
1495
1606
if (failed (alignedConversionPrecondition (rewriter, dstVecType, srcVecType,
0 commit comments