@@ -2128,6 +2128,25 @@ SPIRVToLLVM::transType(SPIRVType *T) {
2128
2128
getOrCreateOpaquePtrType (M, " intel.buffer_rw_t" ,
2129
2129
SPIRAddressSpace::SPIRAS_Global));
2130
2130
}
2131
+ case OpTypeMatrixINTEL:
2132
+ {
2133
+ SPIRVTypeMatrixINTEL *MT = static_cast <SPIRVTypeMatrixINTEL *>(T);
2134
+ const char *typeName = nullptr ;
2135
+ switch (MT->getLayout ()) {
2136
+ case SPIRVTypeMatrixINTEL::LayoutPackedA:
2137
+ typeName = " intel.joint_matrix_packedA_t" ;
2138
+ break ;
2139
+ case SPIRVTypeMatrixINTEL::LayoutPackedB:
2140
+ typeName = " intel.joint_matrix_packedB_t" ;
2141
+ break ;
2142
+ case SPIRVTypeMatrixINTEL::LayoutRowMajor:
2143
+ case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
2144
+ typeName = " intel.joint_matrix_acc_t" ;
2145
+ break ;
2146
+ }
2147
+ IGC_ASSERT_EXIT_MESSAGE (typeName, " Unsupported layout of INTEL Joint Matrix." );
2148
+ return mapType (T, getOrCreateOpaquePtrType (M, typeName, SPIRAddressSpace::SPIRAS_Global));
2149
+ }
2131
2150
default : {
2132
2151
auto OC = T->getOpCode ();
2133
2152
if (isOpaqueGenericTypeOpCode (OC) ||
@@ -3651,6 +3670,199 @@ SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
3651
3670
auto * BC = static_cast <SPIRVUnary*>(BV);
3652
3671
return mapValue (BV, transValue (BC->getOperand (0 ), F, BB));
3653
3672
}
3673
+ case OpMatrixLoadINTEL: {
3674
+ SPIRVMatrixLoadINTEL *ML = static_cast <SPIRVMatrixLoadINTEL *>(BV);
3675
+ std::vector<SPIRVValue *> BArgs = ML->getOperands ();
3676
+ enum SPVIdx { Pointer, Stride, Layout, Scope, MemOp };
3677
+
3678
+ SPIRVTypeMatrixINTEL *MatTy = static_cast <SPIRVTypeMatrixINTEL *>(ML->getType ());
3679
+ const unsigned loadLayout = (unsigned )BM->get <SPIRVConstant>(BArgs[Layout]->getId ())->getZExtIntValue ();
3680
+
3681
+ IGC_ASSERT_MESSAGE (BB, " Invalid BB" );
3682
+
3683
+ /* Get arugment values for the intrinsic call */
3684
+ Value *PtrVal = transValue (BArgs[Pointer], F, BB);
3685
+ Value *StrideVal = transValue (BArgs[Stride], F, BB);
3686
+
3687
+ unsigned AS = static_cast <PointerType *>(PtrVal->getType ())->getAddressSpace ();
3688
+ /* Prepare types for the call: */
3689
+ Type *RetTy = transType (MatTy);
3690
+ Type *PtrTy = PointerType::get (Type::getInt8Ty (*Context), AS);
3691
+ Type *StrideTy = Type::getInt32Ty (*Context);
3692
+ Type *ElemTypeTy = Type::getInt32Ty (*Context);
3693
+ Type *LayoutTy = Type::getInt32Ty (*Context);
3694
+ Type *SizeTy = Type::getInt32Ty (*Context);
3695
+
3696
+ std::vector<Type *> ArgTys = {
3697
+ PtrTy, StrideTy, LayoutTy, ElemTypeTy, SizeTy, SizeTy
3698
+ };
3699
+ FunctionType *builtinTy = FunctionType::get (RetTy, ArgTys, false );
3700
+
3701
+ /* Cast if necessary and prepare rest of the arguments: */
3702
+ CastInst *Ptr = CastInst::CreatePointerCast (PtrVal, PtrTy, " " , BB);
3703
+ if (StrideVal->getType () != StrideTy) {
3704
+ IGC_ASSERT_MESSAGE (StrideVal->getType ()->isIntegerTy (),
3705
+ " Unspupported matrix stide type in load instruction." );
3706
+ StrideVal = CastInst::CreateIntegerCast (StrideVal, StrideTy, false , " stride" , Ptr);
3707
+ }
3708
+
3709
+ Value *LoadLayoutVal = ConstantInt::get (LayoutTy, loadLayout);
3710
+ Value *ElementTypeVal = ConstantInt::get (ElemTypeTy, MatTy->getElementTypeFlags ());
3711
+ Value *RowsVal = ConstantInt::get (SizeTy, MatTy->getRows ());
3712
+ Value *ColumnsVal = ConstantInt::get (SizeTy, MatTy->getColumns ());
3713
+
3714
+ /* Get function to call */
3715
+ const char *suffix = nullptr ;
3716
+ switch (MatTy->getLayout ()) {
3717
+ case SPIRVTypeMatrixINTEL::LayoutPackedA:
3718
+ suffix = " _PackedA" ;
3719
+ break ;
3720
+ case SPIRVTypeMatrixINTEL::LayoutPackedB:
3721
+ suffix = " _PackedB" ;
3722
+ break ;
3723
+ case SPIRVTypeMatrixINTEL::LayoutRowMajor:
3724
+ case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
3725
+ suffix = " _Accumulator" ;
3726
+ break ;
3727
+ }
3728
+ IGC_ASSERT_MESSAGE (suffix, " Unsupported layout type for INTEL Joint Matrix." );
3729
+ auto BI = static_cast <SPIRVInstruction *>(BV);
3730
+ std::string builtinName (getSPIRVBuiltinName (BV->getOpCode (), BI, ArgTys, suffix));
3731
+ Function *Func = cast<Function>(M->getOrInsertFunction (builtinName, builtinTy));
3732
+
3733
+ std::vector<Value *> Args = {
3734
+ Ptr, StrideVal, LoadLayoutVal, ElementTypeVal, RowsVal, ColumnsVal
3735
+ };
3736
+ CallInst *CI = CallInst::Create (Func, Args, " matrix" , BB);
3737
+ return mapValue (BV, CI);
3738
+ }
3739
+ case OpMatrixStoreINTEL: {
3740
+ SPIRVMatrixStoreINTEL *MS = static_cast <SPIRVMatrixStoreINTEL *>(BV);
3741
+ std::vector<SPIRVValue *> BArgs = MS->getOperands ();
3742
+ enum SPVIdx { Pointer, Object, Stride, Layout, Scope, MemOp };
3743
+
3744
+ SPIRVTypeMatrixINTEL *MatTy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[Object]->getType ());
3745
+ const unsigned storeLayout = (unsigned )BM->get <SPIRVConstant>(BArgs[Layout]->getId ())->getZExtIntValue ();
3746
+
3747
+ IGC_ASSERT_MESSAGE (BB, " Invalid BB" );
3748
+
3749
+ /* Get arugment values for the intrinsic call */
3750
+ Value *MatrixVal = transValue (BArgs[Object], F, BB);
3751
+ Value *PtrVal = transValue (BArgs[Pointer], F, BB);
3752
+ Value *StrideVal = transValue (BArgs[Stride], F, BB);
3753
+
3754
+ unsigned AS = static_cast <PointerType *>(PtrVal->getType ())->getAddressSpace ();
3755
+ /* Prepare types for the call: */
3756
+ Type *MatrixTy = transType (MatTy);
3757
+ Type *PtrTy = PointerType::get (Type::getInt8Ty (*Context), AS);
3758
+ Type *StrideTy = Type::getInt32Ty (*Context);
3759
+ Type *ElemTypeTy = Type::getInt32Ty (*Context);
3760
+ Type *LayoutTy = Type::getInt32Ty (*Context);
3761
+ Type *SizeTy = Type::getInt32Ty (*Context);
3762
+
3763
+ std::vector<Type *> ArgTys = {
3764
+ PtrTy, MatrixTy, StrideTy, LayoutTy, ElemTypeTy, SizeTy, SizeTy
3765
+ };
3766
+ FunctionType *builtinTy = FunctionType::get (Type::getVoidTy (*Context), ArgTys, false );
3767
+
3768
+ /* Cast if necessary and prepare rest of the arguments: */
3769
+ CastInst *Ptr = CastInst::CreatePointerCast (PtrVal, PtrTy, " " , BB);
3770
+ if (StrideVal->getType () != StrideTy) {
3771
+ IGC_ASSERT_MESSAGE (StrideVal->getType ()->isIntegerTy (),
3772
+ " Unspupported matrix stide type in store instruction." );
3773
+ StrideVal = CastInst::CreateIntegerCast (StrideVal, StrideTy, false , " stride" , Ptr);
3774
+ }
3775
+
3776
+ Value *StoreLayoutVal = ConstantInt::get (LayoutTy, storeLayout);
3777
+ Value *ElementTypeVal = ConstantInt::get (ElemTypeTy, MatTy->getElementTypeFlags ());
3778
+ Value *RowsVal = ConstantInt::get (SizeTy, MatTy->getRows ());
3779
+ Value *ColumnsVal = ConstantInt::get (SizeTy, MatTy->getColumns ());
3780
+
3781
+ /* Get function to call */
3782
+ const char *suffix = nullptr ;
3783
+ switch (MatTy->getLayout ()) {
3784
+ case SPIRVTypeMatrixINTEL::LayoutPackedA:
3785
+ suffix = " _PackedA" ;
3786
+ break ;
3787
+ case SPIRVTypeMatrixINTEL::LayoutPackedB:
3788
+ suffix = " _PackedB" ;
3789
+ break ;
3790
+ case SPIRVTypeMatrixINTEL::LayoutRowMajor:
3791
+ case SPIRVTypeMatrixINTEL::LayoutColumnMajor:
3792
+ suffix = " _Accumulator" ;
3793
+ break ;
3794
+ }
3795
+ IGC_ASSERT_MESSAGE (suffix, " Unsupported layout type for INTEL Joint Matrix." );
3796
+ auto BI = static_cast <SPIRVInstruction *>(BV);
3797
+ std::string builtinName (getSPIRVBuiltinName (BV->getOpCode (), BI, ArgTys, suffix));
3798
+ Function *Func = cast<Function>(M->getOrInsertFunction (builtinName, builtinTy));
3799
+
3800
+ std::vector<Value *> Args = {
3801
+ Ptr, MatrixVal, StrideVal, StoreLayoutVal, ElementTypeVal, RowsVal, ColumnsVal
3802
+ };
3803
+ CallInst *CI = CallInst::Create (Func, Args, " " , BB);
3804
+ return mapValue (BV, CI);
3805
+ }
3806
+ case OpMatrixMadINTEL: {
3807
+ SPIRVMatrixMadINTEL *MM = static_cast <SPIRVMatrixMadINTEL *>(BV);
3808
+ std::vector<SPIRVValue *> BArgs = MM->getOperands ();
3809
+ enum SPVIdx { A, B, C, Scope };
3810
+
3811
+ auto *MatATy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[A]->getType ());
3812
+ auto *MatBTy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[B]->getType ());
3813
+ auto *MatCTy = static_cast <SPIRVTypeMatrixINTEL *>(BArgs[C]->getType ());
3814
+
3815
+ auto *ResMatTy = static_cast <SPIRVTypeMatrixINTEL *>(MM->getType ());
3816
+
3817
+ const unsigned sizeM = MatATy->getRows ();
3818
+ const unsigned sizeK = MatATy->getColumns ();
3819
+ const unsigned sizeN = MatBTy->getColumns ();
3820
+
3821
+ IGC_ASSERT (sizeM == MatCTy->getRows ());
3822
+ IGC_ASSERT (sizeN == MatCTy->getColumns ());
3823
+ IGC_ASSERT (sizeK == MatBTy->getRows ());
3824
+
3825
+ IGC_ASSERT (ResMatTy->getRows () == MatCTy->getRows ());
3826
+ IGC_ASSERT (ResMatTy->getColumns () == MatCTy->getColumns ());
3827
+
3828
+ Type *RetTy = transType (ResMatTy);
3829
+ Type *ATy = transType (MatATy);
3830
+ Type *BTy = transType (MatBTy);
3831
+ Type *CTy = transType (MatCTy);
3832
+ Type *ElemTypeTy = Type::getInt32Ty (*Context);
3833
+ Type *SizeTy = Type::getInt32Ty (*Context);
3834
+
3835
+ std::vector<Type *> ArgTys = {
3836
+ ATy, ElemTypeTy, SizeTy, SizeTy,
3837
+ BTy, ElemTypeTy, SizeTy, SizeTy,
3838
+ CTy, ElemTypeTy, SizeTy, SizeTy
3839
+ };
3840
+ FunctionType *builtinTy = FunctionType::get (RetTy, ArgTys, false );
3841
+
3842
+ auto BI = static_cast <SPIRVInstruction *>(BV);
3843
+ std::string builtinName (getSPIRVBuiltinName (BV->getOpCode (), BI, ArgTys, " " ));
3844
+ Function *Func = cast<Function>(M->getOrInsertFunction (builtinName, builtinTy));
3845
+
3846
+ std::vector<Value *> Args = {
3847
+ /* Matrix A */
3848
+ transValue (BArgs[A], F, BB),
3849
+ ConstantInt::get (ElemTypeTy, MatATy->getElementTypeFlags ()),
3850
+ ConstantInt::get (SizeTy, MatATy->getRows ()),
3851
+ ConstantInt::get (SizeTy, MatATy->getColumns ()),
3852
+ /* Matrix B */
3853
+ transValue (BArgs[B], F, BB),
3854
+ ConstantInt::get (ElemTypeTy, MatBTy->getElementTypeFlags ()),
3855
+ ConstantInt::get (SizeTy, MatBTy->getRows ()),
3856
+ ConstantInt::get (SizeTy, MatBTy->getColumns ()),
3857
+ /* Matrix C */
3858
+ transValue (BArgs[C], F, BB),
3859
+ ConstantInt::get (ElemTypeTy, MatCTy->getElementTypeFlags ()),
3860
+ ConstantInt::get (SizeTy, MatCTy->getRows ()),
3861
+ ConstantInt::get (SizeTy, MatCTy->getColumns ()),
3862
+ };
3863
+ CallInst *CI = CallInst::Create (Func, Args, " matrix" , BB);
3864
+ return mapValue (BV, CI);
3865
+ }
3654
3866
default : {
3655
3867
auto OC = BV->getOpCode ();
3656
3868
if (isSPIRVCmpInstTransToLLVMInst (static_cast <SPIRVInstruction*>(BV))) {
0 commit comments