Skip to content

Commit 411ee3c

Browse files
committed
[mlir][Vector] Update VectorEmulateNarrowType.cpp (4/N)
This is PR 4 in a series of N patches aimed at improving "VectorEmulateNarrowType.cpp". This is mainly minor refactoring, no major functional changes are made/added. 1. Update `alignedConversionPrecondition` (1): This method didn't require the vector type for the "destination" argument. The underlying element type is sufficient. The corresponding argument has been renamed as `multiByteScalarTy` - this is meant as the multi-byte emulated type (`i8`, `i16`, `i32`, etc). 2. Update `alignedConversionPrecondition` (2): In #121298, we replaced `dstElemBitwidt` in this calculation: ```cpp const int numSrcElemsPerDestElem = dstElemBitwidth / srcElemBitwidth; ``` with the hard-coded value of 8: ```cpp const int numSrcElemsPerDestElem = 8 / srcElemBitwidth; ``` That was correct as for the patterns for which this hook was/is used: * `RewriteAlignedSubByteIntExt`, * `RewriteAlignedSubByteIntTrunc`. The destination type (or, more precisely, the emulated type) was always `i8`. In this PR, I am switching back to a more generic approach - the calculation should take into account the bit-width of the emulated type. Note that at the call sites I am passing `i8` as the emulated type, so the end-result is effectively identical. However, the intent is clearer, i.e., the underlying value is 8 because the emulated type happens to be `i8` (as opposed using a magic number). 3. Update alignedConversionPrecondition (3): The final check has been replaced with a new helper method, `isSubByteVecFittable`. This new method is also re-used within the code and hopefully will allow us more code re-use moving forward (to avoid re-implementing the same condition). 4. Update alignedConversionPrecondition (4): NEXT STEPS: We need to clarify the meaning of "source" and "destination" types. Currently the usage is ambiguous. For example, for this `arith.extsi` Op, `vector<8xi2>` and `vector<8xi32>` are the "source" and "destination" types, respectively: ```mlir %0 = arith.extsi %arg0 : vector<8xi2> to vector<8xi32> } ``` However, patterns like `RewriteAlignedSubByteIntExt` introduce `vector.bitcast` Ops like this: ```mlir %bitcast = vector.bitcast %arg0 : vector<8xi2> to vector<2xi8> ``` I've noticed that we tend to mix `vector<2xi8>` and `vector<8xi32>` as the destination types and that should be clarified.
1 parent 180145d commit 411ee3c

File tree

1 file changed

+102
-32
lines changed

1 file changed

+102
-32
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp

Lines changed: 102 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1050,6 +1050,38 @@ struct ConvertVectorMaskedLoad final
10501050
}
10511051
};
10521052

1053+
/// Check whether `subByteVecTy` fits wthin a vector of `multiByteScalarTy`
1054+
///
1055+
/// "Fitting" means that `subByteVecTy` (a vector of sub-byte elements, e.g.
1056+
/// vector<4xi4>), can fit within N scalar elements of type `multiByteScalarTy`
1057+
/// (a multi-byte scalar, e.g. i16), where N is some integer.
1058+
///
1059+
/// Put differently, this method checks whether this would be valid:
1060+
///
1061+
/// vector.bitcast subByteVecTy into vector<N x multiByteScalarTy>
1062+
///
1063+
/// EXAMPLES:
1064+
/// * vector<4xi4> -> i16 - yes (N = 1)
1065+
/// * vector<4xi4> -> i8 - yes (N = 2)
1066+
/// * vector<3xi4> -> i8 - no (N would have to be 1.5)
1067+
/// * vector<3xi2> -> i16 - no (N would have to be 0.5)
1068+
static bool isSubByteVecFittable(VectorType subByteVecTy,
1069+
Type multiByteScalarTy) {
1070+
assert((isa<IntegerType, FloatType>(multiByteScalarTy)) && "Not scalar!");
1071+
1072+
int subByteBits = subByteVecTy.getElementType().getIntOrFloatBitWidth();
1073+
int multiByteBits = multiByteScalarTy.getIntOrFloatBitWidth();
1074+
1075+
assert(subByteBits < 8 && "Not a sub-byte scalar type!");
1076+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1077+
assert(multiByteBits % subByteBits == 0 && "Unalagined element types!");
1078+
1079+
int elemsPerMultiByte = multiByteBits / subByteBits;
1080+
1081+
// TODO: This is a bit too restrictive for vectors rank > 1.
1082+
return subByteVecTy.getShape().back() % elemsPerMultiByte == 0;
1083+
}
1084+
10531085
//===----------------------------------------------------------------------===//
10541086
// ConvertVectorTransferRead
10551087
//===----------------------------------------------------------------------===//
@@ -1086,7 +1118,8 @@ struct ConvertVectorTransferRead final
10861118
auto origElements = op.getVectorType().getNumElements();
10871119

10881120
// Note, per-element-alignment was already verified above.
1089-
bool isFullyAligned = origElements % emulatedPerContainerElem == 0;
1121+
bool isFullyAligned =
1122+
isSubByteVecFittable(op.getVectorType(), containerElemTy);
10901123

10911124
auto newPadding = rewriter.create<arith::ExtUIOp>(loc, containerElemTy,
10921125
adaptor.getPadding());
@@ -1387,41 +1420,76 @@ LogicalResult BitCastRewriter::commonPrecondition(PatternRewriter &rewriter,
13871420
return commonConversionPrecondition(rewriter, preconditionType, op);
13881421
}
13891422

1390-
/// Verify that `subByteVecType` and `dstType` are aligned. Alignment
1391-
/// means that:
1392-
/// 1. The `dstType` element type is a multiple of the
1393-
/// `srcVectorOfSubByteType` element type (e.g. i4 vs i8 is OK, but i3 vs i8
1394-
/// is not supported). Let this multiple be `N`.
1395-
/// 2. The number of the (trailing) elements in `srcVectorOfSubByteType` is a
1396-
/// multiple of `N` from 1. (e.g., when targetting i8, 2xi4 is OK, but 3xi4 is
1397-
/// not supported).
1423+
/// Verify that `subByteVecTy` (vector) and `containerTy` (scalar) are aligned.
1424+
///
1425+
/// Alignment means that `subByteVecTy` can be packed into a vector of
1426+
/// `containerTy` elements. More specifically:
1427+
/// 1. The bit-width of `containerTy` is a multiple of the
1428+
/// bit-width of `subByteVecTy` elements. For example, for `i4` and `i16`
1429+
/// this multiple is 4.
1430+
/// 2. The multiple from 1. above divides evenly the number of the (trailing)
1431+
/// elements in `subByteVecTy`.
1432+
///
1433+
/// EXAMPLE 1:
1434+
/// `subByteVecTy = vector<2xi4>`, and
1435+
/// `containerTy = i16`
1436+
///
1437+
/// 2 divides evenly 4 ( = 16 / 4), hence both conditions are _met_.
1438+
///
1439+
/// EXAMPLE 2:
1440+
/// `subByteVecTy = vector<3xi4>`, and
1441+
/// `containerTy = i16`
1442+
///
1443+
/// 3 _does not_ divide evenly 4 (= 16/4), hence the conditions are _not met_.
1444+
///
1445+
/// EXAMPLE 3:
1446+
/// `subByteVecTy = vector<3xi3>`, and
1447+
/// `containerTy = i16`
1448+
///
1449+
/// 16 _is not_ a multiple of 3, hence the conditions are _not met_.
13981450
///
13991451
/// NOTE: This method assumes that common conversion preconditions are met. In
1400-
/// particular, the element type of `dstType` is assumed to be a multi-byte
1401-
/// type (e.g. i8, i16, i32).
1452+
/// particular, `containerTy` is assumed to be a
1453+
/// multi-byte scalar type (e.g., i8, i16, i32).
14021454
static LogicalResult alignedConversionPrecondition(PatternRewriter &rewriter,
1403-
VectorType subByteVecType,
1404-
VectorType dstType,
1455+
VectorType subByteVecTy,
1456+
Type containerTy,
14051457
Operation *op) {
1406-
if (!subByteVecType || !dstType)
1407-
return rewriter.notifyMatchFailure(op, "Not a supported aligned case");
1408-
unsigned srcElemBitwidth = subByteVecType.getElementTypeBitWidth();
1409-
unsigned dstElemBitwidth = dstType.getElementTypeBitWidth();
1458+
// TODO: This is validating the inputs rather than checking the conditions
1459+
// documented above. Replace with an assert.
1460+
if (!subByteVecTy)
1461+
return rewriter.notifyMatchFailure(op, "not a vector!");
14101462

1411-
if (dstElemBitwidth < 8)
1412-
return rewriter.notifyMatchFailure(
1413-
op, "the bitwidth of dstType must be greater than or equal to 8");
1414-
if (dstElemBitwidth % srcElemBitwidth != 0)
1415-
return rewriter.notifyMatchFailure(op, "unaligned cases are not supported");
1416-
if (srcElemBitwidth != 2 && srcElemBitwidth != 4)
1463+
// TODO: This is validating the inputs rather than checking the conditions
1464+
// documented above. Replace with an assert.
1465+
if (!containerTy.isIntOrFloat())
1466+
return rewriter.notifyMatchFailure(op, "not a scalar!");
1467+
1468+
unsigned subByteBits = subByteVecTy.getElementTypeBitWidth();
1469+
unsigned multiByteBits = containerTy.getIntOrFloatBitWidth();
1470+
1471+
// Enforced by the common pre-conditions.
1472+
assert(multiByteBits % 8 == 0 && "Not a multi-byte scalar type!");
1473+
1474+
// TODO: Remove this condition - the assert above (and
1475+
// commonConversionPrecondtion) takes care of that.
1476+
if (multiByteBits < 8)
1477+
return rewriter.notifyMatchFailure(op, "not a multi-byte scalar type!");
1478+
1479+
// TODO: Add support other widths (when/if needed)
1480+
if (subByteBits != 2 && subByteBits != 4)
14171481
return rewriter.notifyMatchFailure(
1418-
op, "only src bitwidth of 2 or 4 is supported at this moment");
1482+
op, "only 2-bit and 4-bit sub-byte type is supported at this moment");
1483+
1484+
// Condition 1.
1485+
if (multiByteBits % subByteBits != 0)
1486+
return rewriter.notifyMatchFailure(op, "unalagined element types");
14191487

1420-
const int numSrcElemsPerByte = 8 / srcElemBitwidth;
1421-
if ((subByteVecType.getShape().back() % numSrcElemsPerByte) != 0)
1488+
// Condition 2.
1489+
if (!isSubByteVecFittable(subByteVecTy, containerTy))
14221490
return rewriter.notifyMatchFailure(
1423-
op, "the trailing dimension of the input vector of sub-bytes must be a "
1424-
"multiple of 8 / <sub-byte-width>");
1491+
op, "not possible to fit this sub-byte vector type into a vector of "
1492+
"the given multi-byte type");
14251493

14261494
return success();
14271495
}
@@ -1858,8 +1926,9 @@ struct RewriteAlignedSubByteIntExt : OpRewritePattern<ConversionOpType> {
18581926
return failure();
18591927

18601928
// Check general alignment preconditions.
1861-
if (failed(alignedConversionPrecondition(rewriter, srcVecType, dstVecType,
1862-
conversionOp)))
1929+
Type containerType = rewriter.getI8Type();
1930+
if (failed(alignedConversionPrecondition(rewriter, srcVecType,
1931+
containerType, conversionOp)))
18631932
return failure();
18641933

18651934
// Perform the rewrite.
@@ -1923,8 +1992,9 @@ struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
19231992

19241993
// Check general alignment preconditions. We invert the src/dst type order
19251994
// to reuse the existing precondition logic.
1926-
if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
1927-
truncOp)))
1995+
Type containerType = rewriter.getI8Type();
1996+
if (failed(alignedConversionPrecondition(rewriter, dstVecType,
1997+
containerType, truncOp)))
19281998
return failure();
19291999

19302000
// Create a new iX -> i8 truncation op.

0 commit comments

Comments
 (0)