Skip to content

Commit d698ede

Browse files
authored
[mlir][amx] Restore conversion interface for AMX (#143871)
Restores mistakenly removed AMX interface which ensures that the custom tile type is converted to its LLVM equivalent within other operations such as control flow. Fix after #140559
1 parent 013034c commit d698ede

File tree

4 files changed

+44
-0
lines changed

4 files changed

+44
-0
lines changed

mlir/include/mlir/Dialect/AMX/Transforms.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
2525
/// intrinsics.
2626
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
2727

28+
/// Register LLVM conversion interface for AMX dialect.
29+
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
30+
2831
} // namespace mlir
2932

3033
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

mlir/include/mlir/InitAllExtensions.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
3333
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
3434
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
35+
#include "mlir/Dialect/AMX/Transforms.h"
3536
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
3637
#include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
3738
#include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
8586
registerConvertOpenMPToLLVMInterface(registry);
8687
registerConvertSCFToEmitCInterface(registry);
8788
ub::registerConvertUBToLLVMInterface(registry);
89+
registerConvertAMXToLLVMInterface(registry);
8890
gpu::registerConvertGpuToLLVMInterface(registry);
8991
NVVM::registerConvertGpuToNVVMInterface(registry);
9092
vector::registerConvertVectorToLLVMInterface(registry);

mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
6060
void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
6161
target.addIllegalDialect<AMXDialect>();
6262
}
63+
64+
namespace {
65+
/// Implement the interface to convert AMX to LLVM.
66+
struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
67+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
68+
69+
void populateConvertToLLVMConversionPatterns(
70+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
71+
RewritePatternSet &patterns) const final {
72+
populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
73+
}
74+
};
75+
} // namespace
76+
77+
void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
78+
registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
79+
dialect->addInterfaces<AMXToLLVMDialectInterface>();
80+
});
81+
}

mlir/test/Target/LLVMIR/amx.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
8888
amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
8989
return
9090
}
91+
92+
// CHECK-LABEL: define void @amx_tile_type_through_cf
93+
func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
94+
%idx: index, %cond: i1) {
95+
cf.cond_br %cond, ^bb1, ^bb2
96+
^bb1: // pred: ^bb0
97+
// CHECK: call x86_amx @llvm.x86.tileloadd64.internal
98+
%0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
99+
cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
100+
^bb2: // pred: ^bb0
101+
// CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
102+
%1 = amx.tile_zero : !amx.tile<16x64xi8>
103+
cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
104+
^bb3(%2: !amx.tile<16x64xi8>): // 2 preds: ^bb1, ^bb2
105+
cf.br ^bb4
106+
^bb4: // pred: ^bb3
107+
// CHECK: call void @llvm.x86.tilestored64.internal
108+
amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
109+
return
110+
}

0 commit comments

Comments
 (0)