Skip to content

Commit 49777d7

Browse files
authored
[mlir][spirv] Add atan and atan2 pattern to MathToSPIRV Conversion pass (#102633)
Add missing math.atan to spirv.CL.atan and math.atan2 to spirv.CL.atan2 in MathToSPIRV. Add math.atan to spirv.GL.atan too.
1 parent 82ee31f commit 49777d7

File tree

3 files changed

+97
-74
lines changed

3 files changed

+97
-74
lines changed

mlir/lib/Conversion/MathToSPIRV/MathToSPIRV.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -414,6 +414,7 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
414414
ExpM1OpPattern<spirv::GLExpOp>, PowFOpPattern, RoundOpPattern,
415415
CheckedElementwiseOpPattern<math::AbsFOp, spirv::GLFAbsOp>,
416416
CheckedElementwiseOpPattern<math::AbsIOp, spirv::GLSAbsOp>,
417+
CheckedElementwiseOpPattern<math::AtanOp, spirv::GLAtanOp>,
417418
CheckedElementwiseOpPattern<math::CeilOp, spirv::GLCeilOp>,
418419
CheckedElementwiseOpPattern<math::CosOp, spirv::GLCosOp>,
419420
CheckedElementwiseOpPattern<math::ExpOp, spirv::GLExpOp>,
@@ -431,6 +432,8 @@ void populateMathToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
431432
patterns.add<Log1pOpPattern<spirv::CLLogOp>, ExpM1OpPattern<spirv::CLExpOp>,
432433
CheckedElementwiseOpPattern<math::AbsFOp, spirv::CLFAbsOp>,
433434
CheckedElementwiseOpPattern<math::AbsIOp, spirv::CLSAbsOp>,
435+
CheckedElementwiseOpPattern<math::AtanOp, spirv::CLAtanOp>,
436+
CheckedElementwiseOpPattern<math::Atan2Op, spirv::CLAtan2Op>,
434437
CheckedElementwiseOpPattern<math::CeilOp, spirv::CLCeilOp>,
435438
CheckedElementwiseOpPattern<math::CosOp, spirv::CLCosOp>,
436439
CheckedElementwiseOpPattern<math::ErfOp, spirv::CLErfOp>,

mlir/test/Conversion/MathToSPIRV/math-to-gl-spirv.mlir

Lines changed: 43 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,65 +6,69 @@ module attributes {
66

77
// CHECK-LABEL: @float32_unary_scalar
88
func.func @float32_unary_scalar(%arg0: f32) {
9+
// CHECK: spirv.GL.Atan %{{.*}}: f32
10+
%0 = math.atan %arg0 : f32
911
// CHECK: spirv.GL.Cos %{{.*}}: f32
10-
%0 = math.cos %arg0 : f32
12+
%1 = math.cos %arg0 : f32
1113
// CHECK: spirv.GL.Exp %{{.*}}: f32
12-
%1 = math.exp %arg0 : f32
14+
%2 = math.exp %arg0 : f32
1315
// CHECK: %[[EXP:.+]] = spirv.GL.Exp %arg0
1416
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
1517
// CHECK: spirv.FSub %[[EXP]], %[[ONE]]
16-
%2 = math.expm1 %arg0 : f32
18+
%3 = math.expm1 %arg0 : f32
1719
// CHECK: spirv.GL.Log %{{.*}}: f32
18-
%3 = math.log %arg0 : f32
20+
%4 = math.log %arg0 : f32
1921
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
2022
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
2123
// CHECK: spirv.GL.Log %[[ADDONE]]
22-
%4 = math.log1p %arg0 : f32
24+
%5 = math.log1p %arg0 : f32
2325
// CHECK: spirv.GL.RoundEven %{{.*}}: f32
24-
%5 = math.roundeven %arg0 : f32
26+
%6 = math.roundeven %arg0 : f32
2527
// CHECK: spirv.GL.InverseSqrt %{{.*}}: f32
26-
%6 = math.rsqrt %arg0 : f32
28+
%7 = math.rsqrt %arg0 : f32
2729
// CHECK: spirv.GL.Sqrt %{{.*}}: f32
28-
%7 = math.sqrt %arg0 : f32
30+
%8 = math.sqrt %arg0 : f32
2931
// CHECK: spirv.GL.Tanh %{{.*}}: f32
30-
%8 = math.tanh %arg0 : f32
32+
%9 = math.tanh %arg0 : f32
3133
// CHECK: spirv.GL.Sin %{{.*}}: f32
32-
%9 = math.sin %arg0 : f32
34+
%10 = math.sin %arg0 : f32
3335
// CHECK: spirv.GL.FAbs %{{.*}}: f32
34-
%10 = math.absf %arg0 : f32
36+
%11 = math.absf %arg0 : f32
3537
// CHECK: spirv.GL.Ceil %{{.*}}: f32
36-
%11 = math.ceil %arg0 : f32
38+
%12 = math.ceil %arg0 : f32
3739
// CHECK: spirv.GL.Floor %{{.*}}: f32
38-
%12 = math.floor %arg0 : f32
40+
%13 = math.floor %arg0 : f32
3941
return
4042
}
4143

4244
// CHECK-LABEL: @float32_unary_vector
4345
func.func @float32_unary_vector(%arg0: vector<3xf32>) {
46+
// CHECK: spirv.GL.Atan %{{.*}}: vector<3xf32>
47+
%0 = math.atan %arg0 : vector<3xf32>
4448
// CHECK: spirv.GL.Cos %{{.*}}: vector<3xf32>
45-
%0 = math.cos %arg0 : vector<3xf32>
49+
%1 = math.cos %arg0 : vector<3xf32>
4650
// CHECK: spirv.GL.Exp %{{.*}}: vector<3xf32>
47-
%1 = math.exp %arg0 : vector<3xf32>
51+
%2 = math.exp %arg0 : vector<3xf32>
4852
// CHECK: %[[EXP:.+]] = spirv.GL.Exp %arg0
4953
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
5054
// CHECK: spirv.FSub %[[EXP]], %[[ONE]]
51-
%2 = math.expm1 %arg0 : vector<3xf32>
55+
%3 = math.expm1 %arg0 : vector<3xf32>
5256
// CHECK: spirv.GL.Log %{{.*}}: vector<3xf32>
53-
%3 = math.log %arg0 : vector<3xf32>
57+
%4 = math.log %arg0 : vector<3xf32>
5458
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
5559
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
5660
// CHECK: spirv.GL.Log %[[ADDONE]]
57-
%4 = math.log1p %arg0 : vector<3xf32>
61+
%5 = math.log1p %arg0 : vector<3xf32>
5862
// CHECK: spirv.GL.RoundEven %{{.*}}: vector<3xf32>
59-
%5 = math.roundeven %arg0 : vector<3xf32>
63+
%6 = math.roundeven %arg0 : vector<3xf32>
6064
// CHECK: spirv.GL.InverseSqrt %{{.*}}: vector<3xf32>
61-
%6 = math.rsqrt %arg0 : vector<3xf32>
65+
%7 = math.rsqrt %arg0 : vector<3xf32>
6266
// CHECK: spirv.GL.Sqrt %{{.*}}: vector<3xf32>
63-
%7 = math.sqrt %arg0 : vector<3xf32>
67+
%8 = math.sqrt %arg0 : vector<3xf32>
6468
// CHECK: spirv.GL.Tanh %{{.*}}: vector<3xf32>
65-
%8 = math.tanh %arg0 : vector<3xf32>
69+
%9 = math.tanh %arg0 : vector<3xf32>
6670
// CHECK: spirv.GL.Sin %{{.*}}: vector<3xf32>
67-
%9 = math.sin %arg0 : vector<3xf32>
71+
%10 = math.sin %arg0 : vector<3xf32>
6872
return
6973
}
7074

@@ -229,18 +233,20 @@ module attributes {
229233

230234
// CHECK-LABEL: @vector_2d
231235
func.func @vector_2d(%arg0: vector<2x2xf32>) {
236+
// CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
237+
%0 = math.atan %arg0 : vector<2x2xf32>
232238
// CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
233-
%0 = math.cos %arg0 : vector<2x2xf32>
239+
%1 = math.cos %arg0 : vector<2x2xf32>
234240
// CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
235-
%1 = math.exp %arg0 : vector<2x2xf32>
241+
%2 = math.exp %arg0 : vector<2x2xf32>
236242
// CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
237-
%2 = math.absf %arg0 : vector<2x2xf32>
243+
%3 = math.absf %arg0 : vector<2x2xf32>
238244
// CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
239-
%3 = math.ceil %arg0 : vector<2x2xf32>
245+
%4 = math.ceil %arg0 : vector<2x2xf32>
240246
// CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
241-
%4 = math.floor %arg0 : vector<2x2xf32>
247+
%5 = math.floor %arg0 : vector<2x2xf32>
242248
// CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
243-
%5 = math.powf %arg0, %arg0 : vector<2x2xf32>
249+
%6 = math.powf %arg0, %arg0 : vector<2x2xf32>
244250
// CHECK-NEXT: return
245251
return
246252
}
@@ -249,18 +255,20 @@ func.func @vector_2d(%arg0: vector<2x2xf32>) {
249255

250256
// CHECK-LABEL: @tensor_1d
251257
func.func @tensor_1d(%arg0: tensor<2xf32>) {
258+
// CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
259+
%0 = math.atan %arg0 : tensor<2xf32>
252260
// CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
253-
%0 = math.cos %arg0 : tensor<2xf32>
261+
%1 = math.cos %arg0 : tensor<2xf32>
254262
// CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
255-
%1 = math.exp %arg0 : tensor<2xf32>
263+
%2 = math.exp %arg0 : tensor<2xf32>
256264
// CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
257-
%2 = math.absf %arg0 : tensor<2xf32>
265+
%3 = math.absf %arg0 : tensor<2xf32>
258266
// CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
259-
%3 = math.ceil %arg0 : tensor<2xf32>
267+
%4 = math.ceil %arg0 : tensor<2xf32>
260268
// CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
261-
%4 = math.floor %arg0 : tensor<2xf32>
269+
%5 = math.floor %arg0 : tensor<2xf32>
262270
// CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
263-
%5 = math.powf %arg0, %arg0 : tensor<2xf32>
271+
%6 = math.powf %arg0, %arg0 : tensor<2xf32>
264272
// CHECK-NEXT: return
265273
return
266274
}

mlir/test/Conversion/MathToSPIRV/math-to-opencl-spirv.mlir

Lines changed: 51 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -4,83 +4,91 @@ module attributes { spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Kerne
44

55
// CHECK-LABEL: @float32_unary_scalar
66
func.func @float32_unary_scalar(%arg0: f32) {
7+
// CHECK: spirv.CL.atan %{{.*}}: f32
8+
%0 = math.atan %arg0 : f32
79
// CHECK: spirv.CL.cos %{{.*}}: f32
8-
%0 = math.cos %arg0 : f32
10+
%1 = math.cos %arg0 : f32
911
// CHECK: spirv.CL.exp %{{.*}}: f32
10-
%1 = math.exp %arg0 : f32
12+
%2 = math.exp %arg0 : f32
1113
// CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
1214
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
1315
// CHECK: spirv.FSub %[[EXP]], %[[ONE]]
14-
%2 = math.expm1 %arg0 : f32
16+
%3 = math.expm1 %arg0 : f32
1517
// CHECK: spirv.CL.log %{{.*}}: f32
16-
%3 = math.log %arg0 : f32
18+
%4 = math.log %arg0 : f32
1719
// CHECK: %[[ONE:.+]] = spirv.Constant 1.000000e+00 : f32
1820
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
1921
// CHECK: spirv.CL.log %[[ADDONE]]
20-
%4 = math.log1p %arg0 : f32
22+
%5 = math.log1p %arg0 : f32
2123
// CHECK: spirv.CL.rint %{{.*}}: f32
22-
%5 = math.roundeven %arg0 : f32
24+
%6 = math.roundeven %arg0 : f32
2325
// CHECK: spirv.CL.rsqrt %{{.*}}: f32
24-
%6 = math.rsqrt %arg0 : f32
26+
%7 = math.rsqrt %arg0 : f32
2527
// CHECK: spirv.CL.sqrt %{{.*}}: f32
26-
%7 = math.sqrt %arg0 : f32
28+
%8 = math.sqrt %arg0 : f32
2729
// CHECK: spirv.CL.tanh %{{.*}}: f32
28-
%8 = math.tanh %arg0 : f32
30+
%9 = math.tanh %arg0 : f32
2931
// CHECK: spirv.CL.sin %{{.*}}: f32
30-
%9 = math.sin %arg0 : f32
32+
%10 = math.sin %arg0 : f32
3133
// CHECK: spirv.CL.fabs %{{.*}}: f32
32-
%10 = math.absf %arg0 : f32
34+
%11 = math.absf %arg0 : f32
3335
// CHECK: spirv.CL.ceil %{{.*}}: f32
34-
%11 = math.ceil %arg0 : f32
36+
%12 = math.ceil %arg0 : f32
3537
// CHECK: spirv.CL.floor %{{.*}}: f32
36-
%12 = math.floor %arg0 : f32
38+
%13 = math.floor %arg0 : f32
3739
// CHECK: spirv.CL.erf %{{.*}}: f32
38-
%13 = math.erf %arg0 : f32
40+
%14 = math.erf %arg0 : f32
3941
// CHECK: spirv.CL.round %{{.*}}: f32
40-
%14 = math.round %arg0 : f32
42+
%15 = math.round %arg0 : f32
4143
return
4244
}
4345

4446
// CHECK-LABEL: @float32_unary_vector
4547
func.func @float32_unary_vector(%arg0: vector<3xf32>) {
48+
// CHECK: spirv.CL.atan %{{.*}}: vector<3xf32>
49+
%0 = math.atan %arg0 : vector<3xf32>
4650
// CHECK: spirv.CL.cos %{{.*}}: vector<3xf32>
47-
%0 = math.cos %arg0 : vector<3xf32>
51+
%1 = math.cos %arg0 : vector<3xf32>
4852
// CHECK: spirv.CL.exp %{{.*}}: vector<3xf32>
49-
%1 = math.exp %arg0 : vector<3xf32>
53+
%2 = math.exp %arg0 : vector<3xf32>
5054
// CHECK: %[[EXP:.+]] = spirv.CL.exp %arg0
5155
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
5256
// CHECK: spirv.FSub %[[EXP]], %[[ONE]]
53-
%2 = math.expm1 %arg0 : vector<3xf32>
57+
%3 = math.expm1 %arg0 : vector<3xf32>
5458
// CHECK: spirv.CL.log %{{.*}}: vector<3xf32>
55-
%3 = math.log %arg0 : vector<3xf32>
59+
%4 = math.log %arg0 : vector<3xf32>
5660
// CHECK: %[[ONE:.+]] = spirv.Constant dense<1.000000e+00> : vector<3xf32>
5761
// CHECK: %[[ADDONE:.+]] = spirv.FAdd %[[ONE]], %{{.+}}
5862
// CHECK: spirv.CL.log %[[ADDONE]]
59-
%4 = math.log1p %arg0 : vector<3xf32>
63+
%5 = math.log1p %arg0 : vector<3xf32>
6064
// CHECK: spirv.CL.rint %{{.*}}: vector<3xf32>
61-
%5 = math.roundeven %arg0 : vector<3xf32>
65+
%6 = math.roundeven %arg0 : vector<3xf32>
6266
// CHECK: spirv.CL.rsqrt %{{.*}}: vector<3xf32>
63-
%6 = math.rsqrt %arg0 : vector<3xf32>
67+
%7 = math.rsqrt %arg0 : vector<3xf32>
6468
// CHECK: spirv.CL.sqrt %{{.*}}: vector<3xf32>
65-
%7 = math.sqrt %arg0 : vector<3xf32>
69+
%8 = math.sqrt %arg0 : vector<3xf32>
6670
// CHECK: spirv.CL.tanh %{{.*}}: vector<3xf32>
67-
%8 = math.tanh %arg0 : vector<3xf32>
71+
%9 = math.tanh %arg0 : vector<3xf32>
6872
// CHECK: spirv.CL.sin %{{.*}}: vector<3xf32>
69-
%9 = math.sin %arg0 : vector<3xf32>
73+
%10 = math.sin %arg0 : vector<3xf32>
7074
return
7175
}
7276

7377
// CHECK-LABEL: @float32_binary_scalar
7478
func.func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
79+
// CHECK: spirv.CL.atan2 %{{.*}}: f32
80+
%0 = math.atan2 %lhs, %rhs : f32
7581
// CHECK: spirv.CL.pow %{{.*}}: f32
76-
%0 = math.powf %lhs, %rhs : f32
82+
%1 = math.powf %lhs, %rhs : f32
7783
return
7884
}
7985

8086
// CHECK-LABEL: @float32_binary_vector
8187
func.func @float32_binary_vector(%lhs: vector<4xf32>, %rhs: vector<4xf32>) {
88+
// CHECK: spirv.CL.atan2 %{{.*}}: vector<4xf32>
89+
%0 = math.atan2 %lhs, %rhs : vector<4xf32>
8290
// CHECK: spirv.CL.pow %{{.*}}: vector<4xf32>
83-
%0 = math.powf %lhs, %rhs : vector<4xf32>
91+
%1 = math.powf %lhs, %rhs : vector<4xf32>
8492
return
8593
}
8694

@@ -118,18 +126,20 @@ module attributes {
118126

119127
// CHECK-LABEL: @vector_2d
120128
func.func @vector_2d(%arg0: vector<2x2xf32>) {
129+
// CHECK-NEXT: math.atan {{.+}} : vector<2x2xf32>
130+
%0 = math.atan %arg0 : vector<2x2xf32>
121131
// CHECK-NEXT: math.cos {{.+}} : vector<2x2xf32>
122-
%0 = math.cos %arg0 : vector<2x2xf32>
132+
%1 = math.cos %arg0 : vector<2x2xf32>
123133
// CHECK-NEXT: math.exp {{.+}} : vector<2x2xf32>
124-
%1 = math.exp %arg0 : vector<2x2xf32>
134+
%2 = math.exp %arg0 : vector<2x2xf32>
125135
// CHECK-NEXT: math.absf {{.+}} : vector<2x2xf32>
126-
%2 = math.absf %arg0 : vector<2x2xf32>
136+
%3 = math.absf %arg0 : vector<2x2xf32>
127137
// CHECK-NEXT: math.ceil {{.+}} : vector<2x2xf32>
128-
%3 = math.ceil %arg0 : vector<2x2xf32>
138+
%4 = math.ceil %arg0 : vector<2x2xf32>
129139
// CHECK-NEXT: math.floor {{.+}} : vector<2x2xf32>
130-
%4 = math.floor %arg0 : vector<2x2xf32>
140+
%5 = math.floor %arg0 : vector<2x2xf32>
131141
// CHECK-NEXT: math.powf {{.+}}, {{%.+}} : vector<2x2xf32>
132-
%5 = math.powf %arg0, %arg0 : vector<2x2xf32>
142+
%6 = math.powf %arg0, %arg0 : vector<2x2xf32>
133143
// CHECK-NEXT: return
134144
return
135145
}
@@ -138,18 +148,20 @@ func.func @vector_2d(%arg0: vector<2x2xf32>) {
138148

139149
// CHECK-LABEL: @tensor_1d
140150
func.func @tensor_1d(%arg0: tensor<2xf32>) {
151+
// CHECK-NEXT: math.atan {{.+}} : tensor<2xf32>
152+
%0 = math.atan %arg0 : tensor<2xf32>
141153
// CHECK-NEXT: math.cos {{.+}} : tensor<2xf32>
142-
%0 = math.cos %arg0 : tensor<2xf32>
154+
%1 = math.cos %arg0 : tensor<2xf32>
143155
// CHECK-NEXT: math.exp {{.+}} : tensor<2xf32>
144-
%1 = math.exp %arg0 : tensor<2xf32>
156+
%2 = math.exp %arg0 : tensor<2xf32>
145157
// CHECK-NEXT: math.absf {{.+}} : tensor<2xf32>
146-
%2 = math.absf %arg0 : tensor<2xf32>
158+
%3 = math.absf %arg0 : tensor<2xf32>
147159
// CHECK-NEXT: math.ceil {{.+}} : tensor<2xf32>
148-
%3 = math.ceil %arg0 : tensor<2xf32>
160+
%4 = math.ceil %arg0 : tensor<2xf32>
149161
// CHECK-NEXT: math.floor {{.+}} : tensor<2xf32>
150-
%4 = math.floor %arg0 : tensor<2xf32>
162+
%5 = math.floor %arg0 : tensor<2xf32>
151163
// CHECK-NEXT: math.powf {{.+}}, {{%.+}} : tensor<2xf32>
152-
%5 = math.powf %arg0, %arg0 : tensor<2xf32>
164+
%6 = math.powf %arg0, %arg0 : tensor<2xf32>
153165
// CHECK-NEXT: return
154166
return
155167
}

0 commit comments

Comments
 (0)