@@ -249,6 +249,26 @@ using UIToFPPattern = IToFPPattern<arith::UIToFPOp, ExtensionKind::Zero>;
249
249
// Patterns to Commute Extension Ops
250
250
// ===----------------------------------------------------------------------===//
251
251
252
+ struct ExtensionOverBroadcast final : NarrowingPattern<vector::BroadcastOp> {
253
+ using NarrowingPattern::NarrowingPattern;
254
+
255
+ LogicalResult matchAndRewrite (vector::BroadcastOp op,
256
+ PatternRewriter &rewriter) const override {
257
+ FailureOr<ExtensionOp> ext =
258
+ ExtensionOp::from (op.getSource ().getDefiningOp ());
259
+ if (failed (ext))
260
+ return failure ();
261
+
262
+ VectorType origTy = op.getResultVectorType ();
263
+ VectorType newTy =
264
+ origTy.cloneWith (origTy.getShape (), ext->getInElementType ());
265
+ Value newBroadcast =
266
+ rewriter.create <vector::BroadcastOp>(op.getLoc (), newTy, ext->getIn ());
267
+ ext->recreateAndReplace (rewriter, op, newBroadcast);
268
+ return success ();
269
+ }
270
+ };
271
+
252
272
struct ExtensionOverExtract final : NarrowingPattern<vector::ExtractOp> {
253
273
using NarrowingPattern::NarrowingPattern;
254
274
@@ -421,6 +441,68 @@ struct ExtensionOverInsertStridedSlice final
421
441
}
422
442
};
423
443
444
+ struct ExtensionOverShapeCast final : NarrowingPattern<vector::ShapeCastOp> {
445
+ using NarrowingPattern::NarrowingPattern;
446
+
447
+ LogicalResult matchAndRewrite (vector::ShapeCastOp op,
448
+ PatternRewriter &rewriter) const override {
449
+ FailureOr<ExtensionOp> ext =
450
+ ExtensionOp::from (op.getSource ().getDefiningOp ());
451
+ if (failed (ext))
452
+ return failure ();
453
+
454
+ VectorType origTy = op.getResultVectorType ();
455
+ VectorType newTy =
456
+ origTy.cloneWith (origTy.getShape (), ext->getInElementType ());
457
+ Value newCast =
458
+ rewriter.create <vector::ShapeCastOp>(op.getLoc (), newTy, ext->getIn ());
459
+ ext->recreateAndReplace (rewriter, op, newCast);
460
+ return success ();
461
+ }
462
+ };
463
+
464
+ struct ExtensionOverTranspose final : NarrowingPattern<vector::TransposeOp> {
465
+ using NarrowingPattern::NarrowingPattern;
466
+
467
+ LogicalResult matchAndRewrite (vector::TransposeOp op,
468
+ PatternRewriter &rewriter) const override {
469
+ FailureOr<ExtensionOp> ext =
470
+ ExtensionOp::from (op.getVector ().getDefiningOp ());
471
+ if (failed (ext))
472
+ return failure ();
473
+
474
+ VectorType origTy = op.getResultVectorType ();
475
+ VectorType newTy =
476
+ origTy.cloneWith (origTy.getShape (), ext->getInElementType ());
477
+ Value newTranspose = rewriter.create <vector::TransposeOp>(
478
+ op.getLoc (), newTy, ext->getIn (), op.getTransp ());
479
+ ext->recreateAndReplace (rewriter, op, newTranspose);
480
+ return success ();
481
+ }
482
+ };
483
+
484
+ struct ExtensionOverFlatTranspose final
485
+ : NarrowingPattern<vector::FlatTransposeOp> {
486
+ using NarrowingPattern::NarrowingPattern;
487
+
488
+ LogicalResult matchAndRewrite (vector::FlatTransposeOp op,
489
+ PatternRewriter &rewriter) const override {
490
+ FailureOr<ExtensionOp> ext =
491
+ ExtensionOp::from (op.getMatrix ().getDefiningOp ());
492
+ if (failed (ext))
493
+ return failure ();
494
+
495
+ VectorType origTy = op.getType ();
496
+ VectorType newTy =
497
+ origTy.cloneWith (origTy.getShape (), ext->getInElementType ());
498
+ Value newTranspose = rewriter.create <vector::FlatTransposeOp>(
499
+ op.getLoc (), newTy, ext->getIn (), op.getRowsAttr (),
500
+ op.getColumnsAttr ());
501
+ ext->recreateAndReplace (rewriter, op, newTranspose);
502
+ return success ();
503
+ }
504
+ };
505
+
424
506
// ===----------------------------------------------------------------------===//
425
507
// Pass Definitions
426
508
// ===----------------------------------------------------------------------===//
@@ -449,9 +531,11 @@ void populateArithIntNarrowingPatterns(
449
531
RewritePatternSet &patterns, const ArithIntNarrowingOptions &options) {
450
532
// Add commute patterns with a higher benefit. This is to expose more
451
533
// optimization opportunities to narrowing patterns.
452
- patterns.add <ExtensionOverExtract, ExtensionOverExtractElement,
453
- ExtensionOverExtractStridedSlice, ExtensionOverInsert,
454
- ExtensionOverInsertElement, ExtensionOverInsertStridedSlice>(
534
+ patterns.add <ExtensionOverBroadcast, ExtensionOverExtract,
535
+ ExtensionOverExtractElement, ExtensionOverExtractStridedSlice,
536
+ ExtensionOverInsert, ExtensionOverInsertElement,
537
+ ExtensionOverInsertStridedSlice, ExtensionOverShapeCast,
538
+ ExtensionOverTranspose, ExtensionOverFlatTranspose>(
455
539
patterns.getContext (), options, PatternBenefit (2 ));
456
540
457
541
patterns.add <SIToFPPattern, UIToFPPattern>(patterns.getContext (), options);
0 commit comments