-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][Vector] Add support for trunci to narrow type emulation #82565
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…ector narrow type emulation This PR replaces the generation of `vector.shuffle` with `vector.interleave` in the i4 conversions in vector narrow type emulation. The multi dimensional semantics of `vector.interleave` allow us to enable these conversion emulations also for multi dimensional vectors.
✅ With the latest revision this PR passed the C/C++ code formatter. |
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Diego Caballero (dcaballe) Changes: WIP. Full diff: https://github.com/llvm/llvm-project/pull/82565.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fc11ae63e718a5..82c08cc5a54936 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
// TODO: consider relaxing this restriction in the future if we find ways
// to really work with subbyte elements across the MLIR/LLVM boundary.
- unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
- if (resultBitwidth % 8 != 0)
+ unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+ if (bitwidth % 8 != 0)
return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
return success();
@@ -876,6 +876,57 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}
+/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+ Value srcValue) {
+ VectorType srcVecType = cast<VectorType>(srcValue.getType());
+ assert(srcVecType.getElementType().isSignlessInteger(8) &&
+ "Expected i8 type");
+
+ // 1. De-interleave low and high i8 elements.
+ int64_t vecDimSize = srcVecType.getShape().back();
+ SmallVector<int64_t> deinterleaveLowMaskValues;
+ SmallVector<int64_t> deinterleaveHighMaskValues;
+ deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+ deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+ for (int i = 0, end = vecDimSize; i < end; i += 2) {
+ deinterleaveLowMaskValues.push_back(i);
+ deinterleaveHighMaskValues.push_back(i + 1);
+ }
+
+ auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, srcValue, srcValue,
+ rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+ auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+ loc, srcValue, srcValue,
+ rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+ // 2. Zero out the upper side of each low i8 element.
+ constexpr int8_t i8LowBitMask = 0x0F;
+ Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+ loc,
+ DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+ Value zeroOutLow =
+ rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+
+ // 3. Move high i4 values to upper side of the byte.
+ constexpr int8_t bitsToShift = 4;
+ VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+ auto shiftValues = rewriter.create<arith::ConstantOp>(
+ loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+ Value shlHigh =
+ rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+ // 4. Merge high and low i4 values.
+ auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+
+ // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+ auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+ return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
namespace {
/// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
/// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -1019,7 +1070,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
LogicalResult matchAndRewrite(ConversionOpType conversionOp,
PatternRewriter &rewriter) const override {
- // Set up the BitCastRewriter and verify the preconditions.
+ // Verify the preconditions.
Value srcValue = conversionOp.getIn();
auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1043,6 +1094,63 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
}
};
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+/// arith.trunci %in : vector<8xi32> to vector<8xi4>
+/// is rewritten as
+///
+/// %cst = arith.constant dense<15> : vector<4xi8>
+/// %cst_0 = arith.constant dense<4> : vector<4xi8>
+/// %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+/// %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+/// %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+/// %3 = arith.andi %1, %cst : vector<4xi8>
+/// %4 = arith.shli %2, %cst_0 : vector<4xi8>
+/// %5 = arith.ori %3, %4 : vector<4xi8>
+/// %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+ using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+ PatternRewriter &rewriter) const override {
+ // Verify the preconditions.
+ Value srcValue = truncOp.getIn();
+ auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+ auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+
+ // Only single dim vectors are supported until we have
+ // `vector.deinterleave`.
+ if (srcVecType.getRank() != 1)
+ return failure();
+
+ if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+ return failure();
+
+ // Check general alignment preconditions. We invert the src/dst type order
+ // to reuse the existing precondition logic.
+ if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+ truncOp)))
+ return failure();
+
+ // Create a new iX -> i8 truncation op.
+ Location loc = truncOp.getLoc();
+ auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+ Value i8TruncVal =
+ rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+ // Rewrite the i8 -> i4 truncation part.
+ Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+ // Finalize the rewrite.
+ rewriter.replaceOp(truncOp, subByteTrunc);
+ return success();
+ }
+};
+
/// Rewrite a sub-byte vector transpose into a sequence of instructions that
/// perform the transpose on wider (byte) element types.
/// For example:
@@ -1115,8 +1223,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
// Patterns for aligned cases. We set higher priority as they are expected to
// generate better performance for aligned cases.
patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
- RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
- patterns.getContext(), benefit.getBenefit() + 1);
+ RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+ RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+ benefit.getBenefit() + 1);
}
void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 94e78ce40a3c19..8f0148119806c9 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,48 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
return %0 : vector<8x32xf32>
}
+// CHECK-LABEL: func.func @aligned_trunci(
+func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_base_case(
+func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME: %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-DAG: %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG: %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK: %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK: %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK: %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK: %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK: %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+ %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+ return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
+// CHECK-NOT: vector.shuffle
+// CHECK-NOT: vector.andi
+// CHECK-NOT: vector.shli
+// CHECK-NOT: vector.ori
+// CHECK: arith.trunci
+ %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+ return %0 : vector<8x32xi4>
+}
+
// CHECK-LABEL: func.func @i4_transpose(
func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
// CHECK-SAME: %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {
|
Kind ping |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, just some minor comments.
@@ -876,6 +876,57 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc, | |||
return rewriter.create<vector::InterleaveOp>(loc, low, high); | |||
} | |||
|
|||
/// Rewrite the i8 -> i4 signed extension into a sequence of shuffles and |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
/// Rewrite the i8 -> i4 signed extension into a sequence of shuffles and | |
/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and |
?
"Expected i8 type"); | ||
|
||
// 1. De-interleave low and high i8 elements. | ||
int64_t vecDimSize = srcVecType.getShape().back(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could vecDimSize
not be divisible by 2?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Odd number of i4 elements is not supported. I added that to the preconditions and an assert here.
This PR adds support for `arith.trunci` in vector narrow type emulation, covering iX -> i4 truncations for X >= 8. For now, the pattern only works for 1-D vectors and is based on `vector.shuffle` ops. We would need `vector.deinterleave` to add n-D vector support.