@@ -18285,65 +18285,216 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
18285
18285
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18286
18286
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18287
18287
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18288
- case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18288
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18289
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18290
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18291
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18292
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18293
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18294
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18295
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18296
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18297
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18298
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18299
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18300
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18301
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18302
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18303
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18304
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18305
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18306
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18307
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18308
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18309
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18310
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18311
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18312
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18313
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18314
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18315
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18316
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18317
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18318
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18319
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18320
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18321
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18322
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18323
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18324
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18325
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18326
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18327
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18328
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18329
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18330
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18331
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18332
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
18289
18333
18290
18334
// These operations perform a matrix multiplication and accumulation of
18291
18335
// the form:
18292
18336
// D = A * B + C
18293
- // The return type always matches the type of matrix C.
18294
- unsigned ArgForMatchingRetType;
18337
+ // We need to specify one type for matrices AB and one for matrices CD.
18338
+ // Sparse matrix operations can have different types for A and B as well as
18339
+ // an additional type for sparsity index.
18340
+ // Destination type should be put before types used for source operands.
18341
+ SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18342
+ // On GFX12, the intrinsics with 16-bit accumulator use a packed layout.
18343
+ // There is no need for the variable opsel argument, so always set it to
18344
+ // "false".
18345
+ bool AppendFalseForOpselArg = false;
18295
18346
unsigned BuiltinWMMAOp;
18296
18347
18297
18348
switch (BuiltinID) {
18298
18349
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
18299
18350
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18300
- ArgForMatchingRetType = 2;
18351
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18352
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18353
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18301
18354
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
18302
18355
break;
18303
18356
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
18304
18357
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18305
- ArgForMatchingRetType = 2;
18358
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18359
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18360
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18306
18361
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
18307
18362
break;
18363
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18364
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18365
+ AppendFalseForOpselArg = true;
18366
+ LLVM_FALLTHROUGH;
18308
18367
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
18309
18368
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18310
- ArgForMatchingRetType = 2;
18369
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18311
18370
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
18312
18371
break;
18372
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18373
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18374
+ AppendFalseForOpselArg = true;
18375
+ LLVM_FALLTHROUGH;
18313
18376
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
18314
18377
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18315
- ArgForMatchingRetType = 2;
18378
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18316
18379
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
18317
18380
break;
18318
18381
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
18319
18382
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18320
- ArgForMatchingRetType = 2;
18383
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18321
18384
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
18322
18385
break;
18323
18386
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
18324
18387
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18325
- ArgForMatchingRetType = 2;
18388
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18326
18389
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
18327
18390
break;
18328
18391
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18329
18392
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18330
- ArgForMatchingRetType = 4;
18393
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18394
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18395
+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18331
18396
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
18332
18397
break;
18333
18398
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18334
18399
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18335
- ArgForMatchingRetType = 4;
18400
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18401
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18402
+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18336
18403
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
18337
18404
break;
18405
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18406
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18407
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18408
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18409
+ break;
18410
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18411
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18412
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18413
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18414
+ break;
18415
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18416
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18417
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18418
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18419
+ break;
18420
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18421
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18422
+ ArgsForMatchingMatrixTypes = {2, 0}; // CD, AB
18423
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18424
+ break;
18425
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18426
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18427
+ ArgsForMatchingMatrixTypes = {4, 1}; // CD, AB
18428
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18429
+ break;
18430
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18431
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18432
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18433
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18434
+ break;
18435
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18436
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18437
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18438
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18439
+ break;
18440
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18441
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18442
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18443
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18444
+ break;
18445
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18446
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18447
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18448
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18449
+ break;
18450
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18451
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18452
+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18453
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18454
+ break;
18455
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18456
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18457
+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18458
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18459
+ break;
18460
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18461
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18462
+ ArgsForMatchingMatrixTypes = {4, 1, 3, 5}; // CD, A, B, Index
18463
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18464
+ break;
18465
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18466
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18467
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18468
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18469
+ break;
18470
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18471
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18472
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18473
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18474
+ break;
18475
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18476
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18477
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18478
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18479
+ break;
18480
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18481
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18482
+ ArgsForMatchingMatrixTypes = {2, 0, 1, 3}; // CD, A, B, Index
18483
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18484
+ break;
18338
18485
}
18339
18486
18340
18487
SmallVector<Value *, 6> Args;
18341
18488
for (int i = 0, e = E->getNumArgs(); i != e; ++i)
18342
18489
Args.push_back(EmitScalarExpr(E->getArg(i)));
18490
+ if (AppendFalseForOpselArg)
18491
+ Args.push_back(Builder.getFalse());
18343
18492
18344
- Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18345
- {Args[ArgForMatchingRetType]->getType()});
18493
+ SmallVector<llvm::Type *, 6> ArgTypes;
18494
+ for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18495
+ ArgTypes.push_back(Args[ArgIdx]->getType());
18346
18496
18497
+ Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
18347
18498
return Builder.CreateCall(F, Args);
18348
18499
}
18349
18500
0 commit comments