[mlir][ArmSME] Move vector.extract/insert lowerings to vector-to-arm-sme (NFC) #72852

Merged: 1 commit into llvm:main on Nov 20, 2023

Conversation

MacDue
@MacDue MacDue commented Nov 20, 2023

The vector.extract/vector.insert lowerings were placed in LegalizeForLLVMExport.cpp, which is the wrong stage for them: they lower to high-level ArmSME ops, not intrinsics.
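
For readers unfamiliar with the two lowerings being moved, here is a toy model of the decomposition in plain C++. It is only a sketch and does not use the MLIR API: `Tile`, `Slice`, and every function name below are hypothetical stand-ins, the tile is a fixed-size matrix rather than a scalable vector, and slices are assumed to be rows. It shows why the patterns produce ArmSME ops plus a 1-D `vector.extract`/`vector.insert` rather than intrinsics: an element access is split into a slice move followed by a 1-D access.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins: a 1-D slice and a 2-D tile of i32 values.
using Slice = std::vector<int>;
using Tile = std::vector<Slice>;

// Models arm_sme.move_tile_slice_to_vector: copy one slice (row) out of a tile.
Slice moveTileSliceToVector(const Tile &tile, std::size_t row) {
  return tile.at(row);
}

// Models arm_sme.move_vector_to_tile_slice: return a new tile whose slice
// `row` is replaced by `slice` (value semantics, like SSA results).
Tile moveVectorToTileSlice(const Slice &slice, Tile tile, std::size_t row) {
  tile.at(row) = slice;
  return tile;
}

// Models `vector.extract %tile[%row, %col]`:
// move the slice out, then do a 1-D extract on it.
int extractElement(const Tile &tile, std::size_t row, std::size_t col) {
  Slice slice = moveTileSliceToVector(tile, row);
  return slice.at(col);
}

// Models `vector.insert %el, %tile[%row, %col]`:
// move the slice out, update one element, move the slice back in.
Tile insertElement(int el, const Tile &tile, std::size_t row,
                   std::size_t col) {
  Slice slice = moveTileSliceToVector(tile, row);
  slice.at(col) = el;
  return moveVectorToTileSlice(slice, tile, row);
}
```

In the single-index cases the middle step disappears: extracting `%tile[%row]` is just the slice move, and inserting a whole slice is just the move back into the tile.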

@llvmbot
llvmbot commented Nov 20, 2023

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Full diff: https://github.com/llvm/llvm-project/pull/72852.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp (+113-1)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp (+3-115)
diff --git a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
index 953a465c18de69f..420d2b6b1c08786 100644
--- a/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
+++ b/mlir/lib/Conversion/VectorToArmSME/VectorToArmSME.cpp
@@ -576,6 +576,116 @@ struct VectorOuterProductToArmSMELowering
   }
 };
 
+/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
+///
+/// Example:
+/// ```
+/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
+/// ```
+struct VectorExtractToArmSMELowering
+    : public OpRewritePattern<vector::ExtractOp> {
+  using OpRewritePattern<vector::ExtractOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::ExtractOp extractOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType sourceType = extractOp.getSourceVectorType();
+    if (!arm_sme::isValidSMETileVectorType(sourceType))
+      return failure();
+
+    auto loc = extractOp.getLoc();
+    auto position = extractOp.getMixedPosition();
+
+    Value sourceVector = extractOp.getVector();
+
+    // Extract entire vector. Should be handled by folder, but just to be safe.
+    if (position.empty()) {
+      rewriter.replaceOp(extractOp, sourceVector);
+      return success();
+    }
+
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+    auto moveTileSliceToVector =
+        rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
+                                                          sliceIndex);
+
+    if (position.size() == 1) {
+      // Single index case: Extracts a 1D slice.
+      rewriter.replaceOp(extractOp, moveTileSliceToVector);
+      return success();
+    }
+
+    // Two indices case: Extracts a single element.
+    assert(position.size() == 2);
+    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
+        extractOp, moveTileSliceToVector, position[1]);
+
+    return success();
+  }
+};
+
+/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
+/// `arm_sme.move_tile_slice_to_vector`.
+///
+/// Example:
+/// ```
+/// %new_tile = vector.insert %el, %tile[%row, %col]
+///                     : i32 into vector<[4]x[4]xi32>
+/// ```
+/// Becomes:
+/// ```
+/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
+///            : vector<[4]xi32> from vector<[4]x[4]xi32>
+/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
+/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
+///               : vector<[4]xi32> into vector<[4]x[4]xi32>
+/// ```
+struct VectorInsertToArmSMELowering
+    : public OpRewritePattern<vector::InsertOp> {
+  using OpRewritePattern<vector::InsertOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::InsertOp insertOp,
+                                PatternRewriter &rewriter) const override {
+    VectorType resultType = insertOp.getResult().getType();
+
+    if (!arm_sme::isValidSMETileVectorType(resultType))
+      return failure();
+
+    auto loc = insertOp.getLoc();
+    auto position = insertOp.getMixedPosition();
+
+    Value source = insertOp.getSource();
+
+    // Overwrite entire vector with value. Should be handled by folder, but
+    // just to be safe.
+    if (position.empty()) {
+      rewriter.replaceOp(insertOp, source);
+      return success();
+    }
+
+    Value tileSlice = source;
+    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
+    if (position.size() == 2) {
+      // Two indices case: Insert single element into tile.
+      // We need to first extract the existing slice and update the element.
+      tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
+          loc, insertOp.getDest(), sliceIndex);
+      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
+                                                    position[1]);
+    }
+
+    // Insert the slice into the destination tile.
+    rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
+        insertOp, tileSlice, insertOp.getDest(), sliceIndex);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
@@ -584,5 +694,7 @@ void mlir::populateVectorToArmSMEPatterns(RewritePatternSet &patterns,
                SplatOpToArmSMELowering, TransferReadToArmSMELowering,
                TransferWriteToArmSMELowering, TransposeOpToArmSMELowering,
                VectorLoadToArmSMELowering, VectorStoreToArmSMELowering,
-               VectorOuterProductToArmSMELowering>(&ctx);
+               VectorOuterProductToArmSMELowering,
+               VectorExtractToArmSMELowering, VectorInsertToArmSMELowering>(
+      &ctx);
 }
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
index 6078b3f2c5e4708..041a4897a836503 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/ArmSME/Transforms/LegalizeForLLVMExport.cpp
@@ -456,7 +456,8 @@ struct OuterProductOpConversion
       //   * half-precision   - +sme2p1,+b16b16
       //
       // It should be possible to control lowering based on target features.
-      // [1] https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
+      // [1]
+      // https://developer.arm.com/downloads/-/exploration-tools/feature-names-for-a-profile
       if ((vectorType.getRank() != 2) || !vectorType.allDimsScalable())
         return false;
 
@@ -520,118 +521,6 @@ struct OuterProductOpConversion
   }
 };
 
-/// Lower `vector.extract` using `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %el = vector.extract %tile[%row, %col]: i32 from vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %el = vector.extract %slice[%col] : i32 from vector<[4]xi32>
-/// ```
-struct VectorExtractToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::ExtractOp> {
-  using ConvertOpToLLVMPattern<vector::ExtractOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::ExtractOp extractOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType sourceType = extractOp.getSourceVectorType();
-    if (!isValidSMETileVectorType(sourceType))
-      return failure();
-
-    auto loc = extractOp.getLoc();
-    auto position = extractOp.getMixedPosition();
-
-    Value sourceVector = extractOp.getVector();
-
-    // Extract entire vector. Should be handled by folder, but just to be safe.
-    if (position.empty()) {
-      rewriter.replaceOp(extractOp, sourceVector);
-      return success();
-    }
-
-    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
-    auto moveTileSliceToVector =
-        rewriter.create<arm_sme::MoveTileSliceToVectorOp>(loc, sourceVector,
-                                                          sliceIndex);
-
-    if (position.size() == 1) {
-      // Single index case: Extracts a 1D slice.
-      rewriter.replaceOp(extractOp, moveTileSliceToVector);
-      return success();
-    }
-
-    // Two indices case: Extracts a single element.
-    assert(position.size() == 2);
-    rewriter.replaceOpWithNewOp<vector::ExtractOp>(
-        extractOp, moveTileSliceToVector, position[1]);
-
-    return success();
-  }
-};
-
-/// Lower `vector.insert` using `arm_sme.move_vector_to_tile_slice` and
-/// `arm_sme.move_tile_slice_to_vector`.
-///
-/// Example:
-/// ```
-/// %new_tile = vector.insert %el, %tile[%row, %col]
-///                     : i32 into vector<[4]x[4]xi32>
-/// ```
-/// Becomes:
-/// ```
-/// %slice = arm_sme.move_tile_slice_to_vector %tile[%row]
-///            : vector<[4]xi32> from vector<[4]x[4]xi32>
-/// %new_slice = vector.insert %el, %slice[%col] : i32 into vector<[4]xi32>
-/// %new_tile = arm_sme.move_vector_to_tile_slice %new_slice, %tile, %row
-///               : vector<[4]xi32> into vector<[4]x[4]xi32>
-/// ```
-struct VectorInsertToArmSMELowering
-    : public ConvertOpToLLVMPattern<vector::InsertOp> {
-  using ConvertOpToLLVMPattern<vector::InsertOp>::ConvertOpToLLVMPattern;
-
-  LogicalResult
-  matchAndRewrite(vector::InsertOp insertOp, OpAdaptor adaptor,
-                  ConversionPatternRewriter &rewriter) const override {
-    VectorType resultType = insertOp.getResult().getType();
-
-    if (!isValidSMETileVectorType(resultType))
-      return failure();
-
-    auto loc = insertOp.getLoc();
-    auto position = insertOp.getMixedPosition();
-
-    Value source = adaptor.getSource();
-
-    // Overwrite entire vector with value. Should be handled by folder, but
-    // just to be safe.
-    if (position.empty()) {
-      rewriter.replaceOp(insertOp, source);
-      return success();
-    }
-
-    Value tileSlice = source;
-    Value sliceIndex = vector::getAsValues(rewriter, loc, position[0]).front();
-    if (position.size() == 2) {
-      // Two indices case: Insert single element into tile.
-      // We need to first extract the existing slice and update the element.
-      tileSlice = rewriter.create<arm_sme::MoveTileSliceToVectorOp>(
-          loc, adaptor.getDest(), sliceIndex);
-      tileSlice = rewriter.create<vector::InsertOp>(loc, source, tileSlice,
-                                                    position[1]);
-    }
-
-    // Insert the slice into the destination tile.
-    rewriter.replaceOpWithNewOp<arm_sme::MoveVectorToTileSliceOp>(
-        insertOp, tileSlice, adaptor.getDest(), sliceIndex);
-    return success();
-  }
-};
-
 } // namespace
 
 void mlir::configureArmSMELegalizeForExportTarget(
@@ -661,6 +550,5 @@ void mlir::populateArmSMELegalizeForLLVMExportPatterns(
   patterns.add<
       LoadTileSliceToArmSMELowering, MoveTileSliceToVectorArmSMELowering,
       MoveVectorToTileSliceToArmSMELowering, StoreTileSliceToArmSMELowering,
-      OuterProductOpConversion, ZeroOpConversion, VectorExtractToArmSMELowering,
-      VectorInsertToArmSMELowering>(converter);
+      OuterProductOpConversion, ZeroOpConversion>(converter);
 }

…sme (NFC)

These were placed in LegalizeForLLVMExport.cpp, which is the wrong stage
for these, as these lower to high-level ArmSME ops, not intrinsics.

@c-rhodes c-rhodes left a comment

LGTM cheers

@MacDue MacDue merged commit c4c52d4 into llvm:main Nov 20, 2023
@MacDue MacDue deleted the mv branch November 20, 2023 14:05
zahiraam pushed a commit to zahiraam/llvm-project that referenced this pull request Nov 20, 2023
…sme (NFC) (llvm#72852)