Skip to content

[mlir][ArmSME] Remove ArmSMETypeConverter (and configure LLVM one instead) #73639

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Dec 4, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions mlir/include/mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@ class RewritePatternSet;
#define GEN_PASS_DECL_CONVERTARMSMETOLLVM
#include "mlir/Conversion/Passes.h.inc"

using arm_sme::ArmSMETypeConverter;

/// Create a pass to convert from the ArmSME dialect to LLVM intrinsics.
std::unique_ptr<Pass> createConvertArmSMEToLLVMPass();

Expand All @@ -30,7 +28,7 @@ void configureArmSMEToLLVMConversionLegality(ConversionTarget &target);

/// Populate the given list with patterns that convert from the ArmSME dialect
/// to LLVM intrinsics.
void populateArmSMEToLLVMConversionPatterns(ArmSMETypeConverter &converter,
void populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);

} // namespace mlir
Expand Down
8 changes: 0 additions & 8 deletions mlir/include/mlir/Dialect/ArmSME/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,6 @@ std::unique_ptr<Pass> createEnableArmStreamingPass(
/// Pass that allocates tile IDs to ArmSME operations.
std::unique_ptr<Pass> createTileAllocationPass();

//===----------------------------------------------------------------------===//
// Type ArmSMETypeConverter pass.
//===----------------------------------------------------------------------===//
class ArmSMETypeConverter : public LLVMTypeConverter {
public:
ArmSMETypeConverter(MLIRContext *ctx, const LowerToLLVMOptions &options);
};

//===----------------------------------------------------------------------===//
// Registration
//===----------------------------------------------------------------------===//
Expand Down
16 changes: 11 additions & 5 deletions mlir/lib/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,9 +538,7 @@ struct ConvertArmSMEToLLVMPass
void runOnOperation() override {
LLVMConversionTarget target(getContext());
RewritePatternSet patterns(&getContext());
ArmSMETypeConverter converter(&getContext(),
LowerToLLVMOptions(&getContext()));

LLVMTypeConverter converter(&getContext());
configureArmSMEToLLVMConversionLegality(target);
populateArmSMEToLLVMConversionPatterns(converter, patterns);

Expand Down Expand Up @@ -573,8 +571,16 @@ void mlir::configureArmSMEToLLVMConversionLegality(ConversionTarget &target) {
target.addLegalOp<UnrealizedConversionCastOp>();
}

void mlir::populateArmSMEToLLVMConversionPatterns(
ArmSMETypeConverter &converter, RewritePatternSet &patterns) {
void mlir::populateArmSMEToLLVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
converter.addConversion([&](VectorType type) -> std::optional<Type> {
// There's no LLVM type for SME tiles, but after lowering to intrinsics all
// SME vector types should be eliminated.
if (arm_sme::isValidSMETileVectorType(type))
return type;
return std::nullopt;
});

patterns.add<LoadTileSliceConversion, MoveTileSliceToVectorConversion,
MoveVectorToTileSliceConversion, StoreTileSliceConversion,
OuterProductOpConversion, ZeroOpConversion, GetTileConversion>(
Expand Down
22 changes: 0 additions & 22 deletions mlir/lib/Dialect/ArmSME/Transforms/ArmSMETypeConverter.cpp

This file was deleted.

1 change: 0 additions & 1 deletion mlir/lib/Dialect/ArmSME/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
add_mlir_dialect_library(MLIRArmSMETransforms
ArmSMETypeConverter.cpp
EnableArmStreaming.cpp
TileAllocation.cpp

Expand Down
5 changes: 5 additions & 0 deletions mlir/unittests/Dialect/ArmSME/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
add_mlir_unittest(MLIRArmSMETests
TileTypeConversionTest.cpp)
target_link_libraries(MLIRArmSMETests
PRIVATE
MLIRArmSMEToLLVM)
51 changes: 51 additions & 0 deletions mlir/unittests/Dialect/ArmSME/TileTypeConversionTest.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
//===- TileTypeConversionTest.cpp - Tests ArmSME tile type conversion -----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/ArmSMEToLLVM/ArmSMEToLLVM.h"
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
#include "mlir/Conversion/LLVMCommon/Pattern.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/ArmSME/IR/ArmSME.h"

#include "gtest/gtest.h"

using namespace mlir;

class ArmSMETest : public ::testing::Test {
protected:
ArmSMETest() { context.getOrLoadDialect<mlir::arm_sme::ArmSMEDialect>(); }

mlir::MLIRContext context;
};

TEST_F(ArmSMETest, TestTileTypeConversion) {
LLVMTypeConverter llvmConverter(&context);
LLVMTypeConverter llvmConverterWithArmSMEConversion(&context);

RewritePatternSet patterns(&context);
populateArmSMEToLLVMConversionPatterns(llvmConverterWithArmSMEConversion,
patterns);

Type i32 = IntegerType::get(&context, 32);
auto smeTileType = VectorType::get({4, 4}, i32, {true, true});

// An unmodified LLVMTypeConverer should fail to convert an ArmSME tile type.
{
SmallVector<Type> convertedType;
ASSERT_TRUE(failed(llvmConverter.convertType(smeTileType, convertedType)));
}

// An updated LLVMTypeConverer should return the ArmSME tile vector type
// unchanged.
{
SmallVector<Type> convertedType;
ASSERT_TRUE(succeeded(llvmConverterWithArmSMEConversion.convertType(
smeTileType, convertedType)));
ASSERT_EQ(ArrayRef<Type>(convertedType), ArrayRef<Type>{smeTileType});
}
}
1 change: 1 addition & 0 deletions mlir/unittests/Dialect/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ target_link_libraries(MLIRDialectTests
MLIRIR
MLIRDialect)

add_subdirectory(ArmSME)
add_subdirectory(Index)
add_subdirectory(LLVMIR)
add_subdirectory(MemRef)
Expand Down