25
25
//
26
26
//===----------------------------------------------------------------------===//
27
27
28
- #ifndef AMX
29
- #define AMX
28
+ #ifndef AMX_OPS
29
+ #define AMX_OPS
30
30
31
31
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
32
+ include "mlir/Dialect/AMX/AMXInterfaces.td"
32
33
include "mlir/Interfaces/SideEffectInterfaces.td"
33
34
include "mlir/IR/AttrTypeBase.td"
34
35
include "mlir/IR/BuiltinTypes.td"
@@ -47,8 +48,6 @@ def AMX_Dialect : Dialect {
47
48
48
49
This `AMX` dialect provides a bridge between MLIR concepts such as
49
50
vectors and memrefs and the lower level LLVM IR support of AMX.
50
- The dialect is split into user-facing AMX ops (AMX_Op) and
51
- backend-facing intrinsic ops (AMX_IntrOp).
52
51
53
52
Note that since configuration changes (implicit at dialect level) are
54
53
costly, it is highly recommended to use the AMX dialect on same-shaped
@@ -135,21 +134,17 @@ def AMXTileI8 : AMXTileOf<[I8]>;
135
134
class AMX_Op<string mnemonic, list<Trait> traits = []> :
136
135
Op<AMX_Dialect, mnemonic, traits> {}
137
136
138
- // The "internal" intrinsics are meant for compiler usage.
139
- class AMX_IntrOp<string mnemonic, int numResults, list<Trait> traits = []> :
140
- LLVM_IntrOpBase<AMX_Dialect, mnemonic,
141
- "x86_" # !subst(".", "_", mnemonic) # "_internal",
142
- [], [], traits, numResults>;
143
-
144
137
//===----------------------------------------------------------------------===//
145
- // AMX Op definitions (user facing).
138
+ // AMX Op definitions
146
139
//===----------------------------------------------------------------------===//
147
140
148
141
//
149
142
// Tile reset.
150
143
//
151
144
152
- def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
145
+ def TileZeroOp : AMX_Op<"tile_zero", [Pure,
146
+ AMXIntrinsicOpInterface
147
+ ]> {
153
148
let summary = "tile zero operation";
154
149
let description = [{
155
150
Zeroes the destination tile, with the shape defined by the 2-dim
@@ -167,6 +162,14 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
167
162
TileType getTileType() {
168
163
return ::llvm::cast<TileType>(getRes().getType());
169
164
}
165
+
166
+ std::string getIntrinsicName() {
167
+ return "llvm.x86.tilezero.internal";
168
+ }
169
+ SmallVector<Value> getIntrinsicOperands(
170
+ ::mlir::ArrayRef<Value> operands,
171
+ const ::mlir::LLVMTypeConverter &typeConverter,
172
+ ::mlir::RewriterBase &rewriter);
170
173
}];
171
174
let assemblyFormat = "attr-dict `:` qualified(type($res))";
172
175
let hasVerifier = 1;
@@ -176,7 +179,9 @@ def TileZeroOp : AMX_Op<"tile_zero", [Pure]> {
176
179
// Tile memory operations.
177
180
//
178
181
179
- def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
182
+ def TileLoadOp : AMX_Op<"tile_load", [Pure,
183
+ AMXIntrinsicOpInterface
184
+ ]> {
180
185
let summary = "tile load operation";
181
186
let description = [{
182
187
Loads a tile from memory defined by a base and indices, with the
@@ -200,13 +205,23 @@ def TileLoadOp : AMX_Op<"tile_load", [Pure]> {
200
205
TileType getTileType() {
201
206
return ::llvm::cast<TileType>(getRes().getType());
202
207
}
208
+
209
+ std::string getIntrinsicName() {
210
+ return "llvm.x86.tileloadd64.internal";
211
+ }
212
+ SmallVector<Value> getIntrinsicOperands(
213
+ ::mlir::ArrayRef<Value> operands,
214
+ const ::mlir::LLVMTypeConverter &typeConverter,
215
+ ::mlir::RewriterBase &rewriter);
203
216
}];
204
217
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
205
218
"type($base) `into` qualified(type($res))";
206
219
let hasVerifier = 1;
207
220
}
208
221
209
- def TileStoreOp : AMX_Op<"tile_store"> {
222
+ def TileStoreOp : AMX_Op<"tile_store", [
223
+ AMXIntrinsicOpInterface
224
+ ]> {
210
225
let summary = "tile store operation";
211
226
let description = [{
212
227
Stores a tile to memory defined by a base and indices, with the
@@ -230,6 +245,14 @@ def TileStoreOp : AMX_Op<"tile_store"> {
230
245
TileType getTileType() {
231
246
return ::llvm::cast<TileType>(getVal().getType());
232
247
}
248
+
249
+ std::string getIntrinsicName() {
250
+ return "llvm.x86.tilestored64.internal";
251
+ }
252
+ SmallVector<Value> getIntrinsicOperands(
253
+ ::mlir::ArrayRef<Value> operands,
254
+ const ::mlir::LLVMTypeConverter &typeConverter,
255
+ ::mlir::RewriterBase &rewriter);
233
256
}];
234
257
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
235
258
"type($base) `,` qualified(type($val))";
@@ -240,8 +263,10 @@ def TileStoreOp : AMX_Op<"tile_store"> {
240
263
// Tile arithmetic operations.
241
264
//
242
265
243
- def TileMulFOp : AMX_Op<"tile_mulf", [
244
- Pure, AllTypesMatch<["acc", "res"]>]> {
266
+ def TileMulFOp : AMX_Op<"tile_mulf", [Pure,
267
+ AMXIntrinsicOpInterface,
268
+ AllTypesMatch<["acc", "res"]>
269
+ ]> {
245
270
let summary = "tile multiplication operation (floating-point)";
246
271
let description = [{
247
272
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -270,15 +295,30 @@ def TileMulFOp : AMX_Op<"tile_mulf", [
270
295
TileType getTileType() {
271
296
return ::llvm::cast<TileType>(getRes().getType());
272
297
}
298
+
299
+ std::string getIntrinsicName() {
300
+ std::string intr = "llvm.x86.tdp";
301
+ auto elementType =
302
+ getLhsTileType().getElementType();
303
+ intr += elementType.isF16() ? "fp16" : "bf16";
304
+ intr += "ps.internal";
305
+ return intr;
306
+ }
307
+ SmallVector<Value> getIntrinsicOperands(
308
+ ::mlir::ArrayRef<Value> operands,
309
+ const ::mlir::LLVMTypeConverter &typeConverter,
310
+ ::mlir::RewriterBase &rewriter);
273
311
}];
274
312
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
275
313
"qualified(type($lhs)) `,` qualified(type($rhs))"
276
314
" `,` qualified(type($acc)) ";
277
315
let hasVerifier = 1;
278
316
}
279
317
280
- def TileMulIOp : AMX_Op<"tile_muli", [
281
- Pure, AllTypesMatch<["acc", "res"]>]> {
318
+ def TileMulIOp : AMX_Op<"tile_muli", [Pure,
319
+ AMXIntrinsicOpInterface,
320
+ AllTypesMatch<["acc", "res"]>
321
+ ]> {
282
322
let summary = "tile multiplication operation (integer)";
283
323
let description = [{
284
324
Multiplies a "m x k" tile with a "k x n" tile and accumulates the results
@@ -313,77 +353,22 @@ def TileMulIOp : AMX_Op<"tile_muli", [
313
353
TileType getTileType() {
314
354
return ::llvm::cast<TileType>(getRes().getType());
315
355
}
356
+
357
+ std::string getIntrinsicName() {
358
+ std::string intr = "llvm.x86.tdpb";
359
+ intr += getIsZextLhs() ? "u" : "s";
360
+ intr += getIsZextRhs() ? "u" : "s";
361
+ intr += "d.internal";
362
+ return intr;
363
+ }
364
+ SmallVector<Value> getIntrinsicOperands(
365
+ ::mlir::ArrayRef<Value> operands,
366
+ const ::mlir::LLVMTypeConverter &typeConverter,
367
+ ::mlir::RewriterBase &rewriter);
316
368
}];
317
369
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
318
370
"qualified(type($lhs)) `,` qualified(type($rhs)) `,` qualified(type($acc)) ";
319
371
let hasVerifier = 1;
320
372
}
321
373
322
- //===----------------------------------------------------------------------===//
323
- // AMX IntrOp definitions (LLVM compiler facing).
324
- //===----------------------------------------------------------------------===//
325
-
326
- //
327
- // Tile reset. Parameters define the tile size.
328
- //
329
-
330
- def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
331
- Arguments<(ins AnyInteger, AnyInteger)>;
332
-
333
- //
334
- // Tile memory operations. Parameters define the tile size,
335
- // base address, and stride between consecutive rows for the
336
- // memory operation.
337
- //
338
-
339
- def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
340
- Arguments<(ins AnyInteger,
341
- AnyInteger, LLVM_AnyPointer, AnyInteger)>;
342
-
343
- def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
344
- Arguments<(ins AnyInteger,
345
- AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
346
-
347
- //
348
- // Tile multiplication operations (series of dot products). Parameters
349
- // define the tile sizes and source and destination tiles for the
350
- // operation. Note that the prefix "tdp" stands for tile dot product.
351
- //
352
-
353
- // Dot product of bf16 tiles into f32 tile.
354
- def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
355
- Arguments<(ins AnyInteger,
356
- AnyInteger,
357
- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
358
-
359
- // Dot product of f16 tiles into f32 tile.
360
- def LLVM_x86_amx_tdpfp16ps : AMX_IntrOp<"tdpfp16ps", 1>,
361
- Arguments<(ins AnyInteger,
362
- AnyInteger,
363
- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
364
-
365
- // Dot product of i8 tiles into i32 tile (with sign/sign extension).
366
- def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
367
- Arguments<(ins AnyInteger,
368
- AnyInteger,
369
- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
370
-
371
- // Dot product of i8 tiles into i32 tile (with sign/zero extension).
372
- def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
373
- Arguments<(ins AnyInteger,
374
- AnyInteger,
375
- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
376
-
377
- // Dot product of i8 tiles into i32 tile (with zero/sign extension).
378
- def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
379
- Arguments<(ins AnyInteger,
380
- AnyInteger,
381
- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
382
-
383
- // Dot product of i8 tiles into i32 tile (with zero/zero extension).
384
- def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
385
- Arguments<(ins AnyInteger,
386
- AnyInteger,
387
- AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
388
-
389
- #endif // AMX
374
+ #endif // AMX_OPS
0 commit comments