Skip to content

Commit e5199eb

Browse files
mbelickiigcbot
authored andcommitted
Initial implementation of INTEL Joint Matrix SPIR-V extension.
This patch adds support for lowering INTEL Joint Matrix SPIR-V instructions to DPAS equivalents. The implementation lowers incoming SPV to opaque pointer types and placeholder functions. Then new pass resolved those placeholders based on the current traget platform.
1 parent 4e729d3 commit e5199eb

File tree

14 files changed

+905
-0
lines changed

14 files changed

+905
-0
lines changed

IGC/AdaptorOCL/SPIRV/SPIRVInternal.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,9 @@ _SPIRV_OP(OpSUDotKHR)
454454
_SPIRV_OP(OpUDotAccSatKHR)
455455
_SPIRV_OP(OpSDotAccSatKHR)
456456
_SPIRV_OP(OpSUDotAccSatKHR)
457+
_SPIRV_OP(OpMatrixLoadINTEL)
458+
_SPIRV_OP(OpMatrixStoreINTEL)
459+
_SPIRV_OP(OpMatrixMadINTEL)
457460
#undef _SPIRV_OP
458461

459462
#define _SPIRV_OP(x, y) add(Op##y, #x);

IGC/AdaptorOCL/SPIRV/SPIRVReader.cpp

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2128,6 +2128,25 @@ SPIRVToLLVM::transType(SPIRVType *T) {
21282128
getOrCreateOpaquePtrType(M, "intel.buffer_rw_t",
21292129
SPIRAddressSpace::SPIRAS_Global));
21302130
}
2131+
case OpTypeMatrixINTEL:
2132+
{
2133+
SPIRVTypeMatrixINTEL *MT = static_cast<SPIRVTypeMatrixINTEL *>(T);
2134+
const char *typeName = nullptr;
2135+
switch (MT->getLayout()) {
2136+
case SPIRVTypeMatrixINTEL::LayoutPackedA:
2137+
typeName = "intel.joint_matrix_packedA_t";
2138+
break;
2139+
case SPIRVTypeMatrixINTEL::LayoutPackedB:
2140+
typeName = "intel.joint_matrix_packedB_t";
2141+
break;
2142+
case SPIRVTypeMatrixINTEL::LayoutRowMajor:
2143+
case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
2144+
typeName = "intel.joint_matrix_acc_t";
2145+
break;
2146+
}
2147+
IGC_ASSERT_EXIT_MESSAGE(typeName, "Unsupported layout of INTEL Joint Matrix.");
2148+
return mapType(T, getOrCreateOpaquePtrType(M, typeName, SPIRAddressSpace::SPIRAS_Global));
2149+
}
21312150
default: {
21322151
auto OC = T->getOpCode();
21332152
if (isOpaqueGenericTypeOpCode(OC) ||
@@ -3651,6 +3670,199 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
36513670
auto* BC = static_cast<SPIRVUnary*>(BV);
36523671
return mapValue(BV, transValue(BC->getOperand(0), F, BB));
36533672
}
3673+
case OpMatrixLoadINTEL: {
3674+
SPIRVMatrixLoadINTEL *ML = static_cast<SPIRVMatrixLoadINTEL *>(BV);
3675+
std::vector<SPIRVValue *> BArgs = ML->getOperands();
3676+
enum SPVIdx { Pointer, Stride, Layout, Scope, MemOp };
3677+
3678+
SPIRVTypeMatrixINTEL *MatTy = static_cast<SPIRVTypeMatrixINTEL *>(ML->getType());
3679+
const unsigned loadLayout = (unsigned)BM->get<SPIRVConstant>(BArgs[Layout]->getId())->getZExtIntValue();
3680+
3681+
IGC_ASSERT_MESSAGE(BB, "Invalid BB");
3682+
3683+
/* Get arugment values for the intrinsic call */
3684+
Value *PtrVal = transValue(BArgs[Pointer], F, BB);
3685+
Value *StrideVal = transValue(BArgs[Stride], F, BB);
3686+
3687+
unsigned AS = static_cast<PointerType *>(PtrVal->getType())->getAddressSpace();
3688+
/* Prepare types for the call: */
3689+
Type *RetTy = transType(MatTy);
3690+
Type *PtrTy = PointerType::get(Type::getInt8Ty(*Context), AS);
3691+
Type *StrideTy = Type::getInt32Ty(*Context);
3692+
Type *ElemTypeTy = Type::getInt32Ty(*Context);
3693+
Type *LayoutTy = Type::getInt32Ty(*Context);
3694+
Type *SizeTy = Type::getInt32Ty(*Context);
3695+
3696+
std::vector<Type *> ArgTys = {
3697+
PtrTy, StrideTy, LayoutTy, ElemTypeTy, SizeTy, SizeTy
3698+
};
3699+
FunctionType *builtinTy = FunctionType::get(RetTy, ArgTys, false);
3700+
3701+
/* Cast if necessary and prepare rest of the arguments: */
3702+
CastInst *Ptr = CastInst::CreatePointerCast(PtrVal, PtrTy, "", BB);
3703+
if (StrideVal->getType() != StrideTy) {
3704+
IGC_ASSERT_MESSAGE(StrideVal->getType()->isIntegerTy(),
3705+
"Unspupported matrix stide type in load instruction.");
3706+
StrideVal = CastInst::CreateIntegerCast(StrideVal, StrideTy, false, "stride", Ptr);
3707+
}
3708+
3709+
Value *LoadLayoutVal = ConstantInt::get(LayoutTy, loadLayout);
3710+
Value *ElementTypeVal = ConstantInt::get(ElemTypeTy, MatTy->getElementTypeFlags());
3711+
Value *RowsVal = ConstantInt::get(SizeTy, MatTy->getRows());
3712+
Value *ColumnsVal = ConstantInt::get(SizeTy, MatTy->getColumns());
3713+
3714+
/* Get function to call */
3715+
const char *suffix = nullptr;
3716+
switch (MatTy->getLayout()) {
3717+
case SPIRVTypeMatrixINTEL::LayoutPackedA:
3718+
suffix = "_PackedA";
3719+
break;
3720+
case SPIRVTypeMatrixINTEL::LayoutPackedB:
3721+
suffix = "_PackedB";
3722+
break;
3723+
case SPIRVTypeMatrixINTEL::LayoutRowMajor:
3724+
case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
3725+
suffix = "_Accumulator";
3726+
break;
3727+
}
3728+
IGC_ASSERT_MESSAGE(suffix, "Unsupported layout type for INTEL Joint Matrix.");
3729+
auto BI = static_cast<SPIRVInstruction *>(BV);
3730+
std::string builtinName(getSPIRVBuiltinName(BV->getOpCode(), BI, ArgTys, suffix));
3731+
Function *Func = cast<Function>(M->getOrInsertFunction(builtinName, builtinTy));
3732+
3733+
std::vector<Value *> Args = {
3734+
Ptr, StrideVal, LoadLayoutVal, ElementTypeVal, RowsVal, ColumnsVal
3735+
};
3736+
CallInst *CI = CallInst::Create(Func, Args, "matrix", BB);
3737+
return mapValue(BV, CI);
3738+
}
3739+
case OpMatrixStoreINTEL: {
3740+
SPIRVMatrixStoreINTEL *MS = static_cast<SPIRVMatrixStoreINTEL *>(BV);
3741+
std::vector<SPIRVValue *> BArgs = MS->getOperands();
3742+
enum SPVIdx { Pointer, Object, Stride, Layout, Scope, MemOp };
3743+
3744+
SPIRVTypeMatrixINTEL *MatTy = static_cast<SPIRVTypeMatrixINTEL *>(BArgs[Object]->getType());
3745+
const unsigned storeLayout = (unsigned)BM->get<SPIRVConstant>(BArgs[Layout]->getId())->getZExtIntValue();
3746+
3747+
IGC_ASSERT_MESSAGE(BB, "Invalid BB");
3748+
3749+
/* Get arugment values for the intrinsic call */
3750+
Value *MatrixVal = transValue(BArgs[Object], F, BB);
3751+
Value *PtrVal = transValue(BArgs[Pointer], F, BB);
3752+
Value *StrideVal = transValue(BArgs[Stride], F, BB);
3753+
3754+
unsigned AS = static_cast<PointerType *>(PtrVal->getType())->getAddressSpace();
3755+
/* Prepare types for the call: */
3756+
Type *MatrixTy = transType(MatTy);
3757+
Type *PtrTy = PointerType::get(Type::getInt8Ty(*Context), AS);
3758+
Type *StrideTy = Type::getInt32Ty(*Context);
3759+
Type *ElemTypeTy = Type::getInt32Ty(*Context);
3760+
Type *LayoutTy = Type::getInt32Ty(*Context);
3761+
Type *SizeTy = Type::getInt32Ty(*Context);
3762+
3763+
std::vector<Type *> ArgTys = {
3764+
PtrTy, MatrixTy, StrideTy, LayoutTy, ElemTypeTy, SizeTy, SizeTy
3765+
};
3766+
FunctionType *builtinTy = FunctionType::get(Type::getVoidTy(*Context), ArgTys, false);
3767+
3768+
/* Cast if necessary and prepare rest of the arguments: */
3769+
CastInst *Ptr = CastInst::CreatePointerCast(PtrVal, PtrTy, "", BB);
3770+
if (StrideVal->getType() != StrideTy) {
3771+
IGC_ASSERT_MESSAGE(StrideVal->getType()->isIntegerTy(),
3772+
"Unspupported matrix stide type in store instruction.");
3773+
StrideVal = CastInst::CreateIntegerCast(StrideVal, StrideTy, false, "stride", Ptr);
3774+
}
3775+
3776+
Value *StoreLayoutVal = ConstantInt::get(LayoutTy, storeLayout);
3777+
Value *ElementTypeVal = ConstantInt::get(ElemTypeTy, MatTy->getElementTypeFlags());
3778+
Value *RowsVal = ConstantInt::get(SizeTy, MatTy->getRows());
3779+
Value *ColumnsVal = ConstantInt::get(SizeTy, MatTy->getColumns());
3780+
3781+
/* Get function to call */
3782+
const char *suffix = nullptr;
3783+
switch (MatTy->getLayout()) {
3784+
case SPIRVTypeMatrixINTEL::LayoutPackedA:
3785+
suffix = "_PackedA";
3786+
break;
3787+
case SPIRVTypeMatrixINTEL::LayoutPackedB:
3788+
suffix = "_PackedB";
3789+
break;
3790+
case SPIRVTypeMatrixINTEL::LayoutRowMajor:
3791+
case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
3792+
suffix = "_Accumulator";
3793+
break;
3794+
}
3795+
IGC_ASSERT_MESSAGE(suffix, "Unsupported layout type for INTEL Joint Matrix.");
3796+
auto BI = static_cast<SPIRVInstruction *>(BV);
3797+
std::string builtinName(getSPIRVBuiltinName(BV->getOpCode(), BI, ArgTys, suffix));
3798+
Function *Func = cast<Function>(M->getOrInsertFunction(builtinName, builtinTy));
3799+
3800+
std::vector<Value *> Args = {
3801+
Ptr, MatrixVal, StrideVal, StoreLayoutVal, ElementTypeVal, RowsVal, ColumnsVal
3802+
};
3803+
CallInst *CI = CallInst::Create(Func, Args, "", BB);
3804+
return mapValue(BV, CI);
3805+
}
3806+
case OpMatrixMadINTEL: {
3807+
SPIRVMatrixMadINTEL *MM = static_cast<SPIRVMatrixMadINTEL *>(BV);
3808+
std::vector<SPIRVValue *> BArgs = MM->getOperands();
3809+
enum SPVIdx { A, B, C, Scope };
3810+
3811+
auto *MatATy = static_cast<SPIRVTypeMatrixINTEL *>(BArgs[A]->getType());
3812+
auto *MatBTy = static_cast<SPIRVTypeMatrixINTEL *>(BArgs[B]->getType());
3813+
auto *MatCTy = static_cast<SPIRVTypeMatrixINTEL *>(BArgs[C]->getType());
3814+
3815+
auto *ResMatTy = static_cast<SPIRVTypeMatrixINTEL *>(MM->getType());
3816+
3817+
const unsigned sizeM = MatATy->getRows();
3818+
const unsigned sizeK = MatATy->getColumns();
3819+
const unsigned sizeN = MatBTy->getColumns();
3820+
3821+
IGC_ASSERT(sizeM == MatCTy->getRows());
3822+
IGC_ASSERT(sizeN == MatCTy->getColumns());
3823+
IGC_ASSERT(sizeK == MatBTy->getRows());
3824+
3825+
IGC_ASSERT(ResMatTy->getRows() == MatCTy->getRows());
3826+
IGC_ASSERT(ResMatTy->getColumns() == MatCTy->getColumns());
3827+
3828+
Type *RetTy = transType(ResMatTy);
3829+
Type *ATy = transType(MatATy);
3830+
Type *BTy = transType(MatBTy);
3831+
Type *CTy = transType(MatCTy);
3832+
Type *ElemTypeTy = Type::getInt32Ty(*Context);
3833+
Type *SizeTy = Type::getInt32Ty(*Context);
3834+
3835+
std::vector<Type *> ArgTys = {
3836+
ATy, ElemTypeTy, SizeTy, SizeTy,
3837+
BTy, ElemTypeTy, SizeTy, SizeTy,
3838+
CTy, ElemTypeTy, SizeTy, SizeTy
3839+
};
3840+
FunctionType *builtinTy = FunctionType::get(RetTy, ArgTys, false);
3841+
3842+
auto BI = static_cast<SPIRVInstruction *>(BV);
3843+
std::string builtinName(getSPIRVBuiltinName(BV->getOpCode(), BI, ArgTys, ""));
3844+
Function *Func = cast<Function>(M->getOrInsertFunction(builtinName, builtinTy));
3845+
3846+
std::vector<Value *> Args = {
3847+
/* Matrix A */
3848+
transValue(BArgs[A], F, BB),
3849+
ConstantInt::get(ElemTypeTy, MatATy->getElementTypeFlags()),
3850+
ConstantInt::get(SizeTy, MatATy->getRows()),
3851+
ConstantInt::get(SizeTy, MatATy->getColumns()),
3852+
/* Matrix B */
3853+
transValue(BArgs[B], F, BB),
3854+
ConstantInt::get(ElemTypeTy, MatBTy->getElementTypeFlags()),
3855+
ConstantInt::get(SizeTy, MatBTy->getRows()),
3856+
ConstantInt::get(SizeTy, MatBTy->getColumns()),
3857+
/* Matrix C */
3858+
transValue(BArgs[C], F, BB),
3859+
ConstantInt::get(ElemTypeTy, MatCTy->getElementTypeFlags()),
3860+
ConstantInt::get(SizeTy, MatCTy->getRows()),
3861+
ConstantInt::get(SizeTy, MatCTy->getColumns()),
3862+
};
3863+
CallInst *CI = CallInst::Create(Func, Args, "matrix", BB);
3864+
return mapValue(BV, CI);
3865+
}
36543866
default: {
36553867
auto OC = BV->getOpCode();
36563868
if (isSPIRVCmpInstTransToLLVMInst(static_cast<SPIRVInstruction*>(BV))) {

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2306,5 +2306,20 @@ _SPIRV_OP(VariableLengthArray, true, 4)
23062306
_SPIRV_OP(SaveMemory, true, 3)
23072307
_SPIRV_OP(RestoreMemory, false, 2)
23082308
#undef _SPIRV_OP
2309+
class SPIRVMatrixINTELInst: public SPIRVInstTemplateBase {
2310+
CapVec getRequiredCapability() const override {
2311+
return getVec(CapabilityMatrixINTEL);
2312+
}
2313+
};
2314+
2315+
#define _SPIRV_OP(x, ...) \
2316+
typedef SPIRVInstTemplate<SPIRVMatrixINTELInst, \
2317+
Op##x##INTEL, __VA_ARGS__> \
2318+
SPIRV##x##INTEL;
2319+
2320+
_SPIRV_OP(MatrixLoad, true, 6, true)
2321+
_SPIRV_OP(MatrixStore, false, 5, true)
2322+
_SPIRV_OP(MatrixMad, true, 7)
2323+
#undef _SPIRV_OP
23092324
}
23102325
#endif // SPIRVINSTRUCTION_HPP_

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVOpCode.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ inline bool isTypeOpCode(Op OpCode) {
194194
unsigned OC = OpCode;
195195
return (OpTypeVoid <= OC && OC <= OpTypePipe) ||
196196
isSubgroupAvcINTELTypeOpCode(OpCode) || OC == OpTypeVmeImageINTEL ||
197+
OC == OpTypeMatrixINTEL ||
197198
isVCOpCode(OpCode) || OC == OpTypeTokenINTEL;
198199
}
199200

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVOpCodeEnum.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -515,5 +515,10 @@ _SPIRV_OP(TypeTokenINTEL, 6113)
515515
//_SPIRV_OP(DebugInfoModuleINTEL, 6114)
516516
_SPIRV_OP(ConvertFToBF16INTEL, 6116)
517517
_SPIRV_OP(ConvertBF16ToFINTEL, 6117)
518+
// SPV_INTEL_matrix
519+
_SPIRV_OP(TypeMatrixINTEL, 6119)
520+
_SPIRV_OP(MatrixLoadINTEL, 6120)
521+
_SPIRV_OP(MatrixStoreINTEL, 6121)
522+
_SPIRV_OP(MatrixMadINTEL, 6122)
518523
// SPV_INTEL_arithmetic_fence
519524
_SPIRV_OP(ArithmeticFenceINTEL, 6145)

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -353,5 +353,26 @@ void SPIRVTypeForwardPointer::decode(std::istream& I) {
353353
Decoder >> PointerId >> SC;
354354
}
355355

356+
unsigned SPIRVTypeMatrixINTEL::getLayout() const {
357+
return (unsigned)get<SPIRVConstant>(Layout)->getZExtIntValue();
358+
}
359+
360+
unsigned SPIRVTypeMatrixINTEL::getRows() const {
361+
return (unsigned)get<SPIRVConstant>(Rows)->getZExtIntValue();
362+
}
363+
364+
unsigned SPIRVTypeMatrixINTEL::getColumns() const {
365+
return (unsigned)get<SPIRVConstant>(Columns)->getZExtIntValue();
366+
}
367+
368+
unsigned SPIRVTypeMatrixINTEL::getScope() const {
369+
return (unsigned)get<SPIRVConstant>(Scope)->getZExtIntValue();
370+
}
371+
372+
uint32_t SPIRVTypeMatrixINTEL::getElementTypeFlags() const {
373+
const uint32_t bitWidth = ElemType->getBitWidth();
374+
const bool isFloating = ElemType->isTypeFloat();
375+
return bitWidth | (isFloating << 31);
376+
}
356377
}
357378

IGC/AdaptorOCL/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -677,6 +677,63 @@ class SPIRVTypeNamedBarrier :public SPIRVType {
677677
}
678678
};
679679

680+
class SPIRVTypeMatrixINTEL : public SPIRVType {
681+
public:
682+
const static Op OC = OpTypeMatrixINTEL;
683+
const static SPIRVWord FixedWC = 7;
684+
// Complete constructor
685+
SPIRVTypeMatrixINTEL(SPIRVModule *M, SPIRVId TheId, SPIRVType *ElemType,
686+
SPIRVId Rows, SPIRVId Columns,
687+
SPIRVId Layout, SPIRVId Scope)
688+
: SPIRVType(M, FixedWC, OC, TheId), ElemType(ElemType),
689+
Rows(Rows), Columns(Columns), Layout(Layout), Scope(Scope) {
690+
validate();
691+
}
692+
693+
// Incomplete constructor
694+
SPIRVTypeMatrixINTEL()
695+
: SPIRVType(OC), ElemType(0), Rows(0), Columns(0),
696+
Layout(0), Scope(0) {
697+
}
698+
699+
CapVec getRequiredCapability() const override {
700+
return getVec(SPIRVCapabilityKind::CapabilityMatrixINTEL);
701+
}
702+
703+
SPIRVType *getElemType() const { return ElemType; }
704+
705+
unsigned getLayout() const;
706+
unsigned getRows() const;
707+
unsigned getColumns() const;
708+
unsigned getScope() const;
709+
710+
uint32_t getElementTypeFlags() const;
711+
712+
enum {
713+
LayoutColumnMajor = 0,
714+
LayoutRowMajor = 1,
715+
LayoutPackedA = 2,
716+
LayoutPackedB = 3,
717+
LayoutMAX
718+
};
719+
720+
protected:
721+
_SPIRV_DEF_DEC6(Id, ElemType, Rows, Columns, Layout, Scope)
722+
void validate() const override {
723+
SPIRVEntry::validate();
724+
ElemType->validate();
725+
IGC_ASSERT_EXIT_MESSAGE(getRows() <= 64, "Unsupported rows size.");
726+
IGC_ASSERT_EXIT_MESSAGE(getColumns() <= 64, "Unsupported columns size.");
727+
IGC_ASSERT_EXIT_MESSAGE(getLayout() < LayoutMAX, "Unsupported layout.");
728+
}
729+
730+
private:
731+
SPIRVType *ElemType;
732+
SPIRVId Rows;
733+
SPIRVId Columns;
734+
SPIRVId Layout;
735+
SPIRVId Scope;
736+
};
680737

681738

682739
template<typename T2, typename T1>

IGC/AdaptorOCL/SPIRV/libSPIRV/spirv.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,7 @@ enum Capability {
619619
CapabilityOptNoneINTEL = 6094,
620620
CapabilityTokenTypeINTEL = 6112,
621621
CapabilityDebugInfoModuleINTEL = 6114,
622+
CapabilityMatrixINTEL = 6118,
622623
};
623624

624625
enum PackedVectorFormat {

IGC/AdaptorOCL/UnifyIROCL.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ SPDX-License-Identifier: MIT
9191
#include "Compiler/MetaDataUtilsWrapper.h"
9292
#include "Compiler/SPIRMetaDataTranslation.h"
9393
#include "Compiler/Optimizer/OpenCLPasses/ErrorCheckPass.h"
94+
#include "Compiler/Optimizer/OpenCLPasses/JointMatrixFuncsResolutionPass.h"
9495
#include "Compiler/MetaDataApi/IGCMetaDataHelper.h"
9596
#include "Compiler/CodeGenContextWrapper.hpp"
9697
#include "Compiler/FixResourcePtr.hpp"
@@ -331,6 +332,8 @@ static void CommonOCLBasedPasses(
331332
mpm.add(new HandleFRemInstructions());
332333
}
333334

335+
mpm.add(new JointMatrixFuncsResolutionPass(pContext));
336+
334337
mpm.add(new PreBIImportAnalysis());
335338
mpm.add(createTimeStatsCounterPass(pContext, TIME_Unify_BuiltinImport, STATS_COUNTER_START));
336339
mpm.add(createBuiltInImportPass(std::move(BuiltinGenericModule), std::move(BuiltinSizeModule)));

0 commit comments

Comments
 (0)