Skip to content

Commit 45da762

Browse files
vmaksimojsji
authored andcommitted
Align translation of OpCooperativeMatrixLengthKHR to match the spec (#2964)
`SPV_KHR_cooperative_matrix` extension defines that the only argument accepted in this instruction is `Matrix Type <id>`, not the pointer to an actual matrix. This resolves #2963 Original commit: KhronosGroup/SPIRV-LLVM-Translator@197a800558787c9
1 parent 789f0e6 commit 45da762

File tree

7 files changed

+22
-8
lines changed

7 files changed

+22
-8
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3627,8 +3627,7 @@ Instruction *SPIRVToLLVM::transBuiltinFromInst(const std::string &FuncName,
36273627
Func->addFnAttr(Attribute::Convergent);
36283628
}
36293629
CallInst *Call;
3630-
if (OC == OpCooperativeMatrixLengthKHR &&
3631-
Ops[0]->getOpCode() == OpTypeCooperativeMatrixKHR) {
3630+
if (OC == OpCooperativeMatrixLengthKHR) {
36323631
// OpCooperativeMatrixLengthKHR needs special handling as its operand is
36333632
// a Type instead of a Value.
36343633
llvm::Type *MatTy = transType(reinterpret_cast<SPIRVType *>(Ops[0]));

llvm-spirv/lib/SPIRV/SPIRVWriter.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6787,6 +6787,10 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
67876787
transValue(CI->getArgOperand(2), BB), BB);
67886788
return BM->addStoreInst(transValue(CI->getArgOperand(0), BB), V, {}, BB);
67896789
}
6790+
case OpCooperativeMatrixLengthKHR: {
6791+
return BM->addCooperativeMatrixLengthKHRInst(
6792+
transScavengedType(CI), transType(CI->getArgOperand(0)->getType()), BB);
6793+
}
67906794
case OpGroupNonUniformShuffleDown: {
67916795
Function *F = CI->getCalledFunction();
67926796
if (F->arg_size() && F->getArg(0)->hasStructRetAttr()) {

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.cpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,9 @@ class SPIRVModuleImpl : public SPIRVModule {
279279
SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() override;
280280
SPIRVInstruction *addTaskSequenceGetINTELInst(SPIRVType *, SPIRVValue *,
281281
SPIRVBasicBlock *) override;
282+
SPIRVInstruction *
283+
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
284+
SPIRVBasicBlock *) override;
282285
SPIRVType *addOpaqueGenericType(Op) override;
283286
SPIRVTypeDeviceEvent *addDeviceEventType() override;
284287
SPIRVTypeQueue *addQueueType() override;
@@ -1094,6 +1097,14 @@ SPIRVInstruction *SPIRVModuleImpl::addTaskSequenceGetINTELInst(
10941097
BB);
10951098
}
10961099

1100+
SPIRVInstruction *SPIRVModuleImpl::addCooperativeMatrixLengthKHRInst(
1101+
SPIRVType *RetTy, SPIRVType *MatTy, SPIRVBasicBlock *BB) {
1102+
return addInstruction(
1103+
SPIRVInstTemplateBase::create(OpCooperativeMatrixLengthKHR, RetTy,
1104+
getId(), getVec(MatTy->getId()), BB, this),
1105+
BB);
1106+
}
1107+
10971108
SPIRVType *SPIRVModuleImpl::addOpaqueGenericType(Op TheOpCode) {
10981109
return addType(new SPIRVTypeOpaqueGeneric(TheOpCode, this, getId()));
10991110
}

llvm-spirv/lib/SPIRV/libSPIRV/SPIRVModule.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,9 @@ class SPIRVModule {
272272
virtual SPIRVTypeTaskSequenceINTEL *addTaskSequenceINTELType() = 0;
273273
virtual SPIRVInstruction *
274274
addTaskSequenceGetINTELInst(SPIRVType *, SPIRVValue *, SPIRVBasicBlock *) = 0;
275+
virtual SPIRVInstruction *
276+
addCooperativeMatrixLengthKHRInst(SPIRVType *, SPIRVType *,
277+
SPIRVBasicBlock *) = 0;
275278
virtual SPIRVTypeVoid *addVoidType() = 0;
276279
virtual SPIRVType *addOpaqueGenericType(Op) = 0;
277280
virtual SPIRVTypeDeviceEvent *addDeviceEventType() = 0;

llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
3333
; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy1]]
3434
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy2]] [[#Load1:]]
35-
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
36-
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
35+
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
3736
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]]
3837
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
3938
; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL

llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_prefetch.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,7 @@
3232
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
3333
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
3434
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]]
35-
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
36-
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
35+
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
3736
; CHECK-SPIRV: CooperativeMatrixPrefetchINTEL
3837
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
3938
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]

llvm-spirv/test/extensions/KHR/SPV_KHR_cooperative_matrix/cooperative_matrix.ll

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,7 @@
3030
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
3131
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
3232
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy2]] [[#Load1:]]
33-
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
34-
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
33+
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#MatTy2]]
3534
; CHECK-SPIRV: CooperativeMatrixLoadKHR [[#MatTy3]]
3635
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
3736
; CHECK-SPIRV: CooperativeMatrixStoreKHR

0 commit comments

Comments
 (0)