Skip to content

[SPIR-V]: Add SPIR-V extension: SPV_KHR_cooperative_matrix #96091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 24, 2024

Conversation

VyacheslavLevytskyy
Copy link
Contributor

@VyacheslavLevytskyy VyacheslavLevytskyy commented Jun 19, 2024

This PR adds SPIR-V extension SPV_KHR_cooperative_matrix that "adds a new set of types known as "cooperative matrix" types, where the storage for and computations performed on the matrix are spread across a set of invocations such as a subgroup" (see https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc).

This PR also fixes #96170, a new test cases is attached (llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll).

@VyacheslavLevytskyy VyacheslavLevytskyy marked this pull request as ready for review June 20, 2024 12:04
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2024

@llvm/pr-subscribers-backend-spir-v

Author: Vyacheslav Levytskyy (VyacheslavLevytskyy)

Changes

This PR adds SPIR-V extension SPV_KHR_cooperative_matrix that "adds a new set of types known as "cooperative matrix" types, where the storage for and computations performed on the matrix are spread across a set of invocations such as a subgroup" (see https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc).

This PR also fixes #96170, a new test cases is attached.


Patch is 27.09 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/96091.diff

12 Files Affected:

  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp (+72-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVBuiltins.td (+15-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp (+4)
  • (modified) llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h (+28)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp (+27-5)
  • (modified) llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h (+7-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstrInfo.td (+16)
  • (modified) llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp (+1-1)
  • (modified) llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp (+9)
  • (modified) llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td (+3)
  • (added) llvm/test/CodeGen/SPIRV/extensions/SPV_KHR_cooperative_matrix/cooperative_matrix.ll (+54)
  • (added) llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll (+30)
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
index c14e5098be711..f5f36075d4a31 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
@@ -558,16 +558,21 @@ static Register buildMemSemanticsReg(Register SemanticsRegister,
 
 static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
                                const SPIRV::IncomingCall *Call,
-                               Register TypeReg = Register(0)) {
+                               Register TypeReg,
+                               ArrayRef<uint32_t> ImmArgs = {}) {
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   auto MIB = MIRBuilder.buildInstr(Opcode);
   if (TypeReg.isValid())
     MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
-  for (Register ArgReg : Call->Arguments) {
+  unsigned Sz = Call->Arguments.size() - ImmArgs.size();
+  for (unsigned i = 0; i < Sz; ++i) {
+    Register ArgReg = Call->Arguments[i];
     if (!MRI->getRegClassOrNull(ArgReg))
       MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
     MIB.addUse(ArgReg);
   }
+  for (uint32_t ImmArg : ImmArgs)
+    MIB.addImm(ImmArg);
   return true;
 }
 
@@ -575,7 +580,7 @@ static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
 static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
                                 MachineIRBuilder &MIRBuilder) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call);
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));
 
   assert(Call->Arguments.size() == 2 &&
          "Need 2 arguments for atomic init translation");
@@ -633,7 +638,7 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
                                  MachineIRBuilder &MIRBuilder,
                                  SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call);
+    return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));
 
   Register ScopeRegister =
       buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
@@ -870,7 +875,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
                              MachineIRBuilder &MIRBuilder,
                              SPIRVGlobalRegistry *GR) {
   if (Call->isSpirvOp())
-    return buildOpFromWrapper(MIRBuilder, Opcode, Call);
+    return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));
 
   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
   unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
@@ -1824,6 +1829,45 @@ static bool generateSelectInst(const SPIRV::IncomingCall *Call,
   return true;
 }
 
+static bool generateConstructInst(const SPIRV::IncomingCall *Call,
+                                  MachineIRBuilder &MIRBuilder,
+                                  SPIRVGlobalRegistry *GR) {
+  return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call,
+                            GR->getSPIRVTypeID(Call->ReturnType));
+}
+
+static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
+                                 MachineIRBuilder &MIRBuilder,
+                                 SPIRVGlobalRegistry *GR) {
+  const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
+  unsigned Opcode =
+      SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
+  bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
+  unsigned ArgSz = Call->Arguments.size();
+  unsigned LiteralIdx = 0;
+  if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
+    LiteralIdx = 3;
+  else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
+    LiteralIdx = 4;
+  SmallVector<uint32_t, 1> ImmArgs;
+  MachineRegisterInfo *MRI = MIRBuilder.getMRI();
+  if (LiteralIdx > 0)
+    ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
+  Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
+  if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) {
+    SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
+    if (!CoopMatrType)
+      report_fatal_error("Can't find a register's type definition");
+    MIRBuilder.buildInstr(Opcode)
+        .addDef(Call->ReturnRegister)
+        .addUse(TypeReg)
+        .addUse(CoopMatrType->getOperand(0).getReg());
+    return true;
+  }
+  return buildOpFromWrapper(MIRBuilder, Opcode, Call,
+                            IsSet ? TypeReg : Register(0), ImmArgs);
+}
+
 static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
                                      MachineIRBuilder &MIRBuilder,
                                      SPIRVGlobalRegistry *GR) {
@@ -2382,6 +2426,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
   case SPIRV::Select:
     return generateSelectInst(Call.get(), MIRBuilder);
+  case SPIRV::Construct:
+    return generateConstructInst(Call.get(), MIRBuilder, GR);
   case SPIRV::SpecConstant:
     return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
   case SPIRV::Enqueue:
@@ -2400,6 +2446,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
     return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
   case SPIRV::KernelClock:
     return generateKernelClockInst(Call.get(), MIRBuilder, GR);
+  case SPIRV::CoopMatr:
+    return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
   }
   return false;
 }
@@ -2524,6 +2572,22 @@ static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
                                        ExtensionType->getIntParameter(0)));
 }
 
+static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
+                                  MachineIRBuilder &MIRBuilder,
+                                  SPIRVGlobalRegistry *GR) {
+  assert(ExtensionType->getNumIntParameters() == 4 &&
+         "Invalid number of parameters for SPIR-V coop matrices builtin!");
+  assert(ExtensionType->getNumTypeParameters() == 1 &&
+         "SPIR-V coop matrices builtin type must have a type parameter!");
+  const SPIRVType *ElemType =
+      GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
+  // Create or get an existing type from GlobalRegistry.
+  return GR->getOrCreateOpTypeCoopMatr(
+      MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
+      ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
+      ExtensionType->getIntParameter(3));
+}
+
 static SPIRVType *
 getImageType(const TargetExtType *ExtensionType,
              const SPIRV::AccessQualifier::AccessQualifier Qualifier,
@@ -2654,6 +2718,9 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
   case SPIRV::OpTypeSampledImage:
     TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
     break;
+  case SPIRV::OpTypeCooperativeMatrixKHR:
+    TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
+    break;
   default:
     TargetType =
         getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
diff --git a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
index 2b8e6d856686a..5595d4cde120c 100644
--- a/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
+++ b/llvm/lib/Target/SPIRV/SPIRVBuiltins.td
@@ -60,6 +60,8 @@ def AtomicFloating : BuiltinGroup;
 def GroupUniform : BuiltinGroup;
 def KernelClock : BuiltinGroup;
 def CastToPtr : BuiltinGroup;
+def Construct : BuiltinGroup;
+def CoopMatr : BuiltinGroup;
 
 //===----------------------------------------------------------------------===//
 // Class defining a demangled builtin record. The information in the record
@@ -114,6 +116,12 @@ def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage
 // Select builtin record:
 def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>;
 
+// Composite Construct builtin record:
+def : DemangledBuiltin<"__spirv_CompositeConstruct", OpenCL_std, Construct, 1, 0>;
+
+// Dot builtin record:
+def : DemangledBuiltin<"dot", OpenCL_std, Dot, 2, 2>;
+
 //===----------------------------------------------------------------------===//
 // Class defining an extended builtin record used for lowering into an
 // OpExtInst instruction.
@@ -608,6 +616,12 @@ defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToGlobal", Ope
 defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
 defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
 
+// Cooperative Matrix builtin records:
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadKHR", OpenCL_std, CoopMatr, 2, 0, OpCooperativeMatrixLoadKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixStoreKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixMulAddKHR>;
+defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;
+
 //===----------------------------------------------------------------------===//
 // Class defining a work/sub group builtin that should be translated into a
 // SPIR-V instruction using the defined properties.
@@ -1436,7 +1450,7 @@ def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>;
 def : BuiltinType<"spirv.Image", OpTypeImage>;
 def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>;
 def : BuiltinType<"spirv.Pipe", OpTypePipe>;
-
+def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>;
 
 //===----------------------------------------------------------------------===//
 // Class matching an OpenCL builtin type name to an equivalent SPIR-V
diff --git a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
index 75aa1823b11f2..c2a5b234d6a2b 100644
--- a/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
@@ -66,6 +66,10 @@ static const std::map<std::string, SPIRV::Extension::Extension>
          SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
         {"SPV_KHR_shader_clock",
          SPIRV::Extension::Extension::SPV_KHR_shader_clock},
+        {"SPV_KHR_cooperative_matrix",
+         SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix},
+        {"SPV_INTEL_joint_matrix",
+         SPIRV::Extension::Extension::SPV_INTEL_joint_matrix},
 };
 
 bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
diff --git a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
index a37e65a47eda0..cb8576ddee719 100644
--- a/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
+++ b/llvm/lib/Target/SPIRV/SPIRVDuplicatesTracker.h
@@ -60,6 +60,8 @@ enum SpecialTypeKind {
   STK_Pipe,
   STK_DeviceEvent,
   STK_Pointer,
+  STK_CoopMatr,
+  STK_JointMatr,
   STK_Last = -1
 };
 
@@ -113,7 +115,33 @@ make_descr_sampled_image(const Type *SampledTy, const MachineInstr *ImageTy) {
           .Val,
       SpecialTypeKind::STK_SampledImage);
 }
+/*
+union MatrAttrs {
+  struct BitFlags {
+    unsigned Layout : 2;
 
+    unsigned Depth : 2;
+    unsigned Arrayed : 1;
+    unsigned MS : 1;
+    unsigned Sampled : 2;
+    unsigned ImageFormat : 6;
+    unsigned AQ : 2;
+  } Flags;
+  unsigned Val;
+
+  MatrAttrs(unsigned Dim, unsigned Depth, unsigned Arrayed, unsigned MS,
+             unsigned Sampled, unsigned ImageFormat, unsigned AQ = 0) {
+    Val = 0;
+    Flags.Dim = Dim;
+    Flags.Depth = Depth;
+    Flags.Arrayed = Arrayed;
+    Flags.MS = MS;
+    Flags.Sampled = Sampled;
+    Flags.ImageFormat = ImageFormat;
+    Flags.AQ = AQ;
+  }
+};
+*/
 inline SpecialTypeDescriptor make_descr_sampler() {
   return std::make_tuple(nullptr, 0U, SpecialTypeKind::STK_Sampler);
 }
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
index b22d2a04f75b1..b8710d24bff94 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
@@ -1080,12 +1080,14 @@ bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
   return IntType && IntType->getOperand(2).getImm() != 0;
 }
 
+SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
+  return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
+             ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
+             : nullptr;
+}
+
 unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
-  SPIRVType *PtrType = getSPIRVTypeForVReg(PtrReg);
-  SPIRVType *ElemType =
-      PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
-          ? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
-          : nullptr;
+  SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
   return ElemType ? ElemType->getOpcode() : 0;
 }
 
@@ -1189,6 +1191,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
       .addUse(getSPIRVTypeID(ImageType));
 }
 
+SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
+    MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
+    const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
+    uint32_t Use) {
+  Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
+  if (ResVReg.isValid())
+    return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
+  ResVReg = createTypeVReg(MIRBuilder);
+  SPIRVType *SpirvTy =
+      MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
+          .addDef(ResVReg)
+          .addUse(getSPIRVTypeID(ElemType))
+          .addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
+          .addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
+          .addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
+          .addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
+  DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
+  return SpirvTy;
+}
+
 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
diff --git a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
index db01f68f48de9..cc4e20b8247cc 100644
--- a/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
+++ b/llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
@@ -292,6 +292,8 @@ class SPIRVGlobalRegistry {
     return Res->second;
   }
 
+  // Return a pointee's type, or nullptr otherwise.
+  SPIRVType *getPointeeType(SPIRVType *PtrType);
   // Return a pointee's type op code, or 0 otherwise.
   unsigned getPointeeTypeOp(Register PtrReg);
 
@@ -514,7 +516,11 @@ class SPIRVGlobalRegistry {
 
   SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType,
                                            MachineIRBuilder &MIRBuilder);
-
+  SPIRVType *getOrCreateOpTypeCoopMatr(MachineIRBuilder &MIRBuilder,
+                                       const TargetExtType *ExtensionType,
+                                       const SPIRVType *ElemType,
+                                       uint32_t Scope, uint32_t Rows,
+                                       uint32_t Columns, uint32_t Use);
   SPIRVType *
   getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder,
                         SPIRV::AccessQualifier::AccessQualifier AccQual);
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
index dedfd5e6e32db..63549b06e9670 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
+++ b/llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
@@ -211,6 +211,9 @@ def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
 def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
                   (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
                   "$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
+def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
+                  (ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
+                  "$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;
 
 // 3.42.7 Constant-Creation Instructions
 
@@ -864,3 +867,16 @@ def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$ta
                   "$res = OpAsmINTEL $type $asm_type $target $asm">;
 def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops),
                   "$res = OpAsmCallINTEL $type $asm">;
+
+// SPV_KHR_cooperative_matrix
+def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res),
+                  (ins TYPE:$resType, ID:$pointer, ID:$memory_layout, variable_ops),
+                  "$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">;
+def OpCooperativeMatrixStoreKHR: Op<4458, (outs),
+                  (ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops),
+                  "OpCooperativeMatrixStoreKHR $pointer $objectToStore $memory_layout">;
+def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
+                  (ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops),
+                  "$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
+def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
+                  "$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
index b9e5569029cfd..3134a9108e5e2 100644
--- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
@@ -1105,7 +1105,7 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
   if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
     Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
     SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
-        SrcPtrTy, I, TII, SPIRV::StorageClass::Generic);
+        GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
     MachineBasicBlock &BB = *I.getParent();
     const DebugLoc &DL = I.getDebugLoc();
     bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
diff --git a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
index 30a6c474f467a..ac0aa682ea4be 100644
--- a/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
+++ b/llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
@@ -1168,6 +1168,15 @@ void addInstrRequirements(const MachineInstr &MI,
       Reqs.addCapability(SPIRV::Capability::AsmINTEL);
     }
     break;
+  case SPIRV::OpTypeCooperativeMatrixKHR:
+    if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
+      report_fatal_error(
+          "OpTypeCooperativeMatrixKHR type requires the "
+          "following SPIR-V extension: SPV_KHR_cooperative_matrix",
+          false);
+    Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
+    Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
+    break;
   default:
     break;
   }
diff --git a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
index 318c5cebb7a43..f7e482449e0ca 100644
--- a/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
+++ b/llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
@@ -302,6 +302,8 @@ defm SPV_INTEL_inline_assembly : ExtensionOperand<107>;
 defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
 defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
 defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
+defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;
+defm SPV_INTEL_joint_matrix : ExtensionOperand<112>;
 
 //===----------------------------------------------------------------------===//
 // Multiclass used to define Capabilities enum values and at the same time
@@ -478,6 +480,7 @@ defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_gl
 defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
 defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV...
[truncated]

@@ -66,6 +66,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
{"SPV_KHR_shader_clock",
SPIRV::Extension::Extension::SPV_KHR_shader_clock},
{"SPV_KHR_cooperative_matrix",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SPIRVUsage doc file will also need to be updated

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do in the next PR

@VyacheslavLevytskyy VyacheslavLevytskyy merged commit 57f7937 into llvm:main Jun 24, 2024
8 checks passed
AlexisPerry pushed a commit to llvm-project-tlp/llvm-project that referenced this pull request Jul 9, 2024
This PR adds SPIR-V extension SPV_KHR_cooperative_matrix that "adds a
new set of types known as "cooperative matrix" types, where the storage
for and computations performed on the matrix are spread across a set of
invocations such as a subgroup" (see
https://github.com/KhronosGroup/SPIRV-Registry/blob/main/extensions/KHR/SPV_KHR_cooperative_matrix.asciidoc).

This PR also fixes llvm#96170, a
new test cases is attached
(llvm/test/CodeGen/SPIRV/transcoding/OpPtrCastToGeneric.ll).
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[SPIR-V] Address space casting via a Generic pointer as an intermediary creates a wrong Generic pointer type
3 participants