@@ -573,6 +573,116 @@ struct VectorOuterProductToArmSMELowering
573
573
}
574
574
};
575
575
576
+ // / Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
577
+ // /
578
+ // / Example:
579
+ // / ```
580
+ // / %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
581
+ // / ```
582
+ // / Becomes:
583
+ // / ```
584
+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
585
+ // / : vector<[4]xi32> from vector<[4]x[4]xi32>
586
+ // / %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
587
+ // / ```
588
+ struct VectorExtractToArmSMELowering
589
+ : public OpRewritePattern<vector::ExtractOp> {
590
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
591
+
592
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
593
+ PatternRewriter &rewriter) const override {
594
+ VectorType sourceType = extractOp.getSourceVectorType ();
595
+ if (!arm_sme::isValidSMETileVectorType (sourceType))
596
+ return failure ();
597
+
598
+ auto loc = extractOp.getLoc ();
599
+ auto position = extractOp.getMixedPosition ();
600
+
601
+ Value sourceVector = extractOp.getVector ();
602
+
603
+ // Extract entire vector. Should be handled by folder, but just to be safe.
604
+ if (position.empty ()) {
605
+ rewriter.replaceOp (extractOp, sourceVector);
606
+ return success ();
607
+ }
608
+
609
+ Value sliceIndex = vector::getAsValues (rewriter, loc, position[0 ]).front ();
610
+ auto moveTileSliceToVector =
611
+ rewriter.create <arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
612
+ sliceIndex);
613
+
614
+ if (position.size () == 1 ) {
615
+ // Single index case: Extracts a 1D slice.
616
+ rewriter.replaceOp (extractOp, moveTileSliceToVector);
617
+ return success ();
618
+ }
619
+
620
+ // Two indices case: Extracts a single element.
621
+ assert (position.size () == 2 );
622
+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
623
+ extractOp, moveTileSliceToVector, position[1 ]);
624
+
625
+ return success ();
626
+ }
627
+ };
628
+
629
+ // / Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
630
+ // / `arm_sme.move_tile_slice_to_vector`.
631
+ // /
632
+ // / Example:
633
+ // / ```
634
+ // / %new_tile = vector.insert %el, %tile[%row, %col]
635
+ // / : i32 into vector<[4]x[4]xi32>
636
+ // / ```
637
+ // / Becomes:
638
+ // / ```
639
+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
640
+ // / : vector<[4]xi32> from vector<[4]x[4]xi32>
641
+ // / %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
642
+ // / %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
643
+ // / : vector<[4]xi32> into vector<[4]x[4]xi32>
644
+ // / ```
645
+ struct VectorInsertToArmSMELowering
646
+ : public OpRewritePattern<vector::InsertOp> {
647
+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
648
+
649
+ LogicalResult matchAndRewrite (vector::InsertOp insertOp,
650
+ PatternRewriter &rewriter) const override {
651
+ VectorType resultType = insertOp.getResult ().getType ();
652
+
653
+ if (!arm_sme::isValidSMETileVectorType (resultType))
654
+ return failure ();
655
+
656
+ auto loc = insertOp.getLoc ();
657
+ auto position = insertOp.getMixedPosition ();
658
+
659
+ Value source = insertOp.getSource ();
660
+
661
+ // Overwrite entire vector with value. Should be handled by folder, but
662
+ // just to be safe.
663
+ if (position.empty ()) {
664
+ rewriter.replaceOp (insertOp, source);
665
+ return success ();
666
+ }
667
+
668
+ Value tileSlice = source;
669
+ Value sliceIndex = vector::getAsValues (rewriter, loc, position[0 ]).front ();
670
+ if (position.size () == 2 ) {
671
+ // Two indices case: Insert single element into tile.
672
+ // We need to first extract the existing slice and update the element.
673
+ tileSlice = rewriter.create <arm_sme::MoveTileSliceToVectorOp>(
674
+ loc, insertOp.getDest (), sliceIndex);
675
+ tileSlice = rewriter.create <vector::InsertOp>(loc, source, tileSlice,
676
+ position[1 ]);
677
+ }
678
+
679
+ // Insert the slice into the destination tile.
680
+ rewriter.replaceOpWithNewOp <arm_sme::MoveVectorToTileSliceOp>(
681
+ insertOp, tileSlice, insertOp.getDest (), sliceIndex);
682
+ return success ();
683
+ }
684
+ };
685
+
576
686
} // namespace
577
687
578
688
void mlir::populateVectorToArmSMEPatterns (RewritePatternSet &patterns,
@@ -581,5 +691,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
581
691
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
582
692
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
583
693
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
584
- VectorOuterProductToArmSMELowering>(&ctx);
694
+ VectorOuterProductToArmSMELowering,
695
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
696
+ &ctx);
585
697
}
0 commit comments