-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Lower vector.extract/insert on SME tiles to MOVA intrinsics #67786
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-sme Changes: This patch adds support for lowering vector.insert/extract of tile slices or elements to ArmSME MOVA intrinsics. This enables the following operations for ArmSME:
Full diff: https://github.com/llvm/llvm-project/pull/67786.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 0322c2f3fcd14d4..edf9e333b0e4784 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -550,6 +550,113 @@ struct VectorOuterProductToArmSMELowering
}
};
+/// Lower `vector.extract` on an SME tile using SME MOVA intrinsics.
+///
+/// Only applies when the source vector type is a valid SME tile type; for any
+/// other `vector.extract` the pattern returns failure and leaves the op alone.
+///
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%y,%x]: i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%x] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+ : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+ using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+ /// Rewrites `extractOp` according to how many indices it carries: none
+ /// forwards the source tile, one yields a tile slice, two yield an element.
+ LogicalResult
+ matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType sourceType = extractOp.getSourceVectorType();
+ // Bail out unless the source is an SME tile.
+ if (!isValidSMETileVectorType(sourceType))
+ return failure();
+
+ auto loc = extractOp.getLoc();
+ auto position = extractOp.getMixedPosition();
+
+ // Note: `adaptor` is not consulted; the source operand is read from the op.
+ Value sourceVector = extractOp.getVector();
+
+ // No indices: the result is simply the whole source tile.
+ if (position.empty()) {
+ rewriter.replaceOp(extractOp, sourceVector);
+ return success();
+ }
+
+ // The first index selects the tile slice. Both remaining cases need that
+ // slice, so it is read before distinguishing them.
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+ auto moveTileSliceToVector =
+ rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+ sliceIndex);
+
+ if (position.size() == 1) {
+ // Single index case: Extracts a 1D slice.
+ rewriter.replaceOp(extractOp, moveTileSliceToVector);
+ return success();
+ }
+
+ // Two indices case: Extracts a single element from the slice read above,
+ // using a 1-D `vector.extract` at the second index.
+ assert(position.size() == 2);
+ rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+ extractOp, moveTileSliceToVector, position[1]);
+
+ return success();
+ }
+};
+
+/// Lower `vector.insert` into an SME tile using SME MOVA intrinsics.
+///
+/// Only applies when the result vector type is a valid SME tile type; for any
+/// other `vector.insert` the pattern returns failure and leaves the op alone.
+///
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+/// : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%x] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %y
+/// : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+ : public ConvertOpToLLVMPattern<vector::InsertOp> {
+ using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+ /// Rewrites `insertOp` according to how many indices it carries: none
+ /// forwards the inserted value, one writes a whole tile slice, two update a
+ /// single element via a read-modify-write of the slice.
+ LogicalResult
+ matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const override {
+ VectorType resultType = insertOp.getResult().getType();
+
+ // Bail out unless the result is an SME tile.
+ if (!isValidSMETileVectorType(resultType))
+ return failure();
+
+ auto loc = insertOp.getLoc();
+ auto position = insertOp.getMixedPosition();
+
+ // The value being inserted (a tile, a slice, or an element, by case).
+ Value source = adaptor.getSource();
+
+ // No indices: the inserted value replaces the result outright.
+ if (position.empty()) {
+ rewriter.replaceOp(insertOp, source);
+ return success();
+ }
+
+ // Single index case: `source` already is the slice to write.
+ Value tileSlice = source;
+ Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+ if (position.size() == 2) {
+ // Two indices case: Insert single element into tile.
+ // We need to first extract the existing slice and update the element.
+ tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+ loc, adaptor.getDest(), sliceIndex);
+ tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+ position[1]);
+ }
+
+ // Insert the slice into the destination tile.
+ rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+ insertOp, tileSlice, adaptor.getDest(), sliceIndex);
+ return success();
+ }
+};
+
} // namespace
void mlir::configureArmSMELegalizeForExportTarget(
@@ -601,9 +708,9 @@ void mlir::configureArmSMELegalizeForExportTarget(
void mlir::populateArmSMELegalizeForLLVMExportPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
- patterns
- .add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
- LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
- MoveVectorToTileSliceToArmSMELowering,
- VectorOuterProductToArmSMELowering>(converter);
+ patterns.add<
+ ZeroOpConversion, StoreTileSliceToArmSMELowering,
+ LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
+ MoveVectorToTileSliceToArmSMELowering, VectorOuterProductToArmSMELowering,
+ VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
}
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 687ef79385334cf..9678a0c91c38a32 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -496,3 +496,81 @@ func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vec
%0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
"prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
}
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Inserting a 1-D slice into a tile at a dynamic index: the expected output is
+// a single horizontal tile-slice write using an all-true predicate.
+// CHECK-LABEL: @vector_insert_slice(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[SLICE:.*]]: vector<[4]xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index)
+func.func @vector_insert_slice(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %y: index) -> vector<[4]x[4]xi32>{
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
+ // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ %new_tile = vector.insert %slice, %tile[%y] : vector<[4]xi32> into vector<[4]x[4]xi32>
+ return %new_tile : vector<[4]x[4]xi32>
+}
+
+// -----
+
+
+// Inserting one element into a tile at dynamic [%y, %x]: the expected output
+// reads the slice at %y, inserts the element at %x with llvm.insertelement,
+// then writes the updated slice back to the same tile slice.
+// CHECK-LABEL: @vector_insert_element(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[EL:.*]]: i32,
+// CHECK-SAME: %[[Y:.*]]: index,
+// CHECK-SAME: %[[X:.*]]: index)
+func.func @vector_insert_element(%tile: vector<[4]x[4]xi32>, %el: i32, %y: index, %x: index) -> vector<[4]x[4]xi32> {
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[X_I32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i64
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_cast %[[Y]] : index to i32
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[Y_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ // CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[X_I32]] : i64] : vector<[4]xi32>
+ // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[Y]] : index to i32
+ // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+ %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+ return %new_tile : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// vector.extract
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// Extracting a 1-D slice from a tile at a dynamic index: the expected output
+// is a single horizontal tile-slice read into a zero-initialised vector.
+// NOTE(review): despite its name, this function only exercises vector.extract;
+// @vector_extract_slice would match the naming of the sibling tests - confirm.
+// CHECK-LABEL: @extract_insert_slice(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[INDEX:.*]]: index)
+func.func @extract_insert_slice(%tile: vector<[4]x[4]xi32>, %y: index) -> vector<[4]xi32> {
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ %slice = vector.extract %tile[%y] : vector<[4]xi32> from vector<[4]x[4]xi32>
+ return %slice : vector<[4]xi32>
+}
+
+// -----
+
+// Extracting one element from a tile at dynamic [%y, %x]: the expected output
+// reads the slice at %y and then picks element %x with llvm.extractelement.
+// CHECK-LABEL: @vector_extract_element(
+// CHECK-SAME: %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME: %[[Y:.*]]: index,
+// CHECK-SAME: %[[X:.*]]: index)
+func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %y: index, %x: index) -> i32 {
+ // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+ // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+ // CHECK-NEXT: %[[X_I32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i64
+ // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+ // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_cast %[[Y]] : index to i32
+ // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[Y_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+ // CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[X_I32]] : i64] : vector<[4]xi32>
+ %el = vector.extract %tile[%y,%x] : i32 from vector<[4]x[4]xi32>
+ return %el : i32
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks, this is great!
I think that we are still missing some tests :) For example in https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir.
ded2e2a
to
5108d38
Compare
// ----- | ||
|
||
// CHECK-LABEL: @vector_insert_slice_i128 | ||
func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not to be a pain but would appreciate if these followed the order of the other tests for consistency, i.e.
- i8
- i16
- i32
- i64
- i128
- f16
- bf16
- f32
- f64
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've kept `i32` as the first in each group as it's the only full test. The rest are just for completeness.
…nsics This patch adds support for lowering vector.insert/extract of tile slices or elements to ArmSME MOVA intrinsics. For example: ``` // Extract slice from tile: %slice = vector.extract %tile[%y]: vector<[4]x[4]xi32> // Extract element from tile: %el = vector.extract %tile[%y,%x]: vector<[4]x[4]xi32> // Insert slice into tile: %new_tile = vector.insert %slice, %tile[%y] : vector<[4]xi32> into vector<[4]x[4]xi32> // Insert element into tile: %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32> ```
5108d38
to
4d9501a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Really happy to see that we can reuse the Vector dialect directly for this and how things are composing!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM cheers
This patch adds support for lowering vector.insert/extract of tile slices or elements to ArmSME MOVA intrinsics.
This enables the following operations for ArmSME: