Skip to content

[mlir][Vector] Add support for trunci to narrow type emulation #82565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Feb 27, 2024

Conversation

dcaballe
Copy link
Contributor

@dcaballe dcaballe commented Feb 22, 2024

This PR add support for arith.trunci to vector narrow type emulation for iX -> i4 truncations, for X >= 8. For now, the pattern only works for 1D vectors and is based on vector.shuffle ops. We would need vector.deinterleave to add n-D vector support.

…ector narrow type emulation

This PR replaces the generation of `vector.shuffle` with
`vector.interleave` in the i4 conversions in vector narrow type
emulation. The multi dimensional semantics of `vector.interleave` allow
us to enable these conversion emulations also for multi dimensional
vectors.
@dcaballe dcaballe changed the title [mlir][Vector] Replace vector.shuffle with vector.interleave in vector narrow type emulation [mlir][Vector] Add support for trunci to narrow type emulation Feb 22, 2024
Copy link

github-actions bot commented Feb 22, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

@llvmbot
Copy link
Member

llvmbot commented Feb 22, 2024

@llvm/pr-subscribers-mlir-vector

@llvm/pr-subscribers-mlir

Author: Diego Caballero (dcaballe)

Changes

WIP


Full diff: https://github.com/llvm/llvm-project/pull/82565.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+114-5)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+42)
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index fc11ae63e718a5..82c08cc5a54936 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -729,8 +729,8 @@ static LogicalResult commonConversionPrecondition(PatternRewriter &rewriter,
 
   // TODO: consider relaxing this restriction in the future if we find ways
   // to really work with subbyte elements across the MLIR/LLVM boundary.
-  unsigned resultBitwidth = preconditionType.getElementTypeBitWidth();
-  if (resultBitwidth % 8 != 0)
+  unsigned bitwidth = preconditionType.getElementTypeBitWidth();
+  if (bitwidth % 8 != 0)
     return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
 
   return success();
@@ -876,6 +876,57 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
   return rewriter.create<vector::InterleaveOp>(loc, low, high);
 }
 
+/// Rewrite the i8 -> i4 signed extension into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+static Value rewriteI8ToI4Trunc(PatternRewriter &rewriter, Location loc,
+                                Value srcValue) {
+  VectorType srcVecType = cast<VectorType>(srcValue.getType());
+  assert(srcVecType.getElementType().isSignlessInteger(8) &&
+         "Expected i8 type");
+
+  // 1. De-interleave low and high i8 elements.
+  int64_t vecDimSize = srcVecType.getShape().back();
+  SmallVector<int64_t> deinterleaveLowMaskValues;
+  SmallVector<int64_t> deinterleaveHighMaskValues;
+  deinterleaveLowMaskValues.reserve(vecDimSize / 2);
+  deinterleaveHighMaskValues.reserve(vecDimSize / 2);
+  for (int i = 0, end = vecDimSize; i < end; i += 2) {
+    deinterleaveLowMaskValues.push_back(i);
+    deinterleaveHighMaskValues.push_back(i + 1);
+  }
+
+  auto lowShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveLowMaskValues));
+  auto highShuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, srcValue, srcValue,
+      rewriter.getI64ArrayAttr(deinterleaveHighMaskValues));
+
+  // 2. Zero out the upper side of each low i8 element.
+  constexpr int8_t i8LowBitMask = 0x0F;
+  Value zeroOutMask = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(lowShuffleOp.getResultVectorType(), i8LowBitMask));
+  Value zeroOutLow =
+      rewriter.create<arith::AndIOp>(loc, lowShuffleOp, zeroOutMask);
+
+  // 3. Move high i4 values to upper side of the byte.
+  constexpr int8_t bitsToShift = 4;
+  VectorType deinterI8VecType = highShuffleOp.getResultVectorType();
+  auto shiftValues = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(deinterI8VecType, bitsToShift));
+  Value shlHigh =
+      rewriter.create<arith::ShLIOp>(loc, highShuffleOp, shiftValues);
+
+  // 4. Merge high and low i4 values.
+  auto mergedHiLowOp = rewriter.create<arith::OrIOp>(loc, zeroOutLow, shlHigh);
+
+  // 5. Generate a bitcast vector<Xxi8> -> vector<2Xxi4>.
+  auto i4VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI4Type());
+  return rewriter.create<vector::BitCastOp>(loc, i4VecType, mergedHiLowOp);
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
@@ -1019,7 +1070,7 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
 
   LogicalResult matchAndRewrite(ConversionOpType conversionOp,
                                 PatternRewriter &rewriter) const override {
-    // Set up the BitCastRewriter and verify the preconditions.
+    // Verify the preconditions.
     Value srcValue = conversionOp.getIn();
     auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
     auto dstVecType = dyn_cast<VectorType>(conversionOp.getType());
@@ -1043,6 +1094,63 @@ struct RewriteAlignedSubByteIntSignedExt : OpRewritePattern<ConversionOpType> {
   }
 };
 
+/// Rewrite the i8 -> i4 part of any truncation into a sequence of shuffles and
+/// bitwise ops that take advantage of high-level information to avoid leaving
+/// LLVM to scramble with peephole optimizations.
+///
+/// For example:
+///    arith.trunci %in : vector<8xi32> to vector<8xi4>
+///      is rewriten as
+///
+///        %cst = arith.constant dense<15> : vector<4xi8>
+///        %cst_0 = arith.constant dense<4> : vector<4xi8>
+///        %0 = arith.trunci %in : vector<8xi32> to vector<8xi8>
+///        %1 = vector.shuffle %0, %0 [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+///        %2 = vector.shuffle %0, %0 [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+///        %3 = arith.andi %1, %cst : vector<4xi8>
+///        %4 = arith.shli %2, %cst_0 : vector<4xi8>
+///        %5 = arith.ori %3, %4 : vector<4xi8>
+///        %6 = vector.bitcast %5 : vector<4xi8> to vector<8xi4>
+///
+struct RewriteAlignedSubByteIntTrunc : OpRewritePattern<arith::TruncIOp> {
+  using OpRewritePattern<arith::TruncIOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(arith::TruncIOp truncOp,
+                                PatternRewriter &rewriter) const override {
+    // Verify the preconditions.
+    Value srcValue = truncOp.getIn();
+    auto srcVecType = dyn_cast<VectorType>(srcValue.getType());
+    auto dstVecType = dyn_cast<VectorType>(truncOp.getType());
+
+    // Only single dim vectors are supported until we have
+    // `vector.deinterleave`.
+    if (srcVecType.getRank() != 1)
+      return failure();
+
+    if (failed(commonConversionPrecondition(rewriter, srcVecType, truncOp)))
+      return failure();
+
+    // Check general alignment preconditions. We invert the src/dst type order
+    // to reuse the existing precondition logic.
+    if (failed(alignedConversionPrecondition(rewriter, dstVecType, srcVecType,
+                                             truncOp)))
+      return failure();
+
+    // Create a new iX -> i8 truncation op.
+    Location loc = truncOp.getLoc();
+    auto i8VecType = srcVecType.cloneWith(std::nullopt, rewriter.getI8Type());
+    Value i8TruncVal =
+        rewriter.create<arith::TruncIOp>(loc, i8VecType, srcValue);
+
+    // Rewrite the i8 -> i4 truncation part.
+    Value subByteTrunc = rewriteI8ToI4Trunc(rewriter, loc, i8TruncVal);
+
+    // Finalize the rewrite.
+    rewriter.replaceOp(truncOp, subByteTrunc);
+    return success();
+  }
+};
+
 /// Rewrite a sub-byte vector transpose into a sequence of instructions that
 /// perform the transpose on wider (byte) element types.
 /// For example:
@@ -1115,8 +1223,9 @@ void vector::populateVectorNarrowTypeRewritePatterns(
   // Patterns for aligned cases. We set higher priority as they are expected to
   // generate better performance for aligned cases.
   patterns.add<RewriteAlignedSubByteIntSignedExt<arith::ExtSIOp>,
-               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>>(
-      patterns.getContext(), benefit.getBenefit() + 1);
+               RewriteAlignedSubByteIntSignedExt<arith::SIToFPOp>,
+               RewriteAlignedSubByteIntTrunc>(patterns.getContext(),
+                                              benefit.getBenefit() + 1);
 }
 
 void vector::populateVectorTransposeNarrowTypeRewritePatterns(
diff --git a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
index 94e78ce40a3c19..8f0148119806c9 100644
--- a/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
+++ b/mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir
@@ -262,6 +262,48 @@ func.func @aligned_sitofp_2d(%a: vector<8x32xi4>) -> vector<8x32xf32> {
   return %0 : vector<8x32xf32>
 }
 
+// CHECK-LABEL: func.func @aligned_trunci(
+func.func @aligned_trunci(%a: vector<8xi32>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi32>) -> vector<8xi4> {
+// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[I8:.*]] = arith.trunci %[[IN]] : vector<8xi32> to vector<8xi8>
+// CHECK:           %[[LOW:.*]] = vector.shuffle %[[I8]], %[[I8]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[HIGH:.*]] = vector.shuffle %[[I8]], %[[I8]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi32> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_base_case(
+func.func @aligned_trunci_base_case(%a: vector<8xi8>) -> vector<8xi4> {
+// CHECK-SAME:    %[[IN:.*]]: vector<8xi8>) -> vector<8xi4> {
+// CHECK-DAG:       %[[LOW_MASK:.*]] = arith.constant dense<15> : vector<4xi8>
+// CHECK-DAG:       %[[I4_BITS:.*]] = arith.constant dense<4> : vector<4xi8>
+// CHECK:           %[[LOW:.*]] = vector.shuffle %[[IN]], %[[IN]] [0, 2, 4, 6] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[HIGH:.*]] = vector.shuffle %[[IN]], %[[IN]] [1, 3, 5, 7] : vector<8xi8>, vector<8xi8>
+// CHECK:           %[[ZEROED_LOW:.*]] = arith.andi %[[LOW]], %[[LOW_MASK]] : vector<4xi8>
+// CHECK:           %[[SHL_HIGH:.*]] = arith.shli %[[HIGH]], %[[I4_BITS]] : vector<4xi8>
+// CHECK:           %[[MERGED:.*]] = arith.ori %[[ZEROED_LOW]], %[[SHL_HIGH]] : vector<4xi8>
+// CHECK:           %[[I4:.*]] = vector.bitcast %[[MERGED]] : vector<4xi8> to vector<8xi4>
+  %0 = arith.trunci %a : vector<8xi8> to vector<8xi4>
+  return %0 : vector<8xi4>
+}
+
+// CHECK-LABEL: func.func @aligned_trunci_2d(
+func.func @aligned_trunci_2d(%a: vector<8x32xi32>) -> vector<8x32xi4> {
+// CHECK-NOT:       vector.shuffle
+// CHECK-NOT:       vector.andi
+// CHECK-NOT:       vector.shli
+// CHECK-NOT:       vector.ori
+// CHECK:           arith.trunci
+  %0 = arith.trunci %a : vector<8x32xi32> to vector<8x32xi4>
+  return %0 : vector<8x32xi4>
+}
+
 // CHECK-LABEL: func.func @i4_transpose(
 func.func @i4_transpose(%a: vector<8x16xi4>) -> vector<16x8xi4> {
 // CHECK-SAME:    %[[IN:.*]]: vector<8x16xi4>) -> vector<16x8xi4> {

@dcaballe
Copy link
Contributor Author

Kind ping

Copy link
Member

@MacDue MacDue left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, just a some minor comments.

@@ -876,6 +876,57 @@ static Value rewriteI4ToI8SignedExt(PatternRewriter &rewriter, Location loc,
return rewriter.create<vector::InterleaveOp>(loc, low, high);
}

/// Rewrite the i8 -> i4 signed extension into a sequence of shuffles and
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// Rewrite the i8 -> i4 signed extension into a sequence of shuffles and
/// Rewrite the i8 -> i4 truncation into a sequence of shuffles and

?

"Expected i8 type");

// 1. De-interleave low and high i8 elements.
int64_t vecDimSize = srcVecType.getShape().back();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could vecDimSize not be divisible by 2?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Odd number of i4 elements is not supported. I added that to the preconditions and an assert here.

@dcaballe dcaballe merged commit 9d0acb8 into llvm:main Feb 27, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants