Skip to content

[mlir][spirv] Fix spirv dialect to support Specialization constants as GlobalVar initializer #75660

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 6, 2024

Conversation

drprajap
Copy link
Contributor

Changes include:

  • spirv serialization and deserialization needs handling in cases when GlobalVariableOp initializer is defined using spirv SpecConstant or SpecConstantComposite op, currently even though it allows SpecConst, it only looked up in for GlobalVariable Map to find initializer symbol reference, change is fixing this and extending the support to SpecConstantComposite as an initializer.
  • Adds tests to make sure GlobalVariable can be initialized using specialized constants.

…n GlobalVar initializer

Changes include:
	- spirv serialization and deserialization needs handling in cases when GlobalVariableOp
	  initializer is defined using spirv SpecConstant or SpecConstantComposite op, currently
	  even though it allows SpecConst, it only looked up in for GlobalVariable Map to find
	  initializer symbol reference, change is fixing this and extending the support to
	  SpecConstantComposite as an initializer.
	- Adds tests to make sure GlobalVariable can be initialzed using specialized constants.
@llvmbot
Copy link
Member

llvmbot commented Dec 15, 2023

@llvm/pr-subscribers-mlir

Author: Dimple Prajapati (drprajap)

Changes

Changes include:

  • spirv serialization and deserialization needs handling in cases when GlobalVariableOp initializer is defined using spirv SpecConstant or SpecConstantComposite op, currently even though it allows SpecConst, it only looked up in for GlobalVariable Map to find initializer symbol reference, change is fixing this and extending the support to SpecConstantComposite as an initializer.
  • Adds tests to make sure GlobalVariable can be initialized using specialized constants.

Full diff: https://github.com/llvm/llvm-project/pull/75660.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+3-2)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+13-6)
  • (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+16-6)
  • (modified) mlir/test/Dialect/SPIRV/IR/structure-ops.mlir (+14-1)
  • (modified) mlir/test/Target/SPIRV/global-variable.mlir (+24)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..66f1d6b2e12206 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1163,9 +1163,10 @@ LogicalResult spirv::GlobalVariableOp::verify() {
     // constants and other variables is supported. They could be normal
     // constants in the module scope as well.
     if (!initOp ||
-        !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
+        !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
       return emitOpError("initializer must be result of a "
-                         "spirv.SpecConstant or spirv.GlobalVariable op");
+                         "spirv.SpecConstant or spirv.GlobalVariable or "
+                         "spirv.SpecConstantCompositeOp op");
     }
   }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..ccea690a7c3ded 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -637,14 +637,21 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
 
   // Initializer.
   FlatSymbolRefAttr initializer = nullptr;
+  
   if (wordIndex < operands.size()) {
-    auto initializerOp = getGlobalVariable(operands[wordIndex]);
-    if (!initializerOp) {
-      return emitError(unknownLoc, "unknown <id> ")
-             << operands[wordIndex] << "used as initializer";
-    }
+    Operation *op = nullptr;
+
+    if((op = getGlobalVariable(operands[wordIndex])))
+      initializer = SymbolRefAttr::get((dyn_cast<spirv::GlobalVariableOp>(op)).getOperation());
+    else if ((op  = getSpecConstant(operands[wordIndex])))
+      initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantOp>(op)).getOperation());
+    else if((op = getSpecConstantComposite(operands[wordIndex])))
+      initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
+    else
+      return emitError(unknownLoc,
+                        "Unknown op used as initializer");
+
     wordIndex++;
-    initializer = SymbolRefAttr::get(initializerOp.getOperation());
   }
   if (wordIndex != operands.size()) {
     return emitError(unknownLoc,
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..bd1fc7a84fbd6a 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -383,12 +383,22 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
 
   // Encode initialization.
-  if (auto initializer = varOp.getInitializer()) {
-    auto initializerID = getVariableID(*initializer);
-    if (!initializerID) {
-      return emitError(varOp.getLoc(),
-                       "invalid usage of undefined variable as initializer");
-    }
+  if (auto initializerName = varOp.getInitializer()) {
+
+    uint32_t initializerID = 0;
+    auto init = varOp->getAttrOfType<FlatSymbolRefAttr>("initializer");
+    Operation *initOp = SymbolTable::lookupNearestSymbolFrom(varOp->getParentOp(), init.getAttr());
+
+    // Check if initializer is GlobalVariable or SpecConstant/SpecConstantComposite
+    if(isa<spirv::GlobalVariableOp>(initOp))
+      initializerID = getVariableID(*initializerName);
+    else
+      initializerID = getSpecConstID(*initializerName);
+
+    if (!initializerID)
+        return emitError(varOp.getLoc(),
+                      "invalid usage of undefined variable as initializer");
+   
     operands.push_back(initializerID);
     elidedAttrs.push_back("initializer");
   }
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..77b605050e1442 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -349,6 +349,19 @@ spirv.SpecConstant @sc = 4.0 : f32
 // CHECK: spirv.GlobalVariable @var initializer(@sc)
 spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<f32, Private>
 
+
+// -----
+// Allow SpecConstantComposite as initializer
+  spirv.module Logical GLSL450 {
+  spirv.SpecConstant @sc1 = 1 : i8
+  spirv.SpecConstant @sc2 = 2 : i8
+  spirv.SpecConstant @sc3 = 3 : i8
+  spirv.SpecConstantComposite @scc (@sc1, @sc2, @sc3) : !spirv.array<3 x i8>
+
+  // CHECK: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+  spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+}
+
 // -----
 
 spirv.module Logical GLSL450 {
@@ -410,7 +423,7 @@ spirv.module Logical GLSL450 {
 // -----
 
 spirv.module Logical GLSL450 {
-  // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable op}}
+  // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
   spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
 }
 
diff --git a/mlir/test/Target/SPIRV/global-variable.mlir b/mlir/test/Target/SPIRV/global-variable.mlir
index 66d0782c205c7d..f22d2a9b3d14d9 100644
--- a/mlir/test/Target/SPIRV/global-variable.mlir
+++ b/mlir/test/Target/SPIRV/global-variable.mlir
@@ -23,6 +23,30 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 
 // -----
 
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK:         spirv.SpecConstant @sc = 1 : i8
+  // CHECK-NEXT:    spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+  spirv.SpecConstant @sc = 1 : i8
+
+  spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK:         spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+  // CHECK-NEXT:    spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+  spirv.SpecConstant @sc0 = 1 : i8
+  spirv.SpecConstant @sc1 = 2 : i8
+  spirv.SpecConstant @sc2 = 3 : i8
+
+  spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+
+  spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+}
+
+// -----
+
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
   spirv.func @foo() "None" {

@llvmbot
Copy link
Member

llvmbot commented Dec 15, 2023

@llvm/pr-subscribers-mlir-spirv

Author: Dimple Prajapati (drprajap)

Changes

Changes include:

  • spirv serialization and deserialization needs handling in cases when GlobalVariableOp initializer is defined using spirv SpecConstant or SpecConstantComposite op, currently even though it allows SpecConst, it only looked up in for GlobalVariable Map to find initializer symbol reference, change is fixing this and extending the support to SpecConstantComposite as an initializer.
  • Adds tests to make sure GlobalVariable can be initialized using specialized constants.

Full diff: https://github.com/llvm/llvm-project/pull/75660.diff

5 Files Affected:

  • (modified) mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp (+3-2)
  • (modified) mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp (+13-6)
  • (modified) mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp (+16-6)
  • (modified) mlir/test/Dialect/SPIRV/IR/structure-ops.mlir (+14-1)
  • (modified) mlir/test/Target/SPIRV/global-variable.mlir (+24)
diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
index 2a1d083308282a..66f1d6b2e12206 100644
--- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
+++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp
@@ -1163,9 +1163,10 @@ LogicalResult spirv::GlobalVariableOp::verify() {
     // constants and other variables is supported. They could be normal
     // constants in the module scope as well.
     if (!initOp ||
-        !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp>(initOp)) {
+        !isa<spirv::GlobalVariableOp, spirv::SpecConstantOp, spirv::SpecConstantCompositeOp>(initOp)) {
       return emitOpError("initializer must be result of a "
-                         "spirv.SpecConstant or spirv.GlobalVariable op");
+                         "spirv.SpecConstant or spirv.GlobalVariable or "
+                         "spirv.SpecConstantCompositeOp op");
     }
   }
 
diff --git a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
index 89e2e7ad52fa7d..ccea690a7c3ded 100644
--- a/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
+++ b/mlir/lib/Target/SPIRV/Deserialization/Deserializer.cpp
@@ -637,14 +637,21 @@ spirv::Deserializer::processGlobalVariable(ArrayRef<uint32_t> operands) {
 
   // Initializer.
   FlatSymbolRefAttr initializer = nullptr;
+  
   if (wordIndex < operands.size()) {
-    auto initializerOp = getGlobalVariable(operands[wordIndex]);
-    if (!initializerOp) {
-      return emitError(unknownLoc, "unknown <id> ")
-             << operands[wordIndex] << "used as initializer";
-    }
+    Operation *op = nullptr;
+
+    if((op = getGlobalVariable(operands[wordIndex])))
+      initializer = SymbolRefAttr::get((dyn_cast<spirv::GlobalVariableOp>(op)).getOperation());
+    else if ((op  = getSpecConstant(operands[wordIndex])))
+      initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantOp>(op)).getOperation());
+    else if((op = getSpecConstantComposite(operands[wordIndex])))
+      initializer = SymbolRefAttr::get((dyn_cast<spirv::SpecConstantCompositeOp>(op)).getOperation());
+    else
+      return emitError(unknownLoc,
+                        "Unknown op used as initializer");
+
     wordIndex++;
-    initializer = SymbolRefAttr::get(initializerOp.getOperation());
   }
   if (wordIndex != operands.size()) {
     return emitError(unknownLoc,
diff --git a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
index 44538c38a41b83..bd1fc7a84fbd6a 100644
--- a/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
+++ b/mlir/lib/Target/SPIRV/Serialization/SerializeOps.cpp
@@ -383,12 +383,22 @@ Serializer::processGlobalVariableOp(spirv::GlobalVariableOp varOp) {
   operands.push_back(static_cast<uint32_t>(varOp.storageClass()));
 
   // Encode initialization.
-  if (auto initializer = varOp.getInitializer()) {
-    auto initializerID = getVariableID(*initializer);
-    if (!initializerID) {
-      return emitError(varOp.getLoc(),
-                       "invalid usage of undefined variable as initializer");
-    }
+  if (auto initializerName = varOp.getInitializer()) {
+
+    uint32_t initializerID = 0;
+    auto init = varOp->getAttrOfType<FlatSymbolRefAttr>("initializer");
+    Operation *initOp = SymbolTable::lookupNearestSymbolFrom(varOp->getParentOp(), init.getAttr());
+
+    // Check if initializer is GlobalVariable or SpecConstant/SpecConstantComposite
+    if(isa<spirv::GlobalVariableOp>(initOp))
+      initializerID = getVariableID(*initializerName);
+    else
+      initializerID = getSpecConstID(*initializerName);
+
+    if (!initializerID)
+        return emitError(varOp.getLoc(),
+                      "invalid usage of undefined variable as initializer");
+   
     operands.push_back(initializerID);
     elidedAttrs.push_back("initializer");
   }
diff --git a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
index 722e4434aeaf9f..77b605050e1442 100644
--- a/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
+++ b/mlir/test/Dialect/SPIRV/IR/structure-ops.mlir
@@ -349,6 +349,19 @@ spirv.SpecConstant @sc = 4.0 : f32
 // CHECK: spirv.GlobalVariable @var initializer(@sc)
 spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<f32, Private>
 
+
+// -----
+// Allow SpecConstantComposite as initializer
+  spirv.module Logical GLSL450 {
+  spirv.SpecConstant @sc1 = 1 : i8
+  spirv.SpecConstant @sc2 = 2 : i8
+  spirv.SpecConstant @sc3 = 3 : i8
+  spirv.SpecConstantComposite @scc (@sc1, @sc2, @sc3) : !spirv.array<3 x i8>
+
+  // CHECK: spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+  spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Private>
+}
+
 // -----
 
 spirv.module Logical GLSL450 {
@@ -410,7 +423,7 @@ spirv.module Logical GLSL450 {
 // -----
 
 spirv.module Logical GLSL450 {
-  // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable op}}
+  // expected-error @+1 {{op initializer must be result of a spirv.SpecConstant or spirv.GlobalVariable or spirv.SpecConstantCompositeOp op}}
   spirv.GlobalVariable @var0 initializer(@var1) : !spirv.ptr<f32, Private>
 }
 
diff --git a/mlir/test/Target/SPIRV/global-variable.mlir b/mlir/test/Target/SPIRV/global-variable.mlir
index 66d0782c205c7d..f22d2a9b3d14d9 100644
--- a/mlir/test/Target/SPIRV/global-variable.mlir
+++ b/mlir/test/Target/SPIRV/global-variable.mlir
@@ -23,6 +23,30 @@ spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
 
 // -----
 
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK:         spirv.SpecConstant @sc = 1 : i8
+  // CHECK-NEXT:    spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+  spirv.SpecConstant @sc = 1 : i8
+
+  spirv.GlobalVariable @var initializer(@sc) : !spirv.ptr<i8, Uniform>
+}
+
+// -----
+
+spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
+  // CHECK:         spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+  // CHECK-NEXT:    spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+  spirv.SpecConstant @sc0 = 1 : i8
+  spirv.SpecConstant @sc1 = 2 : i8
+  spirv.SpecConstant @sc2 = 3 : i8
+
+  spirv.SpecConstantComposite @scc (@sc0, @sc1, @sc2) : !spirv.array<3 x i8>
+
+  spirv.GlobalVariable @var initializer(@scc) : !spirv.ptr<!spirv.array<3 x i8>, Uniform>
+}
+
+// -----
+
 spirv.module Logical GLSL450 requires #spirv.vce<v1.0, [Shader], []> {
   spirv.GlobalVariable @globalInvocationID built_in("GlobalInvocationId") : !spirv.ptr<vector<3xi32>, Input>
   spirv.func @foo() "None" {

@drprajap
Copy link
Contributor Author

@mshahneo , @silee2 - FYI

Copy link

github-actions bot commented Dec 15, 2023

✅ With the latest revision this PR passed the C/C++ code formatter.

Copy link
Member

@kuhar kuhar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM but please wait for @antiagainst to take a look to

@drprajap
Copy link
Contributor Author

LGTM but please wait for @antiagainst to take a look to

Thank you :), sure I will wait for 2nd approval.

@antiagainst
Copy link
Member

Thanks for the contribution! Sorry about the delay in reviews. I do have some style-ish comments but given this has been a while I just went ahead fixed them and will land this path directly. Thanks again!

@antiagainst antiagainst merged commit 5e54319 into llvm:main Jan 6, 2024
@drprajap
Copy link
Contributor Author

drprajap commented Jan 8, 2024

Thanks for the contribution! Sorry about the delay in reviews. I do have some style-ish comments but given this has been a while I just went ahead fixed them and will land this path directly. Thanks again!

Thanks for the review and merging it. No worries about the delay, it was holidays time, so understood..

justinfargnoli pushed a commit to justinfargnoli/llvm-project that referenced this pull request Jan 28, 2024
…5660)

Changes include:

- spirv serialization and deserialization needs handling in cases when
GlobalVariableOp initializer is defined using spirv SpecConstant or
SpecConstantComposite op, currently even though it allows SpecConst, it
only looked up in for GlobalVariable Map to find initializer symbol
reference, change is fixing this and extending the support to
SpecConstantComposite as an initializer.
- Adds tests to make sure GlobalVariable can be initialized using
specialized constants.

---------

Co-authored-by: Lei Zhang <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants