Skip to content

[mlir][spirv] Refactor image operations #128552

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
108 changes: 77 additions & 31 deletions mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
//===----------------------------------------------------------------------===//
//
// This file contains image ops for the SPIR-V dialect. It corresponds
// to "3.37.10. Image Instructions" of the SPIR-V specification.
// to "3.56.10. Image Instructions" of the SPIR-V specification.
//
//===----------------------------------------------------------------------===//

Expand All @@ -19,21 +19,65 @@ include "mlir/Interfaces/SideEffectInterfaces.td"

// -----

def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
class SPIRV_ValuesAreContained<string operand, list<string> values, string transform, string type, string getter> :
CPred<"::llvm::is_contained("
"{::mlir::spirv::" # type # "::" # !interleave(values, ", ::mlir::spirv::" # type # "::") # "},"
"::llvm::cast<::mlir::spirv::ImageType>(" # !subst("$_self", "$" # operand # ".getType()", transform) # ")." # getter # "()"
")"
>;

class SPIRV_SampledOperandIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
"the sampled operand of the underlying image must be " # !interleave(values, " or "),
SPIRV_ValuesAreContained<operand, values, transform, "ImageSamplerUseInfo", "getSamplerUseInfo">
>;

class SPIRV_MSOperandIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
"the MS operand of the underlying image type must be " # !interleave(values, " or "),
SPIRV_ValuesAreContained<operand, values, transform, "ImageSamplingInfo", "getSamplingInfo">
>;

class SPIRV_DimIs<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
"the Dim operand of the underlying image must be " # !interleave(values, " or "),
SPIRV_ValuesAreContained<operand, values, transform, "Dim", "getDim">
>;

class SPIRV_DimIsNot<string operand, list<string> values, string transform="$_self"> : PredOpTrait<
"the Dim operand of the underlying image must not be " # !interleave(values, " or "),
Neg<SPIRV_ValuesAreContained<operand, values, transform, "Dim", "getDim">>
>;

class SPIRV_NoneOrElementMatchImage<string operand, string image, string transform="$_self"> : PredOpTrait<
"the " # operand # " component type must match the image sampled type",
CPred<"::llvm::isa<NoneType>(cast<ImageType>(" # !subst("$_self", "$" # image # ".getType()", transform) # ").getElementType()) ||"
"(getElementTypeOrSelf($" # operand # ")"
"=="
"cast<ImageType>(" # !subst("$_self", "$" # image # ".getType()", transform) # ").getElementType())"
>
>;

def SPIRV_SampledImageTransform : StrFunc<"llvm::cast<spirv::SampledImageType>($_self).getImageType()">;

// -----

def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather",
[Pure,
SPIRV_DimIs<"sampled_image", ["Dim2D", "Cube", "Rect"], SPIRV_SampledImageTransform.result>,
SPIRV_MSOperandIs<"sampled_image", ["SingleSampled"], SPIRV_SampledImageTransform.result>,
SPIRV_NoneOrElementMatchImage<"result", "sampled_image", SPIRV_SampledImageTransform.result>]>{
let summary = "Gathers the requested depth-comparison from four texels.";

let description = [{
Result Type must be a vector of four components of floating-point type
or integer type. Its components must be the same as Sampled Type of the
or integer type. Its components must be the same as Sampled Type of the
underlying OpTypeImage (unless that underlying Sampled Type is
OpTypeVoid). It has one component per gathered texel.

Sampled Image must be an object whose type is OpTypeSampledImage. Its
OpTypeImage must have a Dim of 2D, Cube, or Rect. The MS operand of the
underlying OpTypeImage must be 0.

Coordinate must be a scalar or vector of floating-point type. It
contains (u[, v] [, array layer]) as needed by the definition of
Coordinate must be a scalar or vector of floating-point type. It
contains (u[, v] ... [, array layer]) as needed by the definition of
Sampled Image.

Dref is the depth-comparison reference value. It must be a 32-bit
Expand All @@ -44,8 +88,8 @@ def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
#### Example:

```mlir
%0 = spirv.ImageDrefGather %1 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 -> vector<4xi32>
%0 = spirv.ImageDrefGather %1 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, %2 : vector<4xf32>, %3 : f32 ["NonPrivateTexel"] : f32, f32 -> vector<4xi32>
%0 = spirv.ImageDrefGather %1, %2, %3 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 -> vector<4xi32>
%0 = spirv.ImageDrefGather %1, %2, %3 : !spirv.sampled_image<!spirv.image<i32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Unknown>>, vector<4xf32>, f32 ["NonPrivateTexel"] -> vector<4xi32>
```
}];

Expand All @@ -57,23 +101,24 @@ def SPIRV_ImageDrefGatherOp : SPIRV_Op<"ImageDrefGather", [Pure]> {
];

let arguments = (ins
SPIRV_AnySampledImage:$sampledimage,
SPIRV_AnySampledImage:$sampled_image,
SPIRV_ScalarOrVectorOf<SPIRV_Float>:$coordinate,
SPIRV_Float:$dref,
OptionalAttr<SPIRV_ImageOperandsAttr>:$imageoperands,
SPIRV_Float32:$dref,
OptionalAttr<SPIRV_ImageOperandsAttr>:$image_operands,
Variadic<SPIRV_Type>:$operand_arguments
);

let results = (outs
SPIRV_Vector:$result
AnyTypeOf<[SPIRV_Vec4<SPIRV_Integer>, SPIRV_Vec4<SPIRV_Float>]>:$result
);

let assemblyFormat = [{$sampledimage `:` type($sampledimage) `,`
$coordinate `:` type($coordinate) `,` $dref `:` type($dref)
custom<ImageOperands>($imageoperands)
( `(` $operand_arguments^ `:` type($operand_arguments) `)`)?
attr-dict
`->` type($result)}];

let assemblyFormat = [{
$sampled_image `,` $coordinate `,` $dref custom<ImageOperands>($image_operands) ( `(` $operand_arguments^ `)` )? attr-dict
`:` type($sampled_image) `,` type($coordinate) `,` type($dref) ( `(` type($operand_arguments)^ `)` )?
`->` type($result)
}];

}

// -----
Expand All @@ -82,7 +127,7 @@ def SPIRV_ImageQuerySizeOp : SPIRV_Op<"ImageQuerySize", [Pure]> {
let summary = "Query the dimensions of Image, with no level of detail.";

let description = [{
Result Type must be an integer type scalar or vector. The number of
Result Type must be an integer type scalar or vector. The number of
components must be:

1 for the 1D and Buffer dimensionalities,
Expand Down Expand Up @@ -130,12 +175,15 @@ def SPIRV_ImageQuerySizeOp : SPIRV_Op<"ImageQuerySize", [Pure]> {
SPIRV_ScalarOrVectorOf<SPIRV_Integer>:$result
);

let assemblyFormat = "attr-dict $image `:` type($image) `->` type($result)";
let assemblyFormat = "$image attr-dict `:` type($image) `->` type($result)";
}

// -----

def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite",
[SPIRV_SampledOperandIs<"image", ["SamplerUnknown", "NoSampler"]>,
SPIRV_DimIsNot<"image", ["SubpassData"]>,
SPIRV_NoneOrElementMatchImage<"texel", "image">]> {
let summary = "Write a texel to an image without a sampler.";

let description = [{
Expand Down Expand Up @@ -163,7 +211,7 @@ def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {
#### Example:

```mlir
spirv.ImageWrite %0 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, %1 : vector<2xsi32>, %2 : vector<4xf32>
spirv.ImageWrite %0, %1, %2 : !spirv.image<f32, Dim2D, NoDepth, NonArrayed, SingleSampled, NoSampler, Rgba16>, vector<2xsi32>, vector<4xf32>
```
}];

Expand All @@ -177,20 +225,18 @@ def SPIRV_ImageWriteOp : SPIRV_Op<"ImageWrite", []> {

let results = (outs);

let assemblyFormat = [{$image `:` type($image) `,`
$coordinate `:` type($coordinate) `,`
$texel `:` type($texel)
custom<ImageOperands>($image_operands)
( `(` $operand_arguments^ `:` type($operand_arguments) `)`)?
attr-dict}];
let assemblyFormat = [{
$image `,` $coordinate `,` $texel custom<ImageOperands>($image_operands) ( `(` $operand_arguments^ `)`)? attr-dict
`:` type($image) `,` type($coordinate) `,` type($texel) ( `(` type($operand_arguments)^ `)`)?
}];
}

// -----

def SPIRV_ImageOp : SPIRV_Op<"Image",
[Pure,
TypesMatchWith<"type of 'result' matches image type of 'sampledimage'",
"sampledimage", "result",
TypesMatchWith<"type of 'result' matches image type of 'sampled_image'",
"sampled_image", "result",
"::llvm::cast<spirv::SampledImageType>($_self).getImageType()">]> {
let summary = "Extract the image from a sampled image.";

Expand All @@ -210,14 +256,14 @@ def SPIRV_ImageOp : SPIRV_Op<"Image",
}];

let arguments = (ins
SPIRV_AnySampledImage:$sampledimage
SPIRV_AnySampledImage:$sampled_image
);

let results = (outs
SPIRV_AnyImage:$result
);

let assemblyFormat = "attr-dict $sampledimage `:` type($sampledimage)";
let assemblyFormat = "$sampled_image attr-dict `:` type($sampled_image)";

let hasVerifier = 0;
}
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/SPIRV/IR/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ add_mlir_dialect_library(MLIRSPIRVDialect
ControlFlowOps.cpp
CooperativeMatrixOps.cpp
GroupOps.cpp
ImageOps.cpp
IntegerDotProductOps.cpp
MemoryOps.cpp
MeshOps.cpp
Expand Down
137 changes: 137 additions & 0 deletions mlir/lib/Dialect/SPIRV/IR/ImageOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
//===- ImageOps.cpp - MLIR SPIR-V Image Ops ------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Defines the image operations in the SPIR-V dialect.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"

using namespace mlir;

//===----------------------------------------------------------------------===//
// Common utility functions
//===----------------------------------------------------------------------===//

static LogicalResult verifyImageOperands(Operation *imageOp,
spirv::ImageOperandsAttr attr,
Operation::operand_range operands) {
if (!attr) {
if (operands.empty())
return success();

return imageOp->emitError("the Image Operands should encode what operands "
"follow, as per Image Operands");
}

// TODO: Add the validation rules for the following Image Operands.
spirv::ImageOperands noSupportOperands =
spirv::ImageOperands::Bias | spirv::ImageOperands::Lod |
spirv::ImageOperands::Grad | spirv::ImageOperands::ConstOffset |
spirv::ImageOperands::Offset | spirv::ImageOperands::ConstOffsets |
spirv::ImageOperands::Sample | spirv::ImageOperands::MinLod |
spirv::ImageOperands::MakeTexelAvailable |
spirv::ImageOperands::MakeTexelVisible |
spirv::ImageOperands::SignExtend | spirv::ImageOperands::ZeroExtend;

assert(!spirv::bitEnumContainsAny(attr.getValue(), noSupportOperands) &&
"unimplemented operands of Image Operands");

return success();
}

//===----------------------------------------------------------------------===//
// spirv.ImageDrefGather
//===----------------------------------------------------------------------===//

LogicalResult spirv::ImageDrefGatherOp::verify() {
return verifyImageOperands(getOperation(), getImageOperandsAttr(),
getOperandArguments());
}

//===----------------------------------------------------------------------===//
// spirv.ImageWriteOp
//===----------------------------------------------------------------------===//

LogicalResult spirv::ImageWriteOp::verify() {
// TODO: Do we need check for: "If the Arrayed operand is 1, then additional
// capabilities may be required; e.g., ImageCubeArray, or ImageMSArray."?

// TODO: Ideally it should be somewhere verified that "The Image Format must
// not be Unknown, unless the StorageImageWriteWithoutFormat Capability was
// declared." This function however may not be the suitable place for such
// verification.

return verifyImageOperands(getOperation(), getImageOperandsAttr(),
getOperandArguments());
}

//===----------------------------------------------------------------------===//
// spirv.ImageQuerySize
//===----------------------------------------------------------------------===//

LogicalResult spirv::ImageQuerySizeOp::verify() {
spirv::ImageType imageType =
llvm::cast<spirv::ImageType>(getImage().getType());
Type resultType = getResult().getType();

spirv::Dim dim = imageType.getDim();
spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo();
spirv::ImageSamplerUseInfo samplerInfo = imageType.getSamplerUseInfo();
switch (dim) {
case spirv::Dim::Dim1D:
case spirv::Dim::Dim2D:
case spirv::Dim::Dim3D:
case spirv::Dim::Cube:
if (samplingInfo != spirv::ImageSamplingInfo::MultiSampled &&
samplerInfo != spirv::ImageSamplerUseInfo::SamplerUnknown &&
samplerInfo != spirv::ImageSamplerUseInfo::NoSampler)
return emitError(
"if Dim is 1D, 2D, 3D, or Cube, "
"it must also have either an MS of 1 or a Sampled of 0 or 2");
break;
case spirv::Dim::Buffer:
case spirv::Dim::Rect:
break;
default:
return emitError("the Dim operand of the image type must "
"be 1D, 2D, 3D, Buffer, Cube, or Rect");
}

unsigned componentNumber = 0;
switch (dim) {
case spirv::Dim::Dim1D:
case spirv::Dim::Buffer:
componentNumber = 1;
break;
case spirv::Dim::Dim2D:
case spirv::Dim::Cube:
case spirv::Dim::Rect:
componentNumber = 2;
break;
case spirv::Dim::Dim3D:
componentNumber = 3;
break;
default:
break;
}

if (imageType.getArrayedInfo() == spirv::ImageArrayedInfo::Arrayed)
componentNumber += 1;

unsigned resultComponentNumber = 1;
if (auto resultVectorType = llvm::dyn_cast<VectorType>(resultType))
resultComponentNumber = resultVectorType.getNumElements();

if (componentNumber != resultComponentNumber)
return emitError("expected the result to have ")
<< componentNumber << " component(s), but found "
<< resultComponentNumber << " component(s)";

return success();
}
Loading