Skip to content

Commit 7a89444

Browse files
committed
[mlir][spirv] Add ops and patterns for lowering standard max/min ops
Reviewed By: ThomasRaoux Differential Revision: https://reviews.llvm.org/D111143
1 parent cc1d13f commit 7a89444

File tree

5 files changed

+123
-12
lines changed

5 files changed

+123
-12
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,10 @@ def SPV_GLSLTanOp : SPV_GLSLUnaryArithmeticOp<"Tan", 15, SPV_Float16or32> {
261261
let description = [{
262262
The standard trigonometric tangent of x radians.
263263

264-
The operand x must be a scalar or vector whose component type is 16-bit or
264+
The operand x must be a scalar or vector whose component type is 16-bit or
265265
32-bit floating-point.
266266

267-
Result Type and the type of x must be the same type. Results are computed
267+
Result Type and the type of x must be the same type. Results are computed
268268
per component.
269269

270270
<!-- End of AutoGen section -->
@@ -576,6 +576,36 @@ def SPV_GLSLFMaxOp : SPV_GLSLBinaryArithmeticOp<"FMax", 40, SPV_Float> {
576576

577577
// -----
578578

579+
def SPV_GLSLUMaxOp : SPV_GLSLBinaryArithmeticOp<"UMax", 41, SPV_Integer> {
580+
let summary = "Return maximum of two unsigned integer operands";
581+
582+
let description = [{
583+
Result is y if x < y; otherwise result is x, where x and y are interpreted
584+
as unsigned integers.
585+
586+
Result Type and the type of x and y must both be integer scalar or integer
587+
vector types. Result Type and operand types must have the same number of
588+
components with the same component width. Results are computed per
589+
component.
590+
591+
<!-- End of AutoGen section -->
592+
```
593+
integer-scalar-vector-type ::= integer-type |
594+
`vector<` integer-literal `x` integer-type `>`
595+
smax-op ::= ssa-id `=` `spv.GLSL.UMax` ssa-use `:`
596+
integer-scalar-vector-type
597+
```
598+
#### Example:
599+
600+
```mlir
601+
%2 = spv.GLSL.UMax %0, %1 : i32
602+
%3 = spv.GLSL.UMax %0, %1 : vector<3xi16>
603+
```
604+
}];
605+
}
606+
607+
// -----
608+
579609
def SPV_GLSLSMaxOp : SPV_GLSLBinaryArithmeticOp<"SMax", 42, SPV_Integer> {
580610
let summary = "Return maximum of two signed integer operands";
581611

@@ -637,6 +667,36 @@ def SPV_GLSLFMinOp : SPV_GLSLBinaryArithmeticOp<"FMin", 37, SPV_Float> {
637667

638668
// -----
639669

670+
def SPV_GLSLUMinOp : SPV_GLSLBinaryArithmeticOp<"UMin", 38, SPV_Integer> {
671+
let summary = "Return minimum of two unsigned integer operands";
672+
673+
let description = [{
674+
Result is y if y < x; otherwise result is x, where x and y are interpreted
675+
as unsigned integers.
676+
677+
Result Type and the type of x and y must both be integer scalar or integer
678+
vector types. Result Type and operand types must have the same number of
679+
components with the same component width. Results are computed per
680+
component.
681+
682+
<!-- End of AutoGen section -->
683+
```
684+
integer-scalar-vector-type ::= integer-type |
685+
`vector<` integer-literal `x` integer-type `>`
686+
smin-op ::= ssa-id `=` `spv.GLSL.UMin` ssa-use `:`
687+
integer-scalar-vector-type
688+
```
689+
#### Example:
690+
691+
```mlir
692+
%2 = spv.GLSL.UMin %0, %1 : i32
693+
%3 = spv.GLSL.UMin %0, %1 : vector<3xi16>
694+
```
695+
}];
696+
}
697+
698+
// -----
699+
640700
def SPV_GLSLSMinOp : SPV_GLSLBinaryArithmeticOp<"SMin", 39, SPV_Integer> {
641701
let summary = "Return minimum of two signed integer operands";
642702

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRV.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -889,6 +889,12 @@ void populateStandardToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
889889
UnaryAndBinaryOpPattern<CeilFOp, spirv::GLSLCeilOp>,
890890
UnaryAndBinaryOpPattern<DivFOp, spirv::FDivOp>,
891891
UnaryAndBinaryOpPattern<FloorFOp, spirv::GLSLFloorOp>,
892+
UnaryAndBinaryOpPattern<MaxFOp, spirv::GLSLFMaxOp>,
893+
UnaryAndBinaryOpPattern<MaxSIOp, spirv::GLSLSMaxOp>,
894+
UnaryAndBinaryOpPattern<MaxUIOp, spirv::GLSLUMaxOp>,
895+
UnaryAndBinaryOpPattern<MinFOp, spirv::GLSLFMinOp>,
896+
UnaryAndBinaryOpPattern<MinSIOp, spirv::GLSLSMinOp>,
897+
UnaryAndBinaryOpPattern<MinUIOp, spirv::GLSLUMinOp>,
892898
UnaryAndBinaryOpPattern<MulFOp, spirv::FMulOp>,
893899
UnaryAndBinaryOpPattern<MulIOp, spirv::IMulOp>,
894900
UnaryAndBinaryOpPattern<NegFOp, spirv::FNegateOp>,

mlir/test/Conversion/StandardToSPIRV/std-ops-to-spirv.mlir

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ func @int32_scalar(%lhs: i32, %rhs: i32) {
2424
%4 = divi_unsigned %lhs, %rhs: i32
2525
// CHECK: spv.UMod %{{.*}}, %{{.*}}: i32
2626
%5 = remi_unsigned %lhs, %rhs: i32
27+
// CHECK: spv.GLSL.SMax %{{.*}}, %{{.*}}: i32
28+
%6 = maxsi %lhs, %rhs : i32
29+
// CHECK: spv.GLSL.UMax %{{.*}}, %{{.*}}: i32
30+
%7 = maxui %lhs, %rhs : i32
31+
// CHECK: spv.GLSL.SMin %{{.*}}, %{{.*}}: i32
32+
%8 = minsi %lhs, %rhs : i32
33+
// CHECK: spv.GLSL.UMin %{{.*}}, %{{.*}}: i32
34+
%9 = minui %lhs, %rhs : i32
2735
return
2836
}
2937

@@ -67,6 +75,10 @@ func @float32_binary_scalar(%lhs: f32, %rhs: f32) {
6775
%3 = divf %lhs, %rhs: f32
6876
// CHECK: spv.FRem %{{.*}}, %{{.*}}: f32
6977
%4 = remf %lhs, %rhs: f32
78+
// CHECK: spv.GLSL.FMax %{{.*}}, %{{.*}}: f32
79+
%5 = maxf %lhs, %rhs: f32
80+
// CHECK: spv.GLSL.FMin %{{.*}}, %{{.*}}: f32
81+
%6 = minf %lhs, %rhs: f32
7082
return
7183
}
7284

mlir/test/Dialect/SPIRV/IR/glsl-ops.mlir

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,24 +51,42 @@ func @exp(%arg0 : i32) -> () {
5151
// -----
5252

5353
//===----------------------------------------------------------------------===//
54-
// spv.GLSL.FMax
54+
// spv.GLSL.{F|S|U}{Max|Min}
5555
//===----------------------------------------------------------------------===//
5656

57-
func @fmax(%arg0 : f32, %arg1 : f32) -> () {
57+
func @fmaxmin(%arg0 : f32, %arg1 : f32) {
5858
// CHECK: spv.GLSL.FMax {{%.*}}, {{%.*}} : f32
59-
%2 = spv.GLSL.FMax %arg0, %arg1 : f32
59+
%1 = spv.GLSL.FMax %arg0, %arg1 : f32
60+
// CHECK: spv.GLSL.FMin {{%.*}}, {{%.*}} : f32
61+
%2 = spv.GLSL.FMin %arg0, %arg1 : f32
6062
return
6163
}
6264

63-
func @fmaxvec(%arg0 : vector<3xf16>, %arg1 : vector<3xf16>) -> () {
65+
func @fmaxminvec(%arg0 : vector<3xf16>, %arg1 : vector<3xf16>) {
6466
// CHECK: spv.GLSL.FMax {{%.*}}, {{%.*}} : vector<3xf16>
65-
%2 = spv.GLSL.FMax %arg0, %arg1 : vector<3xf16>
67+
%1 = spv.GLSL.FMax %arg0, %arg1 : vector<3xf16>
68+
// CHECK: spv.GLSL.FMin {{%.*}}, {{%.*}} : vector<3xf16>
69+
%2 = spv.GLSL.FMin %arg0, %arg1 : vector<3xf16>
6670
return
6771
}
6872

69-
func @fmaxf64(%arg0 : f64, %arg1 : f64) -> () {
73+
func @fmaxminf64(%arg0 : f64, %arg1 : f64) {
7074
// CHECK: spv.GLSL.FMax {{%.*}}, {{%.*}} : f64
71-
%2 = spv.GLSL.FMax %arg0, %arg1 : f64
75+
%1 = spv.GLSL.FMax %arg0, %arg1 : f64
76+
// CHECK: spv.GLSL.FMin {{%.*}}, {{%.*}} : f64
77+
%2 = spv.GLSL.FMin %arg0, %arg1 : f64
78+
return
79+
}
80+
81+
func @iminmax(%arg0: i32, %arg1: i32) {
82+
// CHECK: spv.GLSL.SMax {{%.*}}, {{%.*}} : i32
83+
%1 = spv.GLSL.SMax %arg0, %arg1 : i32
84+
// CHECK: spv.GLSL.UMax {{%.*}}, {{%.*}} : i32
85+
%2 = spv.GLSL.UMax %arg0, %arg1 : i32
86+
// CHECK: spv.GLSL.SMin {{%.*}}, {{%.*}} : i32
87+
%3 = spv.GLSL.SMin %arg0, %arg1 : i32
88+
// CHECK: spv.GLSL.UMin {{%.*}}, {{%.*}} : i32
89+
%4 = spv.GLSL.UMin %arg0, %arg1 : i32
7290
return
7391
}
7492

mlir/test/Target/SPIRV/glsl-ops.mlir

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
// RUN: mlir-translate -test-spirv-roundtrip %s | FileCheck %s
22

33
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
4-
spv.func @fmul(%arg0 : f32, %arg1 : f32, %arg2 : i32) "None" {
4+
spv.func @math(%arg0 : f32, %arg1 : f32, %arg2 : i32) "None" {
55
// CHECK: {{%.*}} = spv.GLSL.Exp {{%.*}} : f32
66
%0 = spv.GLSL.Exp %arg0 : f32
7-
// CHECK: {{%.*}} = spv.GLSL.FMax {{%.*}}, {{%.*}} : f32
8-
%1 = spv.GLSL.FMax %arg0, %arg1 : f32
97
// CHECK: {{%.*}} = spv.GLSL.Sqrt {{%.*}} : f32
108
%2 = spv.GLSL.Sqrt %arg0 : f32
119
// CHECK: {{%.*}} = spv.GLSL.Cos {{%.*}} : f32
@@ -37,6 +35,23 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
3735
spv.Return
3836
}
3937

38+
spv.func @maxmin(%arg0 : f32, %arg1 : f32, %arg2 : i32, %arg3 : i32) "None" {
39+
// CHECK: {{%.*}} = spv.GLSL.FMax {{%.*}}, {{%.*}} : f32
40+
%1 = spv.GLSL.FMax %arg0, %arg1 : f32
41+
// CHECK: {{%.*}} = spv.GLSL.SMax {{%.*}}, {{%.*}} : i32
42+
%2 = spv.GLSL.SMax %arg2, %arg3 : i32
43+
// CHECK: {{%.*}} = spv.GLSL.UMax {{%.*}}, {{%.*}} : i32
44+
%3 = spv.GLSL.UMax %arg2, %arg3 : i32
45+
46+
// CHECK: {{%.*}} = spv.GLSL.FMin {{%.*}}, {{%.*}} : f32
47+
%4 = spv.GLSL.FMin %arg0, %arg1 : f32
48+
// CHECK: {{%.*}} = spv.GLSL.SMin {{%.*}}, {{%.*}} : i32
49+
%5 = spv.GLSL.SMin %arg2, %arg3 : i32
50+
// CHECK: {{%.*}} = spv.GLSL.UMin {{%.*}}, {{%.*}} : i32
51+
%6 = spv.GLSL.UMin %arg2, %arg3 : i32
52+
spv.Return
53+
}
54+
4055
spv.func @fclamp(%arg0 : f32, %arg1 : f32, %arg2 : f32) "None" {
4156
// CHECK: spv.GLSL.FClamp {{%[^,]*}}, {{%[^,]*}}, {{%[^,]*}} : f32
4257
%13 = spv.GLSL.FClamp %arg0, %arg1, %arg2 : f32

0 commit comments

Comments
 (0)