Skip to content

Commit 1b31b50

Browse files
ergawyantiagainst
authored andcommitted
[MLIR][SPIRV] Extend _reference_of to support SpecConstantCompositeOp.
Adds support for SPIR-V composite speciailization constants to spv._reference_of. Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D88732
1 parent e00f189 commit 1b31b50

File tree

5 files changed

+114
-14
lines changed

5 files changed

+114
-14
lines changed

mlir/include/mlir/Dialect/SPIRV/SPIRVStructureOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ def SPV_ReferenceOfOp : SPV_Op<"_reference_of", [NoSideEffect]> {
472472
let summary = "Reference a specialization constant.";
473473

474474
let description = [{
475-
Specialization constant in module scope are defined using symbol names.
475+
Specialization constants in module scope are defined using symbol names.
476476
This op generates an SSA value that can be used to refer to the symbol
477477
within function scope for use in ops that expect an SSA value.
478478
This operation has no corresponding SPIR-V instruction; it's merely used

mlir/lib/Dialect/SPIRV/SPIRVOps.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2568,17 +2568,27 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
25682568
//===----------------------------------------------------------------------===//
25692569

25702570
static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) {
2571-
auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(
2572-
SymbolTable::lookupNearestSymbolFrom(referenceOfOp.getParentOp(),
2573-
referenceOfOp.spec_const()));
2574-
if (!specConstOp) {
2575-
return referenceOfOp.emitOpError("expected spv.specConstant symbol");
2576-
}
2577-
if (referenceOfOp.reference().getType() !=
2578-
specConstOp.default_value().getType()) {
2571+
auto *specConstSym = SymbolTable::lookupNearestSymbolFrom(
2572+
referenceOfOp.getParentOp(), referenceOfOp.spec_const());
2573+
Type constType;
2574+
2575+
auto specConstOp = dyn_cast_or_null<spirv::SpecConstantOp>(specConstSym);
2576+
if (specConstOp)
2577+
constType = specConstOp.default_value().getType();
2578+
2579+
auto specConstCompositeOp =
2580+
dyn_cast_or_null<spirv::SpecConstantCompositeOp>(specConstSym);
2581+
if (specConstCompositeOp)
2582+
constType = specConstCompositeOp.type();
2583+
2584+
if (!specConstOp && !specConstCompositeOp)
2585+
return referenceOfOp.emitOpError(
2586+
"expected spv.specConstant or spv.SpecConstantComposite symbol");
2587+
2588+
if (referenceOfOp.reference().getType() != constType)
25792589
return referenceOfOp.emitOpError("result type mismatch with the referenced "
25802590
"specialization constant's type");
2581-
}
2591+
25822592
return success();
25832593
}
25842594

mlir/lib/Dialect/SPIRV/Serialization/Deserializer.cpp

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,11 @@ class Deserializer {
187187
return specConstMap.lookup(id);
188188
}
189189

190+
/// Gets the composite specialization constant with the given result <id>.
191+
spirv::SpecConstantCompositeOp getSpecConstantComposite(uint32_t id) {
192+
return specConstCompositeMap.lookup(id);
193+
}
194+
190195
/// Creates a spirv::SpecConstantOp.
191196
spirv::SpecConstantOp createSpecConstant(Location loc, uint32_t resultID,
192197
Attribute defaultValue);
@@ -461,9 +466,12 @@ class Deserializer {
461466
/// (and type) here. Later when it's used, we materialize the constant.
462467
DenseMap<uint32_t, std::pair<Attribute, Type>> constantMap;
463468

464-
// Result <id> to variable mapping.
469+
// Result <id> to spec constant mapping.
465470
DenseMap<uint32_t, spirv::SpecConstantOp> specConstMap;
466471

472+
// Result <id> to composite spec constant mapping.
473+
DenseMap<uint32_t, spirv::SpecConstantCompositeOp> specConstCompositeMap;
474+
467475
// Result <id> to variable mapping.
468476
DenseMap<uint32_t, spirv::GlobalVariableOp> globalVariableMap;
469477

@@ -1565,7 +1573,8 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
15651573
<< operands[0];
15661574
}
15671575

1568-
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(operands[1]));
1576+
auto resultID = operands[1];
1577+
auto symName = opBuilder.getStringAttr(getSpecConstantSymbol(resultID));
15691578

15701579
SmallVector<Attribute, 4> elements;
15711580
elements.reserve(operands.size() - 2);
@@ -1574,9 +1583,10 @@ Deserializer::processSpecConstantComposite(ArrayRef<uint32_t> operands) {
15741583
elements.push_back(opBuilder.getSymbolRefAttr(elementInfo));
15751584
}
15761585

1577-
opBuilder.create<spirv::SpecConstantCompositeOp>(
1586+
auto op = opBuilder.create<spirv::SpecConstantCompositeOp>(
15781587
unknownLoc, TypeAttr::get(resultType), symName,
15791588
opBuilder.getArrayAttr(elements));
1589+
specConstCompositeMap[resultID] = op;
15801590

15811591
return success();
15821592
}
@@ -2208,6 +2218,12 @@ Value Deserializer::getValue(uint32_t id) {
22082218
opBuilder.getSymbolRefAttr(constOp.getOperation()));
22092219
return referenceOfOp.reference();
22102220
}
2221+
if (auto constCompositeOp = getSpecConstantComposite(id)) {
2222+
auto referenceOfOp = opBuilder.create<spirv::ReferenceOfOp>(
2223+
unknownLoc, constCompositeOp.type(),
2224+
opBuilder.getSymbolRefAttr(constCompositeOp.getOperation()));
2225+
return referenceOfOp.reference();
2226+
}
22112227
if (auto undef = getUndefType(id)) {
22122228
return opBuilder.create<spirv::UndefOp>(unknownLoc, undef);
22132229
}

mlir/test/Dialect/SPIRV/Serialization/spec-constant.mlir

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,9 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
1212
// CHECK: spv.specConstant @sc_float spec_id(5) = 1.000000e+00 : f32
1313
spv.specConstant @sc_float spec_id(5) = 1. : f32
1414

15+
// CHECK: spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
16+
spv.specConstantComposite @scc (@sc_int, @sc_int) : !spv.array<2 x i32>
17+
1518
// CHECK-LABEL: @use
1619
spv.func @use() -> (i32) "None" {
1720
// We materialize a `spv._reference_of` op at every use of a
@@ -24,6 +27,43 @@ spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
2427
%1 = spv.IAdd %0, %0 : i32
2528
spv.ReturnValue %1 : i32
2629
}
30+
31+
// CHECK-LABEL: @use
32+
spv.func @use_composite() -> (i32) "None" {
33+
// We materialize a `spv._reference_of` op at every use of a
34+
// specialization constant in the deserializer. So two ops here.
35+
// CHECK: %[[USE1:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
36+
// CHECK: %[[ITM0:.*]] = spv.CompositeExtract %[[USE1]][0 : i32] : !spv.array<2 x i32>
37+
// CHECK: %[[USE2:.*]] = spv._reference_of @scc : !spv.array<2 x i32>
38+
// CHECK: %[[ITM1:.*]] = spv.CompositeExtract %[[USE2]][1 : i32] : !spv.array<2 x i32>
39+
// CHECK: spv.IAdd %[[ITM0]], %[[ITM1]]
40+
41+
%0 = spv._reference_of @scc : !spv.array<2 x i32>
42+
%1 = spv.CompositeExtract %0[0 : i32] : !spv.array<2 x i32>
43+
%2 = spv.CompositeExtract %0[1 : i32] : !spv.array<2 x i32>
44+
%3 = spv.IAdd %1, %2 : i32
45+
spv.ReturnValue %3 : i32
46+
}
47+
}
48+
49+
// -----
50+
51+
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], []> {
52+
53+
spv.specConstant @sc_f32_1 = 1.5 : f32
54+
spv.specConstant @sc_f32_2 = 2.5 : f32
55+
spv.specConstant @sc_f32_3 = 3.5 : f32
56+
57+
spv.specConstant @sc_i32_1 = 1 : i32
58+
59+
// CHECK: spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
60+
spv.specConstantComposite @scc_array (@sc_f32_1, @sc_f32_2, @sc_f32_3) : !spv.array<3 x f32>
61+
62+
// CHECK: spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
63+
spv.specConstantComposite @scc_struct (@sc_i32_1, @sc_f32_2, @sc_f32_3) : !spv.struct<i32, f32, f32>
64+
65+
// CHECK: spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3xf32>
66+
spv.specConstantComposite @scc_vector (@sc_f32_1, @sc_f32_2, @sc_f32_3) : vector<3 x f32>
2767
}
2868

2969
// -----

mlir/test/Dialect/SPIRV/structure-ops.mlir

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -496,13 +496,23 @@ spv.module Logical GLSL450 {
496496
spv.specConstant @sc2 = 42 : i64
497497
spv.specConstant @sc3 = 1.5 : f32
498498

499+
spv.specConstantComposite @scc (@sc1, @sc2, @sc3) : !spv.struct<i1, i64, f32>
500+
499501
// CHECK-LABEL: @reference
500502
spv.func @reference() -> i1 "None" {
501503
// CHECK: spv._reference_of @sc1 : i1
502504
%0 = spv._reference_of @sc1 : i1
503505
spv.ReturnValue %0 : i1
504506
}
505507

508+
// CHECK-LABEL: @reference_composite
509+
spv.func @reference_composite() -> i1 "None" {
510+
// CHECK: spv._reference_of @scc : !spv.struct<i1, i64, f32>
511+
%0 = spv._reference_of @scc : !spv.struct<i1, i64, f32>
512+
%1 = spv.CompositeExtract %0[0 : i32] : !spv.struct<i1, i64, f32>
513+
spv.ReturnValue %1 : i1
514+
}
515+
506516
// CHECK-LABEL: @initialize
507517
spv.func @initialize() -> i64 "None" {
508518
// CHECK: spv._reference_of @sc2 : i64
@@ -534,9 +544,21 @@ func @reference_of() {
534544

535545
// -----
536546

547+
spv.specConstant @sc = 5 : i32
548+
spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
549+
550+
func @reference_of_composite() {
551+
// CHECK: spv._reference_of @scc : !spv.array<1 x i32>
552+
%0 = spv._reference_of @scc : !spv.array<1 x i32>
553+
%1 = spv.CompositeExtract %0[0 : i32] : !spv.array<1 x i32>
554+
return
555+
}
556+
557+
// -----
558+
537559
spv.module Logical GLSL450 {
538560
spv.func @foo() -> () "None" {
539-
// expected-error @+1 {{expected spv.specConstant symbol}}
561+
// expected-error @+1 {{expected spv.specConstant or spv.SpecConstantComposite symbol}}
540562
%0 = spv._reference_of @sc : i32
541563
spv.Return
542564
}
@@ -555,6 +577,18 @@ spv.module Logical GLSL450 {
555577

556578
// -----
557579

580+
spv.module Logical GLSL450 {
581+
spv.specConstant @sc = 42 : i32
582+
spv.specConstantComposite @scc (@sc) : !spv.array<1 x i32>
583+
spv.func @foo() -> () "None" {
584+
// expected-error @+1 {{result type mismatch with the referenced specialization constant's type}}
585+
%0 = spv._reference_of @scc : f32
586+
spv.Return
587+
}
588+
}
589+
590+
// -----
591+
558592
//===----------------------------------------------------------------------===//
559593
// spv.specConstant
560594
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)