Skip to content

[mlir][spirv] Fix spirv dialect to support Specialization constants as GlobalVar initializer #75660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1162,10 +1162,11 @@ LogicalResult spirv::GlobalVariableOp::verify() {
// TODO: Currently only variable initialization with specialization
// constants and other variables is supported. They could be normal
// constants in the module scope as well.
if (!initOp ||
!isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
spirv::SpecConstantCompositeOp>(initOp)) {
return emitOpError("initializer must be result of a "
"spirv.SpecConstant or spirv.GlobalVariable op");
"spirv.SpecConstant or spirv.GlobalVariable or "
"spirv.SpecConstantCompositeOp op");
}
}

Expand Down
16 changes: 12 additions & 4 deletions mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -637,14 +637,22 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {

// Initializer.
FlatSymbolRefAttr initializer = nullptr;

if (wordIndex < operands.size()) {
auto initializerOp = getGlobalVariable(operands[wordIndex]);
if (!initializerOp) {
Operation *op = nullptr;

if (auto initOp = getGlobalVariable(operands[wordIndex]))
op = initOp;
else if (auto initOp = getSpecConstant(operands[wordIndex]))
op = initOp;
else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
op = initOp;
else
return emitError(unknownLoc, "unknown <id> ")
<< operands[wordIndex] << "used as initializer";
}

initializer = SymbolRefAttr::get(op);
wordIndex++;
initializer = SymbolRefAttr::get(initializerOp.getOperation());
}
if (wordIndex != operands.size()) {
return emitError(unknownLoc,
Expand Down
23 changes: 17 additions & 6 deletions mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -383,20 +383,31 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));

// Encode initialization.
if (auto initializer = varOp.getInitializer()) {
auto initializerID = getVariableID(*initializer);
if (!initializerID) {
StringRef initAttrName = varOp.getInitializerAttrName().getValue();
if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
uint32_t initializerID = 0;
auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
varOp->getParentOp(), initRef.getAttr());

// Check if initializer is GlobalVariable or SpecConstant* cases.
if (isa<spirv::GlobalVariableOp>(initOp))
initializerID = getVariableID(*initSymbolName);
else
initializerID = getSpecConstID(*initSymbolName);

if (!initializerID)
return emitError(varOp.getLoc(),
"invalid usage of undefined variable as initializer");
}

operands.push_back(initializerID);
elidedAttrs.push_back("initializer");
elidedAttrs.push_back(initAttrName);
}

if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
return failure();
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
elidedAttrs.push_back("initializer");
elidedAttrs.push_back(initAttrName);

// Encode decorations.
for (auto attr : varOp->getAttrs()) {
Expand Down
15 changes: 14 additions & 1 deletion mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -349,6 +349,19 @@ spirv.SpecConstant @sc = 4.0 : f32
// CHECK: spirv.GlobalVariable @var initializer(@sc)
spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<f32, Private>


// -----
// Allow SpecConstantComposite as initializer
spirv.module Logical GLSL450 {
spirv.SpecConstant @sc1 = 1 : i8
spirv.SpecConstant @sc2 = 2 : i8
spirv.SpecConstant @sc3 = 3 : i8
spirv.SpecConstantComposite @scc (@sc1, @sc2, @sc3) : !spirv.array<3 x i8>

// CHECK: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
}

// -----

spirv.module Logical GLSL450 {
Expand Down Expand Up @@ -410,7 +423,7 @@ spirv.module Logical GLSL450 {
// -----

spirv.module Logical GLSL450 {
// expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable op}}
// expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
}

Expand Down
24 changes: 24 additions & 0 deletions mlir/test/Target/SPIRV/global-variable.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,30 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: spirv.SpecConstant @sc = 1 : i8
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
spirv.SpecConstant @sc = 1 : i8

spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
}

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
// CHECK: spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
spirv.SpecConstant @sc0 = 1 : i8
spirv.SpecConstant @sc1 = 2 : i8
spirv.SpecConstant @sc2 = 3 : i8

spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>

spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
}

// -----

spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
spirv.func @foo() "None" {
Expand Down