Skip to content

Commit 40c42c6

Browse files
mbrkusanin, petar-avramovic, and piotrAMD
authored and committed
[AMDGPU] Add GFX12 WMMA and SWMMAC instructions (llvm#77795)
Co-authored-by: Petar Avramovic <[email protected]>
Co-authored-by: Piotr Sobczak <[email protected]>
Change-Id: I6ab1132823033fb047665f3a527cff748ff69589
1 parent b5e8986 commit 40c42c6

File tree

66 files changed

+17806
-143
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

66 files changed

+17806
-143
lines changed

clang/include/clang/Basic/BuiltinsAMDGPU.def

Lines changed: 62 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -436,5 +436,67 @@ TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_i32, "ii*1", "nc", "gfx12-insts,w
436436
TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_v4i16, "V4sV4s*1", "nc", "gfx12-insts,wavefrontsize64")
437437
TARGET_BUILTIN(__builtin_amdgcn_global_load_tr_v4f16, "V4hV4h*1", "nc", "gfx12-insts,wavefrontsize64")
438438

439+
//===----------------------------------------------------------------------===//
440+
// WMMA builtins.
441+
// Postfix w32 indicates the builtin requires wavefront size of 32.
442+
// Postfix w64 indicates the builtin requires wavefront size of 64.
443+
//
444+
// Some of these are very similar to their GFX11 counterparts, but they don't
445+
// require replication of the A,B matrices, so they use fewer vector elements.
446+
// Therefore, we add an "_gfx12" suffix to distinguish them from the existing
447+
// builtins.
448+
//===----------------------------------------------------------------------===//
449+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12, "V8fV8hV8hV8f", "nc", "gfx12-insts,wavefrontsize32")
450+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12, "V8fV8sV8sV8f", "nc", "gfx12-insts,wavefrontsize32")
451+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12, "V8hV8hV8hV8h", "nc", "gfx12-insts,wavefrontsize32")
452+
TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12, "V8sV8sV8sV8s", "nc", "gfx12-insts,wavefrontsize32")
453+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")
454+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12, "V8iIbiIbiV8iIb", "nc", "gfx12-insts,wavefrontsize32")
455+
// These are gfx12-only, but for consistency with the other WMMA variants we're
456+
// keeping the "_gfx12" suffix.
457+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
458+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
459+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
460+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12, "V8fV2iV2iV8f", "nc", "gfx12-insts,wavefrontsize32")
461+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12, "V8iIbV2iIbV2iV8iIb", "nc", "gfx12-insts,wavefrontsize32")
462+
463+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12, "V4fV4hV4hV4f", "nc", "gfx12-insts,wavefrontsize64")
464+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12, "V4fV4sV4sV4f", "nc", "gfx12-insts,wavefrontsize64")
465+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12, "V4hV4hV4hV4h", "nc", "gfx12-insts,wavefrontsize64")
466+
TARGET_BUILTIN(__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12, "V4sV4sV4sV4s", "nc", "gfx12-insts,wavefrontsize64")
467+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
468+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
469+
// These are gfx12-only, but for consistency with the other WMMA variants we're
470+
// keeping the "_gfx12" suffix.
471+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
472+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
473+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
474+
TARGET_BUILTIN(__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12, "V4fiiV4f", "nc", "gfx12-insts,wavefrontsize64")
475+
TARGET_BUILTIN(__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12, "V4iIbiIbiV4iIb", "nc", "gfx12-insts,wavefrontsize64")
476+
477+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32, "V8fV8hV16hV8fs", "nc", "gfx12-insts,wavefrontsize32")
478+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32, "V8fV8sV16sV8fs", "nc", "gfx12-insts,wavefrontsize32")
479+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32, "V8hV8hV16hV8hs", "nc", "gfx12-insts,wavefrontsize32")
480+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32, "V8sV8sV16sV8ss", "nc", "gfx12-insts,wavefrontsize32")
481+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
482+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32, "V8iIbiIbV2iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
483+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32, "V8iIbV2iIbV4iV8isIb", "nc", "gfx12-insts,wavefrontsize32")
484+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
485+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
486+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
487+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32, "V8fV2iV4iV8fs", "nc", "gfx12-insts,wavefrontsize32")
488+
489+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64, "V4fV4hV8hV4fs", "nc", "gfx12-insts,wavefrontsize64")
490+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64, "V4fV4sV8sV4fs", "nc", "gfx12-insts,wavefrontsize64")
491+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64, "V4hV4hV8hV4hs", "nc", "gfx12-insts,wavefrontsize64")
492+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64, "V4sV4sV8sV4ss", "nc", "gfx12-insts,wavefrontsize64")
493+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
494+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64, "V4iIbiIbiV4isIb", "nc", "gfx12-insts,wavefrontsize64")
495+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64, "V4iIbiIbV2iV4isIb", "nc", "gfx12-insts,wavefrontsize64")
496+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
497+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
498+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
499+
TARGET_BUILTIN(__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64, "V4fiV2iV4fs", "nc", "gfx12-insts,wavefrontsize64")
500+
439501
#undef BUILTIN
440502
#undef TARGET_BUILTIN

clang/lib/CodeGen/CGBuiltin.cpp

Lines changed: 164 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -18285,65 +18285,216 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
1828518285
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1828618286
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
1828718287
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18288-
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18288+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18289+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18290+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18291+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18292+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18293+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18294+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18295+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18296+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18297+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18298+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18299+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18300+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18301+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18302+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18303+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18304+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18305+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18306+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18307+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18308+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18309+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18310+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18311+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18312+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18313+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18314+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18315+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18316+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18317+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18318+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18319+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18320+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18321+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18322+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18323+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18324+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18325+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18326+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18327+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18328+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18329+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18330+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18331+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18332+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
1828918333

1829018334
// These operations perform a matrix multiplication and accumulation of
1829118335
// the form:
1829218336
// D = A * B + C
18293-
// The return type always matches the type of matrix C.
18294-
unsigned ArgForMatchingRetType;
18337+
// We need to specify one type for matrices AB and one for matrices CD.
18338+
// Sparse matrix operations can have different types for A and B as well as
18339+
// an additional type for sparsity index.
18340+
// Destination type should be put before types used for source operands.
18341+
SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18342+
// On GFX12, the intrinsics with 16-bit accumulator use a packed layout.
18343+
// There is no need for the variable opsel argument, so always set it to
18344+
// "false".
18345+
bool AppendFalseForOpselArg = false;
1829518346
unsigned BuiltinWMMAOp;
1829618347

1829718348
switch (BuiltinID) {
1829818349
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
1829918350
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18300-
ArgForMatchingRetType = 2;
18351+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18352+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18353+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830118354
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
1830218355
break;
1830318356
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
1830418357
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18305-
ArgForMatchingRetType = 2;
18358+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18359+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18360+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1830618361
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
1830718362
break;
18363+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18364+
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18365+
AppendFalseForOpselArg = true;
18366+
LLVM_FALLTHROUGH;
1830818367
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
1830918368
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18310-
ArgForMatchingRetType = 2;
18369+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831118370
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
1831218371
break;
18372+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18373+
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18374+
AppendFalseForOpselArg = true;
18375+
LLVM_FALLTHROUGH;
1831318376
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
1831418377
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18315-
ArgForMatchingRetType = 2;
18378+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1831618379
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
1831718380
break;
1831818381
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
1831918382
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18320-
ArgForMatchingRetType = 2;
18383+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1832118384
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
1832218385
break;
1832318386
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
1832418387
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18325-
ArgForMatchingRetType = 2;
18388+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
1832618389
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
1832718390
break;
1832818391
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
1832918392
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18330-
ArgForMatchingRetType = 4;
18393+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18394+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18395+
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1833118396
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
1833218397
break;
1833318398
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
1833418399
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18335-
ArgForMatchingRetType = 4;
18400+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18401+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18402+
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
1833618403
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
1833718404
break;
18405+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18406+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18407+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18408+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18409+
break;
18410+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18411+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18412+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18413+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18414+
break;
18415+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18416+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18417+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18418+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18419+
break;
18420+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18421+
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18422+
ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18423+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18424+
break;
18425+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18426+
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18427+
ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18428+
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18429+
break;
18430+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18431+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18432+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18433+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18434+
break;
18435+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18436+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18437+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18438+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18439+
break;
18440+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18441+
case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18442+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18443+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18444+
break;
18445+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18446+
case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18447+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18448+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18449+
break;
18450+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18451+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18452+
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18453+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18454+
break;
18455+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18456+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18457+
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18458+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18459+
break;
18460+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18461+
case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18462+
ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18463+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18464+
break;
18465+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18466+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18467+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18468+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18469+
break;
18470+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18471+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18472+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18473+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18474+
break;
18475+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18476+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18477+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18478+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18479+
break;
18480+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18481+
case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18482+
ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18483+
BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18484+
break;
1833818485
}
1833918486

1834018487
SmallVector<Value *, 6> Args;
1834118488
for (int i = 0, e = E->getNumArgs(); i != e; ++i)
1834218489
Args.push_back(EmitScalarExpr(E->getArg(i)));
18490+
if (AppendFalseForOpselArg)
18491+
Args.push_back(Builder.getFalse());
1834318492

18344-
Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18345-
{Args[ArgForMatchingRetType]->getType()});
18493+
SmallVector<llvm::Type *, 6> ArgTypes;
18494+
for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18495+
ArgTypes.push_back(Args[ArgIdx]->getType());
1834618496

18497+
Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
1834718498
return Builder.CreateCall(F, Args);
1834818499
}
1834918500

0 commit comments

Comments
 (0)