-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][spirv] Linearize ND vectors. #80451
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
SPIR-V only supports 1D vectors, try to linearize vector in type converter.
@llvm/pr-subscribers-mlir-spirv Author: Ivan Butygin (Hardcode84) ChangesSPIR-V only supports 1D vectors, try to linearize vector in type converter. Not sure is this is a right approach or we should have a dedicated vector linearization pass. Full diff: https://github.com/llvm/llvm-project/pull/80451.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f39..b55eda69f99ec 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -229,8 +229,8 @@ struct ConstantCompositeOpPattern final
if (!srcType || srcType.getNumElements() == 1)
return failure();
- // arith.constant should only have vector or tenor types.
- assert((isa<VectorType, RankedTensorType>(srcType)));
+ assert((isa<VectorType, RankedTensorType>(srcType) &&
+ "arith.constant should only have vector or tensor types"));
Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
@@ -250,8 +250,9 @@ struct ConstantCompositeOpPattern final
srcType.getElementType());
dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
} else {
- // TODO: add support for large vectors.
- return failure();
+ dstAttrType =
+ VectorType::get(srcType.getNumElements(), srcType.getElementType());
+ dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
}
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e8..1ce7dff8ff0e4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -330,6 +330,10 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
if (type.getRank() <= 1 && type.getNumElements() == 1)
return convertScalarType(targetEnv, options, scalarType, storageClass);
+ // Linearize ND vectors
+ if (type.getRank() > 1)
+ type = VectorType::get(type.getNumElements(), scalarType);
+
if (!spirv::CompositeType::isValid(type)) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: not a valid composite type\n");
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d85..551b036ba85e5 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -26,9 +26,9 @@ module attributes {
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+func.func @unsupported_2x2elem_vector(%arg0: vector<3x5xi32>) {
// expected-error@+1 {{failed to legalize operation 'arith.muli'}}
- %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+ %2 = arith.muli %arg0, %arg0: vector<3x5xi32>
return
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index ae47ae36ca51c..4a2ef1f0275c6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -556,6 +556,8 @@ func.func @constant() {
%9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
// CHECK: spirv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
%10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
+ // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %11 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
return
}
|
@llvm/pr-subscribers-mlir Author: Ivan Butygin (Hardcode84) ChangesSPIR-V only supports 1D vectors, try to linearize vector in type converter. Not sure is this is a right approach or we should have a dedicated vector linearization pass. Full diff: https://github.com/llvm/llvm-project/pull/80451.diff 4 Files Affected:
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index edf81bd7a8f39..b55eda69f99ec 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -229,8 +229,8 @@ struct ConstantCompositeOpPattern final
if (!srcType || srcType.getNumElements() == 1)
return failure();
- // arith.constant should only have vector or tenor types.
- assert((isa<VectorType, RankedTensorType>(srcType)));
+ assert((isa<VectorType, RankedTensorType>(srcType) &&
+ "arith.constant should only have vector or tensor types"));
Type dstType = getTypeConverter()->convertType(srcType);
if (!dstType)
@@ -250,8 +250,9 @@ struct ConstantCompositeOpPattern final
srcType.getElementType());
dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
} else {
- // TODO: add support for large vectors.
- return failure();
+ dstAttrType =
+ VectorType::get(srcType.getNumElements(), srcType.getElementType());
+ dstElementsAttr = dstElementsAttr.reshape(dstAttrType);
}
}
diff --git a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
index 2b79c8022b8e8..1ce7dff8ff0e4 100644
--- a/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
+++ b/mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp
@@ -330,6 +330,10 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
if (type.getRank() <= 1 && type.getNumElements() == 1)
return convertScalarType(targetEnv, options, scalarType, storageClass);
+ // Linearize ND vectors
+ if (type.getRank() > 1)
+ type = VectorType::get(type.getNumElements(), scalarType);
+
if (!spirv::CompositeType::isValid(type)) {
LLVM_DEBUG(llvm::dbgs()
<< type << " illegal: not a valid composite type\n");
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
index 0d92a8e676d85..551b036ba85e5 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir
@@ -26,9 +26,9 @@ module attributes {
#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>
} {
-func.func @unsupported_2x2elem_vector(%arg0: vector<2x2xi32>) {
+func.func @unsupported_2x2elem_vector(%arg0: vector<3x5xi32>) {
// expected-error@+1 {{failed to legalize operation 'arith.muli'}}
- %2 = arith.muli %arg0, %arg0: vector<2x2xi32>
+ %2 = arith.muli %arg0, %arg0: vector<3x5xi32>
return
}
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index ae47ae36ca51c..4a2ef1f0275c6 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -556,6 +556,8 @@ func.func @constant() {
%9 = arith.constant dense<[[1, 2], [3, 4], [5, 6]]> : tensor<3x2xi32>
// CHECK: spirv.Constant dense<{{\[}}1, 2, 3, 4, 5, 6]> : tensor<6xi32> : !spirv.array<6 x i32>
%10 = arith.constant dense<[1, 2, 3, 4, 5, 6]> : tensor<6xi32>
+ // CHECK: spirv.Constant dense<[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]> : vector<4xf32>
+ %11 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : vector<2x2xf32>
return
}
|
Alternative approach #81159 |
Closed in favor of #81159 |
SPIR-V only supports 1D vectors, try to linearize vector in type converter.
Not sure is this is a right approach or we should have a dedicated vector linearization pass.