Skip to content

Commit 04ba475

Browse files
[mlir][Vector] Add a rewrite pattern for better low-precision ext(bit… (#66648)
…cast) expansion This revision adds a rewrite for sequences of vector `ext(bitcast)` to use a more efficient sequence of vector operations comprising `shuffle` and `bitwise` ops. Such patterns appear naturally when writing quantization / dequantization functionality with the vector dialect. The rewrite performs a simple enumeration of each of the bits in the result vector and determines its provenance in the source vector. The enumeration is used to generate the proper sequence of `shuffle`, `andi`, `ori` with shifts`. The rewrite currently only applies to 1-D non-scalable vectors and bails out if the final vector element type is not a multiple of 8. This is a failsafe heuristic determined empirically: if the resulting type is not an even number of bytes, further complexities arise that are not improved by this pattern: the heavy lifting still needs to be done by LLVM.
1 parent 2a38d83 commit 04ba475

File tree

5 files changed

+384
-149
lines changed

5 files changed

+384
-149
lines changed

mlir/include/mlir/Dialect/Vector/TransformOps/VectorTransformOps.td

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,8 @@ def ApplyRewriteNarrowTypePatternsOp : Op<Transform_Dialect,
300300

301301
This is usually a late step that is run after bufferization as part of the
302302
process of lowering to e.g. LLVM or NVVM.
303+
304+
Warning: these patterns currently only work for little endian targets.
303305
}];
304306

305307
let assemblyFormat = "attr-dict";

mlir/include/mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ namespace mlir {
2323
class RewritePatternSet;
2424

2525
namespace arith {
26+
class AndIOp;
2627
class NarrowTypeEmulationConverter;
2728
class TruncIOp;
2829
} // namespace arith
@@ -304,13 +305,22 @@ void populateVectorNarrowTypeEmulationPatterns(
304305

305306
/// Rewrite a vector `bitcast(trunci)` to use a more efficient sequence of
306307
/// vector operations comprising `shuffle` and `bitwise` ops.
308+
/// Warning: these patterns currently only work for little endian targets.
307309
FailureOr<Value> rewriteBitCastOfTruncI(RewriterBase &rewriter,
308310
vector::BitCastOp bitCastOp,
309311
arith::TruncIOp truncOp,
310312
vector::BroadcastOp maybeBroadcastOp);
311313

314+
/// Rewrite a vector `ext(bitcast)` to use a more efficient sequence of
315+
/// vector operations comprising `shuffle` and `bitwise` ops.
316+
/// Warning: these patterns currently only work for little endian targets.
317+
FailureOr<Value> rewriteExtOfBitCast(RewriterBase &rewriter, Operation *extOp,
318+
vector::BitCastOp bitCastOp,
319+
vector::BroadcastOp maybeBroadcastOp);
320+
312321
/// Appends patterns for rewriting vector operations over narrow types with
313322
/// ops over wider types.
323+
/// Warning: these patterns currently only work for little endian targets.
314324
void populateVectorNarrowTypeRewritePatterns(RewritePatternSet &patterns,
315325
PatternBenefit benefit = 1);
316326

0 commit comments

Comments
 (0)