Skip to content

Commit d053f45

Browse files
VyacheslavLevytskyysys-ce-bb
authored andcommitted
Out of bounds cooperative matrix load/store/fill work with legacy TypeJointMatrixINTEL (#2318)
This PR is to ensure that CooperativeMatrixLoadCheckedINTEL, CooperativeMatrixStoreCheckedINTEL and CooperativeMatrixConstructCheckedINTEL instructions can correctly accept TypeJointMatrixINTEL as an input during both forward and reverse translation. Original commit: KhronosGroup/SPIRV-LLVM-Translator@632c3f626990f36
1 parent 4ddf13b commit d053f45

File tree

3 files changed

+172
-12
lines changed

3 files changed

+172
-12
lines changed

llvm-spirv/lib/SPIRV/SPIRVReader.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3297,6 +3297,7 @@ Instruction *SPIRVToLLVM::transSPIRVBuiltinFromInst(SPIRVInstruction *BI,
32973297
case OpSUDotAccSatKHR:
32983298
case internal::OpJointMatrixLoadINTEL:
32993299
case OpCooperativeMatrixLoadKHR:
3300+
case internal::OpCooperativeMatrixLoadCheckedINTEL:
33003301
case internal::OpTaskSequenceCreateINTEL:
33013302
AddRetTypePostfix = true;
33023303
break;

llvm-spirv/test/extensions/INTEL/SPV_INTEL_joint_matrix/cooperative_matrix_checked.ll

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
; RUN: llvm-dis < %t.rev.bc | FileCheck %s --check-prefix=CHECK-LLVM
1010

1111
; CHECK-SPIRV-DAG: Capability CooperativeMatrixKHR
12+
; CHECK-SPIRV-DAG: Capability CooperativeMatrixCheckedInstructionsINTEL
1213
; CHECK-SPIRV-DAG: Extension "SPV_KHR_cooperative_matrix"
1314
; CHECK-SPIRV-DAG: Extension "SPV_INTEL_joint_matrix"
1415
; CHECK-SPIRV-DAG: TypeInt [[#Int8Ty:]] 8 0
@@ -30,12 +31,12 @@
3031
; CHECK-SPIRV: CooperativeMatrixMulAddKHR [[#MatTy1]]
3132
; CHECK-SPIRV: CooperativeMatrixStoreCheckedINTEL
3233

33-
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTELiilli(i32 4, i32 4, i64 12, i64 12, i32 %_arg_Initvalue)
34-
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4ciiillli(ptr addrspace(4) %[[MatrixPtr:[%0-9a-z.]+]], i32 0, i32 0, i32 0, i64 12, i64 48, i64 %_arg_K, i32 1)
34+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTELiiiii(i32 4, i32 4, i32 12, i32 12, i32 %_arg_Initvalue)
35+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z95__spirv_CooperativeMatrixLoadCheckedINTEL_RPU3AS144__spirv_CooperativeMatrixKHR__char_3_12_48_0PU3AS4ciiiiili(ptr addrspace(4) %[[MatrixPtr:[%0-9a-z.]+]], i32 0, i32 0, i32 0, i32 12, i32 48, i64 %_arg_K, i32 1)
3536
; CHECK-LLVM: call spir_func i32 @_Z34__spirv_CooperativeMatrixLengthKHRPU3AS144__spirv_CooperativeMatrixKHR__char_3_12_48_0(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0)
36-
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTELPU3AS4ciiilll
37+
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z95__spirv_CooperativeMatrixLoadCheckedINTEL_RPU3AS144__spirv_CooperativeMatrixKHR__char_2_48_12_1PU3AS4ciiiiil
3738
; CHECK-LLVM: call spir_func target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHRPU3AS144__spirv_CooperativeMatrixKHR__char_3_12_48_0PU3AS144__spirv_CooperativeMatrixKHR__char_2_48_12_1PU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_2i(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) %{{.*}}, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) %{{.*}}, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2)
38-
; CHECK-LLVM: call spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTELPU3AS4iiiPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_2illli(ptr addrspace(4) %{{.*}}, i32 0, i32 0, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2)
39+
; CHECK-LLVM: call spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTELPU3AS4iiiPU3AS144__spirv_CooperativeMatrixKHR__uint_3_12_12_2iiili(ptr addrspace(4) %{{.*}}, i32 0, i32 0, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2)
3940

4041
; ModuleID = 'test-matrix-opaque.bc'
4142
source_filename = "matrix-int8-test.cpp"
@@ -77,7 +78,7 @@ entry:
7778
%cmp.i58.i = icmp ult i64 %5, 2147483648
7879
%sub5.i = sub nsw i64 %2, %5
7980
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
80-
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef 4, i32 noundef 4, i64 noundef 12, i64 noundef 12, i32 noundef %_arg_Initvalue) #4
81+
%call.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef 4, i32 noundef 4, i32 noundef 12, i32 noundef 12, i32 noundef %_arg_Initvalue) #4
8182
store target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) %call.i.i, ptr %sub_c.sroa.0.i, align 8
8283
%mul.i = mul nsw i64 %sub.i, 12
8384
%div2452.i = lshr i64 %sub5.i, 4
@@ -102,14 +103,14 @@ for.body.i: ; preds = %for.cond.i
102103
%conv13.i = zext i32 %mul12.i to i64
103104
%add.ptr.i96.i = getelementptr inbounds i8, ptr addrspace(1) %add.ptr.i93.i, i64 %conv13.i
104105
%call.ascast.i66.i = addrspacecast ptr addrspace(1) %add.ptr.i96.i to ptr addrspace(4)
105-
%call1.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(4) noundef %call.ascast.i66.i, i32 noundef 0, i32 noundef 0, i32 noundef 0, i64 noundef 12, i64 noundef 48, i64 noundef %_arg_K, i32 noundef 1) #4
106+
%call1.i.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(4) noundef %call.ascast.i66.i, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 12, i32 noundef 48, i64 noundef %_arg_K, i32 noundef 1) #4
106107
%len = tail call spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) %call1.i.i)
107108
%div20.i = mul nsw i32 %k.0.i, 12
108109
%conv21.i = zext i32 %div20.i to i64
109110
%mul23.i = mul i64 %mul22.i, %conv21.i
110111
%add.ptr.i111.i = getelementptr i8, ptr addrspace(1) %add.ptr.i108140.i, i64 %mul23.i
111112
%call.ascast.i72.i = addrspacecast ptr addrspace(1) %add.ptr.i111.i to ptr addrspace(4)
112-
%call1.i73.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(4) noundef %call.ascast.i72.i, i32 noundef 0, i32 noundef 0, i32 noundef 0, i64 noundef 48, i64 noundef 12, i64 noundef %mul22.i) #4
113+
%call1.i73.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(4) noundef %call.ascast.i72.i, i32 noundef 0, i32 noundef 0, i32 noundef 0, i32 noundef 48, i32 noundef 12, i64 noundef %mul22.i) #4
113114
call void @llvm.lifetime.start.p0(i64 8, ptr nonnull %ref.tmp29.sroa.0.i)
114115
%sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
115116
%call.i77.i = tail call spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef %call1.i.i, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) noundef %call1.i73.i, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0.125.i, i32 noundef 12) #4
@@ -127,27 +128,27 @@ _ZZZ15matrix_multiplyIiaLm24ELm96ELm24ELm96ELm24ELm24EEvR10big_matrixIT_XT5_EXT6
127128
%add.ptr.i81.i = getelementptr inbounds i32, ptr addrspace(1) %add.ptr.i.i, i64 %mul39.i
128129
%call.ascast.i.i = addrspacecast ptr addrspace(1) %add.ptr.i81.i to ptr addrspace(4)
129130
%sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i = load target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2), ptr %sub_c.sroa.0.i, align 8
130-
tail call spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(4) noundef %call.ascast.i.i, i32 noundef 0, i32 noundef 0, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i32 noundef 0, i64 noundef 12, i64 noundef 12, i64 noundef %_arg_N, i32 noundef 1) #4
131+
tail call spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(4) noundef %call.ascast.i.i, i32 noundef 0, i32 noundef 0, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef %sub_c.sroa.0.i.0.sub_c.sroa.0.i.0.sub_c.sroa.0.0.sub_c.sroa.0.0.sub_c.sroa.0.0..i, i32 noundef 0, i32 noundef 12, i32 noundef 12, i64 noundef %_arg_N, i32 noundef 1) #4
131132
call void @llvm.lifetime.end.p0(i64 8, ptr nonnull %sub_c.sroa.0.i)
132133
ret void
133134
}
134135

135136
; Function Attrs: convergent
136-
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef, i32 noundef, i64 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2
137+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z46__spirv_CooperativeMatrixConstructCheckedINTEL(i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef) local_unnamed_addr #2
137138

138139
declare dso_local spir_func noundef i32 @_Z34__spirv_CooperativeMatrixLengthKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef)
139140

140141
; Function Attrs: convergent
141-
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i64 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2
142+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_1(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2
142143

143144
; Function Attrs: convergent
144-
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i64 noundef, i64 noundef) local_unnamed_addr #2
145+
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) @_Z41__spirv_CooperativeMatrixLoadCheckedINTEL_2(ptr addrspace(4) noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef) local_unnamed_addr #2
145146

146147
; Function Attrs: convergent
147148
declare dso_local spir_func noundef target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) @_Z34__spirv_CooperativeMatrixMulAddKHR(target("spirv.CooperativeMatrixKHR", i8, 3, 12, 48, 0) noundef, target("spirv.CooperativeMatrixKHR", i8, 2, 48, 12, 1) noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef, i32 noundef) local_unnamed_addr #2
148149

149150
; Function Attrs: convergent
150-
declare dso_local spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(4) noundef, i32 noundef, i32 noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef, i32 noundef, i64 noundef, i64 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2
151+
declare dso_local spir_func void @_Z42__spirv_CooperativeMatrixStoreCheckedINTEL(ptr addrspace(4) noundef, i32 noundef, i32 noundef, target("spirv.CooperativeMatrixKHR", i32, 3, 12, 12, 2) noundef, i32 noundef, i32 noundef, i32 noundef, i64 noundef, i32 noundef) local_unnamed_addr #2
151152

152153
; Function Attrs: nocallback nofree nosync nounwind willreturn memory(argmem: readwrite)
153154
declare void @llvm.lifetime.start.p0(i64 immarg, ptr nocapture) #3

0 commit comments

Comments
 (0)