1
1
// RUN: mlir-opt -allow-unregistered-dialect -split-input-file -verify-diagnostics %s | FileCheck %s
2
2
3
3
spirv.module Logical GLSL450 requires #spirv.vce <v1.0 , [Shader ], []> {
4
- // CHECK-LABEL: @matrix_times_scalar
5
- spirv.func @matrix_times_scalar (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f32 ) -> !spirv.matrix <3 x vector <3 xf32 >> " None" {
6
- // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32 -> !spirv.matrix<3 x vector<3xf32>>
7
- %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f32 -> !spirv.matrix < 3 x vector < 3 x f32 >>
4
+ // CHECK-LABEL: @matrix_times_scalar_1
5
+ spirv.func @matrix_times_scalar_1 (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f32 ) -> !spirv.matrix <3 x vector <3 xf32 >> " None" {
6
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.matrix<3 x vector<3xf32>>, f32
7
+ %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f32
8
8
spirv.ReturnValue %result : !spirv.matrix <3 x vector <3 xf32 >>
9
9
}
10
10
11
+ // CHECK-LABEL: @matrix_times_scalar_2
12
+ spirv.func @matrix_times_scalar_2 (%arg0 : !spirv.coopmatrix <16 x16 xf16 , Subgroup >, %arg1 : f16 ) -> !spirv.coopmatrix <16 x16 xf16 , Subgroup > " None" {
13
+ // CHECK: {{%.*}} = spirv.MatrixTimesScalar {{%.*}}, {{%.*}} : !spirv.coopmatrix<16x16xf16, Subgroup>, f16
14
+ %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.coopmatrix <16 x16 xf16 , Subgroup >, f16
15
+ spirv.ReturnValue %result : !spirv.coopmatrix <16 x16 xf16 , Subgroup >
16
+ }
17
+
11
18
// CHECK-LABEL: @matrix_transpose_1
12
19
spirv.func @matrix_transpose_1 (%arg0 : !spirv.matrix <3 x vector <2 xf32 >>) -> !spirv.matrix <2 x vector <3 xf32 >> " None" {
13
20
// CHECK: {{%.*}} = spirv.Transpose {{%.*}} : !spirv.matrix<3 x vector<2xf32>> -> !spirv.matrix<2 x vector<3xf32>>
@@ -39,82 +46,74 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
39
46
40
47
// -----
41
48
42
- func.func @input_type_mismatch (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f16 ) -> () {
49
+ func.func @input_type_mismatch (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f16 ) {
43
50
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
44
- %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f16 -> !spirv.matrix <3 x vector <3 xf32 >>
51
+ %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f16
52
+ return
45
53
}
46
54
47
55
// -----
48
56
49
- func.func @input_type_mismatch (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f64 ) -> () {
57
+ func.func @input_type_mismatch (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f64 ) {
50
58
// expected-error @+1 {{input matrix components' type and scaling value must have the same type}}
51
- %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f64 -> !spirv.matrix <3 x vector <3 xf32 >>
52
- }
53
-
54
- // -----
55
-
56
- func.func @input_output_component_type_mismatch (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f32 ) -> () {
57
- // expected-error @+1 {{input and result matrices' columns must have the same component type}}
58
- %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f32 -> !spirv.matrix <3 x vector <3 xf64 >>
59
- }
60
-
61
- // -----
62
-
63
- func.func @input_output_size_mismatch (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : f32 ) -> () {
64
- // expected-error @+1 {{input and result matrices must have the same number of columns}}
65
- %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f32 -> !spirv.matrix <4 x vector <3 xf32 >>
59
+ %result = spirv.MatrixTimesScalar %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, f64
60
+ return
66
61
}
67
62
68
63
// -----
69
64
70
- func.func @transpose_op_shape_mismatch_1 (%arg0 : !spirv.matrix <3 x vector <4 xf32 >>) -> () {
65
+ func.func @transpose_op_shape_mismatch_1 (%arg0 : !spirv.matrix <3 x vector <4 xf32 >>) {
71
66
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
72
67
%result = spirv.Transpose %arg0 : !spirv.matrix <3 x vector <4 xf32 >> -> !spirv.matrix <3 x vector <3 xf32 >>
73
- spirv.Return
68
+ return
74
69
}
75
70
76
71
// -----
77
72
78
- func.func @transpose_op_shape_mismatch_2 (%arg0 : !spirv.matrix <3 x vector <4 xf32 >>) -> () {
73
+ func.func @transpose_op_shape_mismatch_2 (%arg0 : !spirv.matrix <3 x vector <4 xf32 >>) {
79
74
// expected-error @+1 {{input matrix rows count must be equal to output matrix columns count}}
80
75
%result = spirv.Transpose %arg0 : !spirv.matrix <3 x vector <4 xf32 >> -> !spirv.matrix <2 x vector <4 xf32 >>
81
- spirv.Return
76
+ return
82
77
}
83
78
84
79
// -----
85
80
86
- func.func @transpose_op_type_mismatch (%arg0 : !spirv.matrix <3 x vector <4 xf32 >>) -> () {
81
+ func.func @transpose_op_type_mismatch (%arg0 : !spirv.matrix <3 x vector <4 xf32 >>) {
87
82
// expected-error @+1 {{input and output matrices must have the same component type}}
88
83
%result = spirv.Transpose %arg0 : !spirv.matrix <3 x vector <4 xf32 >> -> !spirv.matrix <4 x vector <3 xf16 >>
89
- spirv.Return
84
+ return
90
85
}
91
86
92
87
// -----
93
88
94
89
func.func @matrix_times_matrix_invalid_input_shape_1 (%arg0 : !spirv.matrix <3 x vector <2 xf32 >>, %arg1 : !spirv.matrix <2 x vector <3 xf32 >>){
95
90
// expected-error @+1 {{right and result matrices must have equal columns' count}}
96
91
%result = spirv.MatrixTimesMatrix %arg0 , %arg1 : !spirv.matrix <3 x vector <2 xf32 >>, !spirv.matrix <2 x vector <3 xf32 >> -> !spirv.matrix <3 x vector <2 xf32 >>
92
+ return
97
93
}
98
94
99
95
// -----
100
96
101
97
func.func @matrix_times_matrix_invalid_input_shape_2 (%arg0 : !spirv.matrix <3 x vector <2 xf32 >>, %arg1 : !spirv.matrix <2 x vector <3 xf32 >>){
102
98
// expected-error @+1 {{left and result matrices must have equal rows' count}}
103
99
%result = spirv.MatrixTimesMatrix %arg0 , %arg1 : !spirv.matrix <3 x vector <2 xf32 >>, !spirv.matrix <2 x vector <3 xf32 >> -> !spirv.matrix <2 x vector <3 xf32 >>
100
+ return
104
101
}
105
102
106
103
// -----
107
104
108
105
func.func @matrix_times_matrix_inputs_shape_mismatch (%arg0 : !spirv.matrix <3 x vector <2 xf32 >>, %arg1 : !spirv.matrix <2 x vector <2 xf32 >>){
109
106
// expected-error @+1 {{left matrix columns' count must be equal to the right matrix rows' count}}
110
107
%result = spirv.MatrixTimesMatrix %arg0 , %arg1 : !spirv.matrix <3 x vector <2 xf32 >>, !spirv.matrix <2 x vector <2 xf32 >> -> !spirv.matrix <2 x vector <2 xf32 >>
108
+ return
111
109
}
112
110
113
111
// -----
114
112
115
113
func.func @matrix_times_matrix_component_type_mismatch_1 (%arg0 : !spirv.matrix <3 x vector <3 xf32 >>, %arg1 : !spirv.matrix <3 x vector <3 xf32 >>){
116
114
// expected-error @+1 {{right and result matrices' component type must be the same}}
117
115
%result = spirv.MatrixTimesMatrix %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf32 >>, !spirv.matrix <3 x vector <3 xf32 >> -> !spirv.matrix <3 x vector <3 xf64 >>
116
+ return
118
117
}
119
118
120
119
@@ -123,4 +122,5 @@ func.func @matrix_times_matrix_component_type_mismatch_1(%arg0 : !spirv.matrix<3
123
122
func.func @matrix_times_matrix_component_type_mismatch_2 (%arg0 : !spirv.matrix <3 x vector <3 xf64 >>, %arg1 : !spirv.matrix <3 x vector <3 xf32 >>){
124
123
// expected-error @+1 {{left and result matrices' component type must be the same}}
125
124
%result = spirv.MatrixTimesMatrix %arg0 , %arg1 : !spirv.matrix <3 x vector <3 xf64 >>, !spirv.matrix <3 x vector <3 xf32 >> -> !spirv.matrix <3 x vector <3 xf32 >>
125
+ return
126
126
}
0 commit comments