Skip to content

Commit 496318a

Browse files
authored
[mlir][ArmSME] Lower vector.extract/insert on SME tiles to MOVA intrinsics (llvm#67786)
This patch adds support for lowering vector.insert/extract of tile slices or elements to ArmSME MOVA intrinsics. This enables the following operations for ArmSME: ``` // Extract slice from tile: %slice = vector.extract %tile[%row] : vector<[4]xi32> from vector<[4]x[4]xi32> ``` ``` // Extract element from tile: %el = vector.extract %tile[%row, %col] : i32 from vector<[4]x[4]xi32> ``` ``` // Insert slice into tile: %new_tile = vector.insert %slice, %tile[%row] : vector<[4]xi32> into vector<[4]x[4]xi32> ``` ``` // Insert element into tile; %new_tile = vector.insert %el, %tile[%row, %col] : i32 into vector<[4]x[4]xi32> ```
1 parent 503bc5f commit 496318a

File tree

2 files changed

+495
-1
lines changed

2 files changed

+495
-1
lines changed

mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 114 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,6 +550,118 @@ struct VectorOuterProductToArmSMELowering
550550
}
551551
};
552552

553+
/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
554+
///
555+
/// Example:
556+
/// ```
557+
/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
558+
/// ```
559+
/// Becomes:
560+
/// ```
561+
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
562+
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
563+
/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
564+
/// ```
565+
struct VectorExtractToArmSMELowering
566+
: public ConvertOpToLLVMPattern<vector::ExtractOp> {
567+
using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
568+
569+
LogicalResult
570+
matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
571+
ConversionPatternRewriter &rewriter) const override {
572+
VectorType sourceType = extractOp.getSourceVectorType();
573+
if (!isValidSMETileVectorType(sourceType))
574+
return failure();
575+
576+
auto loc = extractOp.getLoc();
577+
auto position = extractOp.getMixedPosition();
578+
579+
Value sourceVector = extractOp.getVector();
580+
581+
// Extract entire vector. Should be handled by folder, but just to be safe.
582+
if (position.empty()) {
583+
rewriter.replaceOp(extractOp, sourceVector);
584+
return success();
585+
}
586+
587+
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
588+
auto moveTileSliceToVector =
589+
rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
590+
sliceIndex);
591+
592+
if (position.size() == 1) {
593+
// Single index case: Extracts a 1D slice.
594+
rewriter.replaceOp(extractOp, moveTileSliceToVector);
595+
return success();
596+
}
597+
598+
// Two indices case: Extracts a single element.
599+
assert(position.size() == 2);
600+
rewriter.replaceOpWithNewOp<vector::ExtractOp>(
601+
extractOp, moveTileSliceToVector, position[1]);
602+
603+
return success();
604+
}
605+
};
606+
607+
/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
608+
/// `arm_sme.move_tile_slice_to_vector`.
609+
///
610+
/// Example:
611+
/// ```
612+
/// %new_tile = vector.insert %el, %tile[%row, %col]
613+
/// : i32 into vector<[4]x[4]xi32>
614+
/// ```
615+
/// Becomes:
616+
/// ```
617+
/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
618+
/// : vector<[4]xi32> from vector<[4]x[4]xi32>
619+
/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
620+
/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
621+
/// : vector<[4]xi32> into vector<[4]x[4]xi32>
622+
/// ```
623+
struct VectorInsertToArmSMELowering
624+
: public ConvertOpToLLVMPattern<vector::InsertOp> {
625+
using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
626+
627+
LogicalResult
628+
matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
629+
ConversionPatternRewriter &rewriter) const override {
630+
VectorType resultType = insertOp.getResult().getType();
631+
632+
if (!isValidSMETileVectorType(resultType))
633+
return failure();
634+
635+
auto loc = insertOp.getLoc();
636+
auto position = insertOp.getMixedPosition();
637+
638+
Value source = adaptor.getSource();
639+
640+
// Overwrite entire vector with value. Should be handled by folder, but
641+
// just to be safe.
642+
if (position.empty()) {
643+
rewriter.replaceOp(insertOp, source);
644+
return success();
645+
}
646+
647+
Value tileSlice = source;
648+
Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
649+
if (position.size() == 2) {
650+
// Two indices case: Insert single element into tile.
651+
// We need to first extract the existing slice and update the element.
652+
tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
653+
loc, adaptor.getDest(), sliceIndex);
654+
tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
655+
position[1]);
656+
}
657+
658+
// Insert the slice into the destination tile.
659+
rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
660+
insertOp, tileSlice, adaptor.getDest(), sliceIndex);
661+
return success();
662+
}
663+
};
664+
553665
} // namespace
554666

555667
void mlir::configureArmSMELegalizeForExportTarget(
@@ -604,5 +716,6 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
604716
patterns.add<
605717
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
606718
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
607-
VectorOuterProductToArmSMELowering, ZeroOpConversion>(converter);
719+
VectorOuterProductToArmSMELowering, ZeroOpConversion,
720+
VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
608721
}

0 commit comments

Comments
 (0)