Skip to content

Commit 37a98e6

Browse files
committed
Remove amx::TileType conversion from LLVMTypeConverter.
Signed-off-by: Ilya Enkovich <[email protected]>
1 parent b6802a3 commit 37a98e6

File tree

6 files changed

+32
-15
lines changed

6 files changed

+32
-15
lines changed

mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
1616

1717
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
18-
#include "mlir/Dialect/AMX/AMXDialect.h"
1918
#include "mlir/IR/BuiltinTypes.h"
2019
#include "mlir/Transforms/DialectConversion.h"
2120

@@ -259,9 +258,6 @@ class LLVMTypeConverter : public TypeConverter {
259258
/// Convert a 1D vector type into an LLVM vector type.
260259
FailureOr<Type> convertVectorType(VectorType type) const;
261260

262-
/// Convert an AMX tile type to the x86_amx type.
263-
Type convertAMXTileType(amx::TileType type) const;
264-
265261
/// Options for customizing the llvm lowering.
266262
LowerToLLVMOptions options;
267263

mlir/include/mlir/Dialect/AMX/Transforms.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,16 +14,20 @@ namespace mlir {
1414
class LLVMConversionTarget;
1515
class LLVMTypeConverter;
1616
class RewritePatternSet;
17+
class DialectRegistry;
1718

1819
/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
1920
/// intrinsics.
20-
void populateAMXLegalizeForLLVMExportPatterns(
21-
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
21+
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
22+
RewritePatternSet &patterns);
2223

2324
/// Configure the target to support lowering AMX ops to ops that map to LLVM
2425
/// intrinsics.
2526
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
2627

28+
/// Register LLVM conversion interface for AMX dialect.
29+
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
30+
2731
} // namespace mlir
2832

2933
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
2525
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
2626
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
27+
#include "mlir/Dialect/AMX/Transforms.h"
2728
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
2829
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
2930
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
7071
registerConvertNVVMToLLVMInterface(registry);
7172
registerConvertOpenMPToLLVMInterface(registry);
7273
ub::registerConvertUBToLLVMInterface(registry);
74+
registerConvertAMXToLLVMInterface(registry);
7375

7476
// Register all transform dialect extensions.
7577
affine::registerTransformDialectExtension(registry);

mlir/lib/Conversion/LLVMCommon/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
1212
Core
1313

1414
LINK_LIBS PUBLIC
15-
MLIRAMXDialect
1615
MLIRIR
1716
MLIRLLVMDialect
1817
MLIRSupport

mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
6767
return std::nullopt;
6868
return llvmType;
6969
});
70-
addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
7170

7271
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
7372
// before the conversions below since conversions are attempted in reverse
@@ -595,12 +594,6 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
595594
return vectorType;
596595
}
597596

598-
/// Convert an AMX tile type to LLVM x86_amx type.
599-
/// Shape and element type of the tile are ignored.
600-
Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
601-
return LLVM::LLVMX86AMXType::get(&getContext());
602-
}
603-
604597
/// Convert a type in the context of the default or bare pointer calling
605598
/// convention. Calling convention sensitive types, such as MemRefType and
606599
/// UnrankedMemRefType, are converted following the specific rules for the

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "mlir/Dialect/AMX/Transforms.h"
1010

11+
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1112
#include "mlir/Conversion/LLVMCommon/ConversionTarget.h"
1213
#include "mlir/Conversion/LLVMCommon/Pattern.h"
1314
#include "mlir/Dialect/AMX/AMXDialect.h"
@@ -208,9 +209,12 @@ struct TileMulIConversion : public ConvertOpToLLVMPattern<TileMulIOp> {
208209
} // namespace
209210

210211
void mlir::populateAMXLegalizeForLLVMExportPatterns(
211-
const LLVMTypeConverter &converter, RewritePatternSet &patterns) {
212+
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
212213
patterns.add<TileZeroConversion, TileLoadConversion, TileStoreConversion,
213214
TileMulFConversion, TileMulIConversion>(converter);
215+
converter.addConversion([&](amx::TileType type) {
216+
return LLVM::LLVMX86AMXType::get(&converter.getContext());
217+
});
214218
}
215219

216220
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
@@ -220,3 +224,22 @@ void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
220224
target.addIllegalOp<TileZeroOp, TileLoadOp, TileStoreOp, TileMulIOp,
221225
TileMulFOp>();
222226
}
227+
228+
namespace {
229+
/// Implement the interface to convert AMX to LLVM.
230+
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
231+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
232+
233+
void populateConvertToLLVMConversionPatterns(
234+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
235+
RewritePatternSet &patterns) const final {
236+
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
237+
}
238+
};
239+
} // namespace
240+
241+
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
242+
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
243+
dialect->addInterfaces<AMXToLLVMDialectInterface>();
244+
});
245+
}

0 commit comments

Comments
 (0)