Skip to content

Commit 7b3a353

Browse files
authored
[mlir][spirv] Add common SPIRV Extended Ops for Vectors (#122322)
Support for the following SPIR-V Extended Ops: * 67: Distance * 68: Cross * 69: Normalize * 71: Reflect (Found here: https://registry.khronos.org/SPIR-V/specs/1.0/GLSL.std.450.html)
1 parent dce5d1f commit 7b3a353

File tree

3 files changed

+258
-0
lines changed

3 files changed

+258
-0
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLOps.td

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,6 +1029,122 @@ def SPIRV_GLFMixOp :
10291029
let hasVerifier = 0;
10301030
}
10311031

1032+
// -----
1033+
1034+
def SPIRV_GLDistanceOp : SPIRV_GLOp<"Distance", 67, [
1035+
Pure,
1036+
AllTypesMatch<["p0", "p1"]>,
1037+
TypesMatchWith<"result type must match operand element type",
1038+
"p0", "result",
1039+
"::mlir::getElementTypeOrSelf($_self)">
1040+
]> {
1041+
let summary = "Return distance between two points";
1042+
1043+
let description = [{
1044+
Result is the distance between p0 and p1, i.e., length(p0 - p1).
1045+
1046+
The operands must all be a scalar or vector whose component type is floating-point.
1047+
1048+
Result Type must be a scalar of the same type as the component type of the operands.
1049+
1050+
#### Example:
1051+
1052+
```mlir
1053+
%2 = spirv.GL.Distance %0, %1 : vector<3xf32>, vector<3xf32> -> f32
1054+
```
1055+
}];
1056+
1057+
let arguments = (ins
1058+
SPIRV_ScalarOrVectorOf<SPIRV_Float>:$p0,
1059+
SPIRV_ScalarOrVectorOf<SPIRV_Float>:$p1
1060+
);
1061+
1062+
let results = (outs
1063+
SPIRV_Float:$result
1064+
);
1065+
1066+
let assemblyFormat = [{
1067+
operands attr-dict `:` type($p0) `,` type($p1) `->` type($result)
1068+
}];
1069+
1070+
let hasVerifier = 0;
1071+
}
1072+
1073+
// -----
1074+
1075+
def SPIRV_GLCrossOp : SPIRV_GLBinaryArithmeticOp<"Cross", 68, SPIRV_Float> {
1076+
let summary = "Return the cross product of two 3-component vectors";
1077+
1078+
let description = [{
1079+
Result is the cross product of x and y, i.e., the resulting components are, in order:
1080+
1081+
x[1] * y[2] - y[1] * x[2]
1082+
1083+
x[2] * y[0] - y[2] * x[0]
1084+
1085+
x[0] * y[1] - y[0] * x[1]
1086+
1087+
All the operands must be vectors of 3 components of a floating-point type.
1088+
1089+
Result Type and the type of all operands must be the same type.
1090+
1091+
#### Example:
1092+
1093+
```mlir
1094+
%2 = spirv.GL.Cross %0, %1 : vector<3xf32>
1095+
%3 = spirv.GL.Cross %0, %1 : vector<3xf16>
1096+
```
1097+
}];
1098+
}
1099+
1100+
// -----
1101+
1102+
def SPIRV_GLNormalizeOp : SPIRV_GLUnaryArithmeticOp<"Normalize", 69, SPIRV_Float> {
1103+
let summary = "Normalizes a vector operand";
1104+
1105+
let description = [{
1106+
Result is the vector in the same direction as x but with a length of 1.
1107+
1108+
The operand x must be a scalar or vector whose component type is floating-point.
1109+
1110+
Result Type and the type of x must be the same type.
1111+
1112+
#### Example:
1113+
1114+
```mlir
1115+
%2 = spirv.GL.Normalize %0 : vector<3xf32>
1116+
%3 = spirv.GL.Normalize %1 : vector<4xf16>
1117+
```
1118+
}];
1119+
}
1120+
1121+
// -----
1122+
1123+
def SPIRV_GLReflectOp : SPIRV_GLBinaryArithmeticOp<"Reflect", 71, SPIRV_Float> {
1124+
let summary = "Calculate reflection direction vector";
1125+
1126+
let description = [{
1127+
For the incident vector I and surface orientation N, the result is the reflection direction:
1128+
1129+
I - 2 * dot(N, I) * N
1130+
1131+
N must already be normalized in order to achieve the desired result.
1132+
1133+
The operands must all be a scalar or vector whose component type is floating-point.
1134+
1135+
Result Type and the type of all operands must be the same type.
1136+
1137+
#### Example:
1138+
1139+
```mlir
1140+
%2 = spirv.GL.Reflect %0, %1 : f32
1141+
%3 = spirv.GL.Reflect %0, %1 : vector<3xf32>
1142+
```
1143+
}];
1144+
}
1145+
1146+
// ----
1147+
10321148
def SPIRV_GLFindUMsbOp : SPIRV_GLUnaryArithmeticOp<"FindUMsb", 75, SPIRV_Int32> {
10331149
let summary = "Unsigned-integer most-significant bit";
10341150

mlir/test/Dialect/SPIRV/IR/gl-ops.mlir

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -541,3 +541,125 @@ func.func @findumsb(%arg0 : i64) -> () {
541541
%2 = spirv.GL.FindUMsb %arg0 : i64
542542
return
543543
}
544+
545+
// -----
546+
547+
//===----------------------------------------------------------------------===//
548+
// spirv.GL.Distance
549+
//===----------------------------------------------------------------------===//
550+
551+
func.func @distance_scalar(%arg0 : f32, %arg1 : f32) {
552+
// CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32
553+
%0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> f32
554+
return
555+
}
556+
557+
func.func @distance_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
558+
// CHECK: spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32
559+
%0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<3xf32> -> f32
560+
return
561+
}
562+
563+
// -----
564+
565+
func.func @distance_invalid_type(%arg0 : i32, %arg1 : i32) {
566+
// expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
567+
%0 = spirv.GL.Distance %arg0, %arg1 : i32, i32 -> f32
568+
return
569+
}
570+
571+
// -----
572+
573+
func.func @distance_arg_mismatch(%arg0 : vector<3xf32>, %arg1 : vector<4xf32>) {
574+
// expected-error @+1 {{'spirv.GL.Distance' op failed to verify that all of {p0, p1} have same type}}
575+
%0 = spirv.GL.Distance %arg0, %arg1 : vector<3xf32>, vector<4xf32> -> f32
576+
return
577+
}
578+
579+
// -----
580+
581+
func.func @distance_invalid_vector_size(%arg0 : vector<5xf32>, %arg1 : vector<5xf32>) {
582+
// expected-error @+1 {{'spirv.GL.Distance' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16}}
583+
%0 = spirv.GL.Distance %arg0, %arg1 : vector<5xf32>, vector<5xf32> -> f32
584+
return
585+
}
586+
587+
// -----
588+
589+
func.func @distance_invalid_result(%arg0 : f32, %arg1 : f32) {
590+
// expected-error @+1 {{'spirv.GL.Distance' op result #0 must be 16/32/64-bit float}}
591+
%0 = spirv.GL.Distance %arg0, %arg1 : f32, f32 -> i32
592+
return
593+
}
594+
595+
// -----
596+
597+
//===----------------------------------------------------------------------===//
598+
// spirv.GL.Cross
599+
//===----------------------------------------------------------------------===//
600+
601+
func.func @cross(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
602+
%2 = spirv.GL.Cross %arg0, %arg1 : vector<3xf32>
603+
// CHECK: %{{.+}} = spirv.GL.Cross %{{.+}}, %{{.+}} : vector<3xf32>
604+
return
605+
}
606+
607+
// -----
608+
609+
func.func @cross_invalid_type(%arg0 : vector<3xi32>, %arg1 : vector<3xi32>) {
610+
// expected-error @+1 {{'spirv.GL.Cross' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values of length 2/3/4/8/16, but got 'vector<3xi32>'}}
611+
%0 = spirv.GL.Cross %arg0, %arg1 : vector<3xi32>
612+
return
613+
}
614+
615+
// -----
616+
617+
//===----------------------------------------------------------------------===//
618+
// spirv.GL.Normalize
619+
//===----------------------------------------------------------------------===//
620+
621+
func.func @normalize_scalar(%arg0 : f32) {
622+
%2 = spirv.GL.Normalize %arg0 : f32
623+
// CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : f32
624+
return
625+
}
626+
627+
func.func @normalize_vector(%arg0 : vector<3xf32>) {
628+
%2 = spirv.GL.Normalize %arg0 : vector<3xf32>
629+
// CHECK: %{{.+}} = spirv.GL.Normalize %{{.+}} : vector<3xf32>
630+
return
631+
}
632+
633+
// -----
634+
635+
func.func @normalize_invalid_type(%arg0 : i32) {
636+
// expected-error @+1 {{'spirv.GL.Normalize' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
637+
%0 = spirv.GL.Normalize %arg0 : i32
638+
return
639+
}
640+
641+
// -----
642+
643+
//===----------------------------------------------------------------------===//
644+
// spirv.GL.Reflect
645+
//===----------------------------------------------------------------------===//
646+
647+
func.func @reflect_scalar(%arg0 : f32, %arg1 : f32) {
648+
%2 = spirv.GL.Reflect %arg0, %arg1 : f32
649+
// CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : f32
650+
return
651+
}
652+
653+
func.func @reflect_vector(%arg0 : vector<3xf32>, %arg1 : vector<3xf32>) {
654+
%2 = spirv.GL.Reflect %arg0, %arg1 : vector<3xf32>
655+
// CHECK: %{{.+}} = spirv.GL.Reflect %{{.+}}, %{{.+}} : vector<3xf32>
656+
return
657+
}
658+
659+
// -----
660+
661+
func.func @reflect_invalid_type(%arg0 : i32, %arg1 : i32) {
662+
// expected-error @+1 {{'spirv.GL.Reflect' op operand #0 must be 16/32/64-bit float or vector of 16/32/64-bit float values}}
663+
%0 = spirv.GL.Reflect %arg0, %arg1 : i32
664+
return
665+
}

mlir/test/Target/SPIRV/gl-ops.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,4 +81,24 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
8181
%2 = spirv.GL.FindUMsb %arg0 : i32
8282
spirv.Return
8383
}
84+
85+
spirv.func @vector(%arg0 : f32, %arg1 : vector<3xf32>, %arg2 : vector<3xf32>) "None" {
86+
// CHECK: {{%.*}} = spirv.GL.Cross {{%.*}}, {{%.*}} : vector<3xf32>
87+
%0 = spirv.GL.Cross %arg1, %arg2 : vector<3xf32>
88+
// CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : f32
89+
%1 = spirv.GL.Normalize %arg0 : f32
90+
// CHECK: {{%.*}} = spirv.GL.Normalize {{%.*}} : vector<3xf32>
91+
%2 = spirv.GL.Normalize %arg1 : vector<3xf32>
92+
// CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : f32
93+
%3 = spirv.GL.Reflect %arg0, %arg0 : f32
94+
// CHECK: {{%.*}} = spirv.GL.Reflect {{%.*}}, {{%.*}} : vector<3xf32>
95+
%4 = spirv.GL.Reflect %arg1, %arg2 : vector<3xf32>
96+
// CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : f32, f32 -> f32
97+
%5 = spirv.GL.Distance %arg0, %arg0 : f32, f32 -> f32
98+
// CHECK: {{%.*}} = spirv.GL.Distance {{%.*}}, {{%.*}} : vector<3xf32>, vector<3xf32> -> f32
99+
%6 = spirv.GL.Distance %arg1, %arg2 : vector<3xf32>, vector<3xf32> -> f32
100+
spirv.Return
101+
}
102+
103+
84104
}

0 commit comments

Comments
 (0)