@@ -398,6 +398,95 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
398
398
llvm.return %r0 : vector <32 x f32 >
399
399
}
400
400
401
// Translation test for the rocdl.smfmac.* operations: every op below is
// expected to become a call to the llvm.amdgcn.smfmac.* intrinsic with the
// same suffix, keeping operand order and vector types intact.
//
// Arguments are grouped by the element type they supply:
//   %arg1 / %arg2        - f16 vectors (4 and 8 lanes)
//   %arg3 / %arg4        - f32 accumulators (4 and 16 lanes)
//   %arg5 / %arg6        - i16 vectors (4 and 8 lanes), used by the bf16 ops
//   %arg7 / %arg8 / %arg9 - i32 vectors (2, 4 and 16 lanes), used by the
//                           i8 / bf8 / fp8 ops
// %arg0 is not referenced inside the body.
llvm.func @rocdl.smfmac(%arg0 : i32,
                        %arg1 : vector<4xf16>, %arg2 : vector<8xf16>,
                        %arg3 : vector<4xf32>, %arg4 : vector<16xf32>,
                        %arg5 : vector<4xi16>, %arg6 : vector<8xi16>,
                        %arg7 : vector<2xi32>, %arg8 : vector<4xi32>,
                        %arg9 : vector<16xi32>) -> vector<4xf32> {
  // A single i32 constant feeds all three trailing scalar operands of every
  // op, which is why each expected call ends in "i32 42, i32 42, i32 42".
  // NOTE(review): because the three values are identical, a swap between
  // those operands during translation would go unnoticed - confirm this is
  // acceptable for this test.
  %c42 = llvm.mlir.constant(42 : i32) : i32

  // CHECK-LABEL: rocdl.smfmac

  // --- f16 inputs, f32 accumulator -----------------------------------------

  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %f16_16x16 = rocdl.smfmac.f32.16x16x32.f16 %arg1, %arg2, %arg3, %c42, %c42, %c42
      : (vector<4xf16>, vector<8xf16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>

  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.f16(<4 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %f16_32x32 = rocdl.smfmac.f32.32x32x16.f16 %arg1, %arg2, %arg4, %c42, %c42, %c42
      : (vector<4xf16>, vector<8xf16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>

  // --- bf16 inputs (carried as i16 vectors), f32 accumulator ---------------

  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x32.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %bf16_16x16 = rocdl.smfmac.f32.16x16x32.bf16 %arg5, %arg6, %arg3, %c42, %c42, %c42
      : (vector<4xi16>, vector<8xi16>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>

  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x16.bf16(<4 x i16> %{{.*}}, <8 x i16> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %bf16_32x32 = rocdl.smfmac.f32.32x32x16.bf16 %arg5, %arg6, %arg4, %c42, %c42, %c42
      : (vector<4xi16>, vector<8xi16>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>

  // --- i8 inputs (carried as i32 vectors), i32 accumulator -----------------

  // The 4-lane i32 accumulator is supplied by %arg8, which also serves as the
  // second input operand here.
  // CHECK: call <4 x i32> @llvm.amdgcn.smfmac.i32.16x16x64.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 42, i32 42, i32 42)
  %i8_16x16 = rocdl.smfmac.i32.16x16x64.i8 %arg7, %arg8, %arg8, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<4xi32>, i32, i32, i32) -> vector<4xi32>

  // CHECK: call <16 x i32> @llvm.amdgcn.smfmac.i32.32x32x32.i8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 42, i32 42, i32 42)
  %i8_32x32 = rocdl.smfmac.i32.32x32x32.i8 %arg7, %arg8, %arg9, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<16xi32>, i32, i32, i32) -> vector<16xi32>

  // --- bf8 / fp8 input combinations, 16x16x64, 4-lane f32 accumulator ------

  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %bf8_bf8_16x16 = rocdl.smfmac.f32.16x16x64.bf8.bf8 %arg7, %arg8, %arg3, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>

  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %bf8_fp8_16x16 = rocdl.smfmac.f32.16x16x64.bf8.fp8 %arg7, %arg8, %arg3, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>

  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %fp8_bf8_16x16 = rocdl.smfmac.f32.16x16x64.fp8.bf8 %arg7, %arg8, %arg3, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>

  // CHECK: call <4 x float> @llvm.amdgcn.smfmac.f32.16x16x64.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %fp8_fp8_16x16 = rocdl.smfmac.f32.16x16x64.fp8.fp8 %arg7, %arg8, %arg3, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<4xf32>, i32, i32, i32) -> vector<4xf32>

  // --- bf8 / fp8 input combinations, 32x32x32, 16-lane f32 accumulator -----

  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %bf8_bf8_32x32 = rocdl.smfmac.f32.32x32x32.bf8.bf8 %arg7, %arg8, %arg4, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>

  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.bf8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %bf8_fp8_32x32 = rocdl.smfmac.f32.32x32x32.bf8.fp8 %arg7, %arg8, %arg4, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>

  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.bf8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %fp8_bf8_32x32 = rocdl.smfmac.f32.32x32x32.fp8.bf8 %arg7, %arg8, %arg4, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>

  // CHECK: call <16 x float> @llvm.amdgcn.smfmac.f32.32x32x32.fp8.fp8(<2 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x float> %{{.*}}, i32 42, i32 42, i32 42)
  %fp8_fp8_32x32 = rocdl.smfmac.f32.32x32x32.fp8.fp8 %arg7, %arg8, %arg4, %c42, %c42, %c42
      : (vector<2xi32>, vector<4xi32>, vector<16xf32>, i32, i32, i32) -> vector<16xf32>

  // Only the first result is returned; the remaining results exist solely so
  // their calls appear in the translated output.
  llvm.return %f16_16x16 : vector<4xf32>
}
488
+
489
+
401
490
llvm.func @rocdl.mfma.scale.f32.32x32x64.f8f6f4 (%arg0 : i32 ,
402
491
%arg1 : vector <16 x f32 >, %arg2 : vector <8 xi32 >,
403
492
%arg3 : vector <6 xi32 >, %arg4 : vector <4 xi32 >) -> vector <16 x f32 > {
0 commit comments