Skip to content

Commit a1b1f49

Browse files
add initial support for CooperativeMatrixConstructCheckedINTEL (#2331)
Add support for checked matrix construct instruction. Specification draft: https://github.com/intel/llvm/blob/2fa153ee852ea3d7d64df097f1f494cddacee90e/sycl/doc/design/spirv-extensions/SPV_INTEL_joint_matrix.asciidoc
1 parent 239fbd4 commit a1b1f49

File tree

4 files changed

+10
-5
lines changed

4 files changed

+10
-5
lines changed

lib/SPIRV/libSPIRV/SPIRVInstruction.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3457,6 +3457,7 @@ class SPIRVCooperativeMatrixCheckedInstructionsINTELInstBase
34573457
SPIRV##x##INTEL;
34583458
_SPIRV_OP(CooperativeMatrixLoadChecked, true, 9, true, 7)
34593459
_SPIRV_OP(CooperativeMatrixStoreChecked, false, 8, true, 8)
3460+
_SPIRV_OP(CooperativeMatrixConstructChecked, true, 8)
34603461
#undef _SPIRV_OP
34613462

34623463
class SPIRVCooperativeMatrixInvocationInstructionsINTELInstBase

lib/SPIRV/libSPIRV/SPIRVOpCodeEnumInternal.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ _SPIRV_OP_INTERNAL(CooperativeMatrixLoadCheckedINTEL,
2222
internal::OpCooperativeMatrixLoadCheckedINTEL)
2323
_SPIRV_OP_INTERNAL(CooperativeMatrixStoreCheckedINTEL,
2424
internal::OpCooperativeMatrixStoreCheckedINTEL)
25+
_SPIRV_OP_INTERNAL(CooperativeMatrixConstructCheckedINTEL,
26+
internal::OpCooperativeMatrixConstructCheckedINTEL)
2527
_SPIRV_OP_INTERNAL(CooperativeMatrixApplyFunctionINTEL,
2628
internal::OpCooperativeMatrixApplyFunctionINTEL)
2729
_SPIRV_OP_INTERNAL(ComplexFMulINTEL, internal::ComplexFMulINTEL)

lib/SPIRV/libSPIRV/spirv_internal.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ enum InternalOp {
7272
IOpTypeJointMatrixINTELv2 = 6184,
7373
IOpCooperativeMatrixLoadCheckedINTEL = 6193,
7474
IOpCooperativeMatrixStoreCheckedINTEL = 6194,
75+
IOpCooperativeMatrixConstructCheckedINTEL = 6195,
7576
IOpJointMatrixWorkItemLengthINTEL = 6410,
7677
IOpComplexFMulINTEL = 6415,
7778
IOpComplexFDivINTEL = 6416,
@@ -189,6 +190,7 @@ _SPIRV_OP(Op, CooperativeMatrixPrefetchINTEL)
189190
_SPIRV_OP(Capability, CooperativeMatrixCheckedInstructionsINTEL)
190191
_SPIRV_OP(Op, CooperativeMatrixLoadCheckedINTEL)
191192
_SPIRV_OP(Op, CooperativeMatrixStoreCheckedINTEL)
193+
_SPIRV_OP(Op, CooperativeMatrixConstructCheckedINTEL)
192194

193195
_SPIRV_OP(Capability, CooperativeMatrixInvocationInstructionsINTEL)
194196
_SPIRV_OP(Op, CooperativeMatrixApplyFunctionINTEL)

test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@
2222
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy1:]] [[#Int32Ty]] [[#Const3]] [[#Const12]] [[#Const12]] [[#Const2]]
2323
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy2:]] [[#Int8Ty]] [[#Const3]] [[#Const12]] [[#Const48]] [[#Const0]]
2424
; CHECK-SPIRV-DAG: TypeCooperativeMatrixKHR [[#MatTy3:]] [[#Int8Ty]] [[#Const2]] [[#Const48]] [[#Const12]] [[#Const1]]
25-
; CHECK-SPIRV: CompositeConstruct [[#MatTy1]]
25+
; CHECK-SPIRV: CooperativeMatrixConstructCheckedINTEL [[#MatTy1]]
2626
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy2]] [[#Load1:]]
2727
; TODO: Pass Matrix Type Id instead of Matrix Id to CooperativeMatrixLengthKHR.
2828
; CHECK-SPIRV: CooperativeMatrixLengthKHR [[#Int32Ty]] [[#]] [[#Load1]]
2929
; CHECK-SPIRV: CooperativeMatrixLoadCheckedINTEL [[#MatTy3]]
3030
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
3131
; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL
3232

33-
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstructi(i32 0)
33+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTELiilli(i32 4, i32 4, i64 12, i64 12, i32 %_arg_Initvalue)
3434
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4ciiillli(ptr addrspace(4) %[[MatrixPtr:[%0-9a-z.]+]], i32 0, i32 0, i32 0, i64 12, i64 48, i64 %_arg_K, i32 1)
3535
; CHECK-LLVM: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS144__spirv_CooperativeMatrixKHR__char_3_12_48_0(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0)
3636
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4ciiilll
@@ -52,7 +52,7 @@ $_ZTSZZ15matrix_multiply = comdat any
5252
@__spirv_BuiltInLocalInvocationId = external dso_local local_unnamed_addr addrspace(1) constant <3 x i64>, align 32
5353

5454
; Function Attrs: convergent norecurse
55-
define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) noundef align 4 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K) local_unnamed_addr #0 comdat {
55+
define weak_odr dso_local spir_kernel void @_ZTSZZ15matrix_multiply(ptr addrspace(1) noundef align 1 %_arg_accA, ptr addrspace(1) noundef align 1 %_arg_accB, ptr noundef byval(%"class.sycl::_V1::range") align 8 %_arg_accB5, ptr noundef byval(%"class.sycl::_V1::id") align 8 %_arg_accB6, ptr addrspace(1) noundef align 4 %_arg_accC, i64 noundef %_arg_N, i64 noundef %_arg_K, i32 noundef %_arg_Initvalue) local_unnamed_addr #0 comdat {
5656
entry:
5757
%sub_c.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
5858
%ref.tmp29.sroa.0.i = alloca target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), align 8
@@ -77,7 +77,7 @@ entry:
7777
%cmp.i58.i = icmp ult i64 %5, 2147483648
7878
%sub5.i = sub nsw i64 %2, %5
7979
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
80-
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 noundef 0) #4
80+
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef 4, i32 noundef 4, i64 noundef 12, i64 noundef 12, i32 noundef %_arg_Initvalue) #4
8181
store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %call.i.i, ptr %sub_c.sroa.0.i, align 8
8282
%mul.i = mul nsw i64 %sub.i, 12
8383
%div2452.i = lshr i64 %sub5.i, 4
@@ -133,7 +133,7 @@ _ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6
133133
}
134134

135135
; Function Attrs: convergent
136-
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z26__spirv_CompositeConstruct(i32 noundef) local_unnamed_addr #2
136+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef, i32 noundef, i64 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2
137137

138138
declare dso_local spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef)
139139

0 commit comments

Comments
 (0)