Skip to content

[MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. #111197

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Nov 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 117 additions & 41 deletions mlir/include/mlir/Dialect/AMX/AMX.td
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@

include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/AttrTypeBase.td"
include "mlir/IR/BuiltinTypes.td"

//===----------------------------------------------------------------------===//
// AMX dialect definition.
Expand All @@ -55,8 +57,77 @@ def AMX_Dialect : Dialect {
For details, see the Intel documentation:
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
}];
let useDefaultTypePrinterParser = 1;
}

//===----------------------------------------------------------------------===//
// AMX Tile definition.
//===----------------------------------------------------------------------===//

class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
: TypeDef<AMX_Dialect, typeName, traits> {
let mnemonic = typeMnemonic;
}

def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
let cppFunctionName = "isValidTileTypeElementType";
}

def AMX_TileType : AMX_Type<"Tile", "tile", [ShapedTypeInterface, ValueSemantics]> {
let summary = "AMX 2D tile to be used by AMX opertaions.";

let description = [{
This type is used to represent values in AMX tile registers. All AMX operations
work on AMX tiles and these tiles cannot be used in other operations directly.
LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
element type for IR verification and lowering to LLVMIR dialect.
}];

let parameters = (ins
ArrayRefParameter<"int64_t">:$shape,
AMX_TileTypeElementType:$elementType
);

let builders = [
TypeBuilderWithInferredContext<(ins
"ArrayRef<int64_t>":$shape, "Type":$elementType), [{
return $_get(elementType.getContext(), shape, elementType);
}]>
];

let extraClassDeclaration = [{
/// Returns if this type is ranked (always true).
bool hasRank() const { return true; }

/// Clone this tile type with the given shape and element type. If the
/// provided shape is `std::nullopt`, the current shape of the type is used.
TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
Type elementType) const {
return get(shape.value_or(getShape()), elementType);
}
}];

let hasCustomAssemblyFormat = 1;
let skipDefaultBuilders = 1;
}

def IsAMXTilePred : And<[CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">,
CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;

class AMXTileOf<list<Type> allowedTypes> :
ShapedContainerType<allowedTypes, IsAMXTilePred, "tile",
"::mlir::amx::TileType">;

def AnyAMXTile : AMXTileOf<[F32, F16, BF16, I32, I8]>;

def AMXTileF32 : AMXTileOf<[F32]>;

def AMXTileF16OrBF16 : AMXTileOf<[F16, BF16]>;

def AMXTileI32 : AMXTileOf<[I32]>;

def AMXTileI8 : AMXTileOf<[I8]>;

//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -88,17 +159,16 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
Example:

```mlir
%0 = amx.tile_zero : vector<16x16xbf16>
%0 = amx.tile_zero : !amx.tile<16x16xbf16>
```
}];
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
let assemblyFormat = "attr-dict `:` qualified(type($res))";
let hasVerifier = 1;
}

Expand All @@ -117,23 +187,22 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
Example:

```mlir
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let results = (outs AnyAMXTile:$res);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
"type($base) `into` type($res)";
"type($base) `into` qualified(type($res))";
let hasVerifier = 1;
}

Expand All @@ -148,22 +217,22 @@ def TileStoreOp : AMX_Op<"tile_store"> {
Example:

```mlir
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, !amx.tile<16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
AnyAMXTile:$val);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getVal().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getVal().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
"type($base) `,` type($val)";
"type($base) `,` qualified(type($val))";
let hasVerifier = 1;
}

Expand All @@ -184,26 +253,27 @@ def TileMulFOp : AMX_Op<"tile_mulf", [

```mlir
%0 = amx.tile_mulf %a, %b, %c
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
: !amx.tile<16x32xbf16>, !amx.tile<16x32xbf16>, !amx.tile<16x16xf32>
```
}];
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
VectorOfRankAndType<[2], [F32, BF16]>:$acc);
let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
let arguments = (ins AMXTileF16OrBF16:$lhs,
AMXTileF16OrBF16:$rhs,
AMXTileF32:$acc);
let results = (outs AMXTileF32:$res);
let extraClassDeclaration = [{
VectorType getLhsVectorType() {
return ::llvm::cast<VectorType>(getLhs().getType());
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
}
VectorType getRhsVectorType() {
return ::llvm::cast<VectorType>(getRhs().getType());
TileType getRhsTileType() {
return ::llvm::cast<TileType>(getRhs().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
"qualified(type($lhs)) `,` qualified(type($rhs))"
" `,` qualified(type($acc)) ";
let hasVerifier = 1;
}

Expand All @@ -223,29 +293,29 @@ def TileMulIOp : AMX_Op<"tile_muli", [

```mlir
%0 = amx.tile_muli %a zext, %b zext, %c
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
: !amx.tile<16x64xi8>, !amx.tile<16x64xi8>, !amx.tile<16x16xi32>
```
}];
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
VectorOfRankAndType<[2], [I32, I8]>:$acc,
let arguments = (ins AMXTileI8:$lhs,
AMXTileI8:$rhs,
AMXTileI32:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
let results = (outs AMXTileI32:$res);
let extraClassDeclaration = [{
VectorType getLhsVectorType() {
return ::llvm::cast<VectorType>(getLhs().getType());
TileType getLhsTileType() {
return ::llvm::cast<TileType>(getLhs().getType());
}
VectorType getRhsVectorType() {
return ::llvm::cast<VectorType>(getRhs().getType());
TileType getRhsTileType() {
return ::llvm::cast<TileType>(getRhs().getType());
}
VectorType getVectorType() {
return ::llvm::cast<VectorType>(getRes().getType());
TileType getTileType() {
return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
let hasVerifier = 1;
}

Expand Down Expand Up @@ -286,6 +356,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of f16 tiles into f32 tile.
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;

// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins AnyInteger,
Expand Down
3 changes: 3 additions & 0 deletions mlir/include/mlir/Dialect/AMX/AMXDialect.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@

#include "mlir/Dialect/AMX/AMXDialect.h.inc"

#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/AMX/AMXTypes.h.inc"

#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.h.inc"

Expand Down
8 changes: 6 additions & 2 deletions mlir/include/mlir/Dialect/AMX/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,16 +14,20 @@ namespace mlir {
class LLVMConversionTarget;
class LLVMTypeConverter;
class RewritePatternSet;
class DialectRegistry;

/// Collect a set of patterns to lower AMX ops to ops that map to LLVM
/// intrinsics.
void populateAMXLegalizeForLLVMExportPatterns(
const LLVMTypeConverter &converter, RewritePatternSet &patterns);
void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns);

/// Configure the target to support lowering AMX ops to ops that map to LLVM
/// intrinsics.
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);

/// Register LLVM conversion interface for AMX dialect.
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);

} // namespace mlir

#endif // MLIR_DIALECT_AMX_TRANSFORMS_H
12 changes: 12 additions & 0 deletions mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
Original file line number Diff line number Diff line change
Expand Up @@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
}];
}

//===----------------------------------------------------------------------===//
// LLVMX86AMXType
//===----------------------------------------------------------------------===//

def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
let summary = "LLVM x86_amx type.";
let description = [{
The x86_amx type represents a value held in an AMX tile register on an x86
machine. Can only be used in AMX intrinsics calls.
}];
}

#endif // LLVMTYPES_TD
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include "mlir/Conversion/NVVMToLLVM/NVVMToLLVM.h"
#include "mlir/Conversion/OpenMPToLLVM/ConvertOpenMPToLLVM.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/AMX/Transforms.h"
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
Expand Down Expand Up @@ -70,6 +71,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
registerConvertNVVMToLLVMInterface(registry);
registerConvertOpenMPToLLVMInterface(registry);
ub::registerConvertUBToLLVMInterface(registry);
registerConvertAMXToLLVMInterface(registry);

// Register all transform dialect extensions.
affine::registerTransformDialectExtension(registry);
Expand Down
Loading
Loading