Skip to content

[mlir][ArmSME] Lower vector.extract/insert on SME tiles to MOVA intrinsics #67786

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Oct 4, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Sep 29, 2023

This patch adds support for lowering vector.insert/extract of tile slices or elements to ArmSME MOVA intrinsics.

This enables the following operations for ArmSME:

// Extract slice from tile:
%slice = vector.extract %tile[%row] 
                 : vector<[4]xi32> from vector<[4]x[4]xi32>
// Extract element from tile:
%el = vector.extract %tile[%row, %col]
                 : i32 from vector<[4]x[4]xi32>
// Insert slice into tile:
%new_tile = vector.insert %slice, %tile[%row]
                    : vector<[4]xi32> into vector<[4]x[4]xi32>
// Insert element into tile;
%new_tile = vector.insert %el, %tile[%row, %col]
                    : i32 into vector<[4]x[4]xi32>

@llvmbot
Copy link
Member

llvmbot commented Sep 29, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Changes

This patch adds support for lowering vector.insert/extract of tile slices or elements to ArmSME MOVA intrinsic.

This enables the following operations for ArmSME:

// Extract slice from tile:
%slice = vector.extract %tile[%y]: vector&lt;[4]x[4]xi32&gt;
// Extract element from tile:
%el = vector.extract %tile[%y, %x]: vector&lt;[4]x[4]xi32&gt;
// Insert slice into tile:
%new_tile = vector.insert %slice, %tile[%y]
                    : vector&lt;[4]xi32&gt; into vector&lt;[4]x[4]xi32&gt;
// Insert element into tile;
%new_tile = vector.insert %el, %tile[%y, %x]
                    : i32 into vector&lt;[4]x[4]xi32&gt;

Full diff: https://github.com/llvm/llvm-project/pull/67786.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+112-5)
  • (modified) mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir (+78)
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 0322c2f3fcd14d4..edf9e333b0e4784 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -550,6 +550,113 @@ struct VectorOuterProductToArmSMELowering
   }
 };
 
+/// Lower `vector.extract` using SME MOVA intrinsics.
+///
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%y,%x]: i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%x] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
+  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    if (!isValidSMETileVectorType(sourceType))
+      return failure();
+
+    auto loc = extractOp.getLoc();
+    auto position = extractOp.getMixedPosition();
+
+    Value sourceVector = extractOp.getVector();
+
+    if (position.empty()) {
+      rewriter.replaceOp(extractOp, sourceVector);
+      return success();
+    }
+
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+    auto moveTileSliceToVector =
+        rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+                                                          sliceIndex);
+
+    if (position.size() == 1) {
+      // Single index case: Extracts a 1D slice.
+      rewriter.replaceOp(extractOp, moveTileSliceToVector);
+      return success();
+    }
+
+    // Two indices case: Extracts a single element.
+    assert(position.size() == 2);
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, moveTileSliceToVector, position[1]);
+
+    return success();
+  }
+};
+
+/// Lower `vector.insert` using SME MOVA intrinsics.
+///
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%y]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%x] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %y
+///               : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+    : public ConvertOpToLLVMPattern<vector::InsertOp> {
+  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    VectorType resultType = insertOp.getResult().getType();
+
+    if (!isValidSMETileVectorType(resultType))
+      return failure();
+
+    auto loc = insertOp.getLoc();
+    auto position = insertOp.getMixedPosition();
+
+    Value source = adaptor.getSource();
+
+    if (position.empty()) {
+      rewriter.replaceOp(insertOp, source);
+      return success();
+    }
+
+    Value tileSlice = source;
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0])[0];
+    if (position.size() == 2) {
+      // Two indices case: Insert signle element into tile.
+      // We need to first extract the existing slice and update the element.
+      tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+          loc, adaptor.getDest(), sliceIndex);
+      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+                                                    position[1]);
+    }
+
+    // Insert the slice into the destination tile.
+    rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+        insertOp, tileSlice, adaptor.getDest(), sliceIndex);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::configureArmSMELegalizeForExportTarget(
@@ -601,9 +708,9 @@ void mlir::configureArmSMELegalizeForExportTarget(
 void mlir::populateArmSMELegalizeForLLVMExportPatterns(
     LLVMTypeConverter &converter, RewritePatternSet &patterns) {
   patterns.add<EnableZAPattern, DisableZAPattern>(patterns.getContext());
-  patterns
-      .add<ZeroOpConversion, StoreTileSliceToArmSMELowering,
-           LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
-           MoveVectorToTileSliceToArmSMELowering,
-           VectorOuterProductToArmSMELowering>(converter);
+  patterns.add<
+      ZeroOpConversion, StoreTileSliceToArmSMELowering,
+      LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
+      MoveVectorToTileSliceToArmSMELowering, VectorOuterProductToArmSMELowering,
+      VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(converter);
 }
diff --git a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
index 687ef79385334cf..9678a0c91c38a32 100644
--- a/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
+++ b/mlir/test/Dialect/ArmSME/vector-ops-to-llvm.mlir
@@ -496,3 +496,81 @@ func.func @vector_outerproduct_add_masked_f32(%lhs : vector<[4]xf32>, %rhs : vec
   %0 = vector.mask %mask { vector.outerproduct %lhs, %rhs, %acc {kind = #vector.kind<add>} : vector<[4]xf32>, vector<[4]xf32> } : vector<[4]x[4]xi1> -> vector<[4]x[4]xf32>
   "prevent.dce"(%0) : (vector<[4]x[4]xf32>) -> ()
 }
+
+//===----------------------------------------------------------------------===//
+// vector.insert
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @vector_insert_slice(
+// CHECK-SAME:                       %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                       %[[SLICE:.*]]: vector<[4]xi32>,
+// CHECK-SAME:                       %[[INDEX:.*]]: index)
+func.func @vector_insert_slice(%tile: vector<[4]x[4]xi32>, %slice: vector<[4]xi32>, %y: index) -> vector<[4]x[4]xi32>{
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_castui %[[INDEX]] : index to i32
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[TILE_SLICE_INDEX]], %[[PTRUE]], %[[SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %new_tile = vector.insert %slice, %tile[%y] : vector<[4]xi32> into vector<[4]x[4]xi32>
+  return %new_tile : vector<[4]x[4]xi32>
+}
+
+// -----
+
+
+// CHECK-LABEL: @vector_insert_element(
+// CHECK-SAME:                         %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                         %[[EL:.*]]: i32,
+// CHECK-SAME:                         %[[Y:.*]]: index,
+// CHECK-SAME:                         %[[X:.*]]: index)
+func.func @vector_insert_element(%tile: vector<[4]x[4]xi32>, %el: i32, %y: index, %x: index) -> vector<[4]x[4]xi32> {
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[X_I32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i64
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_cast %[[Y]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[Y_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[NEW_SLICE:.*]] = llvm.insertelement %[[EL]], %[[SLICE]]{{\[}}%[[X_I32]] : i64] : vector<[4]xi32>
+  // CHECK-NEXT: %[[SLICE_INDEX:.*]] = arith.index_castui %[[Y]] : index to i32
+  // CHECK-NEXT: "arm_sme.intr.write.horiz"(%[[TILE_ID]], %[[SLICE_INDEX]], %[[PTRUE]], %[[NEW_SLICE]]) : (i32, i32, vector<[4]xi1>, vector<[4]xi32>) -> ()
+  %new_tile = vector.insert %el, %tile[%y,%x] : i32 into vector<[4]x[4]xi32>
+  return %new_tile : vector<[4]x[4]xi32>
+}
+
+//===----------------------------------------------------------------------===//
+// vector.extract
+//===----------------------------------------------------------------------===//
+
+// -----
+
+// CHECK-LABEL: @extract_insert_slice(
+// CHECK-SAME:                        %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                        %[[INDEX:.*]]: index)
+func.func @extract_insert_slice(%tile: vector<[4]x[4]xi32>, %y: index) -> vector<[4]xi32> {
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[TILE_SLICE_INDEX:.*]] = arith.index_cast %[[INDEX]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[TILE_SLICE_INDEX]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  %slice = vector.extract %tile[%y] : vector<[4]xi32> from vector<[4]x[4]xi32>
+  return %slice : vector<[4]xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @vector_extract_element(
+// CHECK-SAME:                          %[[TILE:.*]]: vector<[4]x[4]xi32>,
+// CHECK-SAME:                          %[[Y:.*]]: index,
+// CHECK-SAME:                          %[[X:.*]]: index)
+func.func @vector_extract_element(%tile: vector<[4]x[4]xi32>, %y: index, %x: index) -> i32 {
+  // CHECK-NEXT: %[[ZERO_VEC:.*]] = arith.constant dense<0> : vector<[4]xi32>
+  // CHECK-NEXT: %[[PTRUE:.*]] = arith.constant dense<true> : vector<[4]xi1>
+  // CHECK-NEXT: %[[X_I32:.*]] = builtin.unrealized_conversion_cast %[[X]] : index to i64
+  // CHECK-NEXT: %[[TILE_ID:.*]] = arm_sme.cast_vector_to_tile %[[TILE]] : vector<[4]x[4]xi32> to i32
+  // CHECK-NEXT: %[[Y_I32:.*]] = arith.index_cast %[[Y]] : index to i32
+  // CHECK-NEXT: %[[SLICE:.*]] = "arm_sme.intr.read.horiz"(%[[ZERO_VEC]], %[[PTRUE]], %[[TILE_ID]], %[[Y_I32]]) : (vector<[4]xi32>, vector<[4]xi1>, i32, i32) -> vector<[4]xi32>
+  // CHECK-NEXT: %[[EL:.*]] = llvm.extractelement %[[SLICE]]{{\[}}%[[X_I32]] : i64] : vector<[4]xi32>
+  %el = vector.extract %tile[%y,%x] : i32 from vector<[4]x[4]xi32>
+  return %el : i32
+}

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, this is great!

I think that we are still missing some tests :) For example in https://github.com/llvm/llvm-project/blob/main/mlir/test/Dialect/ArmSME/vector-ops-to-sme.mlir.

@MacDue MacDue force-pushed the vector_insert_and_extract_sme branch from ded2e2a to 5108d38 Compare October 2, 2023 11:50
// -----

// CHECK-LABEL: @vector_insert_slice_i128
func.func @vector_insert_slice_i128(%tile: vector<[1]x[1]xi128>, %slice: vector<[1]xi128>, %row: index) -> vector<[1]x[1]xi128> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not to be a pain but would appreciate if these followed the order of the other tests for consistency, i.e.

  • i8
  • i16
  • i32
  • i64
  • i128
  • f16
  • bf16
  • f32
  • f64

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've kept i32 as the first in each group as it's the only full test. The rest are just for completeness.

MacDue added 4 commits October 3, 2023 17:22
…nsics

This patch adds support for lowering vector.insert/extract of tile
slices or elements to ArmSME MOVA intrinsic.

For example:

```
// Extract slice from tile:
%slice = vector.extract %tile[%y]: vector<[4]x[4]xi32>
// Extract element from tile:
%el = vector.extract %tile[%y,%x]: vector<[4]x[4]xi32>

// Insert slice into tile:
%new_tile = vector.insert %slice, %tile[%y]
                    : vector<[4]xi32> into vector<[4]x[4]xi32>
// Insert element into tile;
%new_tile = vector.insert %el, %tile[%y,%x]
                    : i32 into vector<[4]x[4]xi32>
```
@MacDue MacDue force-pushed the vector_insert_and_extract_sme branch from 5108d38 to 4d9501a Compare October 3, 2023 17:29
Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Really happy to see that we can reuse the Vector dialect directly for this and how things are composing!

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM cheers

@MacDue MacDue merged commit 496318a into llvm:main Oct 4, 2023
@MacDue MacDue deleted the vector_insert_and_extract_sme branch October 4, 2023 19:53
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants