Skip to content

Commit 22d9ea1

Browse files
authored
[mlir][spirv] Add definition for GL Length (#144041)
A canonicalization pattern from `spirv.GL.Length` to `spirv.GL.FAbs` for scalar operands is also added.
1 parent 1bd4f97 commit 22d9ea1

File tree

6 files changed

+142
-2
lines changed

6 files changed

+142
-2
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,6 +1160,46 @@ def SPIRV_GLFMixOp :
11601160

11611161
// -----
11621162

1163+
def SPIRV_GLLengthOp : SPIRV_GLOp<"Length", 66, [
1164+
Pure,
1165+
TypesMatchWith<"result type must match operand element type",
1166+
"operand", "result",
1167+
"::mlir::getElementTypeOrSelf($_self)">
1168+
]> {
1169+
let summary = "Return the length of a vector x";
1170+
1171+
let description = [{
1172+
Result is the length of vector x, i.e., sqrt(x[0]**2 + x[1]**2 + ...).
1173+
1174+
The operand x must be a scalar or vector whose component type is floating-point.
1175+
1176+
Result Type must be a scalar of the same type as the component type of x.
1177+
1178+
#### Example:
1179+
1180+
```mlir
1181+
%2 = spirv.GL.Length %0 : vector<3xf32> -> f32
1182+
%3 = spirv.GL.Length %1 : f32 -> f32
1183+
```
1184+
}];
1185+
1186+
let arguments = (ins
1187+
SPIRV_ScalarOrVectorOf<SPIRV_Float>:$operand
1188+
);
1189+
1190+
let results = (outs
1191+
SPIRV_Float:$result
1192+
);
1193+
1194+
let assemblyFormat = [{
1195+
$operand attr-dict `:` type($operand) `->` type($result)
1196+
}];
1197+
1198+
let hasVerifier = 0;
1199+
}
1200+
1201+
// -----
1202+
11631203
def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [
11641204
Pure,
11651205
AllTypesMatch<["p0", "p1"]>,

mlir/lib/Dialect/SPIRV/IR/SPIRVCanonicalization.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,3 +75,11 @@ def ConvertComparisonIntoClamp2_#CmpClampPair[0] : Pat<
7575
)),
7676
(CmpClampPair[1] $input, $min, $max)>;
7777
}
78+
79+
//===----------------------------------------------------------------------===//
80+
// spirv.GL.Length -> spirv.GL.FAbs
81+
//===----------------------------------------------------------------------===//
82+
83+
def ConvertGLLengthToGLFAbs : Pat<
84+
(SPIRV_GLLengthOp SPIRV_Float:$operand),
85+
(SPIRV_GLFAbsOp $operand)>;

mlir/lib/Dialect/SPIRV/IR/SPIRVGLCanonicalization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,8 @@ void populateSPIRVGLCanonicalizationPatterns(RewritePatternSet &results) {
3434
ConvertComparisonIntoClamp2_SPIRV_SLessThanOp,
3535
ConvertComparisonIntoClamp2_SPIRV_SLessThanEqualOp,
3636
ConvertComparisonIntoClamp2_SPIRV_ULessThanOp,
37-
ConvertComparisonIntoClamp2_SPIRV_ULessThanEqualOp>(
38-
results.getContext());
37+
ConvertComparisonIntoClamp2_SPIRV_ULessThanEqualOp,
38+
ConvertGLLengthToGLFAbs>(results.getContext());
3939
}
4040
} // namespace spirv
4141
} // namespace mlir

mlir/test/Dialect/SPIRV/IR/gl-ops.mlir

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1000,3 +1000,69 @@ func.func @unpack_half_2x16_scalar_out(%arg0 : i32) -> () {
10001000
%0 = spirv.GL.UnpackHalf2x16 %arg0 : i32 -> f32
10011001
return
10021002
}
1003+
1004+
// -----
1005+
1006+
//===----------------------------------------------------------------------===//
1007+
// spirv.GL.Length
1008+
//===----------------------------------------------------------------------===//
1009+
1010+
func.func @length(%arg0 : f32) -> () {
1011+
// CHECK: spirv.GL.Length {{%.*}} : f32 -> f32
1012+
%0 = spirv.GL.Length %arg0 : f32 -> f32
1013+
return
1014+
}
1015+
1016+
func.func @lengthvec(%arg0 : vector<3xf32>) -> () {
1017+
// CHECK: spirv.GL.Length {{%.*}} : vector<3xf32> -> f32
1018+
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> f32
1019+
return
1020+
}
1021+
1022+
// -----
1023+
1024+
func.func @length_i32_in(%arg0 : i32) -> () {
1025+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'i32'}}
1026+
%0 = spirv.GL.Length %arg0 : i32 -> f32
1027+
return
1028+
}
1029+
1030+
// -----
1031+
1032+
func.func @length_f16_in(%arg0 : f16) -> () {
1033+
// expected-error @+1 {{op failed to verify that result type must match operand element type}}
1034+
%0 = spirv.GL.Length %arg0 : f16 -> f32
1035+
return
1036+
}
1037+
1038+
// -----
1039+
1040+
func.func @length_i32vec_in(%arg0 : vector<3xi32>) -> () {
1041+
// expected-error @+1 {{op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
1042+
%0 = spirv.GL.Length %arg0 : vector<3xi32> -> f32
1043+
return
1044+
}
1045+
1046+
// -----
1047+
1048+
func.func @length_f16vec_in(%arg0 : vector<3xf16>) -> () {
1049+
// expected-error @+1 {{op failed to verify that result type must match operand element type}}
1050+
%0 = spirv.GL.Length %arg0 : vector<3xf16> -> f32
1051+
return
1052+
}
1053+
1054+
// -----
1055+
1056+
func.func @length_i32_out(%arg0 : vector<3xf32>) -> () {
1057+
// expected-error @+1 {{op result #0 must be 16/32/64-bit float, but got 'i32'}}
1058+
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> i32
1059+
return
1060+
}
1061+
1062+
// -----
1063+
1064+
func.func @length_vec_out(%arg0 : vector<3xf32>) -> () {
1065+
// expected-error @+1 {{op result #0 must be 16/32/64-bit float, but got 'vector<3xf32>'}}
1066+
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> vector<3xf32>
1067+
return
1068+
}

mlir/test/Dialect/SPIRV/Transforms/gl-canonicalize.mlir

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,25 @@ func.func @clamp_ulessthanequal(%input: i32, %min: i32, %max: i32) -> i32 {
177177
// CHECK-NEXT: spirv.ReturnValue [[RES]]
178178
spirv.ReturnValue %2 : i32
179179
}
180+
181+
// -----
182+
183+
//===----------------------------------------------------------------------===//
184+
// spirv.GL.Length
185+
//===----------------------------------------------------------------------===//
186+
187+
// CHECK-LABEL: @convert_length_into_fabs_scalar
188+
func.func @convert_length_into_fabs_scalar(%arg0 : f32) -> f32 {
189+
//CHECK: spirv.GL.FAbs {{%.*}} : f32
190+
//CHECK-NOT: spirv.GL.Length
191+
%0 = spirv.GL.Length %arg0 : f32 -> f32
192+
spirv.ReturnValue %0 : f32
193+
}
194+
195+
// CHECK-LABEL: @dont_convert_length_into_fabs_vec
196+
func.func @dont_convert_length_into_fabs_vec(%arg0 : vector<3xf32>) -> f32 {
197+
//CHECK: spirv.GL.Length {{%.*}} : vector<3xf32> -> f32
198+
//CHECK-NOT: spirv.GL.FAbs
199+
%0 = spirv.GL.Length %arg0 : vector<3xf32> -> f32
200+
spirv.ReturnValue %0 : f32
201+
}

mlir/test/Target/SPIRV/gl-ops.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,10 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
128128
%8 = spirv.GL.FindSMsb %arg3 : vector<3xi32>
129129
// CHECK: {{%.*}} = spirv.GL.FindUMsb {{%.*}} : vector<3xi32>
130130
%9 = spirv.GL.FindUMsb %arg3 : vector<3xi32>
131+
// CHECK: {{%.*}} = spirv.GL.Length {{%.*}} : f32 -> f32
132+
%10 = spirv.GL.Length %arg0 : f32 -> f32
133+
// CHECK: {{%.*}} = spirv.GL.Length {{%.*}} : vector<3xf32> -> f32
134+
%11 = spirv.GL.Length %arg1 : vector<3xf32> -> f32
131135
spirv.Return
132136
}
133137

0 commit comments

Comments
 (0)