Skip to content

[mlir][amx] Simplify intrinsic generation #140559

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 23, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions mlir/include/mlir/Conversion/LLVMCommon/Pattern.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,15 @@ SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
/// function is used to combine multiple values into a single value.
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
Type dstType);

/// Performs the index computation to get to the element at `indices` of the
/// memory pointed to by `memRefDesc`, using the layout map of `type`.
/// The indices are linearized as:
/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
Value getStridedElementPtr(
OpBuilder &builder, Location loc, const LLVMTypeConverter &converter,
MemRefType type, Value memRefDesc, ValueRange indices,
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none);
} // namespace LLVM

/// Base class for operation conversions targeting the LLVM IR dialect. It
Expand Down Expand Up @@ -107,8 +116,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
Type resultType, int64_t value);

// This is a strided getElementPtr variant that linearizes subscripts as:
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
/// Convenience wrapper for the corresponding helper utility.
/// This is a strided getElementPtr variant with linearized subscripts.
Value getStridedElementPtr(
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
Value memRefDesc, ValueRange indices,
Expand Down
151 changes: 68 additions & 83 deletions mlir/include/mlir/Dialect/AMX/AMX.td
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#define AMX

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Dialect/AMX/AMXInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"
Expand All @@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {

This `AMX` dialect provides a bridge between MLIR concepts such as
vectors and memrefs and the lower level LLVM IR support of AMX.
The dialect is split into user-facing AMX ops (AMX_Op) and
backend-facing intrinsic ops (AMX_IntrOp).

Note that since configuration changes (implicit at dialect level) are
costly, it is highly recommended to use the AMX dialect on same-shaped
Expand Down Expand Up @@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
class AMX_Op<string mnemonic, list<Trait> traits = []> :
Op<AMX_Dialect, mnemonic, traits> {}

// The "internal" intrinsics are meant for compiler usage.
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
LLVM_IntrOpBase<AMX_Dialect, mnemonic,
"x86_" # !subst(".", "_", mnemonic) # "_internal",
[], [], traits, numResults>;

//===----------------------------------------------------------------------===//
// AMX Op definitions (user facing).
// AMX Op definitions
//===----------------------------------------------------------------------===//

//
// Tile reset.
//

def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
def TileZeroOp : AMX_Op<"tile_zero", [Pure,
AMXIntrinsicOpInterface
]> {
let summary = "tile zero operation";
let description = [{
Zeroes the destination tile, with the shape defined by the 2-dim
Expand All @@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
return "llvm.x86.tilezero.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
Expand All @@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
// Tile memory operations.
//

def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
def TileLoadOp : AMX_Op<"tile_load", [Pure,
AMXIntrinsicOpInterface
]> {
let summary = "tile load operation";
let description = [{
Loads a tile from memory defined by a base and indices, with the
Expand All @@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
return "llvm.x86.tileloadd64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
"type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}

def TileStoreOp : AMX_Op<"tile_store"> {
def TileStoreOp : AMX_Op<"tile_store", [
AMXIntrinsicOpInterface
]> {
let summary = "tile store operation";
let description = [{
Stores a tile to memory defined by a base and indices, with the
Expand All @@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
TileType getTileType() {
return ::llvm::cast<TileType>(getVal().getType());
}

std::string getIntrinsicName() {
return "llvm.x86.tilestored64.internal";
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
"type($base) `,` qualified(type($val))";
Expand All @@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
// Tile arithmetic operations.
//

def TileMulFOp : AMX_Op<"tile_mulf", [
Pure, AllTypesMatch<["acc", "res"]>]> {
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
AMXIntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (floating-point)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
Expand Down Expand Up @@ -270,15 +295,30 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdp";
auto elementType =
getLhsTileType().getElementType();
intr += elementType.isF16() ? "fp16" : "bf16";
intr += "ps.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs))"
" `,` qualified(type($acc)) ";
let hasVerifier = 1;
}

def TileMulIOp : AMX_Op<"tile_muli", [
Pure, AllTypesMatch<["acc", "res"]>]> {
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
AMXIntrinsicOpInterface,
AllTypesMatch<["acc", "res"]>
]> {
let summary = "tile multiplication operation (integer)";
let description = [{
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
Expand Down Expand Up @@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}

std::string getIntrinsicName() {
std::string intr = "llvm.x86.tdpb";
intr += getIsZextLhs() ? "u" : "s";
intr += getIsZextRhs() ? "u" : "s";
intr += "d.internal";
return intr;
}
SmallVector<Value> getIntrinsicOperands(
::mlir::ArrayRef<Value> operands,
const ::mlir::LLVMTypeConverter &typeConverter,
::mlir::RewriterBase &rewriter);
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// AMX IntrOp definitions (LLVM compiler facing).
//===----------------------------------------------------------------------===//

//
// Tile reset. Parameters define the tile size.
//

def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
Arguments<(ins AnyInteger, AnyInteger)>;

//
// Tile memory operations. Parameters define the tile size,
// base address, and stride between consecutive rows for the
// memory operation.
//

def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
Arguments<(ins AnyInteger,
AnyInteger, LLVM_AnyPointer, AnyInteger)>;

def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
Arguments<(ins AnyInteger,
AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;

//
// Tile multiplication operations (series of dot products). Parameters
// define the tile sizes and source and destination tiles for the
// operation. Note that the prefix "tdp" stands for tile dot product.
//

// Dot product of bf16 tiles into f32 tile.
def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of f16 tiles into f32 tile.
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with sign/zero extension).
def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with zero/sign extension).
def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with zero/zero extension).
def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

#endif // AMX
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/AMX/AMXDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,15 @@
#define MLIR_DIALECT_AMX_AMXDIALECT_H_

#include "mlir/Bytecode/BytecodeOpInterface.h"
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

/// Include the generated interface declarations.
#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"

#include "mlir/Dialect/AMX/AMXDialect.h.inc"

#define GET_TYPEDEF_CLASSES
Expand Down
31 changes: 31 additions & 0 deletions mlir/include/mlir/Dialect/AMX/AMXInterfaces.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines interfaces for the AMX dialect.
//
//===----------------------------------------------------------------------===//

#ifndef AMX_INTERFACES
#define AMX_INTERFACES

include "mlir/IR/Interfaces.td"
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"

//===----------------------------------------------------------------------===//
// AMX Intrinsic Interface
//===----------------------------------------------------------------------===//

def AMXIntrinsicOpInterface
: OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
let description = [{
A wrapper interface for operations representing AMX LLVM intrinsics.
}];
let cppNamespace = "::mlir::amx";
}

#endif // AMX_INTERFACES
5 changes: 2 additions & 3 deletions mlir/include/mlir/Dialect/AMX/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
add_mlir_dialect(AMX amx)
add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)

set(LLVM_TARGET_DEFINITIONS AMX.td)
mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRAMXConversionsIncGen)
add_mlir_interface(AMXInterfaces)
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)
3 changes: 0 additions & 3 deletions mlir/include/mlir/Dialect/AMX/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);

/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);

} // namespace mlir

#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
2 changes: 0 additions & 2 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
Expand Down Expand Up @@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertOpenMPToLLVMInterface(registry);
registerConvertSCFToEmitCInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
registerConvertAMXToLLVMInterface(registry);
gpu::registerConvertGpuToLLVMInterface(registry);
NVVM::registerConvertGpuToNVVMInterface(registry);
vector::registerConvertVectorToLLVMInterface(registry);
Expand Down

This file was deleted.

Loading