Skip to content

Commit bba9290

Browse files
committed
[mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC)
These were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for these, as these lower to high-level ArmSME ops, not intrinsics.
1 parent 43af73f commit bba9290

File tree

2 files changed

+114
-115
lines changed

2 files changed

+114
-115
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,116 @@ struct VectorOuterProductToArmSMELowering
576576
}
577577
};
578578

579+
/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
580+
///
581+
/// Example:
582+
/// ```
583+
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
584+
/// ```
585+
/// Becomes:
586+
/// ```
587+
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
588+
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
589+
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
590+
/// ```
591+
struct VectorExtractToArmSMELowering
592+
: public OpRewritePattern<vector::ExtractOp> {
593+
using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
594+
595+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
596+
PatternRewriter &rewriter) const override {
597+
VectorType sourceType = extractOp.getSourceVectorType();
598+
if (!arm_sme::isValidSMETileVectorType(sourceType))
599+
return failure();
600+
601+
auto loc = extractOp.getLoc();
602+
auto position = extractOp.getMixedPosition();
603+
604+
Value sourceVector = extractOp.getVector();
605+
606+
// Extract entire vector. Should be handled by folder, but just to be safe.
607+
if (position.empty()) {
608+
rewriter.replaceOp(extractOp, sourceVector);
609+
return success();
610+
}
611+
612+
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
613+
auto moveTileSliceToVector =
614+
rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
615+
sliceIndex);
616+
617+
if (position.size() == 1) {
618+
// Single index case: Extracts a 1D slice.
619+
rewriter.replaceOp(extractOp, moveTileSliceToVector);
620+
return success();
621+
}
622+
623+
// Two indices case: Extracts a single element.
624+
assert(position.size() == 2);
625+
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
626+
extractOp, moveTileSliceToVector, position[1]);
627+
628+
return success();
629+
}
630+
};
631+
632+
/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
633+
/// `arm_sme.move_tile_slice_to_vector`.
634+
///
635+
/// Example:
636+
/// ```
637+
/// %new_tile = vector.insert %el, %tile[%row, %col]
638+
/// : i32 into vector<[4]x[4]xi32>
639+
/// ```
640+
/// Becomes:
641+
/// ```
642+
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
643+
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
644+
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
645+
/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
646+
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
647+
/// ```
648+
struct VectorInsertToArmSMELowering
649+
: public OpRewritePattern<vector::InsertOp> {
650+
using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
651+
652+
LogicalResult matchAndRewrite(vector::InsertOp insertOp,
653+
PatternRewriter &rewriter) const override {
654+
VectorType resultType = insertOp.getResult().getType();
655+
656+
if (!arm_sme::isValidSMETileVectorType(resultType))
657+
return failure();
658+
659+
auto loc = insertOp.getLoc();
660+
auto position = insertOp.getMixedPosition();
661+
662+
Value source = insertOp.getSource();
663+
664+
// Overwrite entire vector with value. Should be handled by folder, but
665+
// just to be safe.
666+
if (position.empty()) {
667+
rewriter.replaceOp(insertOp, source);
668+
return success();
669+
}
670+
671+
Value tileSlice = source;
672+
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
673+
if (position.size() == 2) {
674+
// Two indices case: Insert single element into tile.
675+
// We need to first extract the existing slice and update the element.
676+
tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
677+
loc, insertOp.getDest(), sliceIndex);
678+
tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
679+
position[1]);
680+
}
681+
682+
// Insert the slice into the destination tile.
683+
rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
684+
insertOp, tileSlice, insertOp.getDest(), sliceIndex);
685+
return success();
686+
}
687+
};
688+
579689
} // namespace
580690

581691
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -584,5 +694,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
584694
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
585695
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
586696
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
587-
VectorOuterProductToArmSMELowering>(&ctx);
697+
VectorOuterProductToArmSMELowering,
698+
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
699+
&ctx);
588700
}

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 1 addition & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -520,118 +520,6 @@ struct OuterProductOpConversion
520520
}
521521
};
522522

523-
/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
524-
///
525-
/// Example:
526-
/// ```
527-
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
528-
/// ```
529-
/// Becomes:
530-
/// ```
531-
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
532-
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
533-
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
534-
/// ```
535-
struct VectorExtractToArmSMELowering
536-
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
537-
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
538-
539-
LogicalResult
540-
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
541-
ConversionPatternRewriter &rewriter) const override {
542-
VectorType sourceType = extractOp.getSourceVectorType();
543-
if (!isValidSMETileVectorType(sourceType))
544-
return failure();
545-
546-
auto loc = extractOp.getLoc();
547-
auto position = extractOp.getMixedPosition();
548-
549-
Value sourceVector = extractOp.getVector();
550-
551-
// Extract entire vector. Should be handled by folder, but just to be safe.
552-
if (position.empty()) {
553-
rewriter.replaceOp(extractOp, sourceVector);
554-
return success();
555-
}
556-
557-
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
558-
auto moveTileSliceToVector =
559-
rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
560-
sliceIndex);
561-
562-
if (position.size() == 1) {
563-
// Single index case: Extracts a 1D slice.
564-
rewriter.replaceOp(extractOp, moveTileSliceToVector);
565-
return success();
566-
}
567-
568-
// Two indices case: Extracts a single element.
569-
assert(position.size() == 2);
570-
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
571-
extractOp, moveTileSliceToVector, position[1]);
572-
573-
return success();
574-
}
575-
};
576-
577-
/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
578-
/// `arm_sme.move_tile_slice_to_vector`.
579-
///
580-
/// Example:
581-
/// ```
582-
/// %new_tile = vector.insert %el, %tile[%row, %col]
583-
/// : i32 into vector<[4]x[4]xi32>
584-
/// ```
585-
/// Becomes:
586-
/// ```
587-
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
588-
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
589-
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
590-
/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
591-
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
592-
/// ```
593-
struct VectorInsertToArmSMELowering
594-
: public ConvertOpToLLVMPattern<vector::InsertOp> {
595-
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
596-
597-
LogicalResult
598-
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
599-
ConversionPatternRewriter &rewriter) const override {
600-
VectorType resultType = insertOp.getResult().getType();
601-
602-
if (!isValidSMETileVectorType(resultType))
603-
return failure();
604-
605-
auto loc = insertOp.getLoc();
606-
auto position = insertOp.getMixedPosition();
607-
608-
Value source = adaptor.getSource();
609-
610-
// Overwrite entire vector with value. Should be handled by folder, but
611-
// just to be safe.
612-
if (position.empty()) {
613-
rewriter.replaceOp(insertOp, source);
614-
return success();
615-
}
616-
617-
Value tileSlice = source;
618-
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
619-
if (position.size() == 2) {
620-
// Two indices case: Insert single element into tile.
621-
// We need to first extract the existing slice and update the element.
622-
tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
623-
loc, adaptor.getDest(), sliceIndex);
624-
tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
625-
position[1]);
626-
}
627-
628-
// Insert the slice into the destination tile.
629-
rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
630-
insertOp, tileSlice, adaptor.getDest(), sliceIndex);
631-
return success();
632-
}
633-
};
634-
635523
} // namespace
636524

637525
void mlir::configureArmSMELegalizeForExportTarget(
@@ -661,6 +549,5 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
661549
patterns.add<
662550
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
663551
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
664-
OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
665-
VectorInsertToArmSMELowering>(converter);
552+
OuterProductOpConversion, ZeroOpConversion>(converter);
666553
}

0 commit comments

Comments
 (0)