Skip to content

Commit c4c52d4

Browse files
authored
[mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC) (#72852)
These lowerings were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for them: they lower to high-level ArmSME ops, not to intrinsics.
1 parent 28b5054 commit c4c52d4

File tree

2 files changed

+114
-115
lines changed

2 files changed

+114
-115
lines changed

mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp

Lines changed: 113 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,116 @@ struct VectorOuterProductToArmSMELowering
573573
}
574574
};
575575

576+
/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
577+
///
/// The pattern only matches when the source vector type is accepted by
/// `arm_sme::isValidSMETileVectorType`. The rewrite depends on the number of
/// extract indices:
///  - none: the op is replaced by its source vector;
///  - one:  the op is replaced by the tile slice selected by that index;
///  - two:  the slice selected by the first index is produced, then the
///          element is taken from it with a `vector.extract` using the
///          second index.
///
578+
/// Example:
579+
/// ```
580+
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
581+
/// ```
582+
/// Becomes:
583+
/// ```
584+
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
585+
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
586+
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
587+
/// ```
588+
struct VectorExtractToArmSMELowering
589+
: public OpRewritePattern<vector::ExtractOp> {
590+
using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
591+
592+
LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
593+
PatternRewriter &rewriter) const override {
594+
VectorType sourceType = extractOp.getSourceVectorType();
595+
// Bail out unless the extract reads from a valid SME tile vector type.
if (!arm_sme::isValidSMETileVectorType(sourceType))
596+
return failure();
597+
598+
auto loc = extractOp.getLoc();
599+
auto position = extractOp.getMixedPosition();
600+
601+
Value sourceVector = extractOp.getVector();
602+
603+
// No indices: the whole vector is extracted, so forward the source as-is.
// This should already be handled by the folder; kept here just to be safe.
604+
if (position.empty()) {
605+
rewriter.replaceOp(extractOp, sourceVector);
606+
return success();
607+
}
608+
609+
// The first index selects the tile slice; turn it into a Value so it can be
// passed as the slice index operand.
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
610+
auto moveTileSliceToVector =
611+
rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
612+
sliceIndex);
613+
614+
if (position.size() == 1) {
615+
// Single index case: the extracted 1-D slice is the final result.
616+
rewriter.replaceOp(extractOp, moveTileSliceToVector);
617+
return success();
618+
}
619+
620+
// Two indices case: extract a single element from the slice using the
// second index.
621+
assert(position.size() == 2);
622+
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
623+
extractOp, moveTileSliceToVector, position[1]);
624+
625+
return success();
626+
}
627+
};
628+
629+
/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
630+
/// `arm_sme.move_tile_slice_to_vector`.
631+
///
/// The pattern only matches when the result vector type is accepted by
/// `arm_sme::isValidSMETileVectorType`. The rewrite depends on the number of
/// insert indices:
///  - none: the op is replaced by the inserted value itself;
///  - one:  the inserted value is used directly as the new tile slice and is
///          moved into the destination tile at that index;
///  - two:  the existing slice is read from the destination tile, the element
///          is inserted into it, and the updated slice is moved back into the
///          tile (see the example below).
///
632+
/// Example:
633+
/// ```
634+
/// %new_tile = vector.insert %el, %tile[%row, %col]
635+
/// : i32 into vector<[4]x[4]xi32>
636+
/// ```
637+
/// Becomes:
638+
/// ```
639+
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
640+
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
641+
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
642+
/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
643+
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
644+
/// ```
645+
struct VectorInsertToArmSMELowering
646+
: public OpRewritePattern<vector::InsertOp> {
647+
using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
648+
649+
LogicalResult matchAndRewrite(vector::InsertOp insertOp,
650+
PatternRewriter &rewriter) const override {
651+
VectorType resultType = insertOp.getResult().getType();
652+
653+
// Bail out unless the insert produces a valid SME tile vector type.
if (!arm_sme::isValidSMETileVectorType(resultType))
654+
return failure();
655+
656+
auto loc = insertOp.getLoc();
657+
auto position = insertOp.getMixedPosition();
658+
659+
Value source = insertOp.getSource();
660+
661+
// No indices: the value overwrites the entire vector, so it is the result.
662+
// This should already be handled by the folder; kept here just to be safe.
663+
if (position.empty()) {
664+
rewriter.replaceOp(insertOp, source);
665+
return success();
666+
}
667+
668+
// Default (single index case): the inserted value is itself the new slice.
Value tileSlice = source;
669+
// The first index selects the tile slice to write.
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
670+
if (position.size() == 2) {
671+
// Two indices case: insert a single element into the tile.
672+
// First read the existing slice, then update the element at position[1].
673+
tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
674+
loc, insertOp.getDest(), sliceIndex);
675+
tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
676+
position[1]);
677+
}
678+
679+
// Move the (possibly updated) slice into the destination tile.
680+
rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
681+
insertOp, tileSlice, insertOp.getDest(), sliceIndex);
682+
return success();
683+
}
684+
};
685+
576686
} // namespace
577687

578688
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -581,5 +691,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
581691
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
582692
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
583693
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
584-
VectorOuterProductToArmSMELowering>(&ctx);
694+
VectorOuterProductToArmSMELowering,
695+
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
696+
&ctx);
585697
}

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 1 addition & 114 deletions
Original file line numberDiff line numberDiff line change
@@ -520,118 +520,6 @@ struct OuterProductOpConversion
520520
}
521521
};
522522

523-
/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
524-
///
/// Conversion-pattern form (`ConvertOpToLLVMPattern`). Only matches when the
/// source vector type is accepted by `isValidSMETileVectorType`. With no
/// indices the op is replaced by its source; with one index it is replaced by
/// the selected tile slice; with two indices the element is extracted from
/// that slice.
///
525-
/// Example:
526-
/// ```
527-
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
528-
/// ```
529-
/// Becomes:
530-
/// ```
531-
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
532-
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
533-
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
534-
/// ```
535-
struct VectorExtractToArmSMELowering
536-
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
537-
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
538-
539-
LogicalResult
540-
// NOTE(review): `adaptor` is not referenced in this body; the source vector
// is read from `extractOp` directly.
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
541-
ConversionPatternRewriter &rewriter) const override {
542-
VectorType sourceType = extractOp.getSourceVectorType();
543-
// Bail out unless the extract reads from a valid SME tile vector type.
if (!isValidSMETileVectorType(sourceType))
544-
return failure();
545-
546-
auto loc = extractOp.getLoc();
547-
auto position = extractOp.getMixedPosition();
548-
549-
Value sourceVector = extractOp.getVector();
550-
551-
// No indices: the whole vector is extracted, so forward the source as-is.
// This should already be handled by the folder; kept here just to be safe.
552-
if (position.empty()) {
553-
rewriter.replaceOp(extractOp, sourceVector);
554-
return success();
555-
}
556-
557-
// The first index selects the tile slice.
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
558-
auto moveTileSliceToVector =
559-
rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
560-
sliceIndex);
561-
562-
if (position.size() == 1) {
563-
// Single index case: the extracted 1-D slice is the final result.
564-
rewriter.replaceOp(extractOp, moveTileSliceToVector);
565-
return success();
566-
}
567-
568-
// Two indices case: extract a single element from the slice using the
// second index.
569-
assert(position.size() == 2);
570-
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
571-
extractOp, moveTileSliceToVector, position[1]);
572-
573-
return success();
574-
}
575-
};
576-
577-
/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
578-
/// `arm_sme.move_tile_slice_to_vector`.
579-
///
/// Conversion-pattern form (`ConvertOpToLLVMPattern`). Only matches when the
/// result vector type is accepted by `isValidSMETileVectorType`. The inserted
/// value and the destination tile are read through `adaptor`. With no indices
/// the op is replaced by the inserted value; with one index the value is moved
/// into the tile as a whole slice; with two indices the existing slice is
/// read, updated, and moved back.
///
580-
/// Example:
581-
/// ```
582-
/// %new_tile = vector.insert %el, %tile[%row, %col]
583-
/// : i32 into vector<[4]x[4]xi32>
584-
/// ```
585-
/// Becomes:
586-
/// ```
587-
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
588-
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
589-
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
590-
/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
591-
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
592-
/// ```
593-
struct VectorInsertToArmSMELowering
594-
: public ConvertOpToLLVMPattern<vector::InsertOp> {
595-
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
596-
597-
LogicalResult
598-
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
599-
ConversionPatternRewriter &rewriter) const override {
600-
VectorType resultType = insertOp.getResult().getType();
601-
602-
// Bail out unless the insert produces a valid SME tile vector type.
if (!isValidSMETileVectorType(resultType))
603-
return failure();
604-
605-
auto loc = insertOp.getLoc();
606-
auto position = insertOp.getMixedPosition();
607-
608-
Value source = adaptor.getSource();
609-
610-
// No indices: the value overwrites the entire vector, so it is the result.
611-
// This should already be handled by the folder; kept here just to be safe.
612-
if (position.empty()) {
613-
rewriter.replaceOp(insertOp, source);
614-
return success();
615-
}
616-
617-
// Default (single index case): the inserted value is itself the new slice.
Value tileSlice = source;
618-
// The first index selects the tile slice to write.
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
619-
if (position.size() == 2) {
620-
// Two indices case: insert a single element into the tile.
621-
// First read the existing slice, then update the element at position[1].
622-
tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
623-
loc, adaptor.getDest(), sliceIndex);
624-
tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
625-
position[1]);
626-
}
627-
628-
// Move the (possibly updated) slice into the destination tile.
629-
rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
630-
insertOp, tileSlice, adaptor.getDest(), sliceIndex);
631-
return success();
632-
}
633-
};
634-
635523
} // namespace
636524

637525
void mlir::configureArmSMELegalizeForExportTarget(
@@ -661,6 +549,5 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
661549
patterns.add<
662550
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
663551
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
664-
OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
665-
VectorInsertToArmSMELowering>(converter);
552+
OuterProductOpConversion, ZeroOpConversion>(converter);
666553
}

0 commit comments

Comments
 (0)