Skip to content

Commit 5e54319

Browse files
[mlir][spirv] Support spec constants as GlobalVar initializer (#75660)
Changes include: - spirv serialization and deserialization needs handling in cases when GlobalVariableOp initializer is defined using spirv SpecConstant or SpecConstantComposite op, currently even though it allows SpecConst, it only looked up in for GlobalVariable Map to find initializer symbol reference, change is fixing this and extending the support to SpecConstantComposite as an initializer. - Adds tests to make sure GlobalVariable can be initialized using specialized constants. --------- Co-authored-by: Lei Zhang <[email protected]>
1 parent 651a42f commit 5e54319

File tree

5 files changed

+71
-14
lines changed

5 files changed

+71
-14
lines changed

mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1162,10 +1162,11 @@ LogicalResult spirv::GlobalVariableOp::verify() {
11621162
// TODO: Currently only variable initialization with specialization
11631163
// constants and other variables is supported. They could be normal
11641164
// constants in the module scope as well.
1165-
if (!initOp ||
1166-
!isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
1165+
if (!initOp || !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp,
1166+
spirv::SpecConstantCompositeOp>(initOp)) {
11671167
return emitOpError("initializer must be result of a "
1168-
"spirv.SpecConstant or spirv.GlobalVariable op");
1168+
"spirv.SpecConstant or spirv.GlobalVariable or "
1169+
"spirv.SpecConstantCompositeOp op");
11691170
}
11701171
}
11711172

mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -637,14 +637,22 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
637637

638638
// Initializer.
639639
FlatSymbolRefAttr initializer = nullptr;
640+
640641
if (wordIndex < operands.size()) {
641-
auto initializerOp = getGlobalVariable(operands[wordIndex]);
642-
if (!initializerOp) {
642+
Operation *op = nullptr;
643+
644+
if (auto initOp = getGlobalVariable(operands[wordIndex]))
645+
op = initOp;
646+
else if (auto initOp = getSpecConstant(operands[wordIndex]))
647+
op = initOp;
648+
else if (auto initOp = getSpecConstantComposite(operands[wordIndex]))
649+
op = initOp;
650+
else
643651
return emitError(unknownLoc, "unknown <id> ")
644652
<< operands[wordIndex] << "used as initializer";
645-
}
653+
654+
initializer = SymbolRefAttr::get(op);
646655
wordIndex++;
647-
initializer = SymbolRefAttr::get(initializerOp.getOperation());
648656
}
649657
if (wordIndex != operands.size()) {
650658
return emitError(unknownLoc,

mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -383,20 +383,31 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
383383
operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
384384

385385
// Encode initialization.
386-
if (auto initializer = varOp.getInitializer()) {
387-
auto initializerID = getVariableID(*initializer);
388-
if (!initializerID) {
386+
StringRef initAttrName = varOp.getInitializerAttrName().getValue();
387+
if (std::optional<StringRef> initSymbolName = varOp.getInitializer()) {
388+
uint32_t initializerID = 0;
389+
auto initRef = varOp->getAttrOfType<FlatSymbolRefAttr>(initAttrName);
390+
Operation *initOp = SymbolTable::lookupNearestSymbolFrom(
391+
varOp->getParentOp(), initRef.getAttr());
392+
393+
// Check if initializer is GlobalVariable or SpecConstant* cases.
394+
if (isa<spirv::GlobalVariableOp>(initOp))
395+
initializerID = getVariableID(*initSymbolName);
396+
else
397+
initializerID = getSpecConstID(*initSymbolName);
398+
399+
if (!initializerID)
389400
return emitError(varOp.getLoc(),
390401
"invalid usage of undefined variable as initializer");
391-
}
402+
392403
operands.push_back(initializerID);
393-
elidedAttrs.push_back("initializer");
404+
elidedAttrs.push_back(initAttrName);
394405
}
395406

396407
if (failed(emitDebugLine(typesGlobalValues, varOp.getLoc())))
397408
return failure();
398409
encodeInstructionInto(typesGlobalValues, spirv::Opcode::OpVariable, operands);
399-
elidedAttrs.push_back("initializer");
410+
elidedAttrs.push_back(initAttrName);
400411

401412
// Encode decorations.
402413
for (auto attr : varOp->getAttrs()) {

mlir/test/Dialect/SPIRV/IR/structure-ops.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,19 @@ spirv.SpecConstant @sc = 4.0 : f32
349349
// CHECK: spirv.GlobalVariable @var initializer(@sc)
350350
spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<f32, Private>
351351

352+
353+
// -----
354+
// Allow SpecConstantComposite as initializer
355+
spirv.module Logical GLSL450 {
356+
spirv.SpecConstant @sc1 = 1 : i8
357+
spirv.SpecConstant @sc2 = 2 : i8
358+
spirv.SpecConstant @sc3 = 3 : i8
359+
spirv.SpecConstantComposite @scc (@sc1, @sc2, @sc3) : !spirv.array<3 x i8>
360+
361+
// CHECK: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
362+
spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
363+
}
364+
352365
// -----
353366

354367
spirv.module Logical GLSL450 {
@@ -410,7 +423,7 @@ spirv.module Logical GLSL450 {
410423
// -----
411424

412425
spirv.module Logical GLSL450 {
413-
// expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable op}}
426+
// expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
414427
spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
415428
}
416429

mlir/test/Target/SPIRV/global-variable.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,30 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
2323

2424
// -----
2525

26+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
27+
// CHECK: spirv.SpecConstant @sc = 1 : i8
28+
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
29+
spirv.SpecConstant @sc = 1 : i8
30+
31+
spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
32+
}
33+
34+
// -----
35+
36+
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
37+
// CHECK: spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
38+
// CHECK-NEXT: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
39+
spirv.SpecConstant @sc0 = 1 : i8
40+
spirv.SpecConstant @sc1 = 2 : i8
41+
spirv.SpecConstant @sc2 = 3 : i8
42+
43+
spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
44+
45+
spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
46+
}
47+
48+
// -----
49+
2650
spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
2751
spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
2852
spirv.func @foo() "None" {

0 commit comments

Comments
 (0)