@@ -219,7 +219,9 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
219
219
%arg4 : vector <16 x f32 >, %arg5 : vector <4 xf32 >,
220
220
%arg6 : vector <4 xf16 >, %arg7 : vector <32 x i32 >,
221
221
%arg8 : vector <16 x i32 >, %arg9 : vector <4 xi32 >,
222
- %arg10 : vector <2 xi16 >, %arg11 : i64 ) -> vector <32 x f32 > {
222
+ %arg10 : vector <2 xi16 >, %arg11 : i64 ,
223
+ %arg12 : vector <8 xbf16 >, %arg13 : vector <4 xi32 >,
224
+ %arg14 : vector <8 xf16 >) -> vector <32 x f32 > {
223
225
%csti32 = llvm.mlir.constant (42 : i32 ) : i32
224
226
225
227
// CHECK-LABEL: rocdl.xdlops
@@ -362,6 +364,37 @@ llvm.func @rocdl.xdlops(%arg0 : f32, %arg1 : f32,
362
364
%r27 = rocdl.mfma.f32.32x32x16.bf8.bf8 %arg11 , %arg11 , %arg4 , %csti32 , %csti32 , %csti32 :
363
365
(i64 , i64 , vector <16 xf32 >,
364
366
i32 , i32 , i32 ) -> vector <16 xf32 >
367
+
368
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.bf16(<8 x bfloat> %{{.*}}, <8 x bfloat> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
369
+ %r28 = rocdl.mfma.f32.16x16x32.bf16 %arg12 , %arg12 , %arg5 , %csti32 , %csti32 , %csti32 :
370
+ (vector <8 xbf16 >, vector <8 xbf16 >, vector <4 xf32 >,
371
+ i32 , i32 , i32 ) -> vector <4 xf32 >
372
+
373
+ // CHECK: call <4 x i32> @llvm.amdgcn.mfma.i32.16x16x64.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
374
+ %r29 = rocdl.mfma.i32.16x16x64.i8 %arg9 , %arg9 , %arg9 , %csti32 , %csti32 , %csti32 :
375
+ (vector <4 xi32 >, vector <4 xi32 >, vector <4 xi32 >,
376
+ i32 , i32 , i32 ) -> vector <4 xi32 >
377
+
378
+ // CHECK: call <4 x float> @llvm.amdgcn.mfma.f32.16x16x32.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <4 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
379
+ %r30 = rocdl.mfma.f32.16x16x32.f16 %arg14 , %arg14 , %arg5 , %csti32 , %csti32 , %csti32 :
380
+ (vector <8 xf16 >, vector <8 xf16 >, vector <4 xf32 >,
381
+ i32 , i32 , i32 ) -> vector <4 xi32 >
382
+
383
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.bf16(<8 x bfloat> %1{{.*}}, <8 x bfloat> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
384
+ %r31 = rocdl.mfma.f32.32x32x16.bf16 %arg12 , %arg12 , %arg4 , %csti32 , %csti32 , %csti32 :
385
+ (vector <8 xbf16 >, vector <8 xbf16 >, vector <16 xf32 >,
386
+ i32 , i32 , i32 ) -> vector <16 xf32 >
387
+
388
+ // CHECK: call <16 x i32> @llvm.amdgcn.mfma.i32.32x32x32.i8(<4 x i32> %{{.*}}, <4 x i32> %{{.*}}, <16 x i32> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
389
+ %r32 = rocdl.mfma.i32.32x32x32.i8 %arg9 , %arg9 , %arg8 , %csti32 , %csti32 , %csti32 :
390
+ (vector <4 xi32 >, vector <4 xi32 >, vector <16 xi32 >,
391
+ i32 , i32 , i32 ) -> vector <16 xi32 >
392
+
393
+ // CHECK: call <16 x float> @llvm.amdgcn.mfma.f32.32x32x16.f16(<8 x half> %{{.*}}, <8 x half> %{{.*}}, <16 x float> %{{.*}}, i32 {{.*}}, i32 {{.*}}, i32 {{.*}})
394
+ %r33 = rocdl.mfma.f32.32x32x16.f16 %arg14 , %arg14 , %arg4 , %csti32 , %csti32 , %csti32 :
395
+ (vector <8 xf16 >, vector <8 xf16 >, vector <16 xf32 >,
396
+ i32 , i32 , i32 ) -> vector <16 xf32 >
397
+
365
398
llvm.return %r0 : vector <32 x f32 >
366
399
}
367
400
0 commit comments