@@ -576,6 +576,116 @@ struct VectorOuterProductToArmSMELowering
576
576
}
577
577
};
578
578
579
+ // / Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
580
+ // /
581
+ // / Example:
582
+ // / ```
583
+ // / %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
584
+ // / ```
585
+ // / Becomes:
586
+ // / ```
587
+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
588
+ // / : vector<[4]xi32> from vector<[4]x[4]xi32>
589
+ // / %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
590
+ // / ```
591
+ struct VectorExtractToArmSMELowering
592
+ : public OpRewritePattern<vector::ExtractOp> {
593
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
594
+
595
+ LogicalResult matchAndRewrite (vector::ExtractOp extractOp,
596
+ PatternRewriter &rewriter) const override {
597
+ VectorType sourceType = extractOp.getSourceVectorType ();
598
+ if (!arm_sme::isValidSMETileVectorType (sourceType))
599
+ return failure ();
600
+
601
+ auto loc = extractOp.getLoc ();
602
+ auto position = extractOp.getMixedPosition ();
603
+
604
+ Value sourceVector = extractOp.getVector ();
605
+
606
+ // Extract entire vector. Should be handled by folder, but just to be safe.
607
+ if (position.empty ()) {
608
+ rewriter.replaceOp (extractOp, sourceVector);
609
+ return success ();
610
+ }
611
+
612
+ Value sliceIndex = vector::getAsValues (rewriter, loc, position[0 ]).front ();
613
+ auto moveTileSliceToVector =
614
+ rewriter.create <arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
615
+ sliceIndex);
616
+
617
+ if (position.size () == 1 ) {
618
+ // Single index case: Extracts a 1D slice.
619
+ rewriter.replaceOp (extractOp, moveTileSliceToVector);
620
+ return success ();
621
+ }
622
+
623
+ // Two indices case: Extracts a single element.
624
+ assert (position.size () == 2 );
625
+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
626
+ extractOp, moveTileSliceToVector, position[1 ]);
627
+
628
+ return success ();
629
+ }
630
+ };
631
+
632
+ // / Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
633
+ // / `arm_sme.move_tile_slice_to_vector`.
634
+ // /
635
+ // / Example:
636
+ // / ```
637
+ // / %new_tile = vector.insert %el, %tile[%row, %col]
638
+ // / : i32 into vector<[4]x[4]xi32>
639
+ // / ```
640
+ // / Becomes:
641
+ // / ```
642
+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
643
+ // / : vector<[4]xi32> from vector<[4]x[4]xi32>
644
+ // / %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
645
+ // / %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
646
+ // / : vector<[4]xi32> into vector<[4]x[4]xi32>
647
+ // / ```
648
+ struct VectorInsertToArmSMELowering
649
+ : public OpRewritePattern<vector::InsertOp> {
650
+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
651
+
652
+ LogicalResult matchAndRewrite (vector::InsertOp insertOp,
653
+ PatternRewriter &rewriter) const override {
654
+ VectorType resultType = insertOp.getResult ().getType ();
655
+
656
+ if (!arm_sme::isValidSMETileVectorType (resultType))
657
+ return failure ();
658
+
659
+ auto loc = insertOp.getLoc ();
660
+ auto position = insertOp.getMixedPosition ();
661
+
662
+ Value source = insertOp.getSource ();
663
+
664
+ // Overwrite entire vector with value. Should be handled by folder, but
665
+ // just to be safe.
666
+ if (position.empty ()) {
667
+ rewriter.replaceOp (insertOp, source);
668
+ return success ();
669
+ }
670
+
671
+ Value tileSlice = source;
672
+ Value sliceIndex = vector::getAsValues (rewriter, loc, position[0 ]).front ();
673
+ if (position.size () == 2 ) {
674
+ // Two indices case: Insert single element into tile.
675
+ // We need to first extract the existing slice and update the element.
676
+ tileSlice = rewriter.create <arm_sme::MoveTileSliceToVectorOp>(
677
+ loc, insertOp.getDest (), sliceIndex);
678
+ tileSlice = rewriter.create <vector::InsertOp>(loc, source, tileSlice,
679
+ position[1 ]);
680
+ }
681
+
682
+ // Insert the slice into the destination tile.
683
+ rewriter.replaceOpWithNewOp <arm_sme::MoveVectorToTileSliceOp>(
684
+ insertOp, tileSlice, insertOp.getDest (), sliceIndex);
685
+ return success ();
686
+ }
687
+ };
688
+
579
689
} // namespace
580
690
581
691
void mlir::populateVectorToArmSMEPatterns (RewritePatternSet &patterns,
@@ -584,5 +694,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
584
694
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
585
695
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
586
696
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
587
- VectorOuterProductToArmSMELowering>(&ctx);
697
+ VectorOuterProductToArmSMELowering,
698
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
699
+ &ctx);
588
700
}
0 commit comments