Skip to content

[SPIR-V]: Add SPIR-V extension: SPV_KHR_cooperative_matrix #96091

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 72 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVBuiltins.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -558,24 +558,29 @@ static Register buildMemSemanticsReg(Register SemanticsRegister,

static bool buildOpFromWrapper(MachineIRBuilder &MIRBuilder, unsigned Opcode,
const SPIRV::IncomingCall *Call,
Register TypeReg = Register(0)) {
Register TypeReg,
ArrayRef<uint32_t> ImmArgs = {}) {
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
auto MIB = MIRBuilder.buildInstr(Opcode);
if (TypeReg.isValid())
MIB.addDef(Call->ReturnRegister).addUse(TypeReg);
for (Register ArgReg : Call->Arguments) {
unsigned Sz = Call->Arguments.size() - ImmArgs.size();
for (unsigned i = 0; i < Sz; ++i) {
Register ArgReg = Call->Arguments[i];
if (!MRI->getRegClassOrNull(ArgReg))
MRI->setRegClass(ArgReg, &SPIRV::IDRegClass);
MIB.addUse(ArgReg);
}
for (uint32_t ImmArg : ImmArgs)
MIB.addImm(ImmArg);
return true;
}

/// Helper function for translating atomic init to OpStore.
static bool buildAtomicInitInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call);
return buildOpFromWrapper(MIRBuilder, SPIRV::OpStore, Call, Register(0));

assert(Call->Arguments.size() == 2 &&
"Need 2 arguments for atomic init translation");
Expand Down Expand Up @@ -633,7 +638,7 @@ static bool buildAtomicStoreInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call);
return buildOpFromWrapper(MIRBuilder, SPIRV::OpAtomicStore, Call, Register(0));

Register ScopeRegister =
buildConstantIntReg(SPIRV::Scope::Device, MIRBuilder, GR);
Expand Down Expand Up @@ -870,7 +875,7 @@ static bool buildBarrierInst(const SPIRV::IncomingCall *Call, unsigned Opcode,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
if (Call->isSpirvOp())
return buildOpFromWrapper(MIRBuilder, Opcode, Call);
return buildOpFromWrapper(MIRBuilder, Opcode, Call, Register(0));

MachineRegisterInfo *MRI = MIRBuilder.getMRI();
unsigned MemFlags = getIConstVal(Call->Arguments[0], MRI);
Expand Down Expand Up @@ -1824,6 +1829,45 @@ static bool generateSelectInst(const SPIRV::IncomingCall *Call,
return true;
}

static bool generateConstructInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
return buildOpFromWrapper(MIRBuilder, SPIRV::OpCompositeConstruct, Call,
GR->getSPIRVTypeID(Call->ReturnType));
}

static bool generateCoopMatrInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
const SPIRV::DemangledBuiltin *Builtin = Call->Builtin;
unsigned Opcode =
SPIRV::lookupNativeBuiltin(Builtin->Name, Builtin->Set)->Opcode;
bool IsSet = Opcode != SPIRV::OpCooperativeMatrixStoreKHR;
unsigned ArgSz = Call->Arguments.size();
unsigned LiteralIdx = 0;
if (Opcode == SPIRV::OpCooperativeMatrixLoadKHR && ArgSz > 3)
LiteralIdx = 3;
else if (Opcode == SPIRV::OpCooperativeMatrixStoreKHR && ArgSz > 4)
LiteralIdx = 4;
SmallVector<uint32_t, 1> ImmArgs;
MachineRegisterInfo *MRI = MIRBuilder.getMRI();
if (LiteralIdx > 0)
ImmArgs.push_back(getConstFromIntrinsic(Call->Arguments[LiteralIdx], MRI));
Register TypeReg = GR->getSPIRVTypeID(Call->ReturnType);
if (Opcode == SPIRV::OpCooperativeMatrixLengthKHR) {
SPIRVType *CoopMatrType = GR->getSPIRVTypeForVReg(Call->Arguments[0]);
if (!CoopMatrType)
report_fatal_error("Can't find a register's type definition");
MIRBuilder.buildInstr(Opcode)
.addDef(Call->ReturnRegister)
.addUse(TypeReg)
.addUse(CoopMatrType->getOperand(0).getReg());
return true;
}
return buildOpFromWrapper(MIRBuilder, Opcode, Call,
IsSet ? TypeReg : Register(0), ImmArgs);
}

static bool generateSpecConstantInst(const SPIRV::IncomingCall *Call,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
Expand Down Expand Up @@ -2382,6 +2426,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateSampleImageInst(DemangledCall, Call.get(), MIRBuilder, GR);
case SPIRV::Select:
return generateSelectInst(Call.get(), MIRBuilder);
case SPIRV::Construct:
return generateConstructInst(Call.get(), MIRBuilder, GR);
case SPIRV::SpecConstant:
return generateSpecConstantInst(Call.get(), MIRBuilder, GR);
case SPIRV::Enqueue:
Expand All @@ -2400,6 +2446,8 @@ std::optional<bool> lowerBuiltin(const StringRef DemangledCall,
return generateGroupUniformInst(Call.get(), MIRBuilder, GR);
case SPIRV::KernelClock:
return generateKernelClockInst(Call.get(), MIRBuilder, GR);
case SPIRV::CoopMatr:
return generateCoopMatrInst(Call.get(), MIRBuilder, GR);
}
return false;
}
Expand Down Expand Up @@ -2524,6 +2572,22 @@ static SPIRVType *getPipeType(const TargetExtType *ExtensionType,
ExtensionType->getIntParameter(0)));
}

static SPIRVType *getCoopMatrType(const TargetExtType *ExtensionType,
MachineIRBuilder &MIRBuilder,
SPIRVGlobalRegistry *GR) {
assert(ExtensionType->getNumIntParameters() == 4 &&
"Invalid number of parameters for SPIR-V coop matrices builtin!");
assert(ExtensionType->getNumTypeParameters() == 1 &&
"SPIR-V coop matrices builtin type must have a type parameter!");
const SPIRVType *ElemType =
GR->getOrCreateSPIRVType(ExtensionType->getTypeParameter(0), MIRBuilder);
// Create or get an existing type from GlobalRegistry.
return GR->getOrCreateOpTypeCoopMatr(
MIRBuilder, ExtensionType, ElemType, ExtensionType->getIntParameter(0),
ExtensionType->getIntParameter(1), ExtensionType->getIntParameter(2),
ExtensionType->getIntParameter(3));
}

static SPIRVType *
getImageType(const TargetExtType *ExtensionType,
const SPIRV::AccessQualifier::AccessQualifier Qualifier,
Expand Down Expand Up @@ -2654,6 +2718,9 @@ SPIRVType *lowerBuiltinType(const Type *OpaqueType,
case SPIRV::OpTypeSampledImage:
TargetType = getSampledImageType(BuiltinType, MIRBuilder, GR);
break;
case SPIRV::OpTypeCooperativeMatrixKHR:
TargetType = getCoopMatrType(BuiltinType, MIRBuilder, GR);
break;
default:
TargetType =
getNonParameterizedType(BuiltinType, TypeRecord, MIRBuilder, GR);
Expand Down
13 changes: 12 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVBuiltins.td
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@ def AtomicFloating : BuiltinGroup;
def GroupUniform : BuiltinGroup;
def KernelClock : BuiltinGroup;
def CastToPtr : BuiltinGroup;
def Construct : BuiltinGroup;
def CoopMatr : BuiltinGroup;

//===----------------------------------------------------------------------===//
// Class defining a demangled builtin record. The information in the record
Expand Down Expand Up @@ -114,6 +116,9 @@ def : DemangledBuiltin<"__spirv_ImageSampleExplicitLod", OpenCL_std, SampleImage
// Select builtin record:
def : DemangledBuiltin<"__spirv_Select", OpenCL_std, Select, 3, 3>;

// Composite Construct builtin record:
def : DemangledBuiltin<"__spirv_CompositeConstruct", OpenCL_std, Construct, 1, 0>;

//===----------------------------------------------------------------------===//
// Class defining an extended builtin record used for lowering into an
// OpExtInst instruction.
Expand Down Expand Up @@ -608,6 +613,12 @@ defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToGlobal", Ope
defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToLocal", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;
defm : DemangledNativeBuiltin<"__spirv_OpGenericCastToPtrExplicit_ToPrivate", OpenCL_std, CastToPtr, 2, 2, OpGenericCastToPtr>;

// Cooperative Matrix builtin records:
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLoadKHR", OpenCL_std, CoopMatr, 2, 0, OpCooperativeMatrixLoadKHR>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixStoreKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixStoreKHR>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixMulAddKHR", OpenCL_std, CoopMatr, 3, 0, OpCooperativeMatrixMulAddKHR>;
defm : DemangledNativeBuiltin<"__spirv_CooperativeMatrixLengthKHR", OpenCL_std, CoopMatr, 1, 1, OpCooperativeMatrixLengthKHR>;

//===----------------------------------------------------------------------===//
// Class defining a work/sub group builtin that should be translated into a
// SPIR-V instruction using the defined properties.
Expand Down Expand Up @@ -1436,7 +1447,7 @@ def : BuiltinType<"spirv.DeviceEvent", OpTypeDeviceEvent>;
def : BuiltinType<"spirv.Image", OpTypeImage>;
def : BuiltinType<"spirv.SampledImage", OpTypeSampledImage>;
def : BuiltinType<"spirv.Pipe", OpTypePipe>;

def : BuiltinType<"spirv.CooperativeMatrixKHR", OpTypeCooperativeMatrixKHR>;

//===----------------------------------------------------------------------===//
// Class matching an OpenCL builtin type name to an equivalent SPIR-V
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCommandLine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@ static const std::map<std::string, SPIRV::Extension::Extension>
SPIRV::Extension::Extension::SPV_INTEL_function_pointers},
{"SPV_KHR_shader_clock",
SPIRV::Extension::Extension::SPV_KHR_shader_clock},
{"SPV_KHR_cooperative_matrix",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

SPIRVUsage doc file will also need to be updated

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will do in the next PR

SPIRV::Extension::Extension::SPV_KHR_cooperative_matrix},
};

bool SPIRVExtensionsParser::parse(cl::Option &O, llvm::StringRef ArgName,
Expand Down
32 changes: 27 additions & 5 deletions llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1080,12 +1080,14 @@ bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
return IntType && IntType->getOperand(2).getImm() != 0;
}

SPIRVType *SPIRVGlobalRegistry::getPointeeType(SPIRVType *PtrType) {
return PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
: nullptr;
}

unsigned SPIRVGlobalRegistry::getPointeeTypeOp(Register PtrReg) {
SPIRVType *PtrType = getSPIRVTypeForVReg(PtrReg);
SPIRVType *ElemType =
PtrType && PtrType->getOpcode() == SPIRV::OpTypePointer
? getSPIRVTypeForVReg(PtrType->getOperand(2).getReg())
: nullptr;
SPIRVType *ElemType = getPointeeType(getSPIRVTypeForVReg(PtrReg));
return ElemType ? ElemType->getOpcode() : 0;
}

Expand Down Expand Up @@ -1189,6 +1191,26 @@ SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
.addUse(getSPIRVTypeID(ImageType));
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeCoopMatr(
MachineIRBuilder &MIRBuilder, const TargetExtType *ExtensionType,
const SPIRVType *ElemType, uint32_t Scope, uint32_t Rows, uint32_t Columns,
uint32_t Use) {
Register ResVReg = DT.find(ExtensionType, &MIRBuilder.getMF());
if (ResVReg.isValid())
return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
ResVReg = createTypeVReg(MIRBuilder);
SPIRVType *SpirvTy =
MIRBuilder.buildInstr(SPIRV::OpTypeCooperativeMatrixKHR)
.addDef(ResVReg)
.addUse(getSPIRVTypeID(ElemType))
.addUse(buildConstantInt(Scope, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Rows, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Columns, MIRBuilder, nullptr, true))
.addUse(buildConstantInt(Use, MIRBuilder, nullptr, true));
DT.add(ExtensionType, &MIRBuilder.getMF(), ResVReg);
return SpirvTy;
}

SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
Expand Down
8 changes: 7 additions & 1 deletion llvm/lib/Target/SPIRV/SPIRVGlobalRegistry.h
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,8 @@ class SPIRVGlobalRegistry {
return Res->second;
}

// Return a pointee's type, or nullptr otherwise.
SPIRVType *getPointeeType(SPIRVType *PtrType);
// Return a pointee's type op code, or 0 otherwise.
unsigned getPointeeTypeOp(Register PtrReg);

Expand Down Expand Up @@ -514,7 +516,11 @@ class SPIRVGlobalRegistry {

SPIRVType *getOrCreateOpTypeSampledImage(SPIRVType *ImageType,
MachineIRBuilder &MIRBuilder);

SPIRVType *getOrCreateOpTypeCoopMatr(MachineIRBuilder &MIRBuilder,
const TargetExtType *ExtensionType,
const SPIRVType *ElemType,
uint32_t Scope, uint32_t Rows,
uint32_t Columns, uint32_t Use);
SPIRVType *
getOrCreateOpTypePipe(MachineIRBuilder &MIRBuilder,
SPIRV::AccessQualifier::AccessQualifier AccQual);
Expand Down
16 changes: 16 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVInstrInfo.td
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,9 @@ def OpTypeAccelerationStructureNV: Op<5341, (outs TYPE:$res), (ins),
def OpTypeCooperativeMatrixNV: Op<5358, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols),
"$res = OpTypeCooperativeMatrixNV $compType $scope $rows $cols">;
def OpTypeCooperativeMatrixKHR: Op<4456, (outs TYPE:$res),
(ins TYPE:$compType, ID:$scope, ID:$rows, ID:$cols, ID:$use),
"$res = OpTypeCooperativeMatrixKHR $compType $scope $rows $cols $use">;

// 3.42.7 Constant-Creation Instructions

Expand Down Expand Up @@ -864,3 +867,16 @@ def OpAsmINTEL: Op<5610, (outs ID:$res), (ins TYPE:$type, TYPE:$asm_type, ID:$ta
"$res = OpAsmINTEL $type $asm_type $target $asm">;
def OpAsmCallINTEL: Op<5611, (outs ID:$res), (ins TYPE:$type, ID:$asm, variable_ops),
"$res = OpAsmCallINTEL $type $asm">;

// SPV_KHR_cooperative_matrix
def OpCooperativeMatrixLoadKHR: Op<4457, (outs ID:$res),
(ins TYPE:$resType, ID:$pointer, ID:$memory_layout, variable_ops),
"$res = OpCooperativeMatrixLoadKHR $resType $pointer $memory_layout">;
def OpCooperativeMatrixStoreKHR: Op<4458, (outs),
(ins ID:$pointer, ID:$objectToStore, ID:$memory_layout, variable_ops),
"OpCooperativeMatrixStoreKHR $pointer $objectToStore $memory_layout">;
def OpCooperativeMatrixMulAddKHR: Op<4459, (outs ID:$res),
(ins TYPE:$type, ID:$A, ID:$B, ID:$C, variable_ops),
"$res = OpCooperativeMatrixMulAddKHR $type $A $B $C">;
def OpCooperativeMatrixLengthKHR: Op<4460, (outs ID:$res), (ins TYPE:$type, ID:$coop_matr_type),
"$res = OpCooperativeMatrixLengthKHR $type $coop_matr_type">;
2 changes: 1 addition & 1 deletion llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1105,7 +1105,7 @@ bool SPIRVInstructionSelector::selectAddrSpaceCast(Register ResVReg,
if (isGenericCastablePtr(SrcSC) && isGenericCastablePtr(DstSC)) {
Register Tmp = MRI->createVirtualRegister(&SPIRV::IDRegClass);
SPIRVType *GenericPtrTy = GR.getOrCreateSPIRVPointerType(
SrcPtrTy, I, TII, SPIRV::StorageClass::Generic);
GR.getPointeeType(SrcPtrTy), I, TII, SPIRV::StorageClass::Generic);
MachineBasicBlock &BB = *I.getParent();
const DebugLoc &DL = I.getDebugLoc();
bool Success = BuildMI(BB, I, DL, TII.get(SPIRV::OpPtrCastToGeneric))
Expand Down
9 changes: 9 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVModuleAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1168,6 +1168,15 @@ void addInstrRequirements(const MachineInstr &MI,
Reqs.addCapability(SPIRV::Capability::AsmINTEL);
}
break;
case SPIRV::OpTypeCooperativeMatrixKHR:
if (!ST.canUseExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix))
report_fatal_error(
"OpTypeCooperativeMatrixKHR type requires the "
"following SPIR-V extension: SPV_KHR_cooperative_matrix",
false);
Reqs.addExtension(SPIRV::Extension::SPV_KHR_cooperative_matrix);
Reqs.addCapability(SPIRV::Capability::CooperativeMatrixKHR);
break;
default:
break;
}
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVSymbolicOperands.td
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,7 @@ defm SPV_INTEL_inline_assembly : ExtensionOperand<107>;
defm SPV_INTEL_cache_controls : ExtensionOperand<108>;
defm SPV_INTEL_global_variable_host_access : ExtensionOperand<109>;
defm SPV_INTEL_global_variable_fpga_decorations : ExtensionOperand<110>;
defm SPV_KHR_cooperative_matrix : ExtensionOperand<111>;

//===----------------------------------------------------------------------===//
// Multiclass used to define Capabilities enum values and at the same time
Expand Down Expand Up @@ -478,6 +479,7 @@ defm GlobalVariableHostAccessINTEL : CapabilityOperand<6187, 0, 0, [SPV_INTEL_gl
defm HostAccessINTEL : CapabilityOperand<6188, 0, 0, [SPV_INTEL_global_variable_host_access], []>;
defm GlobalVariableFPGADecorationsINTEL : CapabilityOperand<6189, 0, 0, [SPV_INTEL_global_variable_fpga_decorations], []>;
defm CacheControlsINTEL : CapabilityOperand<6441, 0, 0, [SPV_INTEL_cache_controls], []>;
defm CooperativeMatrixKHR : CapabilityOperand<6022, 0, 0, [SPV_KHR_cooperative_matrix], []>;

//===----------------------------------------------------------------------===//
// Multiclass used to define SourceLanguage enum values and at the same time
Expand Down
Loading
Loading