Skip to content

Commit c457e2a

Browse files
committed
[mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead)
This patch removes the ArmSMETypeConverter, and instead updates `configureArmSMEToLLVMConversionLegality()` to add an ArmSME vector type conversion to the existing LLVMTypeConverter. This makes it easier to add these patterns to an existing `-to-llvm` lowering pass.
1 parent 14e9917 commit c457e2a

File tree

5 files changed

+16
-42
lines changed

5 files changed

+16
-42
lines changed

mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,17 +20,16 @@ class RewritePatternSet;
2020
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
2121
#include "mlir/Conversion/Passes.h.inc"
2222

23-
using arm_sme::ArmSMETypeConverter;
24-
2523
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
2624
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
2725

2826
/// Configure target to convert from the ArmSME dialect to LLVM intrinsics.
29-
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
27+
void configureArmSMEToLLVMConversionLegality(ConversionTarget &target,
28+
LLVMTypeConverter &typeConverter);
3029

3130
/// Populate the given list with patterns that convert from the ArmSME dialect
3231
/// to LLVM intrinsics.
33-
void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
32+
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
3433
RewritePatternSet &patterns);
3534

3635
} // namespace mlir

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
3232
/// Pass that allocates tile IDs to ArmSME operations.
3333
std::unique_ptr<Pass> createTileAllocationPass();
3434

35-
//===----------------------------------------------------------------------===//
36-
// Type ArmSMETypeConverter pass.
37-
//===----------------------------------------------------------------------===//
38-
class ArmSMETypeConverter : public LLVMTypeConverter {
39-
public:
40-
ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
41-
};
42-
4335
//===----------------------------------------------------------------------===//
4436
// Registration
4537
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -538,10 +538,8 @@ struct ConvertArmSMEToLLVMPass
538538
void runOnOperation() override {
539539
LLVMConversionTarget target(getContext());
540540
RewritePatternSet patterns(&getContext());
541-
ArmSMETypeConverter converter(&getContext(),
542-
LowerToLLVMOptions(&getContext()));
543-
544-
configureArmSMEToLLVMConversionLegality(target);
541+
LLVMTypeConverter converter(&getContext());
542+
configureArmSMEToLLVMConversionLegality(target, converter);
545543
populateArmSMEToLLVMConversionPatterns(converter, patterns);
546544

547545
if (failed(applyPartialConversion(getOperation(), target,
@@ -552,7 +550,8 @@ struct ConvertArmSMEToLLVMPass
552550

553551
} // namespace
554552

555-
void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
553+
void mlir::configureArmSMEToLLVMConversionLegality(
554+
ConversionTarget &target, LLVMTypeConverter &typeConverter) {
556555
target.addIllegalDialect<arm_sme::ArmSMEDialect>();
557556
target.addLegalOp<
558557
arm_sme::MaterializeSSATileOp, arm_sme::aarch64_sme_zero,
@@ -571,10 +570,17 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
571570
arm_sme::aarch64_sme_write_vert, arm_sme::aarch64_sme_mopa>();
572571
target.addLegalDialect<arith::ArithDialect>();
573572
target.addLegalOp<UnrealizedConversionCastOp>();
573+
typeConverter.addConversion([&](VectorType type) -> std::optional<Type> {
574+
// There's no LLVM type for SME tiles, but after lowering to intrinsics all
575+
// SME vector types should be eliminated.
576+
if (arm_sme::isValidSMETileVectorType(type))
577+
return type;
578+
return std::nullopt;
579+
});
574580
}
575581

576-
void mlir::populateArmSMEToLLVMConversionPatterns(
577-
ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
582+
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
583+
RewritePatternSet &patterns) {
578584
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
579585
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
580586
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(

mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp

Lines changed: 0 additions & 22 deletions
This file was deleted.

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
add_mlir_dialect_library(MLIRArmSMETransforms
2-
ArmSMETypeConverter.cpp
32
EnableArmStreaming.cpp
43
TileAllocation.cpp
54

0 commit comments

Comments
 (0)