Skip to content

Commit 7f3b0e5

Browse files
committed
[mlir][arith] Add narrowing patterns to commute more vector ops
This commutes the extension ops (`arith.extsi`, `arith.extui`) over the following vector ops: `vector.broadcast`, `vector.shape_cast`, `vector.transpose`, and `vector.flat_transpose`. I focused on these because I saw them being created by vector unroll patterns (with the possible exception of `vector.flat_transpose`). Reviewed By: antiagainst. Differential Revision: https://reviews.llvm.org/D149534
1 parent 3ff8708 commit 7f3b0e5

File tree

2 files changed

+175
-3
lines changed

2 files changed

+175
-3
lines changed

mlir/lib/Dialect/Arith/Transforms/IntNarrowing.cpp

Lines changed: 87 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,26 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
249249
// Patterns to Commute Extension Ops
250250
//===----------------------------------------------------------------------===//
251251

252+
/// Commutes an integer extension (`arith.extsi` / `arith.extui`) over
/// `vector.broadcast`: the broadcast is rebuilt on the narrow, pre-extension
/// operand and the extension is then re-applied to the broadcast result.
struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::BroadcastOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out unless the op defining the broadcast source is recognized by
    // `ExtensionOp::from`.
    auto *producer = op.getSource().getDefiningOp();
    FailureOr<ExtensionOp> extension = ExtensionOp::from(producer);
    if (failed(extension))
      return failure();

    // Keep the result shape, but switch to the extension's input element type.
    VectorType wideTy = op.getResultVectorType();
    auto narrowElemTy = extension->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    // Broadcast the extension's input instead, then hand the new value to the
    // extension helper, which takes care of replacing `op`.
    Value narrowBroadcast = rewriter.create<vector::BroadcastOp>(
        op.getLoc(), narrowTy, extension->getIn());
    extension->recreateAndReplace(rewriter, op, narrowBroadcast);
    return success();
  }
};
271+
252272
struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
253273
using NarrowingPattern::NarrowingPattern;
254274

@@ -421,6 +441,68 @@ struct ExtensionOverInsertStridedSlice final
421441
}
422442
};
423443

444+
/// Commutes an integer extension (`arith.extsi` / `arith.extui`) over
/// `vector.shape_cast`: the shape cast is rebuilt on the narrow, pre-extension
/// operand and the extension is then re-applied to the cast result.
struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::ShapeCastOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out unless the op defining the cast source is recognized by
    // `ExtensionOp::from`.
    auto *producer = op.getSource().getDefiningOp();
    FailureOr<ExtensionOp> extension = ExtensionOp::from(producer);
    if (failed(extension))
      return failure();

    // Keep the result shape, but switch to the extension's input element type.
    VectorType wideTy = op.getResultVectorType();
    auto narrowElemTy = extension->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    // Shape-cast the extension's input instead, then hand the new value to the
    // extension helper, which takes care of replacing `op`.
    Value narrowCast = rewriter.create<vector::ShapeCastOp>(
        op.getLoc(), narrowTy, extension->getIn());
    extension->recreateAndReplace(rewriter, op, narrowCast);
    return success();
  }
};
463+
464+
/// Commutes an integer extension (`arith.extsi` / `arith.extui`) over
/// `vector.transpose`: the transpose is rebuilt on the narrow, pre-extension
/// operand with the same permutation, and the extension is then re-applied to
/// the transposed result.
struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
  using NarrowingPattern::NarrowingPattern;

  LogicalResult matchAndRewrite(vector::TransposeOp op,
                                PatternRewriter &rewriter) const override {
    // Bail out unless the op defining the transposed vector is recognized by
    // `ExtensionOp::from`.
    auto *producer = op.getVector().getDefiningOp();
    FailureOr<ExtensionOp> extension = ExtensionOp::from(producer);
    if (failed(extension))
      return failure();

    // Keep the result shape, but switch to the extension's input element type.
    VectorType wideTy = op.getResultVectorType();
    auto narrowElemTy = extension->getInElementType();
    VectorType narrowTy = wideTy.cloneWith(wideTy.getShape(), narrowElemTy);

    // Transpose the extension's input with the original permutation, then hand
    // the new value to the extension helper, which takes care of replacing
    // `op`.
    Value narrowTranspose = rewriter.create<vector::TransposeOp>(
        op.getLoc(), narrowTy, extension->getIn(), op.getTransp());
    extension->recreateAndReplace(rewriter, op, narrowTranspose);
    return success();
  }
};
483+
484+
struct ExtensionOverFlatTranspose final
485+
: NarrowingPattern<vector::FlatTransposeOp> {
486+
using NarrowingPattern::NarrowingPattern;
487+
488+
LogicalResult matchAndRewrite(vector::FlatTransposeOp op,
489+
PatternRewriter &rewriter) const override {
490+
FailureOr<ExtensionOp> ext =
491+
ExtensionOp::from(op.getMatrix().getDefiningOp());
492+
if (failed(ext))
493+
return failure();
494+
495+
VectorType origTy = op.getType();
496+
VectorType newTy =
497+
origTy.cloneWith(origTy.getShape(), ext->getInElementType());
498+
Value newTranspose = rewriter.create<vector::FlatTransposeOp>(
499+
op.getLoc(), newTy, ext->getIn(), op.getRowsAttr(),
500+
op.getColumnsAttr());
501+
ext->recreateAndReplace(rewriter, op, newTranspose);
502+
return success();
503+
}
504+
};
505+
424506
//===----------------------------------------------------------------------===//
425507
// Pass Definitions
426508
//===----------------------------------------------------------------------===//
@@ -449,9 +531,11 @@ void populateArithIntNarrowingPatterns(
449531
RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
450532
// Add commute patterns with a higher benefit. This is to expose more
451533
// optimization opportunities to narrowing patterns.
452-
patterns.add<ExtensionOverExtract, ExtensionOverExtractElement,
453-
ExtensionOverExtractStridedSlice, ExtensionOverInsert,
454-
ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
534+
patterns.add<ExtensionOverBroadcast, ExtensionOverExtract,
535+
ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
536+
ExtensionOverInsert, ExtensionOverInsertElement,
537+
ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
538+
ExtensionOverTranspose, ExtensionOverFlatTranspose>(
455539
patterns.getContext(), options, PatternBenefit(2));
456540

457541
patterns.add<SIToFPPattern, UIToFPPattern>(patterns.getContext(), options);

mlir/test/Dialect/Arith/int-narrowing.mlir

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -442,3 +442,91 @@ func.func @extui_over_insert_strided_slice_cst_2d(%a: vector<1x2xi8>) -> vector<
442442
%e = vector.insert_strided_slice %d, %cst {offsets = [0, 1], strides = [1, 1]} : vector<1x2xi32> into vector<2x3xi32>
443443
return %e : vector<2x3xi32>
444444
}
445+
446+
// Verifies that `arith.extsi` of a scalar i16 is moved below `vector.broadcast`,
// so the broadcast produces vector<3xi16> and the extension runs on its result.
// CHECK-LABEL: func.func @extsi_over_broadcast_3xi16
447+
// CHECK-SAME: (%[[ARG:.+]]: i16)
448+
// CHECK-NEXT: %[[BCST:.+]] = vector.broadcast %[[ARG]] : i16 to vector<3xi16>
449+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[BCST]] : vector<3xi16> to vector<3xi32>
450+
// CHECK-NEXT: return %[[RET]] : vector<3xi32>
451+
func.func @extsi_over_broadcast_3xi16(%a: i16) -> vector<3xi32> {
452+
%b = arith.extsi %a : i16 to i32
453+
%r = vector.broadcast %b : i32 to vector<3xi32>
454+
return %r : vector<3xi32>
455+
}
456+
457+
// Verifies that `arith.extui` of a vector<3xi16> is moved below a
// vector-to-vector `vector.broadcast` (3 -> 2x3), so the broadcast runs on i16.
// CHECK-LABEL: func.func @extui_over_broadcast_2x3xi16
458+
// CHECK-SAME: (%[[ARG:.+]]: vector<3xi16>)
459+
// CHECK-NEXT: %[[BCST:.+]] = vector.broadcast %[[ARG]] : vector<3xi16> to vector<2x3xi16>
460+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[BCST]] : vector<2x3xi16> to vector<2x3xi32>
461+
// CHECK-NEXT: return %[[RET]] : vector<2x3xi32>
462+
func.func @extui_over_broadcast_2x3xi16(%a: vector<3xi16>) -> vector<2x3xi32> {
463+
%b = arith.extui %a : vector<3xi16> to vector<3xi32>
464+
%r = vector.broadcast %b : vector<3xi32> to vector<2x3xi32>
465+
return %r : vector<2x3xi32>
466+
}
467+
468+
// Verifies that `arith.extsi` is moved below a 2-D `vector.shape_cast`
// (2x3 -> 3x2), so the shape cast runs on i16 elements.
// CHECK-LABEL: func.func @extsi_over_shape_cast_2x3xi16
469+
// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>)
470+
// CHECK-NEXT: %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<2x3xi16> to vector<3x2xi16>
471+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[CAST]] : vector<3x2xi16> to vector<3x2xi32>
472+
// CHECK-NEXT: return %[[RET]] : vector<3x2xi32>
473+
func.func @extsi_over_shape_cast_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
474+
%b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
475+
%r = vector.shape_cast %b : vector<2x3xi32> to vector<3x2xi32>
476+
return %r : vector<3x2xi32>
477+
}
478+
479+
// Verifies that `arith.extui` is moved below a 3-D `vector.shape_cast`
// (5x2x3 -> 2x3x5), so the shape cast runs on i16 elements.
// CHECK-LABEL: func.func @extui_over_shape_cast_5x2x3xi16
480+
// CHECK-SAME: (%[[ARG:.+]]: vector<5x2x3xi16>)
481+
// CHECK-NEXT: %[[CAST:.+]] = vector.shape_cast %[[ARG]] : vector<5x2x3xi16> to vector<2x3x5xi16>
482+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[CAST]] : vector<2x3x5xi16> to vector<2x3x5xi32>
483+
// CHECK-NEXT: return %[[RET]] : vector<2x3x5xi32>
484+
func.func @extui_over_shape_cast_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
485+
%b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
486+
%r = vector.shape_cast %b : vector<5x2x3xi32> to vector<2x3x5xi32>
487+
return %r : vector<2x3x5xi32>
488+
}
489+
490+
// Verifies that `arith.extsi` is moved below a 2-D `vector.transpose` with
// permutation [1, 0], so the transpose runs on i16 elements.
// CHECK-LABEL: func.func @extsi_over_transpose_2x3xi16
491+
// CHECK-SAME: (%[[ARG:.+]]: vector<2x3xi16>)
492+
// CHECK-NEXT: %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 0] : vector<2x3xi16> to vector<3x2xi16>
493+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[TRAN]] : vector<3x2xi16> to vector<3x2xi32>
494+
// CHECK-NEXT: return %[[RET]] : vector<3x2xi32>
495+
func.func @extsi_over_transpose_2x3xi16(%a: vector<2x3xi16>) -> vector<3x2xi32> {
496+
%b = arith.extsi %a : vector<2x3xi16> to vector<2x3xi32>
497+
%r = vector.transpose %b, [1, 0] : vector<2x3xi32> to vector<3x2xi32>
498+
return %r : vector<3x2xi32>
499+
}
500+
501+
// Verifies that `arith.extui` is moved below a 3-D `vector.transpose` with
// permutation [1, 2, 0], so the transpose runs on i16 elements.
// CHECK-LABEL: func.func @extui_over_transpose_5x2x3xi16
502+
// CHECK-SAME: (%[[ARG:.+]]: vector<5x2x3xi16>)
503+
// CHECK-NEXT: %[[TRAN:.+]] = vector.transpose %[[ARG]], [1, 2, 0] : vector<5x2x3xi16> to vector<2x3x5xi16>
504+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[TRAN]] : vector<2x3x5xi16> to vector<2x3x5xi32>
505+
// CHECK-NEXT: return %[[RET]] : vector<2x3x5xi32>
506+
func.func @extui_over_transpose_5x2x3xi16(%a: vector<5x2x3xi16>) -> vector<2x3x5xi32> {
507+
%b = arith.extui %a : vector<5x2x3xi16> to vector<5x2x3xi32>
508+
%r = vector.transpose %b, [1, 2, 0] : vector<5x2x3xi32> to vector<2x3x5xi32>
509+
return %r : vector<2x3x5xi32>
510+
}
511+
512+
// Verifies that `arith.extsi` is moved below `vector.flat_transpose` with
// rows = 4 and columns = 4, keeping the attributes and running on i16 elements.
// CHECK-LABEL: func.func @extsi_over_flat_transpose_16xi16
513+
// CHECK-SAME: (%[[ARG:.+]]: vector<16xi16>)
514+
// CHECK-NEXT: %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 4 : i32, rows = 4 : i32} : vector<16xi16> -> vector<16xi16>
515+
// CHECK-NEXT: %[[RET:.+]] = arith.extsi %[[TRAN]] : vector<16xi16> to vector<16xi32>
516+
// CHECK-NEXT: return %[[RET]] : vector<16xi32>
517+
func.func @extsi_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
518+
%b = arith.extsi %a : vector<16xi16> to vector<16xi32>
519+
%r = vector.flat_transpose %b {columns = 4 : i32, rows = 4 : i32} : vector<16xi32> -> vector<16xi32>
520+
return %r : vector<16xi32>
521+
}
522+
523+
// Verifies that `arith.extui` is moved below `vector.flat_transpose` with
// rows = 2 and columns = 8, keeping the attributes and running on i16 elements.
// CHECK-LABEL: func.func @extui_over_flat_transpose_16xi16
524+
// CHECK-SAME: (%[[ARG:.+]]: vector<16xi16>)
525+
// CHECK-NEXT: %[[TRAN:.+]] = vector.flat_transpose %[[ARG]] {columns = 8 : i32, rows = 2 : i32} : vector<16xi16> -> vector<16xi16>
526+
// CHECK-NEXT: %[[RET:.+]] = arith.extui %[[TRAN]] : vector<16xi16> to vector<16xi32>
527+
// CHECK-NEXT: return %[[RET]] : vector<16xi32>
528+
func.func @extui_over_flat_transpose_16xi16(%a: vector<16xi16>) -> vector<16xi32> {
529+
%b = arith.extui %a : vector<16xi16> to vector<16xi32>
530+
%r = vector.flat_transpose %b {columns = 8 : i32, rows = 2 : i32} : vector<16xi32> -> vector<16xi32>
531+
return %r : vector<16xi32>
532+
}

0 commit comments

Comments
 (0)