9
9
; RUN: igc_opt -platformpvc -igc-joint-matrix-resolution -S 2>&1 < %s | FileCheck %s
10
10
; ------------------------------------------------
11
11
; JointMatrixFuncsResolutionPass
12
+ ;
13
+ ; Test verifies resolution of joint matrix extract and insert functions,
14
+ ; including the addition of joint_matrix_apply metadata.
12
15
; ------------------------------------------------
13
16
14
17
%spirv.JointMatrixINTEL._float_16_16_3_3_2 = type opaque
18
21
; CHECK-SAME: float addrspace(1)* [[PTR1:%.*]], i64 [[IND1:%.*]], float addrspace(1)* [[PTR2:%.*]], i64 [[IND2:%.*]]) {
19
22
define spir_kernel void @test (float addrspace (1 )* %ptr1 , i64 %ind1 , float addrspace (1 )* %ptr2 , i64 %ind2 ) {
20
23
; CHECK-NEXT: [[TMP1:%.*]] = alloca [2 x <64 x float>]
21
- ; CHECK-NEXT: [[TMP2:%.*]] = alloca <16 x float>
24
+ ; CHECK-NEXT: [[TMP2:%.*]] = alloca [2 x <64 x float>]
25
+ ; CHECK-NEXT: [[TMP3:%.*]] = alloca <16 x float>
22
26
23
- ; CHECK-NEXT: [[TMP3 :%.*]] = bitcast <16 x float>* [[TMP2 ]] to i8*
24
- ; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_global_v8i8_pi32_i32(i8* [[TMP3 ]], float addrspace(1)* [[PTR1]], i64 32, i32 0)
25
- ; CHECK-NEXT: [[TMP4 :%.*]] = load <16 x float>, <16 x float>* [[TMP2 ]]
27
+ ; CHECK-NEXT: [[TMP4 :%.*]] = bitcast <16 x float>* [[TMP3 ]] to i8*
28
+ ; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_16x16_i32_16_global_v8i8_pi32_i32(i8* [[TMP4 ]], float addrspace(1)* [[PTR1]], i64 32, i32 0)
29
+ ; CHECK-NEXT: [[TMP5 :%.*]] = load <16 x float>, <16 x float>* [[TMP3 ]]
26
30
%C1 = call spir_func %spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace (1 )* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2PU3AS1fliii (float addrspace (1 )* %ptr1 , i64 32 , i32 0 , i32 3 , i32 0 )
27
31
28
- ; CHECK-NEXT: [[MATRIX_ELEMENT:%.*]] = extractelement <16 x float> [[TMP4 ]], i64 [[IND1]]
32
+ ; CHECK-NEXT: [[MATRIX_ELEMENT:%.*]] = extractelement <16 x float> [[TMP5 ]], i64 [[IND1]]
29
33
%1 = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2l (%spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace (1 )* %C1 , i64 %ind1 )
30
34
31
- ; CHECK-NEXT: [[TMP5 :%.*]] = fadd float [[MATRIX_ELEMENT]], 5.000000e+00
35
+ ; CHECK-NEXT: [[TMP6 :%.*]] = fadd float [[MATRIX_ELEMENT]], 5.000000e+00
32
36
%2 = fadd float %1 , 5 .0
33
37
34
- ; CHECK-NEXT: [[TMP6 :%.*]] = insertelement <16 x float> [[TMP4 ]], float [[TMP5 ]], i64 [[IND1]]
38
+ ; CHECK-NEXT: [[TMP7 :%.*]] = insertelement <16 x float> [[TMP5 ]], float [[TMP6 ]], i64 [[IND1]]
35
39
%3 = call spir_func %spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace (1 )* @_Z27__spirv_VectorInsertDynamicPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2fl (%spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace (1 )* %C1 , float %2 , i64 %ind1 )
36
40
37
- ; CHECK-NEXT: [[TMP7 :%.*]] = bitcast [2 x <64 x float>]* [[TMP1 ]] to i8*
38
- ; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_global_v8i8_pi32_i32(i8* [[TMP7 ]], float addrspace(1)* [[PTR2]], i64 128, i32 0)
39
- ; CHECK-NEXT: [[TMP8 :%.*]] = bitcast [2 x <64 x float>]* [[TMP1 ]] to <64 x float>*
40
- ; CHECK-NEXT: [[TMP9 :%.*]] = load <64 x float>, <64 x float>* [[TMP8 ]]
41
- ; CHECK-NEXT: [[TMP10 :%.*]] = getelementptr <64 x float>, <64 x float>* [[TMP8 ]], i32 1
42
- ; CHECK-NEXT: [[TMP11 :%.*]] = load <64 x float>, <64 x float>* [[TMP10 ]]
43
- ; CHECK-NEXT: [[TMP12 :%.*]] = insertvalue [2 x <64 x float>] undef, <64 x float> [[TMP9 ]], 0
44
- ; CHECK-NEXT: [[TMP13 :%.*]] = insertvalue [2 x <64 x float>] [[TMP12 ]], <64 x float> [[TMP11 ]], 1
41
+ ; CHECK-NEXT: [[TMP8 :%.*]] = bitcast [2 x <64 x float>]* [[TMP2 ]] to i8*
42
+ ; CHECK-NEXT: call void @__builtin_spriv_OpJointMatrixLoadINTEL_Accumulator_RowMajor_SG16_32x64_i32_128_global_v8i8_pi32_i32(i8* [[TMP8 ]], float addrspace(1)* [[PTR2]], i64 128, i32 0)
43
+ ; CHECK-NEXT: [[TMP9 :%.*]] = bitcast [2 x <64 x float>]* [[TMP2 ]] to <64 x float>*
44
+ ; CHECK-NEXT: [[TMP10 :%.*]] = load <64 x float>, <64 x float>* [[TMP9 ]]
45
+ ; CHECK-NEXT: [[TMP11 :%.*]] = getelementptr <64 x float>, <64 x float>* [[TMP9 ]], i32 1
46
+ ; CHECK-NEXT: [[TMP12 :%.*]] = load <64 x float>, <64 x float>* [[TMP11 ]]
47
+ ; CHECK-NEXT: [[TMP13 :%.*]] = insertvalue [2 x <64 x float>] undef, <64 x float> [[TMP10 ]], 0
48
+ ; CHECK-NEXT: [[TMP14 :%.*]] = insertvalue [2 x <64 x float>] [[TMP13 ]], <64 x float> [[TMP12 ]], 1
45
49
%C2 = call spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace (1 )* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2PU3AS1fliii (float addrspace (1 )* %ptr2 , i64 128 , i32 0 , i32 3 , i32 0 )
46
50
47
- ; CHECK-NEXT: [[TMP14:%.*]] = icmp ugt i64 [[IND2]], 63
48
- ; CHECK-NEXT: [[MATRIX_SLICE_HALF0:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 0
49
- ; CHECK-NEXT: [[MATRIX_SLICE_HALF1:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 1
50
- ; CHECK-NEXT: [[MATRIX_SLICE_SELECTED_HALF:%.*]] = select i1 [[TMP14]], <64 x float> [[MATRIX_SLICE_HALF1]], <64 x float> [[MATRIX_SLICE_HALF0]]
51
- ; CHECK-NEXT: [[TMP15:%.*]] = urem i64 [[IND2]], 64
52
- ; CHECK-NEXT: [[MATRIX_ELEMENT5:%.*]] = extractelement <64 x float> [[MATRIX_SLICE_SELECTED_HALF]], i64 [[TMP15]]
51
+ ; CHECK-NEXT: store [2 x <64 x float>] [[TMP14]], [2 x <64 x float>]* [[TMP1]]
52
+ ; CHECK-NEXT: [[TMP15:%.*]] = bitcast [2 x <64 x float>]* [[TMP1]] to float*
53
+ ; CHECK-NEXT: [[TMP16:%.*]] = getelementptr float, float* [[TMP15]], i64 [[IND2]]
54
+ ; CHECK-NEXT: [[TMP17:%.*]] = load float, float* [[TMP16]],{{.*}} !joint_matrix_apply [[MD:![0-9]+]]
53
55
%4 = call spir_func float @_Z28__spirv_VectorExtractDynamicPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2l (%spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace (1 )* %C2 , i64 %ind2 )
54
56
55
- ; CHECK-NEXT: [[TMP16 :%.*]] = fadd float [[MATRIX_ELEMENT5 ]], 5.000000e+00
57
+ ; CHECK-NEXT: [[TMP18 :%.*]] = fadd float [[TMP17 ]], 5.000000e+00
56
58
%5 = fadd float %4 , 5 .0
57
59
58
- ; CHECK-NEXT: [[TMP17:%.*]] = icmp ugt i64 [[IND2]], 63
59
- ; CHECK-NEXT: [[MATRIX_SLICE_HALF07:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 0
60
- ; CHECK-NEXT: [[MATRIX_SLICE_HALF18:%.*]] = extractvalue [2 x <64 x float>] [[TMP13]], 1
61
- ; CHECK-NEXT: [[MATRIX_SLICE_SELECTED_HALF9:%.*]] = select i1 [[TMP17]], <64 x float> [[MATRIX_SLICE_HALF18]], <64 x float> [[MATRIX_SLICE_HALF07]]
62
- ; CHECK-NEXT: [[TMP18:%.*]] = urem i64 [[IND2]], 64
63
- ; CHECK-NEXT: [[TMP19:%.*]] = insertelement <64 x float> [[MATRIX_SLICE_SELECTED_HALF9]], float [[TMP16]], i64 [[TMP18]]
64
- ; CHECK-NEXT: [[TMP20:%.*]] = select i1 [[TMP17]], <64 x float> [[MATRIX_SLICE_HALF07]], <64 x float> [[TMP19]]
65
- ; CHECK-NEXT: [[TMP21:%.*]] = select i1 [[TMP17]], <64 x float> [[TMP19]], <64 x float> [[MATRIX_SLICE_HALF18]]
66
- ; CHECK-NEXT: [[TMP22:%.*]] = insertvalue [2 x <64 x float>] undef, <64 x float> [[TMP20]], 0
67
- ; CHECK-NEXT: [[TMP23:%.*]] = insertvalue [2 x <64 x float>] [[TMP22]], <64 x float> [[TMP21]], 1
60
+ ; CHECK-NEXT: store [2 x <64 x float>] [[TMP14]], [2 x <64 x float>]* [[TMP1]]
61
+ ; CHECK-NEXT: [[TMP19:%.*]] = bitcast [2 x <64 x float>]* [[TMP1]] to float*
62
+ ; CHECK-NEXT: [[TMP20:%.*]] = getelementptr float, float* [[TMP19]], i64 [[IND2]]
63
+ ; CHECK-NEXT: store float [[TMP18]], float* [[TMP20]]
64
+ ; CHECK-NEXT: [[TMP21:%.*]] = load [2 x <64 x float>], [2 x <64 x float>]* [[TMP1]]
68
65
%6 = call spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace (1 )* @_Z27__spirv_VectorInsertDynamicPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2fl (%spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace (1 )* %C2 , float %5 , i64 %ind2 )
69
66
70
67
; CHECK-NEXT: ret void
@@ -79,6 +76,7 @@ declare spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace(1)* @_Z27
79
76
; External declaration of the __spirv_JointMatrixLoadINTEL overload that returns
; %spirv.JointMatrixINTEL._float_32_64_3_3_2; @test calls it on %ptr2 to produce %C2.
declare spir_func %spirv.JointMatrixINTEL._float_32_64_3_3_2 addrspace (1 )* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_32_64_3_3_2PU3AS1fliii (float addrspace (1 )*, i64 , i32 , i32 , i32 )
80
77
; External declaration of the __spirv_JointMatrixLoadINTEL overload that returns
; %spirv.JointMatrixINTEL._float_16_16_3_3_2; @test calls it on %ptr1 to produce %C1.
declare spir_func %spirv.JointMatrixINTEL._float_16_16_3_3_2 addrspace (1 )* @_Z81__spirv_JointMatrixLoadINTEL_RPU3AS143__spirv_JointMatrixINTEL__float_16_16_3_3_2PU3AS1fliii (float addrspace (1 )*, i64 , i32 , i32 , i32 )
81
78
79
+ ; CHECK: [[MD]] = !{i1 true}
82
80
; Named metadata with a single operand, !0, which pairs @test with the node !1.
!igc.functions = !{!0 }
83
81
!0 = !{void (float addrspace (1 )*, i64 , float addrspace (1 )*, i64 )* @test , !1 }
84
82
!1 = !{!2 , !3 }
0 commit comments