-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) #73639
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
@llvm/pr-subscribers-mlir-sme Author: Benjamin Maxwell (MacDue) ChangesThis patch removes the ArmSMETypeConverter, and instead updates Full diff: https://github.com/llvm/llvm-project/pull/73639.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867dff5..b2130742e0f71cf 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,17 +20,16 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
-using arm_sme::ArmSMETypeConverter;
-
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
+ LLVMTypeConverter &typeConverter);
/// Populate the given list with patterns that convert from the ArmSME dialect
/// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..4c9907be4086b1f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
std::unique_ptr<Pass> createTileAllocationPass();
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
- ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..12f9df7ed4de2f4 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -540,10 +540,8 @@ struct ConvertArmSMEToLLVMPass
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
- ArmSMETypeConverter converter(&getContext(),
- LowerToLLVMOptions(&getContext()));
-
- configureArmSMEToLLVMConversionLegality(target);
+ LLVMTypeConverter converter(&getContext());
+ configureArmSMEToLLVMConversionLegality(target, converter);
populateArmSMEToLLVMConversionPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
@@ -554,7 +552,8 @@ struct ConvertArmSMEToLLVMPass
} // namespace
-void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+void mlir::configureArmSMEToLLVMConversionLegality(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
@@ -574,10 +573,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_mopa>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
+ typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+ // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+ // SME vector types should be eliminated.
+ if (arm_sme::isValidSMETileVectorType(type))
+ return type;
+ return std::nullopt;
+ });
}
-void mlir::populateArmSMEToLLVMConversionPatterns(
- ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion>(converter);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf1035..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
- MLIRContext *ctx, const LowerToLLVMOptions &options)
- : LLVMTypeConverter(ctx, options) {
- // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
- // vectors (common in the context of ArmSME), e.g.
- // `vector<[16]x[16]xi8>`,
- // entering the LLVM Type converter. LLVM does not support arrays of scalable
- // vectors, but in the case of SME such types are effectively eliminated when
- // emitting ArmSME LLVM IR intrinsics.
- addConversion([&](VectorType type) { return type; });
-}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..f0854d0678e106d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRArmSMETransforms
- ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
TileAllocation.cpp
|
@llvm/pr-subscribers-mlir Author: Benjamin Maxwell (MacDue) ChangesThis patch removes the ArmSMETypeConverter, and instead updates Full diff: https://github.com/llvm/llvm-project/pull/73639.diff 5 Files Affected:
diff --git a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
index fe851d17867dff5..b2130742e0f71cf 100644
--- a/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
+++ b/mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
@@ -20,17 +20,16 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"
-using arm_sme::ArmSMETypeConverter;
-
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
-void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
+void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
+ LLVMTypeConverter &typeConverter);
/// Populate the given list with patterns that convert from the ArmSME dialect
/// to LLVM intrinsics.
-void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
+void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);
} // namespace mlir
diff --git a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
index 6f7617f5411c57f..4c9907be4086b1f 100644
--- a/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that replaces 'arm_sme.get_tile_id' ops with actual tiles.
std::unique_ptr<Pass> createTileAllocationPass();
-//===----------------------------------------------------------------------===//
-// Type ArmSMETypeConverter pass.
-//===----------------------------------------------------------------------===//
-class ArmSMETypeConverter : public LLVMTypeConverter {
-public:
- ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
-};
-
//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
index e409dc57fb020e2..12f9df7ed4de2f4 100644
--- a/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
+++ b/mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
@@ -540,10 +540,8 @@ struct ConvertArmSMEToLLVMPass
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
- ArmSMETypeConverter converter(&getContext(),
- LowerToLLVMOptions(&getContext()));
-
- configureArmSMEToLLVMConversionLegality(target);
+ LLVMTypeConverter converter(&getContext());
+ configureArmSMEToLLVMConversionLegality(target, converter);
populateArmSMEToLLVMConversionPatterns(converter, patterns);
if (failed(applyPartialConversion(getOperation(), target,
@@ -554,7 +552,8 @@ struct ConvertArmSMEToLLVMPass
} // namespace
-void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
+void mlir::configureArmSMEToLLVMConversionLegality(
+ ConversionTarget &target, LLVMTypeConverter &typeConverter) {
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
target.addLegalOp<
arm_sme::GetTileID, arm_sme::CastTileToVector, arm_sme::CastVectorToTile,
@@ -574,10 +573,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
arm_sme::aarch64_sme_mopa>();
target.addLegalDialect<arith::ArithDialect>();
target.addLegalOp<UnrealizedConversionCastOp>();
+ typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
+ // There's no LLVM type for SME tiles, but after lowering to intrinsics all
+ // SME vector types should be eliminated.
+ if (arm_sme::isValidSMETileVectorType(type))
+ return type;
+ return std::nullopt;
+ });
}
-void mlir::populateArmSMEToLLVMConversionPatterns(
- ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
+void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
+ RewritePatternSet &patterns) {
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion>(converter);
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp b/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
deleted file mode 100644
index 1cefc220ecf1035..000000000000000
--- a/mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp
+++ /dev/null
@@ -1,22 +0,0 @@
-//===- ArmSMETypeConverter.cpp - Convert builtin to LLVM dialect types ----===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
-#include "mlir/Dialect/ArmSME/Transforms/Passes.h"
-
-using namespace mlir;
-arm_sme::ArmSMETypeConverter::ArmSMETypeConverter(
- MLIRContext *ctx, const LowerToLLVMOptions &options)
- : LLVMTypeConverter(ctx, options) {
- // Disable LLVM type conversion for vectors. This is to prevent 2-d scalable
- // vectors (common in the context of ArmSME), e.g.
- // `vector<[16]x[16]xi8>`,
- // entering the LLVM Type converter. LLVM does not support arrays of scalable
- // vectors, but in the case of SME such types are effectively eliminated when
- // emitting ArmSME LLVM IR intrinsics.
- addConversion([&](VectorType type) { return type; });
-}
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
index e2407d9f48f7061..f0854d0678e106d 100644
--- a/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRArmSMETransforms
- ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
TileAllocation.cpp
|
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
grepping on typeConverter.addConversion
this seems consistent with other conversions. One minor suggestion, but otherwise LGTM, cheers.
828d6cc
to
3de9bba
Compare
This will update the global I just want to make sure that we retain the original protection that we gained when splitting this into a dedicated type converter. |
It's not global state, it's updating the single instance of the Note that the Also, note that adding type conversions needed for patterns is a common thing done within other dialects (such as OpenMP and |
A slightly different question then - is there a test that shows that type conversion fails for vectors scalable in 2 dims? Unless lowering for SME? I think that that would be quite helpful. |
…tead) This patch removes the ArmSMETypeConverter, and instead updates `configureArmSMEToLLVMConversionLegality()` to add an ArmSME vector type conversion to the existing LLVMTypeConverter. This makes it easier to add these patterns to an existing `-to-llvm` lowering pass.
It's not clear to me how you could write such a test. Type conversions apply to patterns, there's no patterns (other than our own) that apply to 2D scalable vectors when lowering to LLVM. So if you run
Nothing happens, no patterns match, so no type conversions run, and because it's a partial conversion it also does not error. |
Argh, I missed that! Yeah, that makes sense, thanks for the discussion 🙏🏻 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM, thanks!
3de9bba
to
5e83afd
Compare
I've come up with a test now :) |
This patch removes the ArmSMETypeConverter, and instead updates
populateArmSMEToLLVMConversionPatterns()
to add an ArmSME vector type conversion to the existing LLVMTypeConverter. This makes it easier to add these patterns to an existing-to-llvm
lowering pass.