Skip to content

[mlir][amx] Restore conversion interface for AMX #143871

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged

Conversation

adam-smnk
Copy link
Contributor

Restores mistakenly removed AMX interface which ensures that the custom tile type is converted to its LLVM equivalent within other operations such as control flow.

Fix after #140559

Restores mistakenly removed AMX interface which ensures that
the custom tile type is converted to its LLVM equivalent within other
operations such as control flow.

Fix after llvm#140559
@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-mlir-llvm

@llvm/pr-subscribers-mlir

Author: Adam Siemieniuk (adam-smnk)

Changes

Restores mistakenly removed AMX interface which ensures that the custom tile type is converted to its LLVM equivalent within other operations such as control flow.

Fix after #140559


Full diff: https://github.com/llvm/llvm-project/pull/143871.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMX/Transforms.h (+3)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+19)
  • (modified) mlir/test/Target/LLVMIR/amx.mlir (+20)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 4a751d99ceeee..7391ec2ff6b14 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7dcbabe8aafa3..f356b91b1b6c0 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 7471dc797e0fc..37aebc9fab3eb 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addIllegalDialect<AMXDialect>();
 }
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+    dialect->addInterfaces<AMXToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 094475040436d..abdf2fe3bd534 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
   amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
   return
 }
+
+// CHECK-LABEL: define void @amx_tile_type_through_cf
+func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
+    %idx: index, %cond: i1) {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+  %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
+^bb2:  // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  %1 = amx.tile_zero : !amx.tile<16x64xi8>
+  cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
+^bb3(%2: !amx.tile<16x64xi8>):  // 2 preds: ^bb1, ^bb2
+  cf.br ^bb4
+^bb4:  // pred: ^bb3
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
+  return
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 12, 2025

@llvm/pr-subscribers-mlir-amx

Author: Adam Siemieniuk (adam-smnk)

Changes

Restores mistakenly removed AMX interface which ensures that the custom tile type is converted to its LLVM equivalent within other operations such as control flow.

Fix after #140559


Full diff: https://github.com/llvm/llvm-project/pull/143871.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/AMX/Transforms.h (+3)
  • (modified) mlir/include/mlir/InitAllExtensions.h (+2)
  • (modified) mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp (+19)
  • (modified) mlir/test/Target/LLVMIR/amx.mlir (+20)
diff --git a/mlir/include/mlir/Dialect/AMX/Transforms.h b/mlir/include/mlir/Dialect/AMX/Transforms.h
index 4a751d99ceeee..7391ec2ff6b14 100644
--- a/mlir/include/mlir/Dialect/AMX/Transforms.h
+++ b/mlir/include/mlir/Dialect/AMX/Transforms.h
@@ -25,6 +25,9 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
 /// intrinsics.
 void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
 
+/// Register LLVM conversion interface for AMX dialect.
+void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
+
 } // namespace mlir
 
 #endif // MLIR_DIALECT_AMX_TRANSFORMS_H
diff --git a/mlir/include/mlir/InitAllExtensions.h b/mlir/include/mlir/InitAllExtensions.h
index 7dcbabe8aafa3..f356b91b1b6c0 100644
--- a/mlir/include/mlir/InitAllExtensions.h
+++ b/mlir/include/mlir/InitAllExtensions.h
@@ -32,6 +32,7 @@
 #include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
+#include "mlir/Dialect/AMX/Transforms.h"
 #include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
 #include "mlir/Dialect/ArmNeon/TransformOps/ArmNeonVectorTransformOps.h"
 #include "mlir/Dialect/ArmSVE/TransformOps/ArmSVEVectorTransformOps.h"
@@ -85,6 +86,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
   registerConvertOpenMPToLLVMInterface(registry);
   registerConvertSCFToEmitCInterface(registry);
   ub::registerConvertUBToLLVMInterface(registry);
+  registerConvertAMXToLLVMInterface(registry);
   gpu::registerConvertGpuToLLVMInterface(registry);
   NVVM::registerConvertGpuToNVVMInterface(registry);
   vector::registerConvertVectorToLLVMInterface(registry);
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index 7471dc797e0fc..37aebc9fab3eb 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -60,3 +60,22 @@ void mlir::populateAMXLegalizeForLLVMExportPatterns(
 void mlir::configureAMXLegalizeForExportTarget(LLVMConversionTarget &target) {
   target.addIllegalDialect<AMXDialect>();
 }
+
+namespace {
+/// Implement the interface to convert AMX to LLVM.
+struct AMXToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
+  using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
+
+  void populateConvertToLLVMConversionPatterns(
+      ConversionTarget &target, LLVMTypeConverter &typeConverter,
+      RewritePatternSet &patterns) const final {
+    populateAMXLegalizeForLLVMExportPatterns(typeConverter, patterns);
+  }
+};
+} // namespace
+
+void mlir::registerConvertAMXToLLVMInterface(DialectRegistry &registry) {
+  registry.addExtension(+[](MLIRContext *ctx, amx::AMXDialect *dialect) {
+    dialect->addInterfaces<AMXToLLVMDialectInterface>();
+  });
+}
diff --git a/mlir/test/Target/LLVMIR/amx.mlir b/mlir/test/Target/LLVMIR/amx.mlir
index 094475040436d..abdf2fe3bd534 100644
--- a/mlir/test/Target/LLVMIR/amx.mlir
+++ b/mlir/test/Target/LLVMIR/amx.mlir
@@ -88,3 +88,23 @@ func.func @amx_tile_muli(%matA: memref<?x?xi8>, %matB: memref<?x?xi8>,
   amx.tile_store %out[%c16, %c16], %res3 : memref<?x?xi8>, !amx.tile<16x16xi32>
   return
 }
+
+// CHECK-LABEL: define void @amx_tile_type_through_cf
+func.func @amx_tile_type_through_cf(%src: memref<?x?xi8>, %out: memref<?x?xi8>,
+    %idx: index, %cond: i1) {
+  cf.cond_br %cond, ^bb1, ^bb2
+^bb1:  // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tileloadd64.internal
+  %0 = amx.tile_load %src[%idx, %idx] : memref<?x?xi8> into !amx.tile<16x64xi8>
+  cf.br ^bb3(%0 : !amx.tile<16x64xi8>)
+^bb2:  // pred: ^bb0
+  // CHECK: call x86_amx @llvm.x86.tilezero.internal(i16 16, i16 64)
+  %1 = amx.tile_zero : !amx.tile<16x64xi8>
+  cf.br ^bb3(%1 : !amx.tile<16x64xi8>)
+^bb3(%2: !amx.tile<16x64xi8>):  // 2 preds: ^bb1, ^bb2
+  cf.br ^bb4
+^bb4:  // pred: ^bb3
+  // CHECK: call void @llvm.x86.tilestored64.internal
+  amx.tile_store %out[%idx, %idx], %2 : memref<?x?xi8>, !amx.tile<16x64xi8>
+  return
+}

@adam-smnk adam-smnk merged commit d698ede into llvm:main Jun 12, 2025
11 checks passed
rolfmorel added a commit to libxsmm/tpp-mlir that referenced this pull request Jun 12, 2025
* llvm/llvm-project#139340
```
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h
sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp
```

* llvm/llvm-project#141466 &
llvm/llvm-project#141019
  * Add `BufferizationState &state` to `bufferize` and `getBuffer` 

* llvm/llvm-project#143159 &
llvm/llvm-project#142683 &
llvm/llvm-project#143779
  * Updates to `transform.apply_registered_pass` and its Python-bindings

* llvm/llvm-project#143217
* `tilingResult->mergeResult.replacements` ->
`tilingResult->replacements`

* llvm/llvm-project#140559 &
llvm/llvm-project#143871
* Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s
& fix which enables conversion again.
tomtor pushed a commit to tomtor/llvm-project that referenced this pull request Jun 14, 2025
Restores mistakenly removed AMX interface which ensures that the custom
tile type is converted to its LLVM equivalent within other operations
such as control flow.

Fix after llvm#140559
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants