Skip to content

Commit a4a35bf

Browse files
committed
[mlir][spirv] Refactor image operations
This patch makes multiple changes to images ops: 1) The assembly format is unified with the rest of the dialect to use `%0 = spirv.op %1, %2, %3 : f32, f32, f32` rather than having each type directly attached to each argument. 2) The verification is moved from `SPIRVOps.cpp` to a new file so the ops can be easier maintained. 3) Majority of C++ verification is removed and moved into ODS. Verification of `ImageQuerySizeOp` is left in C++ due to the complexity of rules. 4) `spirv::bitEnumContainsAll` is replaced by `spirv::bitEnumContainsAny` in `verifyImageOperands`. In this context `...Any` seems to be the correct function, as we want to check whether unsupported operand is being used - in opposite to checking if all unsupported operands are being used. 5) Simplify target tests by removing entry points and adding `Linkage` capability to the modules.
1 parent d3d2ea6 commit a4a35bf

File tree

6 files changed

+245
-241
lines changed

6 files changed

+245
-241
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td

Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
//===----------------------------------------------------------------------===//
88
//
99
// This file contains image ops for the SPIR-V dialect. It corresponds
10-
// to "3.37.10. Image Instructions" of the SPIR-V specification.
10+
// to "3.56.10. Image Instructions" of the SPIR-V specification.
1111
//
1212
//===----------------------------------------------------------------------===//
1313

@@ -19,21 +19,65 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
1919

2020
// -----
2121

22-
def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
22+
class SPIRV_ValuesAreContained<string operand, list<string> values, string transform, string type, string getter> :
23+
CPred<"::llvm::is_contained("
24+
"{::mlir::spirv::" # type # "::" # !interleave(values, ", ::mlir::spirv::" # type # "::") # "},"
25+
"::llvm::cast<::mlir::spirv::ImageType>(" # !subst("$_self", "$" # operand # ".getType()", transform) # ")." # getter # "()"
26+
")"
27+
>;
28+
29+
class SPIRV_SampledOperandIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
30+
"the sampled operand of the underlying image must be " # !interleave(values, " or "),
31+
SPIRV_ValuesAreContained<operand, values, transform, "ImageSamplerUseInfo", "getSamplerUseInfo">
32+
>;
33+
34+
class SPIRV_MSOperandIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
35+
"the MS operand of the underlying image type must be " # !interleave(values, " or "),
36+
SPIRV_ValuesAreContained<operand, values, transform, "ImageSamplingInfo", "getSamplingInfo">
37+
>;
38+
39+
class SPIRV_DimIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
40+
"the Dim operand of the underlying image must be " # !interleave(values, " or "),
41+
SPIRV_ValuesAreContained<operand, values, transform, "Dim", "getDim">
42+
>;
43+
44+
class SPIRV_DimIsNot<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
45+
"the Dim operand of the underlying image must not be " # !interleave(values, " or "),
46+
Neg<SPIRV_ValuesAreContained<operand, values, transform, "Dim", "getDim">>
47+
>;
48+
49+
class SPIRV_NoneOrElementMatchImage<string operand, string image, string transform="$_self"> : PredOpTrait<
50+
"the " # operand # " component type must match the image sampled type",
51+
CPred<"::llvm::isa<NoneType>(cast<ImageType>(" # !subst("$_self", "$" # image # ".getType()", transform) # ").getElementType()) ||"
52+
"(getElementTypeOrSelf($" # operand # ")"
53+
"=="
54+
"cast<ImageType>(" # !subst("$_self", "$" # image # ".getType()", transform) # ").getElementType())"
55+
>
56+
>;
57+
58+
def SPIRV_SampledImageTransform : StrFunc<"llvm::cast<spirv::SampledImageType>($_self).getImageType()">;
59+
60+
// -----
61+
62+
def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather",
63+
[Pure,
64+
SPIRV_DimIs<"sampled_image", ["Dim2D", "Cube", "Rect"], SPIRV_SampledImageTransform.result>,
65+
SPIRV_MSOperandIs<"sampled_image", ["SingleSampled"], SPIRV_SampledImageTransform.result>,
66+
SPIRV_NoneOrElementMatchImage<"result", "sampled_image", SPIRV_SampledImageTransform.result>]>{
2367
let summary = "Gathers the requested depth-comparison from four texels.";
2468

2569
let description = [{
2670
Result Type must be a vector of four components of floating-point type
27-
or integer type. Its components must be the same as Sampled Type of the
71+
or integer type. Its components must be the same as Sampled Type of the
2872
underlying OpTypeImage (unless that underlying Sampled Type is
2973
OpTypeVoid). It has one component per gathered texel.
3074

3175
Sampled Image must be an object whose type is OpTypeSampledImage. Its
3276
OpTypeImage must have a Dim of 2D, Cube, or Rect. The MS operand of the
3377
underlying OpTypeImage must be 0.
3478

35-
Coordinate must be a scalar or vector of floating-point type. It
36-
contains (u[, v] [, array layer]) as needed by the definition of
79+
Coordinate must be a scalar or vector of floating-point type. It
80+
contains (u[, v] ... [, array layer]) as needed by the definition of
3781
Sampled Image.
3882

3983
Dref is the depth-comparison reference value. It must be a 32-bit
@@ -44,8 +88,8 @@ def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
4488
#### Example:
4589

4690
```mlir
47-
%0 = spirv.ImageDrefGather %1 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 -> vector<4xi32>
48-
%0 = spirv.ImageDrefGather %1 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 ["NonPrivateTexel"] : f32, f32 -> vector<4xi32>
91+
%0 = spirv.ImageDrefGather %1, %2, %3 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
92+
%0 = spirv.ImageDrefGather %1, %2, %3 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 ["NonPrivateTexel"] -> vector<4xi32>
4993
```
5094
}];
5195

@@ -57,23 +101,24 @@ def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
57101
];
58102

59103
let arguments = (ins
60-
SPIRV_AnySampledImage:$sampledimage,
104+
SPIRV_AnySampledImage:$sampled_image,
61105
SPIRV_ScalarOrVectorOf<SPIRV_Float>:$coordinate,
62-
SPIRV_Float:$dref,
63-
OptionalAttr<SPIRV_ImageOperandsAttr>:$imageoperands,
106+
SPIRV_Float32:$dref,
107+
OptionalAttr<SPIRV_ImageOperandsAttr>:$image_operands,
64108
Variadic<SPIRV_Type>:$operand_arguments
65109
);
66110

67111
let results = (outs
68-
SPIRV_Vector:$result
112+
AnyTypeOf<[SPIRV_Vec4<SPIRV_Integer>, SPIRV_Vec4<SPIRV_Float>]>:$result
69113
);
70114

71-
let assemblyFormat = [{$sampledimage `:` type($sampledimage) `,`
72-
$coordinate `:` type($coordinate) `,` $dref `:` type($dref)
73-
custom<ImageOperands>($imageoperands)
74-
( `(` $operand_arguments^ `:` type($operand_arguments) `)`)?
75-
attr-dict
76-
`->` type($result)}];
115+
116+
let assemblyFormat = [{
117+
$sampled_image `,` $coordinate `,` $dref custom<ImageOperands>($image_operands) ( `(` $operand_arguments^ `)` )? attr-dict
118+
`:` type($sampled_image) `,` type($coordinate) `,` type($dref) ( `(` type($operand_arguments)^ `)` )?
119+
`->` type($result)
120+
}];
121+
77122
}
78123

79124
// -----
@@ -82,7 +127,7 @@ def SPIRV_ImageQuerySizeOp : SPIRV_Op<"ImageQuerySize", [Pure]> {
82127
let summary = "Query the dimensions of Image, with no level of detail.";
83128

84129
let description = [{
85-
Result Type must be an integer type scalar or vector. The number of
130+
Result Type must be an integer type scalar or vector. The number of
86131
components must be:
87132

88133
1 for the 1D and Buffer dimensionalities,
@@ -130,12 +175,15 @@ def SPIRV_ImageQuerySizeOp : SPIRV_Op<"ImageQuerySize", [Pure]> {
130175
SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$result
131176
);
132177

133-
let assemblyFormat = "attr-dict $image `:` type($image) `->` type($result)";
178+
let assemblyFormat = "$image attr-dict `:` type($image) `->` type($result)";
134179
}
135180

136181
// -----
137182

138-
def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
183+
def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite",
184+
[SPIRV_SampledOperandIs<"image", ["SamplerUnknown", "NoSampler"]>,
185+
SPIRV_DimIsNot<"image", ["SubpassData"]>,
186+
SPIRV_NoneOrElementMatchImage<"texel", "image">]> {
139187
let summary = "Write a texel to an image without a sampler.";
140188

141189
let description = [{
@@ -163,7 +211,7 @@ def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
163211
#### Example:
164212

165213
```mlir
166-
spirv.ImageWrite %0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %1 : vector<2xsi32>, %2 : vector<4xf32>
214+
spirv.ImageWrite %0, %1, %2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, vector<4xf32>
167215
```
168216
}];
169217

@@ -177,20 +225,18 @@ def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
177225

178226
let results = (outs);
179227

180-
let assemblyFormat = [{$image `:` type($image) `,`
181-
$coordinate `:` type($coordinate) `,`
182-
$texel `:` type($texel)
183-
custom<ImageOperands>($image_operands)
184-
( `(` $operand_arguments^ `:` type($operand_arguments) `)`)?
185-
attr-dict}];
228+
let assemblyFormat = [{
229+
$image `,` $coordinate `,` $texel custom<ImageOperands>($image_operands) ( `(` $operand_arguments^ `)`)? attr-dict
230+
`:` type($image) `,` type($coordinate) `,` type($texel) ( `(` type($operand_arguments)^ `)`)?
231+
}];
186232
}
187233

188234
// -----
189235

190236
def SPIRV_ImageOp : SPIRV_Op<"Image",
191237
[Pure,
192-
TypesMatchWith<"type of 'result' matches image type of 'sampledimage'",
193-
"sampledimage", "result",
238+
TypesMatchWith<"type of 'result' matches image type of 'sampled_image'",
239+
"sampled_image", "result",
194240
"::llvm::cast<spirv::SampledImageType>($_self).getImageType()">]> {
195241
let summary = "Extract the image from a sampled image.";
196242

@@ -210,14 +256,14 @@ def SPIRV_ImageOp : SPIRV_Op<"Image",
210256
}];
211257

212258
let arguments = (ins
213-
SPIRV_AnySampledImage:$sampledimage
259+
SPIRV_AnySampledImage:$sampled_image
214260
);
215261

216262
let results = (outs
217263
SPIRV_AnyImage:$result
218264
);
219265

220-
let assemblyFormat = "attr-dict $sampledimage `:` type($sampledimage)";
266+
let assemblyFormat = "$sampled_image attr-dict `:` type($sampled_image)";
221267

222268
let hasVerifier = 0;
223269
}

mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
88
ControlFlowOps.cpp
99
CooperativeMatrixOps.cpp
1010
GroupOps.cpp
11+
ImageOps.cpp
1112
IntegerDotProductOps.cpp
1213
MemoryOps.cpp
1314
MeshOps.cpp
Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
//===- ImageOps.cpp - MLIR SPIR-V Image Ops ------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// Defines the image operations in the SPIR-V dialect.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
14+
15+
using namespace mlir;
16+
17+
//===----------------------------------------------------------------------===//
18+
// Common utility functions
19+
//===----------------------------------------------------------------------===//
20+
21+
static LogicalResult verifyImageOperands(Operation *imageOp,
22+
spirv::ImageOperandsAttr attr,
23+
Operation::operand_range operands) {
24+
if (!attr) {
25+
if (operands.empty())
26+
return success();
27+
28+
return imageOp->emitError("the Image Operands should encode what operands "
29+
"follow, as per Image Operands");
30+
}
31+
32+
// TODO: Add the validation rules for the following Image Operands.
33+
spirv::ImageOperands noSupportOperands =
34+
spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
35+
spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
36+
spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
37+
spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
38+
spirv::ImageOperands::MakeTexelAvailable |
39+
spirv::ImageOperands::MakeTexelVisible |
40+
spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;
41+
42+
assert(!spirv::bitEnumContainsAny(attr.getValue(), noSupportOperands) &&
43+
"unimplemented operands of Image Operands");
44+
45+
return success();
46+
}
47+
48+
//===----------------------------------------------------------------------===//
49+
// spirv.ImageDrefGather
50+
//===----------------------------------------------------------------------===//
51+
52+
LogicalResult spirv::ImageDrefGatherOp::verify() {
53+
return verifyImageOperands(getOperation(), getImageOperandsAttr(),
54+
getOperandArguments());
55+
}
56+
57+
//===----------------------------------------------------------------------===//
58+
// spirv.ImageWriteOp
59+
//===----------------------------------------------------------------------===//
60+
61+
LogicalResult spirv::ImageWriteOp::verify() {
62+
// TODO: Do we need check for: "If the Arrayed operand is 1, then additional
63+
// capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."?
64+
65+
// TODO: Ideally it should be somewhere verified that "The Image Format must
66+
// not be Unknown, unless the StorageImageWriteWithoutFormat Capability was
67+
// declared." This function however may not be the suitable place for such
68+
// verification.
69+
70+
return verifyImageOperands(getOperation(), getImageOperandsAttr(),
71+
getOperandArguments());
72+
}
73+
74+
//===----------------------------------------------------------------------===//
75+
// spirv.ImageQuerySize
76+
//===----------------------------------------------------------------------===//
77+
78+
LogicalResult spirv::ImageQuerySizeOp::verify() {
79+
spirv::ImageType imageType =
80+
llvm::cast<spirv::ImageType>(getImage().getType());
81+
Type resultType = getResult().getType();
82+
83+
spirv::Dim dim = imageType.getDim();
84+
spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
85+
spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
86+
switch (dim) {
87+
case spirv::Dim::Dim1D:
88+
case spirv::Dim::Dim2D:
89+
case spirv::Dim::Dim3D:
90+
case spirv::Dim::Cube:
91+
if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
92+
samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
93+
samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
94+
return emitError(
95+
"if Dim is 1D, 2D, 3D, or Cube, "
96+
"it must also have either an MS of 1 or a Sampled of 0 or 2");
97+
break;
98+
case spirv::Dim::Buffer:
99+
case spirv::Dim::Rect:
100+
break;
101+
default:
102+
return emitError("the Dim operand of the image type must "
103+
"be 1D, 2D, 3D, Buffer, Cube, or Rect");
104+
}
105+
106+
unsigned componentNumber = 0;
107+
switch (dim) {
108+
case spirv::Dim::Dim1D:
109+
case spirv::Dim::Buffer:
110+
componentNumber = 1;
111+
break;
112+
case spirv::Dim::Dim2D:
113+
case spirv::Dim::Cube:
114+
case spirv::Dim::Rect:
115+
componentNumber = 2;
116+
break;
117+
case spirv::Dim::Dim3D:
118+
componentNumber = 3;
119+
break;
120+
default:
121+
break;
122+
}
123+
124+
if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
125+
componentNumber += 1;
126+
127+
unsigned resultComponentNumber = 1;
128+
if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
129+
resultComponentNumber = resultVectorType.getNumElements();
130+
131+
if (componentNumber != resultComponentNumber)
132+
return emitError("expected the result to have ")
133+
<< componentNumber << " component(s), but found "
134+
<< resultComponentNumber << " component(s)";
135+
136+
return success();
137+
}

0 commit comments

Comments
 (0)