-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[MLIR] [AMX] Utilize x86_amx type for AMX dialect in MLIR. #111197
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Thank you for submitting a Pull Request (PR) to the LLVM Project! This PR will be automatically labeled and the relevant teams will be notified. If you wish to, you can add reviewers by using the "Reviewers" section on this page. If this is not working for you, it is probably because you do not have write permissions for the repository. In which case you can instead tag reviewers by name in a comment by using If you have received no comments on your PR for a week, you can request a review by "ping"ing the PR by adding a comment “Ping”. The common courtesy "ping" rate is once a week. Please remember that you are asking for valuable time from other developers. If you have further questions, they may be answered by the LLVM GitHub User Guide. You can also ask questions in a comment on this PR, on the LLVM Discord or on the forums. |
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-amx Author: Ilya Enkovich (ienkovich) ChangesThis patch is intended to resolve #109481 and improve the usability of the AMX dialect. In LLVM IR, AMX intrinsics use In addition to translation problems, this inconsistency between types used in MLIR and LLVM IR makes MLIR verification and transformation quite problematic. Both To remove this inconsistency and make AMX operations more explicit in their limitations, I propose to add P.S. This patch also adds missing FP16 support. It's trivial but unrelated to type system changes, so let me know if I should submit it separately. Patch is 38.15 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/111197.diff 15 Files Affected:
diff --git a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
index d79b90f840ce83..bd4b3e73f07410 100644
--- a/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
+++ b/mlir/include/mlir/Conversion/LLVMCommon/TypeConverter.h
@@ -15,6 +15,7 @@
#define MLIR_CONVERSION_LLVMCOMMON_TYPECONVERTER_H
#include "mlir/Conversion/LLVMCommon/LoweringOptions.h"
+#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/DialectConversion.h"
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter {
/// Convert a 1D vector type into an LLVM vector type.
FailureOr<Type> convertVectorType(VectorType type) const;
+ /// Convert AMX tile type x86_amx type.
+ Type convertAMXTileType(amx::TileType type) const;
+
/// Options for customizing the llvm lowering.
LowerToLLVMOptions options;
diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td
index fcc8d169eab5ac..8ef5ac25fbbddf 100644
--- a/mlir/include/mlir/Dialect/AMX/AMX.td
+++ b/mlir/include/mlir/Dialect/AMX/AMX.td
@@ -30,6 +30,8 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
+include "mlir/IR/AttrTypeBase.td"
+include "mlir/IR/BuiltinTypes.td"
//===----------------------------------------------------------------------===//
// AMX dialect definition.
@@ -55,8 +57,69 @@ def AMX_Dialect : Dialect {
For details, see the Intel documentation:
https://software.intel.com/content/www/us/en/develop/articles/intel-sdm.html
}];
+ let useDefaultTypePrinterParser = 1;
}
+//===----------------------------------------------------------------------===//
+// AMX Tile definition.
+//===----------------------------------------------------------------------===//
+
+class AMX_Type<string typeName, string typeMnemonic, list<Trait> traits = []>
+ : TypeDef<AMX_Dialect, typeName, traits> {
+ let mnemonic = typeMnemonic;
+}
+
+def AMX_TileTypeElementType : AnyTypeOf<[F32, F16, BF16, I32, I8]> {
+ let cppFunctionName = "isValidTileTypeElementType";
+}
+
+def AMX_TileType : AMX_Type<"Tile", "amx.tile", [ShapedTypeInterface, ValueSemantics]> {
+ let summary = "AMX 2D tile to be used by AMX opertaions.";
+
+ let description = [{
+ This type is used to represent values in AMX tile registers. All AMX operations
+ work on AMX tiles and these tiles cannot be used in other operations directly.
+ LLVM IR type for AMX tile is a primitive type, but in MLIR we provide shape and
+ element type for IR verification and lowering to LLVMIR dialect.
+ }];
+
+ let parameters = (ins
+ ArrayRefParameter<"int64_t">:$shape,
+ AMX_TileTypeElementType:$elementType
+ );
+
+ let builders = [
+ TypeBuilderWithInferredContext<(ins
+ "ArrayRef<int64_t>":$shape, "Type":$elementType), [{
+ return $_get(elementType.getContext(), shape, elementType);
+ }]>
+ ];
+
+ let extraClassDeclaration = [{
+ /// Returns if this type is ranked (always true).
+ bool hasRank() const { return true; }
+
+ /// Clone this tile type with the given shape and element type. If the
+ /// provided shape is `std::nullopt`, the current shape of the type is used.
+ TileType cloneWith(std::optional<ArrayRef<int64_t>> shape,
+ Type elementType) const {
+ return get(shape.value_or(getShape()), elementType);
+ }
+ }];
+
+ let hasCustomAssemblyFormat = 1;
+ let skipDefaultBuilders = 1;
+}
+
+def IsAMXTilePred : CPred<"::llvm::isa<::mlir::amx::TileType>($_self)">;
+
+def IsAMX2DTilePred : And<[IsAMXTilePred,
+ CPred<[{::llvm::cast<::mlir::amx::TileType>($_self).getRank() == 2}]>]>;
+
+class AMX2DTileOf<list<Type> allowedTypes> :
+ ShapedContainerType<allowedTypes, IsAMX2DTilePred, "tile",
+ "::mlir::amx::TileType">;
+
//===----------------------------------------------------------------------===//
// AMX Op and IntrOp definitions.
//===----------------------------------------------------------------------===//
@@ -88,14 +151,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
Example:
```mlir
- %0 = amx.tile_zero : vector<16x16xbf16>
+ %0 = amx.tile_zero : <16x16xbf16>
```
}];
let results = (outs
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
@@ -117,19 +180,19 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
Example:
```mlir
- %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
+ %0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into <16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
@@ -148,18 +211,18 @@ def TileStoreOp : AMX_Op<"tile_store"> {
Example:
```mlir
- amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
+ amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, <16x64xi8>
```
}];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
- VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
+ AMX2DTileOf<[F32, F16, BF16, I32, I8]>:$val);
let extraClassDeclaration = [{
MemRefType getMemRefType() {
return ::llvm::cast<MemRefType>(getBase().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getVal().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getVal().getType());
}
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
@@ -183,23 +246,22 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
Example:
```mlir
- %0 = amx.tile_mulf %a, %b, %c
- : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
+ %0 = amx.tile_mulf %a, %b, %c : <16x32xbf16>, <16x32xbf16>, <16x16xf32>
```
}];
- let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
- VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
- VectorOfRankAndType<[2], [F32, BF16]>:$acc);
- let results = (outs VectorOfRankAndType<[2], [F32, BF16]>:$res);
+ let arguments = (ins AMX2DTileOf<[F16, BF16]>:$lhs,
+ AMX2DTileOf<[F16, BF16]>:$rhs,
+ AMX2DTileOf<[F32]>:$acc);
+ let results = (outs AMX2DTileOf<[F32]>:$res);
let extraClassDeclaration = [{
- VectorType getLhsVectorType() {
- return ::llvm::cast<VectorType>(getLhs().getType());
+ TileType getLhsTileType() {
+ return ::llvm::cast<TileType>(getLhs().getType());
}
- VectorType getRhsVectorType() {
- return ::llvm::cast<VectorType>(getRhs().getType());
+ TileType getRhsTileType() {
+ return ::llvm::cast<TileType>(getRhs().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
@@ -222,26 +284,25 @@ def TileMulIOp : AMX_Op<"tile_muli", [
Example:
```mlir
- %0 = amx.tile_muli %a zext, %b zext, %c
- : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
+ %0 = amx.tile_muli %a zext, %b zext, %c : <16x64xi8>, <16x64xi8>, <16x16xi32>
```
}];
- let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
- VectorOfRankAndType<[2], [I32, I8]>:$rhs,
- VectorOfRankAndType<[2], [I32, I8]>:$acc,
+ let arguments = (ins AMX2DTileOf<[I8]>:$lhs,
+ AMX2DTileOf<[I8]>:$rhs,
+ AMX2DTileOf<[I32]>:$acc,
UnitAttr:$isZextLhs,
UnitAttr:$isZextRhs
);
- let results = (outs VectorOfRankAndType<[2], [I32, I8]>:$res);
+ let results = (outs AMX2DTileOf<[I32]>:$res);
let extraClassDeclaration = [{
- VectorType getLhsVectorType() {
- return ::llvm::cast<VectorType>(getLhs().getType());
+ TileType getLhsTileType() {
+ return ::llvm::cast<TileType>(getLhs().getType());
}
- VectorType getRhsVectorType() {
- return ::llvm::cast<VectorType>(getRhs().getType());
+ TileType getRhsTileType() {
+ return ::llvm::cast<TileType>(getRhs().getType());
}
- VectorType getVectorType() {
- return ::llvm::cast<VectorType>(getRes().getType());
+ TileType getTileType() {
+ return ::llvm::cast<TileType>(getRes().getType());
}
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
@@ -286,6 +347,12 @@ def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+// Dot product of f16 tiles into f32 tile.
+def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
+ Arguments<(ins AnyInteger,
+ AnyInteger,
+ AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
+
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins AnyInteger,
diff --git a/mlir/include/mlir/Dialect/AMX/AMXDialect.h b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
index 47c92479814dea..c0553ad8733fd4 100644
--- a/mlir/include/mlir/Dialect/AMX/AMXDialect.h
+++ b/mlir/include/mlir/Dialect/AMX/AMXDialect.h
@@ -21,6 +21,9 @@
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.h.inc"
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.h.inc"
diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
index 8f9c2f2f8a0b44..09dd0919c318fb 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMTypes.td
@@ -250,4 +250,16 @@ def LLVMTargetExtType : LLVMType<"LLVMTargetExt", "target"> {
}];
}
+//===----------------------------------------------------------------------===//
+// LLVMX86AMXType
+//===----------------------------------------------------------------------===//
+
+def LLVMX86AMXType : LLVMType<"LLVMX86AMX", "x86_amx"> {
+ let summary = "LLVM x86_amx type.";
+ let description = [{
+ The x86_amx type represents a value held in an AMX tile register on an x86
+ machine. Can only be used in AMX intrinsics calls.
+ }];
+}
+
#endif // LLVMTYPES_TD
diff --git a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
index 568d9339aaabcb..39199e4affccfa 100644
--- a/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
+++ b/mlir/lib/Conversion/LLVMCommon/CMakeLists.txt
@@ -12,6 +12,7 @@ add_mlir_conversion_library(MLIRLLVMCommonConversion
Core
LINK_LIBS PUBLIC
+ MLIRAMXDialect
MLIRIR
MLIRLLVMDialect
MLIRSupport
diff --git a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
index fd6369b5bb4ee5..a585a4f6ab76f6 100644
--- a/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
+++ b/mlir/lib/Conversion/LLVMCommon/TypeConverter.cpp
@@ -67,6 +67,7 @@ LLVMTypeConverter::LLVMTypeConverter(MLIRContext *ctx,
return std::nullopt;
return llvmType;
});
+ addConversion([&](amx::TileType type) { return convertAMXTileType(type); });
// LLVM-compatible types are legal, so add a pass-through conversion. Do this
// before the conversions below since conversions are attempted in reverse
@@ -596,6 +597,12 @@ FailureOr<Type> LLVMTypeConverter::convertVectorType(VectorType type) const {
return vectorType;
}
+/// Convert an AMX tile type to LLVM x86_amx type.
+/// Shape and element type of the tile are ignored.
+Type LLVMTypeConverter::convertAMXTileType(amx::TileType type) const {
+ return LLVM::LLVMX86AMXType::get(&getContext());
+}
+
/// Convert a type in the context of the default or bare pointer calling
/// convention. Calling convention sensitive types, such as MemRefType and
/// UnrankedMemRefType, are converted following the specific rules for the
diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
index f0e434407c8a2d..829f48e223383e 100644
--- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
+++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp
@@ -13,14 +13,22 @@
#include "mlir/Dialect/AMX/AMXDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMTypes.h"
#include "mlir/IR/Builders.h"
+#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h"
+#include "llvm/ADT/TypeSwitch.h"
+
using namespace mlir;
#include "mlir/Dialect/AMX/AMXDialect.cpp.inc"
void amx::AMXDialect::initialize() {
+ addTypes<
+#define GET_TYPEDEF_LIST
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
+ >();
+
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/AMX/AMX.cpp.inc"
@@ -28,7 +36,7 @@ void amx::AMXDialect::initialize() {
}
/// Verify that AMX supports the implied tile shape.
-static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
+static LogicalResult verifyTileSize(Operation *op, amx::TileType tp) {
const unsigned kMaxRows = 16;
const unsigned kBitsPerRow = 64 * 8;
unsigned col = tp.getDimSize(1) * tp.getElementType().getIntOrFloatBitWidth();
@@ -40,8 +48,8 @@ static LogicalResult verifyTileSize(Operation *op, VectorType tp) {
}
/// Verify that AMX supports the multiplication.
-static LogicalResult verifyMultShape(Operation *op, VectorType atp,
- VectorType btp, VectorType ctp,
+static LogicalResult verifyMultShape(Operation *op, amx::TileType atp,
+ amx::TileType btp, amx::TileType ctp,
unsigned scale) {
unsigned am = atp.getDimSize(0), ak = atp.getDimSize(1) >> scale;
unsigned bk = btp.getDimSize(0), bn = btp.getDimSize(1) >> scale;
@@ -53,27 +61,27 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
}
LogicalResult amx::TileZeroOp::verify() {
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileLoadOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileStoreOp::verify() {
unsigned rank = getMemRefType().getRank();
if (getIndices().size() != rank)
return emitOpError("requires ") << rank << " indices";
- return verifyTileSize(*this, getVectorType());
+ return verifyTileSize(*this, getTileType());
}
LogicalResult amx::TileMulFOp::verify() {
- VectorType aType = getLhsVectorType();
- VectorType bType = getRhsVectorType();
- VectorType cType = getVectorType();
+ amx::TileType aType = getLhsTileType();
+ amx::TileType bType = getRhsTileType();
+ amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -82,15 +90,15 @@ LogicalResult amx::TileMulFOp::verify() {
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
- if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
+ if ((!ta.isBF16() && !ta.isF16()) || (ta != tb) || !tc.isF32())
return emitOpError("unsupported type combination");
return success();
}
LogicalResult amx::TileMulIOp::verify() {
- VectorType aType = getLhsVectorType();
- VectorType bType = getRhsVectorType();
- VectorType cType = getVectorType();
+ amx::TileType aType = getLhsTileType();
+ amx::TileType bType = getRhsTileType();
+ amx::TileType cType = getTileType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
@@ -104,5 +112,34 @@ LogicalResult amx::TileMulIOp::verify() {
return success();
}
+Type amx::TileType::parse(AsmParser &parser) {
+ if (parser.parseLess())
+ return nullptr;
+
+ SmallVector<int64_t, 2> shape;
+ if (parser.parseDimensionList(shape, false, true))
+ return nullptr;
+
+ Type elementType;
+ if (parser.parseType(elementType))
+ return nullptr;
+
+ if (parser.parseGreater())
+ return nullptr;
+
+ return TileType::get(shape, elementType);
+}
+
+void amx::TileType::print(AsmPrinter &os) const {
+ os << "<";
+ os.printDimensionList(getShape());
+ os << 'x';
+ os.printType(getElementType());
+ os << '>';
+}
+
#define GET_OP_CLASSES
#include "mlir/Dialect/AMX/AMX.cpp.inc"
+
+#define GET_TYPEDEF_CLASSES
+#include "mlir/Dialect/AMX/AMXTypes.cpp.inc"
diff --git a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
index a8b10f63315d41..415a0998f684f9 100644
--- a/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
+++ b/mlir/lib/Dialect/AMX/Transforms/LegalizeForLLVMExport.cpp
@@ -25,13 +25,13 @@ namespace {
/// The second dimensions needs to be scaled by the number of bytes.
std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
const LLVMTypeConverter &typeConverter,
- VectorType vType, Location loc) {
+ amx::TileType tType, Location loc) {
Type llvmInt16Type = IntegerType::get(&typeConverter.getContext(), 16);
- unsigned width = vType.getElementType().getIntOrFloatBitWidth();
+ unsigned width = tType.getElementType().getIntOrFloatBitWidth();
assert(llvm::isPowerOf2_64(width) && width >= 8);
unsigned bytes = width >> 3;
- auto mattr = rewriter.getI16IntegerAttr(vType.getDimSize(0));
- auto nattr = rewriter.getI16IntegerAttr(vType.getDimSize(1) * bytes);
+ auto mattr = rewriter.getI16IntegerAttr(tType.getDimSize(0));
+ auto nattr = rewriter.getI16IntegerAttr(tType.getDimSize(1) * bytes);
return std::make_pair(
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, mattr),
rewriter.create<LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
@@ -78,12 +78,12 @@ struct TileZeroConversion : public ConvertOpToLLVMPattern<TileZeroOp> {
LogicalResult
matchAndRewrite(TileZeroOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Value> tsz =
- getTileSizes(rewriter, *getTypeConverter(), vType, op.getLoc());
+ getTileSizes(rewriter, *getTypeConverter(), tType, op.getLoc());
// Replace operation with intrinsic.
- Type resType = typeConverter->convertType(vType);
+ Type resType = typeConverter->convertType(tType);
rewriter.replaceOpWithNewOp<amx::x86_amx_tilezero>(op, resType, tsz.first,
tsz.second);
return success();
@@ -97,10 +97,10 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
matchAndRewrite(TileLoadOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MemRefType mType = op.getMemRefType();
- VectorType vType = op.getVectorType();
+ amx::TileType tType = op.getTileType();
// Determine m x n tile sizes.
std::pair<Value, Valu...
[truncated]
|
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter { | |||
/// Convert a 1D vector type into an LLVM vector type. | |||
FailureOr<Type> convertVectorType(VectorType type) const; | |||
|
|||
/// Convert AMX tile type x86_amx type. | |||
Type convertAMXTileType(amx::TileType type) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hm, I see the chicken and egg problem here. Technically, the LLVM header should not include an AMX header, but the conversion is on an AMX type.
Perhaps the AMX tile type needs to be "lowered" to a vector type before it gets converted to LLVM's X86_amx
type, but that's just odd.
Alternatively, since the X86_amx
type is actually in LLVM itself, the type could be declared there, then you don't need to cross-include a higher level file.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To be clear, you may need to declare the type twice: amx::TileType
that lowers directly to llvm::x86_amx
without need for conversion.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't see how I can go directly from amx::TileType
to llvm::x86_type
. For proper conversion of MLIR module to LLVM module, all types should be LLVM compatible and this is defined by LLVMIR dialect. I cannot declare an external LLVM compatible type.
The only simple way to get rid of the new dependency here is to move amx::TileType
to bulitin types. It would better match what we have in LLVM, but doesn't feel good.
Probably, the best solution would be to introduce a type interface for converting custom types to LLVMIR dialect types. Then we can utilize this interface in LLVMTypeConverter
and keep amx::TileType
to LLVMX86AMXType
conversion code in AMX dialect removing dependency of LLVM converter on AMX dialect.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cannot declare an external LLVM compatible type.
I'm not sure what you mean by this. Should be possible to create a new type in the LLVM dialect that lowers directly to the LLVM IR type.
The only simple way to get rid of the new dependency here is to move
amx::TileType
to bulitin types.
That's not the right way around, indeed.
Probably, the best solution would be to introduce a type interface for converting custom types to LLVMIR dialect types.
Is this the only case of custom type conversion into LLVM? Are there no other examples in tree?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm not sure what you mean by this. Should be possible to create a new type in the LLVM dialect that lowers directly to the LLVM IR type.
This patch already has such a type in the LLVM dialect. It's LLVM::LLVMX86AMXType
added to LLVMTypes.td
. Changes added to the LLVM type converter are required to convert amx::TileType
to this LLVMX86AMXType
. These changes allow existing passes (such as CF to LLVM lowering) to properly convert amx::TileType
values to LLVM::LLVMX86AMXType
values.
Is this the only case of custom type conversion into LLVM? Are there no other examples in tree?
I didn't see any similar cases. Users have to use their type converters to handle non-builtin types.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks much better, thanks! One final comment on the conversion. We need to get that right without layer violations in headers.
@@ -258,6 +259,9 @@ class LLVMTypeConverter : public TypeConverter { | |||
/// Convert a 1D vector type into an LLVM vector type. | |||
FailureOr<Type> convertVectorType(VectorType type) const; | |||
|
|||
/// Convert AMX tile type x86_amx type. | |||
Type convertAMXTileType(amx::TileType type) const; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I cannot declare an external LLVM compatible type.
I'm not sure what you mean by this. Should be possible to create a new type in the LLVM dialect that lowers directly to the LLVM IR type.
The only simple way to get rid of the new dependency here is to move
amx::TileType
to bulitin types.
That's not the right way around, indeed.
Probably, the best solution would be to introduce a type interface for converting custom types to LLVMIR dialect types.
Is this the only case of custom type conversion into LLVM? Are there no other examples in tree?
Signed-off-by: Ilya Enkovich <[email protected]>
Signed-off-by: Ilya Enkovich <[email protected]>
Signed-off-by: Ilya Enkovich <[email protected]>
bbaeadd
to
37a98e6
Compare
@rengolin Since the dependency of Another option I was looking at was to extend |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is much cleaner, thank! I agree it's not as practical, but it keeps things separate.
One small comment and LGTM, thanks!
@ienkovich Congratulations on having your first Pull Request (PR) merged into the LLVM Project! Your changes will be combined with recent changes from other authors, then tested by our build bots. If there is a problem with a build, you may receive a report in an email or a comment on this PR. Please check whether problems have been caused by your change specifically, as the builds can include changes from many authors. It is not uncommon for your change to be included in a build that fails due to someone else's changes, or infrastructure issues. How to do this, and the rest of the post-merge process, is covered in detail here. If your change does cause a problem, it may be reverted, or you can revert it yourself. This is a normal part of LLVM development. You can fix your changes and open a new PR to merge them again. If you don't get any reports, no action is required from you. Your changes are working as expected, well done! |
This patch is intended to resolve #109481 and improve the usability of the AMX dialect.
In LLVM IR, AMX intrinsics use
x86_amx
which is one of the primitive types. This type is supposed to be used for AMX intrinsic calls and no other operations. AMX dialect of MLIR uses regular 2D vector types, which are then lowered to arrays of vectors in the LLVMIR dialect. This creates an inconsistency in the types used in the LLVMIR dialect and LLVMIR. Translation of AMX intrinsic calls to LLVM IR doesn't require result types to match and that is where tile loads and mul operation results getx86_amx
type. This works in very simple cases when mul and tile store operations directly consume the result of another AMX intrinsic call, but it doesn't work when an argument is a block argument (phi node).In addition to translation problems, this inconsistency between types used in MLIR and LLVM IR makes MLIR verification and transformation quite problematic. Both
amx.tileload
andvector::transfer_read
can load values of the same type, but only one of them can be used in AMX operations. In general, by looking at a type of value, we cannot determine if it can only be used for AMX operations or contrary can be used in other operations but AMX ones.To remove this inconsistency and make AMX operations more explicit in their limitations, I propose to add
LLVMX86AMXType
type to the LLVMIR dialect to matchx86_amx
type in LLVM IR, and introduceamx::TileType
to be used by AMX operations in MLIR. This resolves translation problems for AMX usage with phi nodes and provides proper type verification in MLIR for AMX operations.P.S. This patch also adds missing FP16 support. It's trivial but unrelated to type system changes, so let me know if I should submit it separately.