Skip to content

Commit 703c9e9

Browse files
[mlir][Vector] Add a rewrite pattern for better low-precision ext(bitcast) expansion
This revision adds a rewrite for sequences of vector `ext(bitcast)` to use a more efficient sequence of vector operations comprising `shuffle` and `bitwise` ops. Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect. The implementation is 90% a refactoring of the existing `trunci(bitcast)` pattern into a common BitCastRewriter. The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori` with shifts`. The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type is not an even number of bytes, further complexities arise that are not improved by this pattern: the heavy lifting still needs to be done by LLVM.
1 parent d4d8f21 commit 703c9e9

File tree

4 files changed

+379
-149
lines changed

4 files changed

+379
-149
lines changed

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace mlir {
2323
class RewritePatternSet;
2424

2525
namespace arith {
26+
class AndIOp;
2627
class NarrowTypeEmulationConverter;
2728
class TruncIOp;
2829
} // namespace arith
@@ -309,6 +310,12 @@ FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
309310
arith::TruncIOp truncOp,
310311
vector::BroadcastOp maybeBroadcastOp);
311312

313+
/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
314+
/// vector operations comprising `shuffle` and `bitwise` ops.
315+
FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
316+
vector::BitCastOp bitCastOp,
317+
vector::BroadcastOp maybeBroadcastOp);
318+
312319
/// Appends patterns for rewriting vector operations over narrow types with
313320
/// ops over wider types.
314321
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,

0 commit comments

Comments
 (0)