Skip to content

Commit 01e40a8

Browse files
authored
[mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) (#73639)
This patch removes the ArmSMETypeConverter, and instead updates `populateArmSMEToLLVMConversionPatterns()` to add an ArmSME vector type conversion to the existing LLVMTypeConverter. This makes it easier to add these patterns to an existing `-to-llvm` lowering pass.
1 parent bbd2b08 commit 01e40a8

File tree

8 files changed

+69
-39
lines changed

8 files changed

+69
-39
lines changed

mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,6 @@ class RewritePatternSet;
2020
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
2121
#include "mlir/Conversion/Passes.h.inc"
2222

23-
using arm_sme::ArmSMETypeConverter;
24-
2523
/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
2624
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();
2725

@@ -30,7 +28,7 @@ void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);
3028

3129
/// Populate the given list with patterns that convert from the ArmSME dialect
3230
/// to LLVM intrinsics.
33-
void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
31+
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
3432
RewritePatternSet &patterns);
3533

3634
} // namespace mlir

mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
3232
/// Pass that allocates tile IDs to ArmSME operations.
3333
std::unique_ptr<Pass> createTileAllocationPass();
3434

35-
//===----------------------------------------------------------------------===//
36-
// Type ArmSMETypeConverter pass.
37-
//===----------------------------------------------------------------------===//
38-
class ArmSMETypeConverter : public LLVMTypeConverter {
39-
public:
40-
ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
41-
};
42-
4335
//===----------------------------------------------------------------------===//
4436
// Registration
4537
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -538,9 +538,7 @@ struct ConvertArmSMEToLLVMPass
538538
void runOnOperation() override {
539539
LLVMConversionTarget target(getContext());
540540
RewritePatternSet patterns(&getContext());
541-
ArmSMETypeConverter converter(&getContext(),
542-
LowerToLLVMOptions(&getContext()));
543-
541+
LLVMTypeConverter converter(&getContext());
544542
configureArmSMEToLLVMConversionLegality(target);
545543
populateArmSMEToLLVMConversionPatterns(converter, patterns);
546544

@@ -573,8 +571,16 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
573571
target.addLegalOp<UnrealizedConversionCastOp>();
574572
}
575573

576-
void mlir::populateArmSMEToLLVMConversionPatterns(
577-
ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
574+
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
575+
RewritePatternSet &patterns) {
576+
converter.addConversion([&](VectorType type) -> std::optional<Type> {
577+
// There's no LLVM type for SME tiles, but after lowering to intrinsics all
578+
// SME vector types should be eliminated.
579+
if (arm_sme::isValidSMETileVectorType(type))
580+
return type;
581+
return std::nullopt;
582+
});
583+
578584
patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
579585
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
580586
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(

mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp

Lines changed: 0 additions & 22 deletions
This file was deleted.

mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
add_mlir_dialect_library(MLIRArmSMETransforms
2-
ArmSMETypeConverter.cpp
32
EnableArmStreaming.cpp
43
TileAllocation.cpp
54

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
add_mlir_unittest(MLIRArmSMETests
2+
TileTypeConversionTest.cpp)
3+
target_link_libraries(MLIRArmSMETests
4+
PRIVATE
5+
MLIRArmSMEToLLVM)
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
//===- TileTypeConversionTest.cpp - Tests ArmSME tile type conversion -----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
10+
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
11+
#include "mlir/Conversion/LLVMCommon/Pattern.h"
12+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
13+
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"
14+
15+
#include "gtest/gtest.h"
16+
17+
using namespace mlir;
18+
19+
class ArmSMETest : public ::testing::Test {
20+
protected:
21+
ArmSMETest() { context.getOrLoadDialect<mlir::arm_sme::ArmSMEDialect>(); }
22+
23+
mlir::MLIRContext context;
24+
};
25+
26+
TEST_F(ArmSMETest, TestTileTypeConversion) {
27+
LLVMTypeConverter llvmConverter(&context);
28+
LLVMTypeConverter llvmConverterWithArmSMEConversion(&context);
29+
30+
RewritePatternSet patterns(&context);
31+
populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion,
32+
patterns);
33+
34+
Type i32 = IntegerType::get(&context, 32);
35+
auto smeTileType = VectorType::get({4, 4}, i32, {true, true});
36+
37+
// An unmodified LLVMTypeConverer should fail to convert an ArmSME tile type.
38+
{
39+
SmallVector<Type> convertedType;
40+
ASSERT_TRUE(failed(llvmConverter.convertType(smeTileType, convertedType)));
41+
}
42+
43+
// An updated LLVMTypeConverer should return the ArmSME tile vector type
44+
// unchanged.
45+
{
46+
SmallVector<Type> convertedType;
47+
ASSERT_TRUE(succeeded(llvmConverterWithArmSMEConversion.convertType(
48+
smeTileType, convertedType)));
49+
ASSERT_EQ(ArrayRef<Type>(convertedType), ArrayRef<Type>{smeTileType});
50+
}
51+
}

mlir/unittests/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
66
MLIRIR
77
MLIRDialect)
88

9+
add_subdirectory(ArmSME)
910
add_subdirectory(Index)
1011
add_subdirectory(LLVMIR)
1112
add_subdirectory(MemRef)

0 commit comments

Comments
 (0)