Skip to content

Commit 94189e1

Browse files
authored
[mlir][spirv] Implement missing validation rules for ptr variables (#67871)
Variables that point to physical storage buffer require aliasing decorations. This is specified by the `SPV_KHR_physical_storage_buffer` extension. Also add an example of a variable with a decoration attribute.
1 parent 4b13c86 commit 94189e1

File tree

3 files changed

+121
-13
lines changed

3 files changed

+121
-13
lines changed

mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,9 @@ def SPIRV_PtrAccessChainOp : SPIRV_Op<"PtrAccessChain", [Pure]> {
289289
MinVersion<SPIRV_V_1_0>,
290290
MaxVersion<SPIRV_V_1_6>,
291291
Extension<[]>,
292-
Capability<[SPIRV_C_Addresses, SPIRV_C_PhysicalStorageBufferAddresses, SPIRV_C_VariablePointers, SPIRV_C_VariablePointersStorageBuffer]>
292+
Capability<[
293+
SPIRV_C_Addresses, SPIRV_C_PhysicalStorageBufferAddresses,
294+
SPIRV_C_VariablePointers, SPIRV_C_VariablePointersStorageBuffer]>
293295
];
294296

295297
let arguments = (ins
@@ -375,11 +377,16 @@ def SPIRV_VariableOp : SPIRV_Op<"Variable", []> {
375377
must be the `Function` Storage Class.
376378

377379
Initializer is optional. If Initializer is present, it will be the
378-
initial value of the variables memory content. Initializer must be an
380+
initial value of the variable's memory content. Initializer must be an
379381
<id> from a constant instruction or a global (module scope) OpVariable
380382
instruction. Initializer must have the same type as the type pointed to
381383
by Result Type.
382384

385+
From `SPV_KHR_physical_storage_buffer`:
386+
If an OpVariable's pointee type is a pointer (or array of pointers) in
387+
PhysicalStorageBuffer storage class, then the variable must be decorated
388+
with exactly one of AliasedPointer or RestrictPointer.
389+
383390
<!-- End of AutoGen section -->
384391

385392
```
@@ -396,6 +403,9 @@ def SPIRV_VariableOp : SPIRV_Op<"Variable", []> {
396403

397404
%1 = spirv.Variable : !spirv.ptr<f32, Function>
398405
%2 = spirv.Variable init(%0): !spirv.ptr<f32, Function>
406+
407+
%3 = spirv.Variable {aliased_pointer} :
408+
!spirv.ptr<!spirv.ptr<f32, PhysicalStorageBuffer>, Function>
399409
```
400410
}];
401411

@@ -407,6 +417,15 @@ def SPIRV_VariableOp : SPIRV_Op<"Variable", []> {
407417
let results = (outs
408418
SPIRV_AnyPtr:$pointer
409419
);
420+
421+
let extraClassDeclaration = [{
422+
::mlir::spirv::PointerType getPointerType() {
423+
return ::llvm::cast<::mlir::spirv::PointerType>(getType());
424+
}
425+
::mlir::Type getPointeeType() {
426+
return getPointerType().getPointeeType();
427+
}
428+
}];
410429
}
411430

412431
#endif // MLIR_DIALECT_SPIRV_IR_MEMORY_OPS

mlir/lib/Dialect/SPIRV/IR/MemoryOps.cpp

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -10,12 +10,16 @@
1010
//
1111
//===----------------------------------------------------------------------===//
1212

13+
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1314
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
1415

1516
#include "SPIRVOpUtils.h"
1617
#include "SPIRVParsingUtils.h"
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
19+
#include "mlir/IR/Diagnostics.h"
1720

1821
#include "llvm/ADT/StringExtras.h"
22+
#include "llvm/Support/Casting.h"
1923

2024
using namespace mlir::spirv::AttrNames;
2125

@@ -730,19 +734,49 @@ LogicalResult VariableOp::verify() {
730734
"constant or spirv.GlobalVariable op");
731735
}
732736

737+
auto getDecorationAttr = [op = getOperation()](spirv::Decoration decoration) {
738+
return op->getAttr(
739+
llvm::convertToSnakeFromCamelCase(stringifyDecoration(decoration)));
740+
};
741+
733742
// TODO: generate these strings using ODS.
734-
auto *op = getOperation();
735-
auto descriptorSetName = llvm::convertToSnakeFromCamelCase(
736-
stringifyDecoration(spirv::Decoration::DescriptorSet));
737-
auto bindingName = llvm::convertToSnakeFromCamelCase(
738-
stringifyDecoration(spirv::Decoration::Binding));
739-
auto builtInName = llvm::convertToSnakeFromCamelCase(
740-
stringifyDecoration(spirv::Decoration::BuiltIn));
741-
742-
for (const auto &attr : {descriptorSetName, bindingName, builtInName}) {
743-
if (op->getAttr(attr))
743+
for (auto decoration :
744+
{spirv::Decoration::DescriptorSet, spirv::Decoration::Binding,
745+
spirv::Decoration::BuiltIn}) {
746+
if (auto attr = getDecorationAttr(decoration))
744747
return emitOpError("cannot have '")
745-
<< attr << "' attribute (only allowed in spirv.GlobalVariable)";
748+
<< llvm::convertToSnakeFromCamelCase(
749+
stringifyDecoration(decoration))
750+
<< "' attribute (only allowed in spirv.GlobalVariable)";
751+
}
752+
753+
// From SPV_KHR_physical_storage_buffer:
754+
// > If an OpVariable's pointee type is a pointer (or array of pointers) in
755+
// > PhysicalStorageBuffer storage class, then the variable must be decorated
756+
// > with exactly one of AliasedPointer or RestrictPointer.
757+
auto pointeePtrType = dyn_cast<spirv::PointerType>(getPointeeType());
758+
if (!pointeePtrType) {
759+
if (auto pointeeArrayType = dyn_cast<spirv::ArrayType>(getPointeeType())) {
760+
pointeePtrType =
761+
dyn_cast<spirv::PointerType>(pointeeArrayType.getElementType());
762+
}
763+
}
764+
765+
if (pointeePtrType && pointeePtrType.getStorageClass() ==
766+
spirv::StorageClass::PhysicalStorageBuffer) {
767+
bool hasAliasedPtr =
768+
getDecorationAttr(spirv::Decoration::AliasedPointer) != nullptr;
769+
bool hasRestrictPtr =
770+
getDecorationAttr(spirv::Decoration::RestrictPointer) != nullptr;
771+
772+
if (!hasAliasedPtr && !hasRestrictPtr)
773+
return emitOpError() << " with physical buffer pointer must be decorated "
774+
"either 'AliasedPointer' or 'RestrictPointer'";
775+
776+
if (hasAliasedPtr && hasRestrictPtr)
777+
return emitOpError()
778+
<< " with physical buffer pointer must have exactly one "
779+
"aliasing decoration";
746780
}
747781

748782
return success();

mlir/test/Dialect/SPIRV/IR/memory-ops.mlir

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,61 @@ spirv.module Logical GLSL450 {
525525

526526
// -----
527527

528+
func.func @variable_ptr_physical_buffer() -> () {
529+
%0 = spirv.Variable {aliased_pointer} :
530+
!spirv.ptr<!spirv.ptr<f32, PhysicalStorageBuffer>, Function>
531+
%1 = spirv.Variable {restrict_pointer} :
532+
!spirv.ptr<!spirv.ptr<f32, PhysicalStorageBuffer>, Function>
533+
return
534+
}
535+
536+
// -----
537+
538+
func.func @variable_ptr_physical_buffer_no_decoration() -> () {
539+
// expected-error @+1 {{must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
540+
%0 = spirv.Variable : !spirv.ptr<!spirv.ptr<f32, PhysicalStorageBuffer>, Function>
541+
return
542+
}
543+
544+
// -----
545+
546+
func.func @variable_ptr_physical_buffer_two_alias_decorations() -> () {
547+
// expected-error @+1 {{must have exactly one aliasing decoration}}
548+
%0 = spirv.Variable {aliased_pointer, restrict_pointer} :
549+
!spirv.ptr<!spirv.ptr<f32, PhysicalStorageBuffer>, Function>
550+
return
551+
}
552+
553+
// -----
554+
555+
func.func @variable_ptr_array_physical_buffer() -> () {
556+
%0 = spirv.Variable {aliased_pointer} :
557+
!spirv.ptr<!spirv.array<4x!spirv.ptr<f32, PhysicalStorageBuffer>>, Function>
558+
%1 = spirv.Variable {restrict_pointer} :
559+
!spirv.ptr<!spirv.array<4x!spirv.ptr<f32, PhysicalStorageBuffer>>, Function>
560+
return
561+
}
562+
563+
// -----
564+
565+
func.func @variable_ptr_array_physical_buffer_no_decoration() -> () {
566+
// expected-error @+1 {{must be decorated either 'AliasedPointer' or 'RestrictPointer'}}
567+
%0 = spirv.Variable :
568+
!spirv.ptr<!spirv.array<4x!spirv.ptr<f32, PhysicalStorageBuffer>>, Function>
569+
return
570+
}
571+
572+
// -----
573+
574+
func.func @variable_ptr_array_physical_buffer_two_alias_decorations() -> () {
575+
// expected-error @+1 {{must have exactly one aliasing decoration}}
576+
%0 = spirv.Variable {aliased_pointer, restrict_pointer} :
577+
!spirv.ptr<!spirv.array<4x!spirv.ptr<f32, PhysicalStorageBuffer>>, Function>
578+
return
579+
}
580+
581+
// -----
582+
528583
func.func @variable_bind() -> () {
529584
// expected-error @+1 {{cannot have 'descriptor_set' attribute (only allowed in spirv.GlobalVariable)}}
530585
%0 = spirv.Variable bind(1, 2) : !spirv.ptr<f32, Function>

0 commit comments

Comments
 (0)