Skip to content

[mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… #66648

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Sep 18, 2023

Conversation

nicolasvasilache
Copy link
Contributor

…cast) expansion

This revision adds a rewrite for sequences of vector ext(bitcast) to use a more efficient sequence of vector operations comprising shuffle and bitwise ops.

Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the source vector. The enumeration is used to generate the proper sequence of shuffle, andi, ori with shifts`.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type is not an even number of bytes, further complexities arise that are not improved by this pattern: the heavy lifting still needs to be done by LLVM.

@llvmbot
Copy link
Member

llvmbot commented Sep 18, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-vector

Changes

…cast) expansion

This revision adds a rewrite for sequences of vector ext(bitcast) to use a more efficient sequence of vector operations comprising shuffle and bitwise ops.

Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the source vector. The enumeration is used to generate the proper sequence of shuffle, andi, ori with shifts`.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type is not an even number of bytes, further complexities arise that are not improved by this pattern: the heavy lifting still needs to be done by LLVM.

Patch is 45.88 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/66648.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h (+7)
  • (modified) mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp (+270-148)
  • (modified) mlir/test/Dialect/Vector/vector-rewrite-narrow-types.mlir (+159-136)
  • (modified) mlir/test/Integration/Dialect/Vector/CPU/test-rewrite-narrow-types.mlir (+46)
diff --git a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
index 8652fc7f5e5c640..eb561ba3b23557a 100644
--- a/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
+++ b/mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h
@@ -23,6 +23,7 @@ namespace mlir {
 class RewritePatternSet;
 
 namespace arith {
+class AndIOp;
 class NarrowTypeEmulationConverter;
 class TruncIOp;
 } // namespace arith
@@ -309,6 +310,12 @@ FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
                                         arith::TruncIOp truncOp,
                                         vector::BroadcastOp maybeBroadcastOp);
 
+/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
+/// vector operations comprising `shuffle` and `bitwise` ops.
+FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
+                                     vector::BitCastOp bitCastOp,
+                                     vector::BroadcastOp maybeBroadcastOp);
+
 /// Appends patterns for rewriting vector operations over narrow types with
 /// ops over wider types.
 void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
index 9d659bf694a2445..aa04d804b3a57f2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorEmulateNarrowType.cpp
@@ -18,11 +18,15 @@
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
+#include "mlir/Support/LLVM.h"
 #include "mlir/Transforms/DialectConversion.h"
+#include "llvm/ADT/APInt.h"
 #include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/TypeSwitch.h"
 #include "llvm/Support/Debug.h"
 #include "llvm/Support/raw_ostream.h"
 #include <cstdint>
+#include <numeric>
 
 using namespace mlir;
 
@@ -224,6 +228,98 @@ struct BitCastBitsEnumerator {
   SmallVector<SourceElementRangeList> sourceElementRanges;
 };
 
+/// Rewrite vector.bitcast to a sequence of shuffles and bitwise ops that take
+/// advantage of high-level information to avoid leaving LLVM to scramble with
+/// peephole optimizations.
+/// BitCastBitsEnumerator encodes for each element of the target vector the
+/// provenance of the bits in the source vector. We can "transpose" this
+/// information to build a sequence of shuffles and bitwise ops that will
+/// produce the desired result.
+//
+/// Consider the following motivating example:
+/// ```
+///   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
+/// ```
+//
+/// BitCastBitsEnumerator contains the following information:
+/// ```
+///   { 0: b@[0..5) lshl: 0}{ 1: b@[0..3) lshl: 5}
+///   { 1: b@[3..5) lshl: 0}{ 2: b@[0..5) lshl: 2}{ 3: b@[0..1) lshl: 7}
+///   { 3: b@[1..5) lshl: 0}{ 4: b@[0..4) lshl: 4}
+///   { 4: b@[4..5) lshl: 0}{ 5: b@[0..5) lshl: 1}{ 6: b@[0..2) lshl: 6}
+///   { 6: b@[2..5) lshl: 0}{ 7: b@[0..5) lshl: 3}
+///   { 8: b@[0..5) lshl: 0}{ 9: b@[0..3) lshl: 5}
+///   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7}
+///   {11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4}
+///   {12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6}
+///   {14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
+///   {16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
+///   {17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
+///   {19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
+///   {20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1}{22: b@[0..2) lshl: 6}
+///   {22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3}
+///   {24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5}
+///   {25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7}
+///   {27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
+///   {28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
+///   {30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3}
+/// ```
+//
+/// In the above, each row represents one target vector element and each
+/// column represents one bit contribution from a source vector element.
+/// The algorithm creates vector.shuffle operations (in this case there are 3
+/// shuffles (i.e. the max number of columns in BitCastBitsEnumerator), as
+/// follows:
+///   1. for each vector.shuffle, collect the source vectors that participate in
+///     this shuffle. One source vector per target element of the resulting
+///     vector.shuffle. If there is no source element contributing bits for the
+///     current vector.shuffle, take 0 (i.e. row 0 in the above example has only
+///     2 columns).
+///   2. represent the bitrange in the source vector as a mask. If there is no
+///     source element contributing bits for the current vector.shuffle, take 0.
+///   3. shift right by the proper amount to align the source bitrange at
+///     position 0. This is exactly the low end of the bitrange. For instance,
+///     the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
+///     shift right by 3 to get the bits contributed by the source element #1
+///     into position 0.
+///   4. shift left by the proper amount to to align to the desired position in
+///     the result element vector.  For instance, the contribution of the second
+///     source element for the first row needs to be shifted by `5` to form the
+///     first i8 result element.
+///
+/// Eventually, we end up building  the sequence
+/// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update
+/// the result vector (i.e. the `shiftright -> shiftleft -> or` part) with the
+/// bits extracted from the source vector (i.e. the `shuffle -> and` part).
+struct BitCastRewriter {
+  /// Helper metadata struct to hold the static quantities for the rewrite.
+  struct Metadata {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+  };
+
+  BitCastRewriter(VectorType sourceVectorType, VectorType targetVectorType);
+
+  /// Verify that the preconditions for the rewrite are met.
+  LogicalResult precondition(PatternRewriter &rewriter,
+                             VectorType targetVectorType, Operation *op);
+
+  /// Precompute the metadata for the rewrite.
+  SmallVector<BitCastRewriter::Metadata>
+  precomputeMetadata(IntegerType shuffledElementType);
+
+  /// Rewrite one step of the sequence:
+  ///   `(shuffle -> and -> shiftright -> shiftleft -> or)`.
+  Value rewriteStep(PatternRewriter &rewriter, Location loc, Value initialValue,
+                    Value runningResult,
+                    const BitCastRewriter::Metadata &metadata);
+
+private:
+  /// Underlying enumerator that encodes the provenance of the bits in the each
+  /// element of the result vector.
+  BitCastBitsEnumerator enumerator;
+};
+
 } // namespace
 
 static raw_ostream &operator<<(raw_ostream &os,
@@ -275,79 +371,104 @@ BitCastBitsEnumerator::BitCastBitsEnumerator(VectorType sourceVectorType,
   }
 }
 
+BitCastRewriter::BitCastRewriter(VectorType sourceVectorType,
+                                 VectorType targetVectorType)
+    : enumerator(BitCastBitsEnumerator(sourceVectorType, targetVectorType)) {
+  LDBG("\n" << enumerator.sourceElementRanges);
+}
+
+LogicalResult BitCastRewriter::precondition(PatternRewriter &rewriter,
+                                            VectorType targetVectorType,
+                                            Operation *op) {
+  if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
+    return rewriter.notifyMatchFailure(op, "scalable or >1-D vector");
+
+  // TODO: consider relaxing this restriction in the future if we find ways
+  // to really work with subbyte elements across the MLIR/LLVM boundary.
+  int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
+  if (resultBitwidth % 8 != 0)
+    return rewriter.notifyMatchFailure(op, "bitwidth is not k * 8");
+
+  return success();
+}
+
+SmallVector<BitCastRewriter::Metadata>
+BitCastRewriter::precomputeMetadata(IntegerType shuffledElementType) {
+  SmallVector<BitCastRewriter::Metadata> result;
+  for (int64_t shuffleIdx = 0, e = enumerator.getMaxNumberOfEntries();
+       shuffleIdx < e; ++shuffleIdx) {
+    SmallVector<int64_t> shuffles;
+    SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
+
+    // Create the attribute quantities for the shuffle / mask / shift ops.
+    for (auto &l : enumerator.sourceElementRanges) {
+      int64_t sourceElement =
+          (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceElementIdx : 0;
+      shuffles.push_back(sourceElement);
+
+      int64_t bitLo =
+          (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitBegin : 0;
+      int64_t bitHi =
+          (shuffleIdx < (int64_t)l.size()) ? l[shuffleIdx].sourceBitEnd : 0;
+      IntegerAttr mask = IntegerAttr::get(
+          shuffledElementType,
+          llvm::APInt::getBitsSet(shuffledElementType.getIntOrFloatBitWidth(),
+                                  bitLo, bitHi));
+      masks.push_back(mask);
+
+      int64_t shiftRight = bitLo;
+      shiftRightAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftRight));
+
+      int64_t shiftLeft = l.computeLeftShiftAmount(shuffleIdx);
+      shiftLeftAmounts.push_back(
+          IntegerAttr::get(shuffledElementType, shiftLeft));
+    }
+
+    result.push_back({shuffles, masks, shiftRightAmounts, shiftLeftAmounts});
+  }
+  return result;
+}
+
+Value BitCastRewriter::rewriteStep(PatternRewriter &rewriter, Location loc,
+                                   Value initialValue, Value runningResult,
+                                   const BitCastRewriter::Metadata &metadata) {
+  // Create vector.shuffle from the metadata.
+  auto shuffleOp = rewriter.create<vector::ShuffleOp>(
+      loc, initialValue, initialValue, metadata.shuffles);
+
+  // Intersect with the mask.
+  VectorType shuffledVectorType = shuffleOp.getResultVectorType();
+  auto constOp = rewriter.create<arith::ConstantOp>(
+      loc, DenseElementsAttr::get(shuffledVectorType, metadata.masks));
+  Value andValue = rewriter.create<arith::AndIOp>(loc, shuffleOp, constOp);
+
+  // Align right on 0.
+  auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftRightAmounts));
+  Value shiftedRight =
+      rewriter.create<arith::ShRUIOp>(loc, andValue, shiftRightConstantOp);
+
+  // Shift bits left into their final position.
+  auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
+      loc,
+      DenseElementsAttr::get(shuffledVectorType, metadata.shiftLeftAmounts));
+  Value shiftedLeft =
+      rewriter.create<arith::ShLIOp>(loc, shiftedRight, shiftLeftConstantOp);
+
+  runningResult =
+      runningResult
+          ? rewriter.create<arith::OrIOp>(loc, runningResult, shiftedLeft)
+          : shiftedLeft;
+
+  return runningResult;
+}
+
 namespace {
 /// Rewrite bitcast(trunci) to a sequence of shuffles and bitwise ops that take
 /// advantage of high-level information to avoid leaving LLVM to scramble with
 /// peephole optimizations.
-
-// BitCastBitsEnumerator encodes for each element of the target vector the
-// provenance of the bits in the source vector. We can "transpose" this
-// information to build a sequence of shuffles and bitwise ops that will
-// produce the desired result.
-//
-// Let's take the following motivating example to explain the algorithm:
-// ```
-//   %0 = arith.trunci %a : vector<32xi64> to vector<32xi5>
-//   %1 = vector.bitcast %0 : vector<32xi5> to vector<20xi8>
-// ```
-//
-// BitCastBitsEnumerator contains the following information:
-// ```
-//   { 0: b@[0..5) lshl: 0}{1: b@[0..3) lshl: 5 }
-//   { 1: b@[3..5) lshl: 0}{2: b@[0..5) lshl: 2}{3: b@[0..1) lshl: 7 }
-//   { 3: b@[1..5) lshl: 0}{4: b@[0..4) lshl: 4 }
-//   { 4: b@[4..5) lshl: 0}{5: b@[0..5) lshl: 1}{6: b@[0..2) lshl: 6 }
-//   { 6: b@[2..5) lshl: 0}{7: b@[0..5) lshl: 3 }
-//   { 8: b@[0..5) lshl: 0}{9: b@[0..3) lshl: 5 }
-//   { 9: b@[3..5) lshl: 0}{10: b@[0..5) lshl: 2}{11: b@[0..1) lshl: 7 }
-//   { 11: b@[1..5) lshl: 0}{12: b@[0..4) lshl: 4 }
-//   { 12: b@[4..5) lshl: 0}{13: b@[0..5) lshl: 1}{14: b@[0..2) lshl: 6 }
-//   { 14: b@[2..5) lshl: 0}{15: b@[0..5) lshl: 3}
-//   { 16: b@[0..5) lshl: 0}{17: b@[0..3) lshl: 5}
-//   { 17: b@[3..5) lshl: 0}{18: b@[0..5) lshl: 2}{19: b@[0..1) lshl: 7}
-//   { 19: b@[1..5) lshl: 0}{20: b@[0..4) lshl: 4}
-//   { 20: b@[4..5) lshl: 0}{21: b@[0..5) lshl: 1 }{22: b@[0..2) lshl: 6}
-//   { 22: b@[2..5) lshl: 0}{23: b@[0..5) lshl: 3 }
-//   { 24: b@[0..5) lshl: 0}{25: b@[0..3) lshl: 5 }
-//   { 25: b@[3..5) lshl: 0}{26: b@[0..5) lshl: 2}{27: b@[0..1) lshl: 7 }
-//   { 27: b@[1..5) lshl: 0}{28: b@[0..4) lshl: 4}
-//   { 28: b@[4..5) lshl: 0}{29: b@[0..5) lshl: 1}{30: b@[0..2) lshl: 6}
-//   { 30: b@[2..5) lshl: 0}{31: b@[0..5) lshl: 3 }
-// ```
-//
-// In the above, each row represents one target vector element and each
-// column represents one bit contribution from a source vector element.
-// The algorithm creates vector.shuffle operations (in this case there are 3
-// shuffles (i.e. the max number of columns in BitCastBitsEnumerator). The
-// algorithm populates the bits as follows:
-// ```
-//     src bits 0 ...
-// 1st shuffle |xxxxx   |xx      |...
-// 2nd shuffle |     xxx|  xxxxx |...
-// 3rd shuffle |        |       x|...
-// ```
-//
-// The algorithm proceeds as follows:
-//   1. for each vector.shuffle, collect the source vectors that participate in
-//     this shuffle. One source vector per target element of the resulting
-//     vector.shuffle. If there is no source element contributing bits for the
-//     current vector.shuffle, take 0 (i.e. row 0 in the above example has only
-//     2 columns).
-//   2. represent the bitrange in the source vector as a mask. If there is no
-//     source element contributing bits for the current vector.shuffle, take 0.
-//   3. shift right by the proper amount to align the source bitrange at
-//     position 0. This is exactly the low end of the bitrange. For instance,
-//     the first element of row 2 is `{ 1: b@[3..5) lshl: 0}` and one needs to
-//     shift right by 3 to get the bits contributed by the source element #1
-//     into position 0.
-//   4. shift left by the proper amount to to align to the desired position in
-//     the result element vector.  For instance, the contribution of the second
-//     source element for the first row needs to be shifted by `5` to form the
-//     first i8 result element.
-// Eventually, we end up building  the sequence
-// `(shuffle -> and -> shiftright -> shiftleft -> or)` to iteratively update the
-// result vector (i.e. the `shiftright -> shiftleft -> or` part) with the bits
-// extracted from the source vector (i.e. the `shuffle -> and` part).
 struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
   using OpRewritePattern::OpRewritePattern;
 
@@ -359,93 +480,92 @@ struct RewriteBitCastOfTruncI : OpRewritePattern<vector::BitCastOp> {
     if (!truncOp)
       return rewriter.notifyMatchFailure(bitCastOp, "not a trunci source");
 
+    // Set up the BitCastRewriter and verify the precondition.
+    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
     VectorType targetVectorType = bitCastOp.getResultVectorType();
-    if (targetVectorType.getRank() != 1 || targetVectorType.isScalable())
-      return rewriter.notifyMatchFailure(bitCastOp, "scalable or >1-D vector");
-    // TODO: consider relaxing this restriction in the future if we find ways
-    // to really work with subbyte elements across the MLIR/LLVM boundary.
-    int64_t resultBitwidth = targetVectorType.getElementTypeBitWidth();
-    if (resultBitwidth % 8 != 0)
-      return rewriter.notifyMatchFailure(bitCastOp, "bitwidth is not k * 8");
+    BitCastRewriter bcr(sourceVectorType, targetVectorType);
+    if (failed(bcr.precondition(rewriter, targetVectorType, bitCastOp)))
+      return failure();
 
-    VectorType sourceVectorType = bitCastOp.getSourceVectorType();
-    BitCastBitsEnumerator be(sourceVectorType, targetVectorType);
-    LDBG("\n" << be.sourceElementRanges);
-
-    Value initialValue = truncOp.getIn();
-    auto initalVectorType = initialValue.getType().cast<VectorType>();
-    auto initalElementType = initalVectorType.getElementType();
-    auto initalElementBitWidth = initalElementType.getIntOrFloatBitWidth();
-
-    Value res;
-    for (int64_t shuffleIdx = 0, e = be.getMaxNumberOfEntries(); shuffleIdx < e;
-         ++shuffleIdx) {
-      SmallVector<int64_t> shuffles;
-      SmallVector<Attribute> masks, shiftRightAmounts, shiftLeftAmounts;
-
-      // Create the attribute quantities for the shuffle / mask / shift ops.
-      for (auto &srcEltRangeList : be.sourceElementRanges) {
-        bool idxContributesBits =
-            (shuffleIdx < (int64_t)srcEltRangeList.size());
-        int64_t sourceElementIdx =
-            idxContributesBits ? srcEltRangeList[shuffleIdx].sourceElementIdx
-                               : 0;
-        shuffles.push_back(sourceElementIdx);
-
-        int64_t bitLo = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitBegin
-                            : 0;
-        int64_t bitHi = (shuffleIdx < (int64_t)srcEltRangeList.size())
-                            ? srcEltRangeList[shuffleIdx].sourceBitEnd
-                            : 0;
-        IntegerAttr mask = IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth),
-            llvm::APInt::getBitsSet(initalElementBitWidth, bitLo, bitHi));
-        masks.push_back(mask);
-
-        int64_t shiftRight = bitLo;
-        shiftRightAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftRight));
-
-        int64_t shiftLeft = srcEltRangeList.computeLeftShiftAmount(shuffleIdx);
-        shiftLeftAmounts.push_back(IntegerAttr::get(
-            rewriter.getIntegerType(initalElementBitWidth), shiftLeft));
-      }
-
-      // Create vector.shuffle #shuffleIdx.
-      auto shuffleOp = rewriter.create<vector::ShuffleOp>(
-          bitCastOp.getLoc(), initialValue, initialValue, shuffles);
-      // And with the mask.
-      VectorType vt = VectorType::Builder(initalVectorType)
-                          .setDim(initalVectorType.getRank() - 1, masks.size());
-      auto constOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, masks));
-      Value andValue = rewriter.create<arith::AndIOp>(bitCastOp.getLoc(),
-                                                      shuffleOp, constOp);
-      // Align right on 0.
-      auto shiftRightConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftRightAmounts));
-      Value shiftedRight = rewriter.create<arith::ShRUIOp>(
-          bitCastOp.getLoc(), andValue, shiftRightConstantOp);
-
-      auto shiftLeftConstantOp = rewriter.create<arith::ConstantOp>(
-          bitCastOp.getLoc(), DenseElementsAttr::get(vt, shiftLeftAmounts));
-      Value shiftedLeft = rewriter.create<arith::ShLIOp>(
-          bitCastOp.getLoc(), shiftedRight, shiftLeftConstantOp);
-
-      res = res ? rewriter.create<arith::OrIOp>(bitCastOp.getLoc(), res,
-                                                shiftedLeft)
-                : shiftedLeft;
+    // Perform the rewrite.
+    Value truncValue = truncOp.getIn();
+    auto shuffledElementType =
+        cast<IntegerType>(getElementTypeOrSelf(truncValue.getType()));
+    Value runningResult;
+    for (const BitCastRewriter ::Metadata &metadata :
+         bcr.precomputeMetadata(shuffledElementType)) {
+      runningResult = bcr.rewriteStep(rewriter, bitCastOp->getLoc(), truncValue,
+                                      runningResult, metadata);
     }
 
-    bool narrowing = resultBitwidth <= initalElementBitWidth;
+    // Finalize the rewrite.
+    bool narrowing = targetVectorType.getElementTypeBitWidth() <=
+                     shuffledElementType.getIntOrFloatBitWidth();
     if (narrowing) {
       rewriter.replaceOpWithNewOp<arith::TruncIOp>(
-          bitCastOp, bitCastOp.getResultVectorType(), res);
+          bitCastOp, bitCastOp.getResultVectorType(), run...
[truncated]

@nicolasvasilache nicolasvasilache force-pushed the ext branch 2 times, most recently from 8bbacf9 to ae7115e Compare September 18, 2023 14:12
…cast) expansion

This revision adds a rewrite for sequences of vector `ext(bitcast)` to use a more efficient sequence of
vector operations comprising `shuffle` and `bitwise` ops.

Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect.

The implementation is 90% a refactoring of the existing `trunci(bitcast)` pattern into a common BitCastRewriter.

The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance
in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori`
with shifts`.

The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type is
not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type
is not an even number of bytes, further complexities arise that are not improved by this pattern:
the heavy lifting still needs to be done by LLVM.
Copy link
Collaborator

@qcolombet qcolombet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@nicolasvasilache nicolasvasilache merged commit 04ba475 into llvm:main Sep 18, 2023
ZijunZhaoCCK pushed a commit to ZijunZhaoCCK/llvm-project that referenced this pull request Sep 19, 2023
llvm#66648)

…cast) expansion

This revision adds a rewrite for sequences of vector `ext(bitcast)` to
use a more efficient sequence of vector operations comprising `shuffle`
and `bitwise` ops.

Such patterns appear naturally when writing quantization /
dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the
result vector and determines its provenance in the source vector. The
enumeration is used to generate the proper sequence of `shuffle`,
`andi`, `ori` with shifts`.

The rewrite currently only applies to 1-D non-scalable vectors and bails
out if the final vector element type is not a multiple of 8. This is a
failsafe heuristic determined empirically: if the resulting type is not
an even number of bytes, further complexities arise that are not
improved by this pattern: the heavy lifting still needs to be done by
LLVM.
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
llvm#66648)

…cast) expansion

This revision adds a rewrite for sequences of vector `ext(bitcast)` to
use a more efficient sequence of vector operations comprising `shuffle`
and `bitwise` ops.

Such patterns appear naturally when writing quantization /
dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the
result vector and determines its provenance in the source vector. The
enumeration is used to generate the proper sequence of `shuffle`,
`andi`, `ori` with shifts`.

The rewrite currently only applies to 1-D non-scalable vectors and bails
out if the final vector element type is not a multiple of 8. This is a
failsafe heuristic determined empirically: if the resulting type is not
an even number of bytes, further complexities arise that are not
improved by this pattern: the heavy lifting still needs to be done by
LLVM.
zahiraam pushed a commit to tahonermann/llvm-project that referenced this pull request Oct 24, 2023
llvm#66648)

…cast) expansion

This revision adds a rewrite for sequences of vector `ext(bitcast)` to
use a more efficient sequence of vector operations comprising `shuffle`
and `bitwise` ops.

Such patterns appear naturally when writing quantization /
dequantization functionality with the vector dialect.

The rewrite performs a simple enumeration of each of the bits in the
result vector and determines its provenance in the source vector. The
enumeration is used to generate the proper sequence of `shuffle`,
`andi`, `ori` with shifts`.

The rewrite currently only applies to 1-D non-scalable vectors and bails
out if the final vector element type is not a multiple of 8. This is a
failsafe heuristic determined empirically: if the resulting type is not
an even number of bytes, further complexities arise that are not
improved by this pattern: the heavy lifting still needs to be done by
LLVM.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants