Skip to content

Commit f6f7114

Browse files
authored
Represent init/fill joint_matrix instruction as OpCompositeConstruct (#1363)
Signed-off-by: Dmitry Sidorov <[email protected]>
1 parent 39cae09 commit f6f7114

File tree

6 files changed

+42
-9
lines changed

6 files changed

+42
-9
lines changed

lib/SPIRV/SPIRVReader.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2116,12 +2116,12 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
21162116

21172117
case OpCompositeConstruct: {
21182118
auto CC = static_cast<SPIRVCompositeConstruct *>(BV);
2119-
auto Constituents = transValue(CC->getConstituents(), F, BB);
2119+
auto Constituents = transValue(CC->getOperands(), F, BB);
21202120
std::vector<Constant *> CV;
21212121
for (const auto &I : Constituents) {
21222122
CV.push_back(dyn_cast<Constant>(I));
21232123
}
2124-
switch (BV->getType()->getOpCode()) {
2124+
switch (static_cast<size_t>(BV->getType()->getOpCode())) {
21252125
case OpTypeVector:
21262126
return mapValue(BV, ConstantVector::get(CV));
21272127
case OpTypeArray:
@@ -2132,6 +2132,8 @@ Value *SPIRVToLLVM::transValueWithoutDecoration(SPIRVValue *BV, Function *F,
21322132
return mapValue(BV,
21332133
ConstantStruct::get(
21342134
dyn_cast<StructType>(transType(CC->getType())), CV));
2135+
case internal::OpTypeJointMatrixINTEL:
2136+
return mapValue(BV, transSPIRVBuiltinFromInst(CC, BB));
21352137
default:
21362138
llvm_unreachable("Unhandled type!");
21372139
}

lib/SPIRV/SPIRVWriter.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4444,6 +4444,12 @@ LLVMToSPIRVBase::transBuiltinToInstWithoutDecoration(Op OC, CallInst *CI,
44444444
transValue(CI->getArgOperand(1), BB), MemoryAccess,
44454445
BB);
44464446
}
4447+
case OpCompositeConstruct: {
4448+
std::vector<SPIRVId> Operands = {
4449+
transValue(CI->getArgOperand(0), BB)->getId()};
4450+
return BM->addCompositeConstructInst(transType(CI->getType()), Operands,
4451+
BB);
4452+
}
44474453
default: {
44484454
if (isCvtOpCode(OC) && OC != OpGenericCastToPtrExplicit) {
44494455
return BM->addUnaryInst(OC, transType(CI->getType()),

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1863,7 +1863,7 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
18631863
// Incomplete constructor
18641864
SPIRVCompositeConstruct() : SPIRVInstruction(OC) {}
18651865

1866-
const std::vector<SPIRVValue *> getConstituents() const {
1866+
std::vector<SPIRVValue *> getOperands() override {
18671867
return getValues(Constituents);
18681868
}
18691869

@@ -1875,13 +1875,15 @@ class SPIRVCompositeConstruct : public SPIRVInstruction {
18751875
_SPIRV_DEF_ENCDEC3(Type, Id, Constituents)
18761876
void validate() const override {
18771877
SPIRVInstruction::validate();
1878-
switch (getValueType(this->getId())->getOpCode()) {
1878+
size_t TypeOpCode = this->getType()->getOpCode();
1879+
switch (TypeOpCode) {
18791880
case OpTypeVector:
1880-
assert(getConstituents().size() > 1 &&
1881+
assert(Constituents.size() > 1 &&
18811882
"There must be at least two Constituent operands in vector");
18821883
break;
18831884
case OpTypeArray:
18841885
case OpTypeStruct:
1886+
case internal::OpTypeJointMatrixINTEL:
18851887
break;
18861888
default:
18871889
assert(false && "Invalid type");

lib/SPIRV/libSPIRV/SPIRVType.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,8 @@ bool SPIRVType::isTypeArray() const { return OpCode == OpTypeArray; }
151151
bool SPIRVType::isTypeBool() const { return OpCode == OpTypeBool; }
152152

153153
bool SPIRVType::isTypeComposite() const {
154-
return isTypeVector() || isTypeArray() || isTypeStruct();
154+
return isTypeVector() || isTypeArray() || isTypeStruct() ||
155+
isTypeJointMatrixINTEL();
155156
}
156157

157158
bool SPIRVType::isTypeFloat(unsigned Bits) const {
@@ -193,6 +194,10 @@ bool SPIRVType::isTypeStruct() const { return OpCode == OpTypeStruct; }
193194

194195
bool SPIRVType::isTypeVector() const { return OpCode == OpTypeVector; }
195196

197+
bool SPIRVType::isTypeJointMatrixINTEL() const {
198+
return OpCode == internal::OpTypeJointMatrixINTEL;
199+
}
200+
196201
bool SPIRVType::isTypeVectorBool() const {
197202
return isTypeVector() && getVectorComponentType()->isTypeBool();
198203
}

lib/SPIRV/libSPIRV/SPIRVType.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ class SPIRVType : public SPIRVEntry {
9595
bool isTypeSampler() const;
9696
bool isTypeStruct() const;
9797
bool isTypeVector() const;
98+
bool isTypeJointMatrixINTEL() const;
9899
bool isTypeVectorInt() const;
99100
bool isTypeVectorFloat() const;
100101
bool isTypeVectorBool() const;

test/transcoding/SPV_INTEL_joint_matrix/joint_matrix.ll

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#Two:]] 2
2323
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#Three:]] 3
2424
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#Sixteen:]] 16
25+
; CHECK-SPIRV-DAG: Constant [[#IntTy]] [[#FortyTwo:]] 42
2526
; CHECK-SPIRV: TypeJointMatrixINTEL [[#CTy:]] [[#ShortTy]] [[#Two]] [[#Two]] [[#Zero]] [[#Three]]
2627
; CHECK-SPIRV: TypeJointMatrixINTEL [[#ATy:]] [[#CharTy]] [[#Two]] [[#Sixteen]] [[#Zero]] [[#Three]]
2728
; CHECK-SPIRV: TypeJointMatrixINTEL [[#BTy:]] [[#CharTy]] [[#Sixteen]] [[#Two]] [[#Three]] [[#Three]]
@@ -39,8 +40,12 @@
3940
; CHECK-SPIRV: JointMatrixLoadINTEL [[#ATy]] [[#A:]] [[#Aptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
4041
; CHECK-SPIRV: JointMatrixLoadINTEL [[#BTy]] [[#B:]] [[#Bptr:]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
4142
; CHECK-SPIRV: JointMatrixMadINTEL [[#CTy]] [[#CMad]] [[#A]] [[#B]] [[#C]] [[#Three]]
42-
4343
; CHECK-SPIRV: JointMatrixStoreINTEL [[#Cptr:]] [[#C]] [[#Stride]] [[#Zero]] [[#Three]] [[#Zero]]
44+
; CHECK-SPIRV: CompositeConstruct [[#CTy]] [[#Cnew:]] [[#FortyTwo]]
45+
; CHECK-SPIRV: Store [[#PtrToZero:]] [[#Zero]]
46+
; CHECK-SPIRV: Load [[#]] [[#ZeroLoad:]] [[#PtrToZero]]
47+
; CHECK-SPIRV: CompositeConstruct [[#CTy]] [[#CnewLoad:]] [[#ZeroLoad]]
48+
4449

4550
; CHECK-LLVM: %spirv.JointMatrixINTEL._short_2_2_0_3
4651
; CHECK-LLVM: %spirv.JointMatrixINTEL._char_2_16_0_3
@@ -51,9 +56,11 @@
5156
; CHECK-LLVM: [[A:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_2_16_0_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_2_16_0_3PU3AS4cliii(i8 addrspace(4)* [[APtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
5257
; CHECK-LLVM: [[B:%.*]] = call spir_func %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* @_Z77__spirv_JointMatrixLoadINTEL_RPU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS4cliii(i8 addrspace(4)* [[BPtr:%.*]], i64 [[Stride]], i32 0, i32 3, i32 0)
5358
; CHECK-LLVM: [[CMad:%.*]] = call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z27__spirv_JointMatrixMadINTELPU3AS139__spirv_JointMatrixINTEL__char_2_16_0_3PU3AS139__spirv_JointMatrixINTEL__char_16_2_3_3PU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3i(%spirv.JointMatrixINTEL._char_2_16_0_3 addrspace(1)* [[A]], %spirv.JointMatrixINTEL._char_16_2_3_3 addrspace(1)* [[B]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i32 3)
54-
5559
; CHECK-LLVM: call spir_func void @_Z29__spirv_JointMatrixStoreINTELPU3AS4sPU3AS139__spirv_JointMatrixINTEL__short_2_2_0_3liii(i16 addrspace(4)* [[CPtr]], %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* [[C]], i64 [[Stride]], i32 0, i32 3, i32 0)
56-
60+
; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 42)
61+
; CHECK-LLVM: store i32 0, i32 addrspace(4)* [[StoredZero:%.*]], align 4
62+
; CHECK-LLVM: [[LoadedZero:%.*]] = load i32, i32 addrspace(4)* [[StoredZero]], align 8
63+
; CHECK-LLVM: call spir_func %spirv.JointMatrixINTEL._short_2_2_0_3 addrspace(1)* @_Z26__spirv_CompositeConstructi(i32 [[LoadedZero]])
5764

5865
; ModuleID = 'joint_matrix_test-sycl-spir64-unknown-unknown.bc'
5966
source_filename = "./joint_matrix_test.cpp"
@@ -119,6 +126,13 @@ for.body.i: ; preds = %for.cond.i
119126

120127
_ZZ4mainENKUlN2cl4sycl7nd_itemILi2EEEE_clES2_.exit: ; preds = %for.cond.i
121128
tail call spir_func void @_Z29__spirv_JointMatrixStoreINTELIsLm2ELm2ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(i16 addrspace(4)* %add.ptr7.i, %"struct.__spv::__spirv_JointMatrixINTEL" addrspace(4)* %C.0.i, i64 %_arg_1, i32 0, i32 3, i32 0) #3
129+
%C.0.i.new = call spir_func %"struct.__spv::__spirv_JointMatrixINTEL" addrspace(4)* @_Z26__spirv_CompositeConstructi(i32 42) #1
130+
%ref.tmp = alloca i32, align 4
131+
%ref.tmp.ascast = addrspacecast i32* %ref.tmp to i32 addrspace(4)*
132+
store i32 0, i32 addrspace(4)* %ref.tmp.ascast, align 4
133+
%zero = load i32, i32 addrspace(4)* %ref.tmp.ascast, align 8
134+
%C.0.i.new.load = call spir_func %"struct.__spv::__spirv_JointMatrixINTEL" addrspace(4)* @_Z26__spirv_CompositeConstructi(i32 %zero) #1
135+
122136
ret void
123137
}
124138

@@ -137,6 +151,9 @@ declare dso_local spir_func %"struct.__spv::__spirv_JointMatrixINTEL" addrspace(
137151
; Function Attrs: convergent
138152
declare dso_local spir_func void @_Z29__spirv_JointMatrixStoreINTELIsLm2ELm2ELN5__spv12MatrixLayoutE0ELNS0_5Scope4FlagE3EEvPT_PNS0_24__spirv_JointMatrixINTELIS4_XT0_EXT1_EXT2_EXT3_EEEmS1_S3_i(i16 addrspace(4)*, %"struct.__spv::__spirv_JointMatrixINTEL" addrspace(4)*, i64, i32, i32, i32) local_unnamed_addr #1
139153

154+
; Function Attrs: convergent
155+
declare dso_local spir_func %"struct.__spv::__spirv_JointMatrixINTEL" addrspace(4)* @_Z26__spirv_CompositeConstructi(i32) #1
156+
140157
; Function Attrs: inaccessiblememonly nofree nosync nounwind willreturn
141158
declare void @llvm.assume(i1 noundef) #2
142159

0 commit comments

Comments
 (0)