Skip to content

[mlir][spirv] Linearize ND vectors. #80451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from

Conversation

Hardcode84
Copy link
Contributor

SPIR-V only supports 1D vectors, try to linearize vector in type converter.

Not sure is this is a right approach or we should have a dedicated vector linearization pass.

SPIR-V only supports 1D vectors, try to linearize vector in type converter.
@llvmbot
Copy link
Member

llvmbot commented Feb 2, 2024

@llvm/pr-subscribers-mlir-spirv

Author: Ivan Butygin (Hardcode84)

Changes

SPIR-V only supports 1D vectors, try to linearize vector in type converter.

Not sure is this is a right approach or we should have a dedicated vector linearization pass.


Full diff: https://github.com/llvm/llvm-project/pull/80451.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+5-4)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+4)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+2-2)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+2)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f39..b55eda69f99ec 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -229,8 +229,8 @@ struct ConstantCompositeOpPattern final
     if (!srcType || srcType.getNumElements() == 1)
       return failure();
 
-    // arith.constant should only have vector or tenor types.
-    assert((isa<VectorType, RankedTensorType>(srcType)));
+    assert((isa<VectorType, RankedTensorType>(srcType) &&
+            "arith.constant should only have vector or tensor types"));
 
     Type dstType = getTypeConverter()->convertType(srcType);
     if (!dstType)
@@ -250,8 +250,9 @@ struct ConstantCompositeOpPattern final
                                             srcType.getElementType());
         dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
       } else {
-        // TODO: add support for large vectors.
-        return failure();
+        dstAttrType =
+            VectorType::get(srcType.getNumElements(), srcType.getElementType());
+        dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
       }
     }
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e8..1ce7dff8ff0e4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -330,6 +330,10 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   if (type.getRank() <= 1 && type.getNumElements() == 1)
     return convertScalarType(targetEnv, options, scalarType, storageClass);
 
+  // Linearize ND vectors
+  if (type.getRank() > 1)
+    type = VectorType::get(type.getNumElements(), scalarType);
+
   if (!spirv::CompositeType::isValid(type)) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: not a valid composite type\n");
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d85..551b036ba85e5 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -26,9 +26,9 @@ module attributes {
     #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
 } {
 
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+func.func @unsupported_2x2elem_vector(%arg0: vector<3x5xi32>) {
   // expected-error@+1 {{failed to legalize operation 'arith.muli'}}
-  %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+  %2 = arith.muli %arg0, %arg0: vector<3x5xi32>
   return
 }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index ae47ae36ca51c..4a2ef1f0275c6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -556,6 +556,8 @@ func.func @constant() {
   %9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
   // CHECK: spirv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
   %10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
+  // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %11 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
   return
 }
 

@llvmbot
Copy link
Member

llvmbot commented Feb 2, 2024

@llvm/pr-subscribers-mlir

Author: Ivan Butygin (Hardcode84)

Changes

SPIR-V only supports 1D vectors, try to linearize vector in type converter.

Not sure is this is a right approach or we should have a dedicated vector linearization pass.


Full diff: https://github.com/llvm/llvm-project/pull/80451.diff

4 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+5-4)
  • (modified) mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp (+4)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir (+2-2)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+2)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f39..b55eda69f99ec 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -229,8 +229,8 @@ struct ConstantCompositeOpPattern final
     if (!srcType || srcType.getNumElements() == 1)
       return failure();
 
-    // arith.constant should only have vector or tenor types.
-    assert((isa<VectorType, RankedTensorType>(srcType)));
+    assert((isa<VectorType, RankedTensorType>(srcType) &&
+            "arith.constant should only have vector or tensor types"));
 
     Type dstType = getTypeConverter()->convertType(srcType);
     if (!dstType)
@@ -250,8 +250,9 @@ struct ConstantCompositeOpPattern final
                                             srcType.getElementType());
         dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
       } else {
-        // TODO: add support for large vectors.
-        return failure();
+        dstAttrType =
+            VectorType::get(srcType.getNumElements(), srcType.getElementType());
+        dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
       }
     }
 
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e8..1ce7dff8ff0e4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -330,6 +330,10 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
   if (type.getRank() <= 1 && type.getNumElements() == 1)
     return convertScalarType(targetEnv, options, scalarType, storageClass);
 
+  // Linearize ND vectors
+  if (type.getRank() > 1)
+    type = VectorType::get(type.getNumElements(), scalarType);
+
   if (!spirv::CompositeType::isValid(type)) {
     LLVM_DEBUG(llvm::dbgs()
                << type << " illegal: not a valid composite type\n");
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d85..551b036ba85e5 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -26,9 +26,9 @@ module attributes {
     #spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
 } {
 
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+func.func @unsupported_2x2elem_vector(%arg0: vector<3x5xi32>) {
   // expected-error@+1 {{failed to legalize operation 'arith.muli'}}
-  %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+  %2 = arith.muli %arg0, %arg0: vector<3x5xi32>
   return
 }
 
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index ae47ae36ca51c..4a2ef1f0275c6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -556,6 +556,8 @@ func.func @constant() {
   %9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
   // CHECK: spirv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
   %10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
+  // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+  %11 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
   return
 }
 

@Hardcode84
Copy link
Contributor Author

Alternative approach #81159

@Hardcode84
Copy link
Contributor Author

Closed in favor of #81159

@Hardcode84 Hardcode84 closed this Feb 13, 2024
@Hardcode84 Hardcode84 deleted the spir-nd-vec branch February 13, 2024 12:31
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants