Skip to content

Commit a9bc8ce

Browse files
mbelickiigcbot
authored andcommitted
Support JointMatrix Use parameter in custom SPIR-V translator.
This patch adds support to new OpTypeJointMatrix op with additional 'use' parameter.
1 parent 5805c38 commit a9bc8ce

File tree

2 files changed

+62
-27
lines changed

2 files changed

+62
-27
lines changed

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 39 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -354,34 +354,58 @@ void SPIRVTypeForwardPointer::decode(std::istream& I) {
354354
}
355355

356356
unsigned SPIRVTypeJointMatrixINTEL::getLayout() const {
357-
return (unsigned)get<SPIRVConstant>(Layout)->getZExtIntValue();
357+
return (unsigned)get<SPIRVConstant>(Args[2])->getZExtIntValue();
358+
}
359+
360+
unsigned SPIRVTypeJointMatrixINTEL::getUse() const {
361+
if (isUseParameterPresent())
362+
return (unsigned)get<SPIRVConstant>(Args[4])->getZExtIntValue();
363+
return 0;
358364
}
359365

360366
unsigned SPIRVTypeJointMatrixINTEL::getRows() const {
361-
return (unsigned)get<SPIRVConstant>(Rows)->getZExtIntValue();
367+
return (unsigned)get<SPIRVConstant>(Args[0])->getZExtIntValue();
362368
}
363369

364370
unsigned SPIRVTypeJointMatrixINTEL::getColumns() const {
365-
return (unsigned)get<SPIRVConstant>(Columns)->getZExtIntValue();
371+
return (unsigned)get<SPIRVConstant>(Args[1])->getZExtIntValue();
366372
}
367373

368374
unsigned SPIRVTypeJointMatrixINTEL::getScope() const {
369-
return (unsigned)get<SPIRVConstant>(Scope)->getZExtIntValue();
375+
return (unsigned)get<SPIRVConstant>(Args[3])->getZExtIntValue();
376+
}
377+
378+
bool SPIRVTypeJointMatrixINTEL::isUseParameterPresent() const {
379+
return Args.size() > 4;
370380
}
371381

372382
std::string SPIRVTypeJointMatrixINTEL::getMangledName() const {
373383
std::string name;
374-
switch (getLayout()) {
375-
case SPIRVTypeJointMatrixINTEL::LayoutPackedA:
376-
name += "packedA_";
377-
break;
378-
case SPIRVTypeJointMatrixINTEL::LayoutPackedB:
379-
name += "packedB_";
380-
break;
381-
case SPIRVTypeJointMatrixINTEL::LayoutRowMajor:
382-
case SPIRVTypeJointMatrixINTEL::LayoutColumnMajor:
383-
name += "acc_";
384-
break;
384+
if (isUseParameterPresent()) {
385+
switch (getUse()) {
386+
case SPIRVTypeJointMatrixINTEL::UseMatrixA:
387+
name += "packedA_";
388+
break;
389+
case SPIRVTypeJointMatrixINTEL::UseMatrixB:
390+
name += "packedB_";
391+
break;
392+
case SPIRVTypeJointMatrixINTEL::UseAccumulator:
393+
name += "acc_";
394+
break;
395+
}
396+
} else {
397+
switch (getLayout()) {
398+
case SPIRVTypeJointMatrixINTEL::LayoutPackedA:
399+
name += "packedA_";
400+
break;
401+
case SPIRVTypeJointMatrixINTEL::LayoutPackedB:
402+
name += "packedB_";
403+
break;
404+
case SPIRVTypeJointMatrixINTEL::LayoutRowMajor:
405+
case SPIRVTypeJointMatrixINTEL::LayoutColumnMajor:
406+
name += "acc_";
407+
break;
408+
}
385409
}
386410
name += std::to_string(getRows());
387411
name += "x";

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -680,58 +680,69 @@ class SPIRVTypeNamedBarrier :public SPIRVType {
680680
class SPIRVTypeJointMatrixINTEL : public SPIRVType {
681681
public:
682682
const static Op OC = OpTypeJointMatrixINTEL;
683-
const static SPIRVWord FixedWC = 7;
683+
const static SPIRVWord FixedWC = 3;
684684
// Complete constructor
685685
SPIRVTypeJointMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *ElemType,
686-
SPIRVId Rows, SPIRVId Columns,
687-
SPIRVId Layout, SPIRVId Scope)
688-
: SPIRVType(M, FixedWC, OC, TheId), ElemType(ElemType),
689-
Rows(Rows), Columns(Columns), Layout(Layout), Scope(Scope) {
686+
std::vector<SPIRVId> Args)
687+
: SPIRVType(M, FixedWC, OC, TheId), ElemType(ElemType), Args(Args) {
690688
validate();
691689
}
692690

693691
// Incomplete constructor
694692
SPIRVTypeJointMatrixINTEL()
695-
: SPIRVType(OC), ElemType(0), Rows(0), Columns(0),
696-
Layout(0), Scope(0) {
693+
: SPIRVType(OC), ElemType(0), Args({0, 0, 0, 0}) {
697694
}
698695

699696
CapVec getRequiredCapability() const override {
700697
return getVec(SPIRVCapabilityKind::CapabilityJointMatrixINTEL);
701698
}
702699

700+
void setWordCount(SPIRVWord WordCount) override {
701+
SPIRVType::setWordCount(WordCount);
702+
Args.resize(WordCount - FixedWC);
703+
}
704+
703705
SPIRVType *getElemType() const { return ElemType; }
704706

705707
unsigned getLayout() const;
708+
unsigned getUse() const;
706709
unsigned getRows() const;
707710
unsigned getColumns() const;
708711
unsigned getScope() const;
712+
709713
std::string getMangledName() const;
714+
bool isUseParameterPresent() const;
710715

711716
enum {
712717
LayoutColumnMajor = 0,
713718
LayoutRowMajor = 1,
714719
LayoutPackedA = 2,
715720
LayoutPackedB = 3,
721+
LayoutUnused = 4,
716722
LayoutMAX
717723
};
718724

725+
enum {
726+
UseMatrixA = 0,
727+
UseMatrixB = 1,
728+
UseAccumulator = 2,
729+
UseMAX
730+
};
731+
719732
protected:
720-
_SPIRV_DEF_DEC6(Id, ElemType, Rows, Columns, Layout, Scope)
733+
_SPIRV_DEF_DEC3_OVERRIDE(Id, ElemType, Args)
721734
void validate() const override {
722735
SPIRVEntry::validate();
723736
ElemType->validate();
724737
IGC_ASSERT_EXIT_MESSAGE(getRows() <= 64, "Unsupported rows size.");
725738
IGC_ASSERT_EXIT_MESSAGE(getColumns() <= 64, "Unsupported columns size.");
726739
IGC_ASSERT_EXIT_MESSAGE(getLayout() < LayoutMAX, "Unsupported layout.");
740+
IGC_ASSERT_EXIT_MESSAGE(getUse() < UseMAX, "Unsupported use parameter.");
727741
}
728742

729743
private:
730744
SPIRVType *ElemType;
731-
SPIRVId Rows;
732-
SPIRVId Columns;
733-
SPIRVId Layout;
734-
SPIRVId Scope;
745+
std::vector<SPIRVId> Args;
735746
};
736747

737748

0 commit comments

Comments
 (0)