@@ -550,6 +550,118 @@ struct VectorOuterProductToArmSMELowering
550
550
}
551
551
};
552
552
553
+ // / Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
554
+ // /
555
+ // / Example:
556
+ // / ```
557
+ // / %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
558
+ // / ```
559
+ // / Becomes:
560
+ // / ```
561
+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
562
+ // / : vector<[4]xi32> from vector<[4]x[4]xi32>
563
+ // / %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
564
+ // / ```
565
+ struct VectorExtractToArmSMELowering
566
+ : public ConvertOpToLLVMPattern<vector::ExtractOp> {
567
+ using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
568
+
569
+ LogicalResult
570
+ matchAndRewrite (vector::ExtractOp extractOp, OpAdaptor adaptor,
571
+ ConversionPatternRewriter &rewriter) const override {
572
+ VectorType sourceType = extractOp.getSourceVectorType ();
573
+ if (!isValidSMETileVectorType (sourceType))
574
+ return failure ();
575
+
576
+ auto loc = extractOp.getLoc ();
577
+ auto position = extractOp.getMixedPosition ();
578
+
579
+ Value sourceVector = extractOp.getVector ();
580
+
581
+ // Extract entire vector. Should be handled by folder, but just to be safe.
582
+ if (position.empty ()) {
583
+ rewriter.replaceOp (extractOp, sourceVector);
584
+ return success ();
585
+ }
586
+
587
+ Value sliceIndex = vector::getAsValues (rewriter, loc, position[0 ]).front ();
588
+ auto moveTileSliceToVector =
589
+ rewriter.create <arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
590
+ sliceIndex);
591
+
592
+ if (position.size () == 1 ) {
593
+ // Single index case: Extracts a 1D slice.
594
+ rewriter.replaceOp (extractOp, moveTileSliceToVector);
595
+ return success ();
596
+ }
597
+
598
+ // Two indices case: Extracts a single element.
599
+ assert (position.size () == 2 );
600
+ rewriter.replaceOpWithNewOp <vector::ExtractOp>(
601
+ extractOp, moveTileSliceToVector, position[1 ]);
602
+
603
+ return success ();
604
+ }
605
+ };
606
+
607
+ // / Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
608
+ // / `arm_sme.move_tile_slice_to_vector`.
609
+ // /
610
+ // / Example:
611
+ // / ```
612
+ // / %new_tile = vector.insert %el, %tile[%row, %col]
613
+ // / : i32 into vector<[4]x[4]xi32>
614
+ // / ```
615
+ // / Becomes:
616
+ // / ```
617
+ // / %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
618
+ // / : vector<[4]xi32> from vector<[4]x[4]xi32>
619
+ // / %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
620
+ // / %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
621
+ // / : vector<[4]xi32> into vector<[4]x[4]xi32>
622
+ // / ```
623
+ struct VectorInsertToArmSMELowering
624
+ : public ConvertOpToLLVMPattern<vector::InsertOp> {
625
+ using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
626
+
627
+ LogicalResult
628
+ matchAndRewrite (vector::InsertOp insertOp, OpAdaptor adaptor,
629
+ ConversionPatternRewriter &rewriter) const override {
630
+ VectorType resultType = insertOp.getResult ().getType ();
631
+
632
+ if (!isValidSMETileVectorType (resultType))
633
+ return failure ();
634
+
635
+ auto loc = insertOp.getLoc ();
636
+ auto position = insertOp.getMixedPosition ();
637
+
638
+ Value source = adaptor.getSource ();
639
+
640
+ // Overwrite entire vector with value. Should be handled by folder, but
641
+ // just to be safe.
642
+ if (position.empty ()) {
643
+ rewriter.replaceOp (insertOp, source);
644
+ return success ();
645
+ }
646
+
647
+ Value tileSlice = source;
648
+ Value sliceIndex = vector::getAsValues (rewriter, loc, position[0 ]).front ();
649
+ if (position.size () == 2 ) {
650
+ // Two indices case: Insert single element into tile.
651
+ // We need to first extract the existing slice and update the element.
652
+ tileSlice = rewriter.create <arm_sme::MoveTileSliceToVectorOp>(
653
+ loc, adaptor.getDest (), sliceIndex);
654
+ tileSlice = rewriter.create <vector::InsertOp>(loc, source, tileSlice,
655
+ position[1 ]);
656
+ }
657
+
658
+ // Insert the slice into the destination tile.
659
+ rewriter.replaceOpWithNewOp <arm_sme::MoveVectorToTileSliceOp>(
660
+ insertOp, tileSlice, adaptor.getDest (), sliceIndex);
661
+ return success ();
662
+ }
663
+ };
664
+
553
665
} // namespace
554
666
555
667
void mlir::configureArmSMELegalizeForExportTarget (
@@ -604,5 +716,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
604
716
patterns.add <
605
717
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
606
718
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
607
- VectorOuterProductToArmSMELowering, ZeroOpConversion>(converter);
719
+ VectorOuterProductToArmSMELowering, ZeroOpConversion,
720
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
608
721
}
0 commit comments