Skip to content

[mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) #73639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 4, 2023

Conversation

MacDue
Copy link
Member

@MacDue MacDue commented Nov 28, 2023

This patch removes the ArmSMETypeConverter, and instead updates populateArmSMEToLLVMConversionPatterns() to add an ArmSME vector type conversion to the existing LLVMTypeConverter. This makes it easier to add these patterns to an existing -to-llvm lowering pass.

@llvmbot
Copy link
Member

llvmbot commented Nov 28, 2023

@llvm/pr-subscribers-mlir-sme

Author: Benjamin Maxwell (MacDue)

Changes

This patch removes the ArmSMETypeConverter, and instead updates configureArmSMEToLLVMConversionLegality() to add an ArmSME vector type conversion to the existing LLVMTypeConverter. This makes it easier to add these patterns to an existing -to-llvm lowering pass.


Full diff: https://github.com/llvm/llvm-project/pull/73639.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h (+3-4)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (-8)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+13-7)
  • (removed) mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp (-22)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (-1)
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867dff5..b2130742e0f71cf 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,17 +20,16 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTARMSMETOLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-using arm_sme::ArmSMETypeConverter;
-
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
 std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
+                                             LLVMTypeConverter &typeConverter);
 
 /// Populate the given list with patterns that convert from the ArmSME dialect
 /// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..4c9907be4086b1f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
 
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
-  ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..12f9df7ed4de2f4 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -540,10 +540,8 @@ struct ConvertArmSMEToLLVMPass
   void runOnOperation() override {
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
-    ArmSMETypeConverter converter(&getContext(),
-                                  LowerToLLVMOptions(&getContext()));
-
-    configureArmSMEToLLVMConversionLegality(target);
+    LLVMTypeConverter converter(&getContext());
+    configureArmSMEToLLVMConversionLegality(target, converter);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
@@ -554,7 +552,8 @@ struct ConvertArmSMEToLLVMPass
 
 } // namespace
 
-void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+void mlir::configureArmSMEToLLVMConversionLegality(
+    ConversionTarget &target, LLVMTypeConverter &typeConverter) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
       arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
@@ -574,10 +573,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_mopa>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
+  typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+    // SME vector types should be eliminated.
+    if (arm_sme::isValidSMETileVectorType(type))
+      return type;
+    return std::nullopt;
+  });
 }
 
-void mlir::populateArmSMEToLLVMConversionPatterns(
-    ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns) {
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
                OuterProductOpConversion, ZeroOpConversion>(converter);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf1035..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
-    MLIRContext *ctx, const LowerToLLVMOptions &options)
-    : LLVMTypeConverter(ctx, options) {
-  // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
-  // vectors (common in the context of ArmSME), e.g.
-  //    `vector<[16]x[16]xi8>`,
-  // entering the LLVM Type converter. LLVM does not support arrays of scalable
-  // vectors, but in the case of SME such types are effectively eliminated when
-  // emitting ArmSME LLVM IR intrinsics.
-  addConversion([&](VectorType type) { return type; });
-}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..f0854d0678e106d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
-  ArmSMETypeConverter.cpp
   EnableArmStreaming.cpp
   TileAllocation.cpp
 

@llvmbot
Copy link
Member

llvmbot commented Nov 28, 2023

@llvm/pr-subscribers-mlir

Author: Benjamin Maxwell (MacDue)

Changes

This patch removes the ArmSMETypeConverter, and instead updates configureArmSMEToLLVMConversionLegality() to add an ArmSME vector type conversion to the existing LLVMTypeConverter. This makes it easier to add these patterns to an existing -to-llvm lowering pass.


Full diff: https://github.com/llvm/llvm-project/pull/73639.diff

5 Files Affected:

  • (modified) mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h (+3-4)
  • (modified) mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h (-8)
  • (modified) mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp (+13-7)
  • (removed) mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp (-22)
  • (modified) mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt (-1)
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867dff5..b2130742e0f71cf 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,17 +20,16 @@ class RewritePatternSet;
 #define GEN_PASS_DECL_CONVERTARMSMETOLLVM
 #include "mlir/Conversion/Passes.h.inc"
 
-using arm_sme::ArmSMETypeConverter;
-
 /// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
 std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
 
 /// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
+                                             LLVMTypeConverter &typeConverter);
 
 /// Populate the given list with patterns that convert from the ArmSME dialect
 /// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
                                             RewritePatternSet &patterns);
 
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..4c9907be4086b1f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
 /// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
 std::unique_ptr<Pass> createTileAllocationPass();
 
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
-  ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
 //===----------------------------------------------------------------------===//
 // Registration
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..12f9df7ed4de2f4 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -540,10 +540,8 @@ struct ConvertArmSMEToLLVMPass
   void runOnOperation() override {
     LLVMConversionTarget target(getContext());
     RewritePatternSet patterns(&getContext());
-    ArmSMETypeConverter converter(&getContext(),
-                                  LowerToLLVMOptions(&getContext()));
-
-    configureArmSMEToLLVMConversionLegality(target);
+    LLVMTypeConverter converter(&getContext());
+    configureArmSMEToLLVMConversionLegality(target, converter);
     populateArmSMEToLLVMConversionPatterns(converter, patterns);
 
     if (failed(applyPartialConversion(getOperation(), target,
@@ -554,7 +552,8 @@ struct ConvertArmSMEToLLVMPass
 
 } // namespace
 
-void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+void mlir::configureArmSMEToLLVMConversionLegality(
+    ConversionTarget &target, LLVMTypeConverter &typeConverter) {
   target.addIllegalDialect<arm_sme::ArmSMEDialect>();
   target.addLegalOp<
       arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
@@ -574,10 +573,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
       arm_sme::aarch64_sme_mopa>();
   target.addLegalDialect<arith::ArithDialect>();
   target.addLegalOp<UnrealizedConversionCastOp>();
+  typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+    // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+    // SME vector types should be eliminated.
+    if (arm_sme::isValidSMETileVectorType(type))
+      return type;
+    return std::nullopt;
+  });
 }
 
-void mlir::populateArmSMEToLLVMConversionPatterns(
-    ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+                                                  RewritePatternSet &patterns) {
   patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
                MoveVectorToTileSliceConversion, StoreTileSliceConversion,
                OuterProductOpConversion, ZeroOpConversion>(converter);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf1035..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
-    MLIRContext *ctx, const LowerToLLVMOptions &options)
-    : LLVMTypeConverter(ctx, options) {
-  // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
-  // vectors (common in the context of ArmSME), e.g.
-  //    `vector<[16]x[16]xi8>`,
-  // entering the LLVM Type converter. LLVM does not support arrays of scalable
-  // vectors, but in the case of SME such types are effectively eliminated when
-  // emitting ArmSME LLVM IR intrinsics.
-  addConversion([&](VectorType type) { return type; });
-}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..f0854d0678e106d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
 add_mlir_dialect_library(MLIRArmSMETransforms
-  ArmSMETypeConverter.cpp
   EnableArmStreaming.cpp
   TileAllocation.cpp
 

@c-rhodes
Copy link
Collaborator

c-rhodes commented Nov 29, 2023

... This makes it easier to add these patterns to an existing -to-llvm lowering pass.

Context: iree-org/iree@main...MacDue:iree:armsme_tiling_and_lowering#diff-2c6c26aee4e11748464ce8d4fa0f83a0d5e52953a599c4c5dc2ef42cb55662a1R1072

Copy link
Collaborator

@c-rhodes c-rhodes left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

grepping on typeConverter.addConversion this seems consistent with other conversions. One minor suggestion, but otherwise LGTM, cheers.

@MacDue MacDue force-pushed the arm_sme_type_conv_fixup branch from 828d6cc to 3de9bba Compare November 29, 2023 16:06
@banach-space
Copy link
Contributor

This will update the global LLVMTypeConverter used during the compilation, right? And so it will update it for all its consumers (even non-SME)? But that should be OK, because it will only happen when lowering to SME?

I just want to make sure that we retain the original protection that we gained when splitting this into a dedicated type converter.

@MacDue
Copy link
Member Author

MacDue commented Dec 1, 2023

This will update the global LLVMTypeConverter used during the compilation, right? And so it will update it for all its consumers (even non-SME)? But that should be OK, because it will only happen when lowering to SME?

I just want to make sure that we retain the original protection that we gained when splitting this into a dedicated type converter.

It's not global state, it's updating the single instance of the LLVMTypeConveter that's passed to it. Within MLIR core this is a NFC (since we construct a LLVMTypeConveter in our -convert-arm-sme-to-llvm pass).

Note that the ArmSMETypeConverter did not override any methods, it just called addConversion within its constructor. Constructing a new LLVMTypeConverter, and then calling addConversion() on it is equivalent to that (without an extra class).

Also, note that adding type conversions needed for patterns is a common thing done within other dialects (such as OpenMP and
SPIR-V).

@banach-space
Copy link
Contributor

This will update the global LLVMTypeConverter used during the compilation, right? And so it will update it for all its consumers (even non-SME)? But that should be OK, because it will only happen when lowering to SME?
I just want to make sure that we retain the original protection that we gained when splitting this into a dedicated type converter.

It's not global state, it's updating the single instance of the LLVMTypeConveter that's passed to it. Within MLIR core this is a NFC (since we construct a LLVMTypeConveter in our -convert-arm-sme-to-llvm pass).

Note that the ArmSMETypeConverter did not override any methods, it just called addConversion within its constructor. Constructing a new LLVMTypeConverter, and then calling addConversion() on it is equivalent to that (without an extra class).

Also, note that adding type conversions needed for patterns is a common thing done within other dialects (such as OpenMP and SPIR-V).

A slightly different question then - is there a test that shows that type conversion fails for vectors scalable in 2 dims? Unless lowering for SME? I think that that would be quite helpful.

…tead)

This patch removes the ArmSMETypeConverter, and instead updates
`configureArmSMEToLLVMConversionLegality()` to add an ArmSME vector type
conversion to the existing LLVMTypeConverter. This makes it easier to
add these patterns to an existing `-to-llvm` lowering pass.
@MacDue
Copy link
Member Author

MacDue commented Dec 1, 2023

This will update the global LLVMTypeConverter used during the compilation, right? And so it will update it for all its consumers (even non-SME)? But that should be OK, because it will only happen when lowering to SME?
I just want to make sure that we retain the original protection that we gained when splitting this into a dedicated type converter.

It's not global state, it's updating the single instance of the LLVMTypeConveter that's passed to it. Within MLIR core this is a NFC (since we construct a LLVMTypeConveter in our -convert-arm-sme-to-llvm pass).
Note that the ArmSMETypeConverter did not override any methods, it just called addConversion within its constructor. Constructing a new LLVMTypeConverter, and then calling addConversion() on it is equivalent to that (without an extra class).
Also, note that adding type conversions needed for patterns is a common thing done within other dialects (such as OpenMP and SPIR-V).

A slightly different question then - is there a test that shows that type conversion fails for vectors scalable in 2 dims? Unless lowering for SME? I think that that would be quite helpful.

It's not clear to me how you could write such a test. Type conversions apply to patterns, there's no patterns (other than our own) that apply to 2D scalable vectors when lowering to LLVM. So if you run -convert-vector-to-llvm on:

func.func @function_using_sme_size_2d_scalable_vec(%vec: vector<[4]x[4]xi32>) -> vector<[4]x[4]xi32> {
  %c7 = arith.constant 7 : i32
  %newVec = vector.insert %c7, %vec[0, 0] : i32 into vector<[4]x[4]xi32>
  return %newVec : vector<[4]x[4]xi32>
}

Nothing happens, no patterns match, so no type conversions run, and because it's a partial conversion it also does not error.

See: https://godbolt.org/z/TTG9cGrrj

@banach-space
Copy link
Contributor

Nothing happens, no patterns match, so no type conversions run, and because it's a partial conversion it also does not error.

See: https://godbolt.org/z/TTG9cGrrj

Argh, I missed that! Yeah, that makes sense, thanks for the discussion 🙏🏻

Copy link
Contributor

@banach-space banach-space left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks!

@MacDue MacDue force-pushed the arm_sme_type_conv_fixup branch from 3de9bba to 5e83afd Compare December 4, 2023 10:50
@MacDue
Copy link
Member Author

MacDue commented Dec 4, 2023

Nothing happens, no patterns match, so no type conversions run, and because it's a partial conversion it also does not error.
See: https://godbolt.org/z/TTG9cGrrj

Argh, I missed that! Yeah, that makes sense, thanks for the discussion 🙏🏻

I've come up with a test now :)
I just had to write in in C++ rather than as an IR test, see:
https://github.com/llvm/llvm-project/blob/5e83afd0171c212a4d5c15db8e263650ac25e0f0/mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp

@MacDue MacDue merged commit 01e40a8 into llvm:main Dec 4, 2023
@MacDue MacDue deleted the arm_sme_type_conv_fixup branch December 4, 2023 17:02
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants