-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC) #72852
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Author: Benjamin Maxwell (MacDue) Changes: The vector.extract/insert lowerings were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for them, as they lower to high-level ArmSME ops, not intrinsics. Full diff: https://github.com/llvm/llvm-project/pull/72852.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..420d2b6b1c08786 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -576,6 +576,116 @@ struct VectorOuterProductToArmSMELowering
}
};
+/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
+/// Only matches when the source vector type is a valid SME tile vector type.
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+ : public OpRewritePattern<vector::ExtractOp> {
+ using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+ // The rewrite is keyed on the number of position indices: 0, 1 or 2.
+ LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+ PatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ if (!arm_sme::isValidSMETileVectorType(sourceType))
+ return failure(); // Pattern does not apply to non-SME-tile sources.
+
+ auto loc = extractOp.getLoc();
+ auto position = extractOp.getMixedPosition();
+
+ Value sourceVector = extractOp.getVector();
+
+ // Zero indices case: the op extracts the entire vector, so forward the
+ // source unchanged. Should be handled by folder, but just to be safe.
+ if (position.empty()) {
+ rewriter.replaceOp(extractOp, sourceVector);
+ return success();
+ }
+ // The first index selects which tile slice to move out of the tile.
+ // NOTE(review): getAsValues presumably materializes a static index as a
+ // constant Value -- confirm against vector::getAsValues.
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+ auto moveTileSliceToVector =
+ rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+ sliceIndex);
+
+ if (position.size() == 1) {
+ // Single index case: Extracts a 1D slice, which is exactly the result of
+ // the tile-slice move created above.
+ rewriter.replaceOp(extractOp, moveTileSliceToVector);
+ return success();
+ }
+
+ // Two indices case: Extracts a single element by emitting a new
+ // vector.extract on the 1D slice using the second index.
+ assert(position.size() == 2);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, moveTileSliceToVector, position[1]);
+
+ return success();
+ }
+};
+
+/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
+/// `arm_sme.move_tile_slice_to_vector`.
+/// Only matches when the result vector type is a valid SME tile vector type.
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%row, %col]
+/// : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+ : public OpRewritePattern<vector::InsertOp> {
+ using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+ // The rewrite is keyed on the number of position indices: 0, 1 or 2.
+ LogicalResult matchAndRewrite(vector::InsertOp insertOp,
+ PatternRewriter &rewriter) const override {
+ VectorType resultType = insertOp.getResult().getType();
+
+ if (!arm_sme::isValidSMETileVectorType(resultType))
+ return failure(); // Pattern does not apply to non-SME-tile results.
+
+ auto loc = insertOp.getLoc();
+ auto position = insertOp.getMixedPosition();
+
+ Value source = insertOp.getSource();
+
+ // Zero indices case: overwrite entire vector with value, i.e. the result
+ // is simply the source. Should be handled by folder, but just to be safe.
+ if (position.empty()) {
+ rewriter.replaceOp(insertOp, source);
+ return success();
+ }
+ // With a single index the source is itself the slice to write into the tile.
+ Value tileSlice = source;
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+ if (position.size() == 2) {
+ // Two indices case: Insert single element into tile.
+ // We need to first extract the existing slice and update the element.
+ tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+ loc, insertOp.getDest(), sliceIndex);
+ tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+ position[1]);
+ }
+
+ // Insert the (possibly updated) slice into the destination tile at sliceIndex.
+ rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+ insertOp, tileSlice, insertOp.getDest(), sliceIndex);
+ return success();
+ }
+};
+
} // namespace
void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -584,5 +694,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
SplatOpToArmSMELowering, TransferReadToArmSMELowering,
TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
- VectorOuterProductToArmSMELowering>(&ctx);
+ VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
+ &ctx);
}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6078b3f2c5e4708..041a4897a836503 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -456,7 +456,8 @@ struct OuterProductOpConversion
// * half-precision - +sme2p1,+b16b16
//
// It should be possible to control lowering based on target features.
- // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+ // [1]
+ // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
return false;
@@ -520,118 +521,6 @@ struct OuterProductOpConversion
}
};
-/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
-/// ```
-struct VectorExtractToArmSMELowering
- : public ConvertOpToLLVMPattern<vector::ExtractOp> {
- using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- VectorType sourceType = extractOp.getSourceVectorType();
- if (!isValidSMETileVectorType(sourceType))
- return failure();
-
- auto loc = extractOp.getLoc();
- auto position = extractOp.getMixedPosition();
-
- Value sourceVector = extractOp.getVector();
-
- // Extract entire vector. Should be handled by folder, but just to be safe.
- if (position.empty()) {
- rewriter.replaceOp(extractOp, sourceVector);
- return success();
- }
-
- Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
- auto moveTileSliceToVector =
- rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
- sliceIndex);
-
- if (position.size() == 1) {
- // Single index case: Extracts a 1D slice.
- rewriter.replaceOp(extractOp, moveTileSliceToVector);
- return success();
- }
-
- // Two indices case: Extracts a single element.
- assert(position.size() == 2);
- rewriter.replaceOpWithNewOp<vector::ExtractOp>(
- extractOp, moveTileSliceToVector, position[1]);
-
- return success();
- }
-};
-
-/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
-/// `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %new_tile = vector.insert %el, %tile[%row, %col]
-/// : i32 into vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-/// : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
-/// : vector<[4]xi32> into vector<[4]x[4]xi32>
-/// ```
-struct VectorInsertToArmSMELowering
- : public ConvertOpToLLVMPattern<vector::InsertOp> {
- using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
-
- LogicalResult
- matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
- ConversionPatternRewriter &rewriter) const override {
- VectorType resultType = insertOp.getResult().getType();
-
- if (!isValidSMETileVectorType(resultType))
- return failure();
-
- auto loc = insertOp.getLoc();
- auto position = insertOp.getMixedPosition();
-
- Value source = adaptor.getSource();
-
- // Overwrite entire vector with value. Should be handled by folder, but
- // just to be safe.
- if (position.empty()) {
- rewriter.replaceOp(insertOp, source);
- return success();
- }
-
- Value tileSlice = source;
- Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
- if (position.size() == 2) {
- // Two indices case: Insert single element into tile.
- // We need to first extract the existing slice and update the element.
- tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
- loc, adaptor.getDest(), sliceIndex);
- tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
- position[1]);
- }
-
- // Insert the slice into the destination tile.
- rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
- insertOp, tileSlice, adaptor.getDest(), sliceIndex);
- return success();
- }
-};
-
} // namespace
void mlir::configureArmSMELegalizeForExportTarget(
@@ -661,6 +550,5 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
patterns.add<
LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
- OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
- VectorInsertToArmSMELowering>(converter);
+ OuterProductOpConversion, ZeroOpConversion>(converter);
}
|
[mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC): These lowerings were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for them, as they lower to high-level ArmSME ops, not intrinsics.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
[mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC) (llvm#72852): These lowerings were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for them, as they lower to high-level ArmSME ops, not intrinsics.
These were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for these, as these lower to high-level ArmSME ops, not intrinsics.