Skip to content

Commit 0fa3ba7

Browse files
authored
[mlir][amx] Simplify intrinsic generation (#140559)
Replaces separate amx named intrinsic operations with direct calls to LLVM intrinsic functions. The existing amx tests are updated and expanded. The separate conversion step translating amx intrinsics into LLVM IR is eliminated. Instead, this step is now performed by the existing llvm dialect infrastructure. Related RFC: https://discourse.llvm.org/t/rfc-simplify-x86-intrinsic-generation/85581/7
1 parent 0f1277d commit 0fa3ba7

File tree

20 files changed

+450
-488
lines changed

20 files changed

+450
-488
lines changed

mlir/include/mlir/Conversion/LLVMCommon/Pattern.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,15 @@ SmallVector<Value> decomposeValue(OpBuilder &builder, Location loc, Value src,
6969
/// function is used to combine multiple values into a single value.
7070
Value composeValue(OpBuilder &builder, Location loc, ValueRange src,
7171
Type dstType);
72+
73+
/// Performs the index computation to get to the element at `indices` of the
74+
/// memory pointed to by `memRefDesc`, using the layout map of `type`.
75+
/// The indices are linearized as:
76+
/// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
77+
Value getStridedElementPtr(
78+
OpBuilder &builder, Location loc, const LLVMTypeConverter &converter,
79+
MemRefType type, Value memRefDesc, ValueRange indices,
80+
LLVM::GEPNoWrapFlags noWrapFlags = LLVM::GEPNoWrapFlags::none);
7281
} // namespace LLVM
7382

7483
/// Base class for operation conversions targeting the LLVM IR dialect. It
@@ -107,8 +116,8 @@ class ConvertToLLVMPattern : public ConversionPattern {
107116
static Value createIndexAttrConstant(OpBuilder &builder, Location loc,
108117
Type resultType, int64_t value);
109118

110-
// This is a strided getElementPtr variant that linearizes subscripts as:
111-
// `base_offset + index_0 * stride_0 + ... + index_n * stride_n`.
119+
/// Convenience wrapper for the corresponding helper utility.
120+
/// This is a strided getElementPtr variant with linearized subscripts.
112121
Value getStridedElementPtr(
113122
ConversionPatternRewriter &rewriter, Location loc, MemRefType type,
114123
Value memRefDesc, ValueRange indices,

mlir/include/mlir/Dialect/AMX/AMX.td

Lines changed: 68 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#define AMX
3030

3131
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
32+
include "mlir/Dialect/AMX/AMXInterfaces.td"
3233
include "mlir/Interfaces/SideEffectInterfaces.td"
3334
include "mlir/IR/AttrTypeBase.td"
3435
include "mlir/IR/BuiltinTypes.td"
@@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {
4748

4849
This `AMX` dialect provides a bridge between MLIR concepts such as
4950
vectors and memrefs and the lower level LLVM IR support of AMX.
50-
The dialect is split into user-facing AMX ops (AMX_Op) and
51-
backend-facing intrinsic ops (AMX_IntrOp).
5251

5352
Note that since configuration changes (implicit at dialect level) are
5453
costly, it is highly recommended to use the AMX dialect on same-shaped
@@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
135134
class AMX_Op<string mnemonic, list<Trait> traits = []> :
136135
Op<AMX_Dialect, mnemonic, traits> {}
137136

138-
// The "internal" intrinsics are meant for compiler usage.
139-
class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
140-
LLVM_IntrOpBase<AMX_Dialect, mnemonic,
141-
"x86_" # !subst(".", "_", mnemonic) # "_internal",
142-
[], [], traits, numResults>;
143-
144137
//===----------------------------------------------------------------------===//
145-
// AMX Op definitions (user facing).
138+
// AMX Op definitions
146139
//===----------------------------------------------------------------------===//
147140

148141
//
149142
// Tile reset.
150143
//
151144

152-
def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
145+
def TileZeroOp : AMX_Op<"tile_zero", [Pure,
146+
AMXIntrinsicOpInterface
147+
]> {
153148
let summary = "tile zero operation";
154149
let description = [{
155150
Zeroes the destination tile, with the shape defined by the 2-dim
@@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
167162
TileType getTileType() {
168163
return ::llvm::cast<TileType>(getRes().getType());
169164
}
165+
166+
std::string getIntrinsicName() {
167+
return "llvm.x86.tilezero.internal";
168+
}
169+
SmallVector<Value> getIntrinsicOperands(
170+
::mlir::ArrayRef<Value> operands,
171+
const ::mlir::LLVMTypeConverter &typeConverter,
172+
::mlir::RewriterBase &rewriter);
170173
}];
171174
let assemblyFormat = "attr-dict `:` qualified(type($res))";
172175
let hasVerifier = 1;
@@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
176179
// Tile memory operations.
177180
//
178181

179-
def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
182+
def TileLoadOp : AMX_Op<"tile_load", [Pure,
183+
AMXIntrinsicOpInterface
184+
]> {
180185
let summary = "tile load operation";
181186
let description = [{
182187
Loads a tile from memory defined by a base and indices, with the
@@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
200205
TileType getTileType() {
201206
return ::llvm::cast<TileType>(getRes().getType());
202207
}
208+
209+
std::string getIntrinsicName() {
210+
return "llvm.x86.tileloadd64.internal";
211+
}
212+
SmallVector<Value> getIntrinsicOperands(
213+
::mlir::ArrayRef<Value> operands,
214+
const ::mlir::LLVMTypeConverter &typeConverter,
215+
::mlir::RewriterBase &rewriter);
203216
}];
204217
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
205218
"type($base) `into` qualified(type($res))";
206219
let hasVerifier = 1;
207220
}
208221

209-
def TileStoreOp : AMX_Op<"tile_store"> {
222+
def TileStoreOp : AMX_Op<"tile_store", [
223+
AMXIntrinsicOpInterface
224+
]> {
210225
let summary = "tile store operation";
211226
let description = [{
212227
Stores a tile to memory defined by a base and indices, with the
@@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
230245
TileType getTileType() {
231246
return ::llvm::cast<TileType>(getVal().getType());
232247
}
248+
249+
std::string getIntrinsicName() {
250+
return "llvm.x86.tilestored64.internal";
251+
}
252+
SmallVector<Value> getIntrinsicOperands(
253+
::mlir::ArrayRef<Value> operands,
254+
const ::mlir::LLVMTypeConverter &typeConverter,
255+
::mlir::RewriterBase &rewriter);
233256
}];
234257
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
235258
"type($base) `,` qualified(type($val))";
@@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
240263
// Tile arithmetic operations.
241264
//
242265

243-
def TileMulFOp : AMX_Op<"tile_mulf", [
244-
Pure, AllTypesMatch<["acc", "res"]>]> {
266+
def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
267+
AMXIntrinsicOpInterface,
268+
AllTypesMatch<["acc", "res"]>
269+
]> {
245270
let summary = "tile multiplication operation (floating-point)";
246271
let description = [{
247272
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -270,15 +295,30 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
270295
TileType getTileType() {
271296
return ::llvm::cast<TileType>(getRes().getType());
272297
}
298+
299+
std::string getIntrinsicName() {
300+
std::string intr = "llvm.x86.tdp";
301+
auto elementType =
302+
getLhsTileType().getElementType();
303+
intr += elementType.isF16() ? "fp16" : "bf16";
304+
intr += "ps.internal";
305+
return intr;
306+
}
307+
SmallVector<Value> getIntrinsicOperands(
308+
::mlir::ArrayRef<Value> operands,
309+
const ::mlir::LLVMTypeConverter &typeConverter,
310+
::mlir::RewriterBase &rewriter);
273311
}];
274312
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
275313
"qualified(type($lhs)) `,` qualified(type($rhs))"
276314
" `,` qualified(type($acc)) ";
277315
let hasVerifier = 1;
278316
}
279317

280-
def TileMulIOp : AMX_Op<"tile_muli", [
281-
Pure, AllTypesMatch<["acc", "res"]>]> {
318+
def TileMulIOp : AMX_Op<"tile_muli", [Pure,
319+
AMXIntrinsicOpInterface,
320+
AllTypesMatch<["acc", "res"]>
321+
]> {
282322
let summary = "tile multiplication operation (integer)";
283323
let description = [{
284324
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
313353
TileType getTileType() {
314354
return ::llvm::cast<TileType>(getRes().getType());
315355
}
356+
357+
std::string getIntrinsicName() {
358+
std::string intr = "llvm.x86.tdpb";
359+
intr += getIsZextLhs() ? "u" : "s";
360+
intr += getIsZextRhs() ? "u" : "s";
361+
intr += "d.internal";
362+
return intr;
363+
}
364+
SmallVector<Value> getIntrinsicOperands(
365+
::mlir::ArrayRef<Value> operands,
366+
const ::mlir::LLVMTypeConverter &typeConverter,
367+
::mlir::RewriterBase &rewriter);
316368
}];
317369
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
318370
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
319371
let hasVerifier = 1;
320372
}
321373

322-
//===----------------------------------------------------------------------===//
323-
// AMX IntrOp definitions (LLVM compiler facing).
324-
//===----------------------------------------------------------------------===//
325-
326-
//
327-
// Tile reset. Parameters define the tile size.
328-
//
329-
330-
def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
331-
Arguments<(ins AnyInteger, AnyInteger)>;
332-
333-
//
334-
// Tile memory operations. Parameters define the tile size,
335-
// base address, and stride between consecutive rows for the
336-
// memory operation.
337-
//
338-
339-
def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
340-
Arguments<(ins AnyInteger,
341-
AnyInteger, LLVM_AnyPointer, AnyInteger)>;
342-
343-
def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
344-
Arguments<(ins AnyInteger,
345-
AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
346-
347-
//
348-
// Tile multiplication operations (series of dot products). Parameters
349-
// define the tile sizes and source and destination tiles for the
350-
// operation. Note that the prefix "tdp" stands for tile dot product.
351-
//
352-
353-
// Dot product of bf16 tiles into f32 tile.
354-
def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
355-
Arguments<(ins AnyInteger,
356-
AnyInteger,
357-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
358-
359-
// Dot product of f16 tiles into f32 tile.
360-
def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
361-
Arguments<(ins AnyInteger,
362-
AnyInteger,
363-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
364-
365-
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
366-
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
367-
Arguments<(ins AnyInteger,
368-
AnyInteger,
369-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
370-
371-
// Dot product of i8 tiles into i32 tile (with sign/zero extension).
372-
def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
373-
Arguments<(ins AnyInteger,
374-
AnyInteger,
375-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
376-
377-
// Dot product of i8 tiles into i32 tile (with zero/sign extension).
378-
def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
379-
Arguments<(ins AnyInteger,
380-
AnyInteger,
381-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
382-
383-
// Dot product of i8 tiles into i32 tile (with zero/zero extension).
384-
def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
385-
Arguments<(ins AnyInteger,
386-
AnyInteger,
387-
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
388-
389374
#endif // AMX

mlir/include/mlir/Dialect/AMX/AMXDialect.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,15 @@
1414
#define MLIR_DIALECT_AMX_AMXDIALECT_H_
1515

1616
#include "mlir/Bytecode/BytecodeOpInterface.h"
17+
#include "mlir/Dialect/LLVMIR/LLVMInterfaces.h"
1718
#include "mlir/IR/BuiltinTypes.h"
1819
#include "mlir/IR/Dialect.h"
1920
#include "mlir/IR/OpDefinition.h"
2021
#include "mlir/Interfaces/SideEffectInterfaces.h"
2122

23+
/// Include the generated interface declarations.
24+
#include "mlir/Dialect/AMX/AMXInterfaces.h.inc"
25+
2226
#include "mlir/Dialect/AMX/AMXDialect.h.inc"
2327

2428
#define GET_TYPEDEF_CLASSES
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
//===- AMXInterfaces.td - AMX interfaces -------------------*- tablegen -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This file defines interfaces for the AMX dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef AMX_INTERFACES
14+
#define AMX_INTERFACES
15+
16+
include "mlir/IR/Interfaces.td"
17+
include "mlir/Dialect/LLVMIR/LLVMInterfaces.td"
18+
19+
//===----------------------------------------------------------------------===//
20+
// AMX Intrinsic Interface
21+
//===----------------------------------------------------------------------===//
22+
23+
def AMXIntrinsicOpInterface
24+
: OpInterface<"AMXIntrinsicOp", [OneToOneIntrinsicOpInterface]> {
25+
let description = [{
26+
A wrapper interface for operations representing AMX LLVM intrinsics.
27+
}];
28+
let cppNamespace = "::mlir::amx";
29+
}
30+
31+
#endif // AMX_INTERFACES
Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
add_mlir_dialect(AMX amx)
22
add_mlir_doc(AMX AMX Dialects/ -gen-dialect-doc -dialect=amx)
33

4-
set(LLVM_TARGET_DEFINITIONS AMX.td)
5-
mlir_tablegen(AMXConversions.inc -gen-llvmir-conversions)
6-
add_public_tablegen_target(MLIRAMXConversionsIncGen)
4+
add_mlir_interface(AMXInterfaces)
5+
add_dependencies(MLIRAMXIncGen MLIRAMXInterfacesIncGen)

mlir/include/mlir/Dialect/AMX/Transforms.h

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,6 @@ void populateAMXLegalizeForLLVMExportPatterns(LLVMTypeConverter &converter,
2525
/// intrinsics.
2626
void configureAMXLegalizeForExportTarget(LLVMConversionTarget &target);
2727

28-
/// Register LLVM conversion interface for AMX dialect.
29-
void registerConvertAMXToLLVMInterface(DialectRegistry &registry);
30-
3128
} // namespace mlir
3229

3330
#endif // MLIR_DIALECT_AMX_TRANSFORMS_H

mlir/include/mlir/InitAllExtensions.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@
3232
#include "mlir/Conversion/SCFToEmitC/SCFToEmitC.h"
3333
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
3434
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
35-
#include "mlir/Dialect/AMX/Transforms.h"
3635
#include "mlir/Dialect/Affine/TransformOps/AffineTransformOps.h"
3736
#include "mlir/Dialect/Bufferization/TransformOps/BufferizationTransformOps.h"
3837
#include "mlir/Dialect/DLTI/TransformOps/DLTITransformOps.h"
@@ -84,7 +83,6 @@ inline void registerAllExtensions(DialectRegistry &registry) {
8483
registerConvertOpenMPToLLVMInterface(registry);
8584
registerConvertSCFToEmitCInterface(registry);
8685
ub::registerConvertUBToLLVMInterface(registry);
87-
registerConvertAMXToLLVMInterface(registry);
8886
gpu::registerConvertGpuToLLVMInterface(registry);
8987
NVVM::registerConvertGpuToNVVMInterface(registry);
9088
vector::registerConvertVectorToLLVMInterface(registry);

mlir/include/mlir/Target/LLVMIR/Dialect/AMX/AMXToLLVMIRTranslation.h

Lines changed: 0 additions & 31 deletions
This file was deleted.

0 commit comments

Comments
 (0)