@@ -415,6 +415,105 @@ struct FoldExtractFromVectorOfSMELikeCreateMasks
415
415
}
416
416
};
417
417
418
+ // Shuffles arith extend ops after vector.extract op.
419
+ //
420
+ // This transforms IR like:
421
+ // %0 = arith.extsi %src : vector<4x[8]xi8> to vector<4x[8]xi32>
422
+ // %1 = vector.extract %0[0] : vector<[8]xi32> from vector<4x[8]xi32>
423
+ // Into:
424
+ // %0 = vector.extract %src[0] : vector<[8]xi8> from vector<4x[8]xi8>
425
+ // %1 = arith.extsi %0 : vector<[8]xi8> to vector<[8]xi32>
426
+ //
427
+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
428
+ // pass when the result is the input to an outer product.
429
+ struct SwapVectorExtractOfArithExtend
430
+ : public OpRewritePattern<vector::ExtractOp> {
431
+ using OpRewritePattern::OpRewritePattern;
432
+
433
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
434
+ PatternRewriter &rewriter) const override {
435
+ VectorType resultType = llvm::dyn_cast<VectorType>(extractOp.getType ());
436
+ if (!resultType)
437
+ return rewriter.notifyMatchFailure (extractOp,
438
+ " extracted type is not a vector type" );
439
+
440
+ auto numScalableDims = llvm::count (resultType.getScalableDims (), true );
441
+ if (numScalableDims != 1 )
442
+ return rewriter.notifyMatchFailure (
443
+ extractOp, " extracted type is not a 1-D scalable vector type" );
444
+
445
+ auto *extendOp = extractOp.getVector ().getDefiningOp ();
446
+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
447
+ extendOp))
448
+ return rewriter.notifyMatchFailure (extractOp,
449
+ " extract not from extend op" );
450
+
451
+ auto loc = extractOp.getLoc ();
452
+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
453
+ Value extendSource = extendOp->getOperand (0 );
454
+
455
+ // Create new extract from source of extend.
456
+ Value newExtract = rewriter.create <vector::ExtractOp>(
457
+ loc, extendSource, extractOp.getMixedPosition ());
458
+
459
+ // Extend new extract to original result type.
460
+ Operation *newExtend =
461
+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
462
+
463
+ rewriter.replaceOp (extractOp, newExtend->getResult (0 ));
464
+
465
+ return success ();
466
+ }
467
+ };
468
+
469
+ // Shuffles arith extend ops after vector.scalable.extract op.
470
+ //
471
+ // This transforms IR like:
472
+ // %0 = arith.extsi %src : vector<[8]xi8> to vector<[8]xi32>
473
+ // %1 = vector.scalable.extract %0[0] : vector<[4]xi32> from vector<[8]xi32>
474
+ // Into:
475
+ // %0 = vector.scalable.extract %src[0] : vector<[4]xi8> from vector<[8]xi8>
476
+ // %1 = arith.extsi %0 : vector<[4]xi8> to vector<[4]xi32>
477
+ //
478
+ // This enables outer product fusion in the `-arm-sme-outer-product-fusion`
479
+ // pass when the result is the input to an outer product.
480
+ struct SwapVectorScalableExtractOfArithExtend
481
+ : public OpRewritePattern<vector::ScalableExtractOp> {
482
+ using OpRewritePattern::OpRewritePattern;
483
+
484
+ LogicalResult matchAndRewrite (vector::ScalableExtractOp extractOp,
485
+ PatternRewriter &rewriter) const override {
486
+ auto *extendOp = extractOp.getSource ().getDefiningOp ();
487
+ if (!isa_and_present<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp>(
488
+ extendOp))
489
+ return rewriter.notifyMatchFailure (extractOp,
490
+ " extract not from extend op" );
491
+
492
+ auto loc = extractOp.getLoc ();
493
+ VectorType resultType = extractOp.getResultVectorType ();
494
+
495
+ Value extendSource = extendOp->getOperand (0 );
496
+ StringAttr extendOpName = extendOp->getName ().getIdentifier ();
497
+ VectorType extendSourceVectorType =
498
+ cast<VectorType>(extendSource.getType ());
499
+
500
+ // Create new extract from source of extend.
501
+ VectorType extractResultVectorType =
502
+ VectorType::Builder (resultType)
503
+ .setElementType (extendSourceVectorType.getElementType ());
504
+ Value newExtract = rewriter.create <vector::ScalableExtractOp>(
505
+ loc, extractResultVectorType, extendSource, extractOp.getPos ());
506
+
507
+ // Extend new extract to original result type.
508
+ Operation *newExtend =
509
+ rewriter.create (loc, extendOpName, Value (newExtract), resultType);
510
+
511
+ rewriter.replaceOp (extractOp, newExtend->getResult (0 ));
512
+
513
+ return success ();
514
+ }
515
+ };
516
+
418
517
struct VectorLegalizationPass
419
518
: public arm_sme::impl::VectorLegalizationBase<VectorLegalizationPass> {
420
519
void runOnOperation () override {
@@ -434,7 +533,10 @@ struct VectorLegalizationPass
434
533
return success ();
435
534
});
436
535
437
- patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks>(context);
536
+ patterns.add <FoldExtractFromVectorOfSMELikeCreateMasks,
537
+ SwapVectorExtractOfArithExtend,
538
+ SwapVectorScalableExtractOfArithExtend>(context);
539
+
438
540
// Note: High benefit to ensure masked outer products are lowered first.
439
541
patterns.add <LegalizeMaskedVectorOuterProductOpsByDecomposition>(
440
542
converter, context, 1024 );
0 commit comments