Skip to content

Commit f9ca93b

Browse files
committed
[mlir][amx] Simplify intrinsic generation
Replaces separate amx named intrinsic operations with direct calls to LLVM intrinsic functions. The existing amx tests are updated and expanded. The separate conversion step translating amx intrinsics into LLVM IR is eliminated. Instead, this step is now performed by the existing llvm dialect infrastructure. Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581
1 parent a21986b commit f9ca93b

File tree

18 files changed

+432
-446
lines changed

18 files changed

+432
-446
lines changed

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 71 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,11 @@
2525
//
2626
//===----------------------------------------------------------------------===//
2727

28-
#ifndef AMX
29-
#define AMX
28+
#ifndef AMX_OPS
29+
#define AMX_OPS
3030

3131
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
32+
include "mlir/Dialect/AMX/AMXInterfaces.td"
3233
include "mlir/Interfaces/SideEffectInterfaces.td"
3334
include "mlir/IR/AttrTypeBase.td"
3435
include "mlir/IR/BuiltinTypes.td"
@@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {
4748

4849
This `AMX` dialect provides a bridge between MLIR concepts such as
4950
vectors and memrefs and the lower level LLVM IR support of AMX.
50-
The dialect is split into user-facing AMX ops (AMX_Op) and
51-
backend-facing intrinsic ops (AMX_IntrOp).
5251

5352
Note that since configuration changes (implicit at dialect level) are
5453
costly, it is highly recommended to use the AMX dialect on same-shaped
@@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
135134
class AMX_Op<string mnemonic, list<Trait> traits = []> :
136135
Op<AMX_Dialect, mnemonic, traits> {}
137136

138-
// The "internal" intrinsics are meant for compiler usage.
139-
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
140-
LLVM_IntrOpBase<AMX_Dialect, mnemonic,
141-
"x86_" # !subst(".", "_", mnemonic) # "_internal",
142-
[], [], traits, numResults>;
143-
144137
//===----------------------------------------------------------------------===//
145-
// AMX Op definitions (user facing).
138+
// AMX Op definitions
146139
//===----------------------------------------------------------------------===//
147140

148141
//
149142
// Tile reset.
150143
//
151144

152-
def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
145+
def TileZeroOp : AMX_Op<"tile_zero", [Pure,
146+
AMXIntrinsicOpInterface
147+
]> {
153148
let summary = "tile zero operation";
154149
let description = [{
155150
Zeroes the destination tile, with the shape defined by the 2-dim
@@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
167162
TileType getTileType() {
168163
return ::llvm::cast<TileType>(getRes().getType());
169164
}
165+
166+
std::string getIntrinsicName() {
167+
return "llvm.x86.tilezero.internal";
168+
}
169+
SmallVector<Value> getIntrinsicOperands(
170+
::mlir::ArrayRef<Value> operands,
171+
const ::mlir::LLVMTypeConverter &typeConverter,
172+
::mlir::RewriterBase &rewriter);
170173
}];
171174
let assemblyFormat = "attr-dict `:` qualified(type($res))";
172175
let hasVerifier = 1;
@@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
176179
// Tile memory operations.
177180
//
178181

179-
def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
182+
def TileLoadOp : AMX_Op<"tile_load", [Pure,
183+
AMXIntrinsicOpInterface
184+
]> {
180185
let summary = "tile load operation";
181186
let description = [{
182187
Loads a tile from memory defined by a base and indices, with the
@@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
200205
TileType getTileType() {
201206
return ::llvm::cast<TileType>(getRes().getType());
202207
}
208+
209+
std::string getIntrinsicName() {
210+
return "llvm.x86.tileloadd64.internal";
211+
}
212+
SmallVector<Value> getIntrinsicOperands(
213+
::mlir::ArrayRef<Value> operands,
214+
const ::mlir::LLVMTypeConverter &typeConverter,
215+
::mlir::RewriterBase &rewriter);
203216
}];
204217
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
205218
"type($base) `into` qualified(type($res))";
206219
let hasVerifier = 1;
207220
}
208221

209-
def TileStoreOp : AMX_Op<"tile_store"> {
222+
def TileStoreOp : AMX_Op<"tile_store", [
223+
AMXIntrinsicOpInterface
224+
]> {
210225
let summary = "tile store operation";
211226
let description = [{
212227
Stores a tile to memory defined by a base and indices, with the
@@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
230245
TileType getTileType() {
231246
return ::llvm::cast<TileType>(getVal().getType());
232247
}
248+
249+
std::string getIntrinsicName() {
250+
return "llvm.x86.tilestored64.internal";
251+
}
252+
SmallVector<Value> getIntrinsicOperands(
253+
::mlir::ArrayRef<Value> operands,
254+
const ::mlir::LLVMTypeConverter &typeConverter,
255+
::mlir::RewriterBase &rewriter);
233256
}];
234257
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
235258
"type($base) `,` qualified(type($val))";
@@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
240263
// Tile arithmetic operations.
241264
//
242265

243-
def TileMulFOp : AMX_Op<"tile_mulf", [
244-
Pure, AllTypesMatch<["acc", "res"]>]> {
266+
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
267+
AMXIntrinsicOpInterface,
268+
AllTypesMatch<["acc", "res"]>
269+
]> {
245270
let summary = "tile multiplication operation (floating-point)";
246271
let description = [{
247272
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -270,15 +295,30 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
270295
TileType getTileType() {
271296
return ::llvm::cast<TileType>(getRes().getType());
272297
}
298+
299+
std::string getIntrinsicName() {
300+
std::string intr = "llvm.x86.tdp";
301+
auto elementType =
302+
getLhsTileType().getElementType();
303+
intr += elementType.isF16() ? "fp16" : "bf16";
304+
intr += "ps.internal";
305+
return intr;
306+
}
307+
SmallVector<Value> getIntrinsicOperands(
308+
::mlir::ArrayRef<Value> operands,
309+
const ::mlir::LLVMTypeConverter &typeConverter,
310+
::mlir::RewriterBase &rewriter);
273311
}];
274312
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
275313
"qualified(type($lhs)) `,` qualified(type($rhs))"
276314
" `,` qualified(type($acc)) ";
277315
let hasVerifier = 1;
278316
}
279317

280-
def TileMulIOp : AMX_Op<"tile_muli", [
281-
Pure, AllTypesMatch<["acc", "res"]>]> {
318+
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
319+
AMXIntrinsicOpInterface,
320+
AllTypesMatch<["acc", "res"]>
321+
]> {
282322
let summary = "tile multiplication operation (integer)";
283323
let description = [{
284324
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
313353
TileType getTileType() {
314354
return ::llvm::cast<TileType>(getRes().getType());
315355
}
356+
357+
std::string getIntrinsicName() {
358+
std::string intr = "llvm.x86.tdpb";
359+
intr += getIsZextLhs() ? "u" : "s";
360+
intr += getIsZextRhs() ? "u" : "s";
361+
intr += "d.internal";
362+
return intr;
363+
}
364+
SmallVector<Value> getIntrinsicOperands(
365+
::mlir::ArrayRef<Value> operands,
366+
const ::mlir::LLVMTypeConverter &typeConverter,
367+
::mlir::RewriterBase &rewriter);
316368
}];
317369
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
318370
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
319371
let hasVerifier = 1;
320372
}
321373

322-
//===----------------------------------------------------------------------===//
323-
// AMX IntrOp definitions (LLVM compiler facing).
324-
//===----------------------------------------------------------------------===//
325-
326-
//
327-
// Tile reset. Parameters define the tile size.
328-
//
329-
330-
def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
331-
Arguments<(ins AnyInteger, AnyInteger)>;
332-
333-
//
334-
// Tile memory operations. Parameters define the tile size,
335-
// base address, and stride between consecutive rows for the
336-
// memory operation.
337-
//
338-
339-
def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
340-
Arguments<(ins AnyInteger,
341-
AnyInteger, LLVM_AnyPointer, AnyInteger)>;
342-
343-
def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
344-
Arguments<(ins AnyInteger,
345-
AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
346-
347-
//
348-
// Tile multiplication operations (series of dot products). Parameters
349-
// define the tile sizes and source and destination tiles for the
350-
// operation. Note that the prefix "tdp" stands for tile dot product.
351-
//
352-
353-
// Dot product of bf16 tiles into f32 tile.
354-
def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
355-
Arguments<(ins AnyInteger,
356-
AnyInteger,
357-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
358-
359-
// Dot product of f16 tiles into f32 tile.
360-
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
361-
Arguments<(ins AnyInteger,
362-
AnyInteger,
363-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
364-
365-
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
366-
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
367-
Arguments<(ins AnyInteger,
368-
AnyInteger,
369-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
370-
371-
// Dot product of i8 tiles into i32 tile (with sign/zero extension).
372-
def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
373-
Arguments<(ins AnyInteger,
374-
AnyInteger,
375-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
376-
377-
// Dot product of i8 tiles into i32 tile (with zero/sign extension).
378-
def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
379-
Arguments<(ins AnyInteger,
380-
AnyInteger,
381-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
382-
383-
// Dot product of i8 tiles into i32 tile (with zero/zero extension).
384-
def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
385-
Arguments<(ins AnyInteger,
386-
AnyInteger,
387-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
388-
389-
#endif // AMX
374+
#endif // AMX_OPS

mlir/include/mlir/Dialect/AMX/AMXDialect.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
#define MLIR_DIALECT_AMX_AMXDIALECT_H_
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/Dialect.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122

23+
/// Include the generated interface declarations.
24+
#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
25+
2226
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
2327

2428
#define GET_TYPEDEF_CLASSES
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines interfaces for the AMX dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef AMX_INTERFACES
14+
#define AMX_INTERFACES
15+
16+
include "mlir/IR/Interfaces.td"
17+
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
18+
19+
//===----------------------------------------------------------------------===//
20+
// AMX Intrinsic Interface
21+
//===----------------------------------------------------------------------===//
22+
23+
def AMXIntrinsicOpInterface
24+
: OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
25+
let description = [{
26+
A wrapper interface for operations representing AMX LLVM intrinsics.
27+
}];
28+
let cppNamespace = "::mlir::amx";
29+
}
30+
31+
#endif // AMX_INTERFACES
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_dialect(AMX amx)
22
add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
33

4-
set(LLVM_TARGET_DEFINITIONS AMX.td)
5-
mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
6-
add_public_tablegen_target(MLIRAMXConversionsIncGen)
4+
add_mlir_interface(AMXInterfaces)
5+
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)

mlir/include/mlir/Dialect/AMX/Transforms.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
2525
/// intrinsics.
2626
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
2727

28-
/// Register LLVM conversion interface for AMX dialect.
29-
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
30-
3128
} // namespace mlir
3229

3330
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

mlir/include/mlir/InitAllExtensions.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
3333
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
3434
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
35-
#include "mlir/Dialect/AMX/Transforms.h"
3635
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
3736
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
3837
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
8483
registerConvertOpenMPToLLVMInterface(registry);
8584
registerConvertSCFToEmitCInterface(registry);
8685
ub::registerConvertUBToLLVMInterface(registry);
87-
registerConvertAMXToLLVMInterface(registry);
8886
gpu::registerConvertGpuToLLVMInterface(registry);
8987
NVVM::registerConvertGpuToNVVMInterface(registry);
9088
vector::registerConvertVectorToLLVMInterface(registry);

mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h

Lines changed: 0 additions & 31 deletions
This file was deleted.

mlir/include/mlir/Target/LLVMIR/Dialect/All.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
#ifndef MLIR_TARGET_LLVMIR_DIALECT_ALL_H
1515
#define MLIR_TARGET_LLVMIR_DIALECT_ALL_H
1616

17-
#include "mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h"
1817
#include "mlir/Target/LLVMIR/Dialect/ArmNeon/ArmNeonToLLVMIRTranslation.h"
1918
#include "mlir/Target/LLVMIR/Dialect/ArmSME/ArmSMEToLLVMIRTranslation.h"
2019
#include "mlir/Target/LLVMIR/Dialect/ArmSVE/ArmSVEToLLVMIRTranslation.h"
@@ -37,7 +36,6 @@ class DialectRegistry;
3736
/// corresponding translation interfaces.
3837
static inline void registerAllToLLVMIRTranslations(DialectRegistry &registry) {
3938
registerArmNeonDialectTranslation(registry);
40-
registerAMXDialectTranslation(registry);
4139
registerArmSMEDialectTranslation(registry);
4240
registerArmSVEDialectTranslation(registry);
4341
registerBuiltinDialectTranslation(registry);

0 commit comments

Comments
 (0)