@@ -18279,65 +18279,211 @@ Value *CodeGenFunction::EmitAMDGPUBuiltinExpr(unsigned BuiltinID,
18279
18279
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18280
18280
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18281
18281
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18282
- case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64: {
18282
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18283
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18284
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18285
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18286
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18287
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18288
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18289
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18290
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18291
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18292
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18293
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18294
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18295
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18296
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18297
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18298
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18299
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18300
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18301
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18302
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18303
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18304
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18305
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18306
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18307
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18308
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18309
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18310
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18311
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18312
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18313
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18314
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18315
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18316
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18317
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18318
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18319
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18320
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18321
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18322
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18323
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18324
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18325
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18326
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64: {
18283
18327
18284
18328
// These operations perform a matrix multiplication and accumulation of
18285
18329
// the form:
18286
18330
// D = A * B + C
18287
- // The return type always matches the type of matrix C.
18288
- unsigned ArgForMatchingRetType;
18331
+ // We need to specify one type for matrices AB and one for matrices CD.
18332
+ SmallVector<unsigned, 2> ArgsForMatchingMatrixTypes;
18333
+ // Some intrinsics expect "false" as an extra bool argument.
18334
+ bool AppendExtraBoolArg = false;
18289
18335
unsigned BuiltinWMMAOp;
18290
18336
18291
18337
switch (BuiltinID) {
18292
18338
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32:
18293
18339
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64:
18294
- ArgForMatchingRetType = 2;
18340
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12:
18341
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_f16_w64_gfx12:
18342
+ ArgsForMatchingMatrixTypes = {0, 2};
18295
18343
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_f16;
18296
18344
break;
18297
18345
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32:
18298
18346
case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64:
18299
- ArgForMatchingRetType = 2;
18347
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w32_gfx12:
18348
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf16_w64_gfx12:
18349
+ ArgsForMatchingMatrixTypes = {0, 2};
18300
18350
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf16;
18301
18351
break;
18352
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32_gfx12:
18353
+ case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64_gfx12:
18354
+ AppendExtraBoolArg = true;
18355
+ LLVM_FALLTHROUGH;
18302
18356
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w32:
18303
18357
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_w64:
18304
- ArgForMatchingRetType = 2 ;
18358
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18305
18359
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16;
18306
18360
break;
18361
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32_gfx12:
18362
+ case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64_gfx12:
18363
+ AppendExtraBoolArg = true;
18364
+ LLVM_FALLTHROUGH;
18307
18365
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w32:
18308
18366
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_w64:
18309
- ArgForMatchingRetType = 2 ;
18367
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18310
18368
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16;
18311
18369
break;
18312
18370
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w32:
18313
18371
case AMDGPU::BI__builtin_amdgcn_wmma_f16_16x16x16_f16_tied_w64:
18314
- ArgForMatchingRetType = 2 ;
18372
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18315
18373
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f16_16x16x16_f16_tied;
18316
18374
break;
18317
18375
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w32:
18318
18376
case AMDGPU::BI__builtin_amdgcn_wmma_bf16_16x16x16_bf16_tied_w64:
18319
- ArgForMatchingRetType = 2 ;
18377
+ ArgsForMatchingMatrixTypes = {0, 2} ;
18320
18378
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_bf16_16x16x16_bf16_tied;
18321
18379
break;
18322
18380
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32:
18323
18381
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64:
18324
- ArgForMatchingRetType = 4;
18382
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w32_gfx12:
18383
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu8_w64_gfx12:
18384
+ ArgsForMatchingMatrixTypes = {1, 4};
18325
18385
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu8;
18326
18386
break;
18327
18387
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32:
18328
18388
case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64:
18329
- ArgForMatchingRetType = 4;
18389
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w32_gfx12:
18390
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x16_iu4_w64_gfx12:
18391
+ ArgsForMatchingMatrixTypes = {1, 4};
18330
18392
BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x16_iu4;
18331
18393
break;
18394
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w32_gfx12:
18395
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_fp8_w64_gfx12:
18396
+ ArgsForMatchingMatrixTypes = {0, 2};
18397
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_fp8;
18398
+ break;
18399
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w32_gfx12:
18400
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_fp8_bf8_w64_gfx12:
18401
+ ArgsForMatchingMatrixTypes = {0, 2};
18402
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_fp8_bf8;
18403
+ break;
18404
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w32_gfx12:
18405
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_fp8_w64_gfx12:
18406
+ ArgsForMatchingMatrixTypes = {0, 2};
18407
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_fp8;
18408
+ break;
18409
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w32_gfx12:
18410
+ case AMDGPU::BI__builtin_amdgcn_wmma_f32_16x16x16_bf8_bf8_w64_gfx12:
18411
+ ArgsForMatchingMatrixTypes = {0, 2};
18412
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_f32_16x16x16_bf8_bf8;
18413
+ break;
18414
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w32_gfx12:
18415
+ case AMDGPU::BI__builtin_amdgcn_wmma_i32_16x16x32_iu4_w64_gfx12:
18416
+ ArgsForMatchingMatrixTypes = {1, 4};
18417
+ BuiltinWMMAOp = Intrinsic::amdgcn_wmma_i32_16x16x32_iu4;
18418
+ break;
18419
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w32:
18420
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_f16_w64:
18421
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18422
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_f16;
18423
+ break;
18424
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w32:
18425
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf16_w64:
18426
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18427
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf16;
18428
+ break;
18429
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w32:
18430
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f16_16x16x32_f16_w64:
18431
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18432
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f16_16x16x32_f16;
18433
+ break;
18434
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w32:
18435
+ case AMDGPU::BI__builtin_amdgcn_swmmac_bf16_16x16x32_bf16_w64:
18436
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18437
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_bf16_16x16x32_bf16;
18438
+ break;
18439
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w32:
18440
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu8_w64:
18441
+ ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
18442
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu8;
18443
+ break;
18444
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w32:
18445
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x32_iu4_w64:
18446
+ ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
18447
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x32_iu4;
18448
+ break;
18449
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w32:
18450
+ case AMDGPU::BI__builtin_amdgcn_swmmac_i32_16x16x64_iu4_w64:
18451
+ ArgsForMatchingMatrixTypes = {1, 3, 4, 5};
18452
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_i32_16x16x64_iu4;
18453
+ break;
18454
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w32:
18455
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_fp8_w64:
18456
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18457
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_fp8;
18458
+ break;
18459
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w32:
18460
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_fp8_bf8_w64:
18461
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18462
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_fp8_bf8;
18463
+ break;
18464
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w32:
18465
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_fp8_w64:
18466
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18467
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_fp8;
18468
+ break;
18469
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w32:
18470
+ case AMDGPU::BI__builtin_amdgcn_swmmac_f32_16x16x32_bf8_bf8_w64:
18471
+ ArgsForMatchingMatrixTypes = {0, 1, 2, 3};
18472
+ BuiltinWMMAOp = Intrinsic::amdgcn_swmmac_f32_16x16x32_bf8_bf8;
18473
+ break;
18332
18474
}
18333
18475
18334
18476
SmallVector<Value *, 6> Args;
18335
18477
for (int i = 0, e = E->getNumArgs(); i != e; ++i)
18336
18478
Args.push_back(EmitScalarExpr(E->getArg(i)));
18479
+ if (AppendExtraBoolArg)
18480
+ Args.push_back(Builder.getFalse());
18337
18481
18338
- Function *F = CGM.getIntrinsic(BuiltinWMMAOp,
18339
- {Args[ArgForMatchingRetType]->getType()});
18482
+ SmallVector<llvm::Type *, 6> ArgTypes;
18483
+ for (auto ArgIdx : ArgsForMatchingMatrixTypes)
18484
+ ArgTypes.push_back(Args[ArgIdx]->getType());
18340
18485
18486
+ Function *F = CGM.getIntrinsic(BuiltinWMMAOp, ArgTypes);
18341
18487
return Builder.CreateCall(F, Args);
18342
18488
}
18343
18489
0 commit comments