@@ -320,6 +320,22 @@ llvm::Type *getJointMatrixINTELExtType(llvm::Type *CompTy,
320
320
" spirv.JointMatrixINTEL" , {CompTy}, Params);
321
321
}
322
322
323
+ llvm::Type *
324
+ getCooperativeMatrixKHRExtType (llvm::Type *CompTy,
325
+ ArrayRef<TemplateArgument> TemplateArgs) {
326
+ assert (TemplateArgs.size () == 5 &&
327
+ " Wrong CooperativeMatrixKHR template parameters number" );
328
+ std::vector<unsigned > Params;
329
+ for (size_t I = 1 ; I != TemplateArgs.size (); ++I) {
330
+ assert (TemplateArgs[I].getKind () == TemplateArgument::Integral &&
331
+ " Wrong CooperativeMatrixKHR template parameter" );
332
+ Params.push_back (TemplateArgs[I].getAsIntegral ().getExtValue ());
333
+ }
334
+
335
+ return llvm::TargetExtType::get (
336
+ CompTy->getContext (), " spirv.CooperativeMatrixKHR" , {CompTy}, Params);
337
+ }
338
+
323
339
// / ConvertSYCLJointMatrixINTELType - Convert SYCL joint_matrix type
324
340
// / which is represented as a pointer to a structure to LLVM extension type
325
341
// / with the parameters that follow SPIR-V JointMatrixINTEL type.
@@ -363,6 +379,39 @@ llvm::Type *CodeGenTypes::ConvertSYCLJointMatrixINTELType(RecordDecl *RD) {
363
379
return getJointMatrixINTELExtType (CompTy, TemplateArgs);
364
380
}
365
381
382
+ // / ConvertSPVCooperativeMatrixType - Convert SYCL joint_matrix type
383
+ // / which is represented as a pointer to a structure to LLVM extension type
384
+ // / with the parameters that follow SPIR-V CooperativeMatrixKHR type.
385
+ // / The expected representation is:
386
+ // / target("spirv.CooperativeMatrixKHR", %element_type, %scope%, %rows%, %cols%,
387
+ // / %use%)
388
+ llvm::Type *CodeGenTypes::ConvertSPVCooperativeMatrixType (RecordDecl *RD) {
389
+ auto *TemplateDecl = cast<ClassTemplateSpecializationDecl>(RD);
390
+ ArrayRef<TemplateArgument> TemplateArgs =
391
+ TemplateDecl->getTemplateArgs ().asArray ();
392
+ assert (TemplateArgs[0 ].getKind () == TemplateArgument::Type &&
393
+ " 1st CooperativeMatrixKHR template parameter must be type" );
394
+ llvm::Type *CompTy = ConvertType (TemplateArgs[0 ].getAsType ());
395
+
396
+ if (CompTy->isStructTy ()) {
397
+ StringRef LlvmTyName = CompTy->getStructName ();
398
+ // Emit half/int16/float for sycl[::*]::{half,bfloat16,tf32}
399
+ if (LlvmTyName.starts_with (" class.sycl::" ) ||
400
+ LlvmTyName.starts_with (" class.__sycl_internal::" ))
401
+ LlvmTyName = LlvmTyName.rsplit (" ::" ).second ;
402
+ if (LlvmTyName == " half" ) {
403
+ CompTy = llvm::Type::getHalfTy (getLLVMContext ());
404
+ } else if (LlvmTyName == " tf32" ) {
405
+ CompTy = llvm::Type::getFloatTy (getLLVMContext ());
406
+ } else if (LlvmTyName == " bfloat16" ) {
407
+ CompTy = llvm::Type::getInt16Ty (getLLVMContext ());
408
+ } else {
409
+ llvm_unreachable (" Wrong matrix base type!" );
410
+ }
411
+ }
412
+ return getCooperativeMatrixKHRExtType (CompTy, TemplateArgs);
413
+ }
414
+
366
415
// / ConvertType - Convert the specified type to its LLVM form.
367
416
llvm::Type *CodeGenTypes::ConvertType (QualType T) {
368
417
T = Context.getCanonicalType (T);
@@ -654,6 +703,10 @@ llvm::Type *CodeGenTypes::ConvertType(QualType T) {
654
703
" __spv::__spirv_JointMatrixINTEL" ) {
655
704
ResultType = ConvertSYCLJointMatrixINTELType (RD);
656
705
break ;
706
+ } else if (RD && RD->getQualifiedNameAsString () ==
707
+ " __spv::__spirv_CooperativeMatrixKHR" ) {
708
+ ResultType = ConvertSPVCooperativeMatrixType (RD);
709
+ break ;
657
710
} else if (RD && RD->getQualifiedNameAsString () ==
658
711
" __spv::__spirv_TaskSequenceINTEL" ) {
659
712
ResultType = llvm::TargetExtType::get (getLLVMContext (),
0 commit comments