-
Notifications
You must be signed in to change notification settings - Fork 14.3k
[mlir][vector][spirv] Handle 1-element vector.{load|store} lowering. #126294
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
Add support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}.
@llvm/pr-subscribers-mlir Author: Md Abdullah Shahneous Bari (mshahneo) ChangesAdd support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}. Full diff: https://github.com/llvm/llvm-project/pull/126294.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1ecb892a4ea9297..bca77ba68fbd181 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = loadOp.getVectorType();
- auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
- Value castedAccessChain =
- rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
- rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+ // Use the converted vector type instead of original (single element vector
+ // would get converted to scalar).
+ auto spirvVectorType = typeConverter.convertType(vectorType);
+ auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+
+ // For single element vectors, we don't need to bitcast the access chain to
+ // the original vector type. Both is going to be the same, a pointer
+ // to a scalar.
+ Value castedAccessChain = (vectorType.getNumElements() == 1)
+ ? accessChain
+ : rewriter.create<spirv::BitcastOp>(
+ loc, vectorPtrType, accessChain);
+
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
castedAccessChain);
return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
- Value castedAccessChain =
- rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+
+ // For single element vectors, we don't need to bitcast the access chain to
+ // the original vector type. Both is going to be the same, a pointer
+ // to a scalar.
+ Value castedAccessChain = (vectorType.getNumElements() == 1)
+ ? accessChain
+ : rewriter.create<spirv::BitcastOp>(
+ loc, vectorPtrType, accessChain);
+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
adaptor.getValueToStore());
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 3f0bf1962e299b0..4701ac5d960096d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
return %0: vector<4xf32>
}
+
+// CHECK-LABEL: @vector_load_single_elem
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
+// CHECK: %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
+// CHECK: return %[[R0]] : vector<1xf32>
+func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
+ %idx = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+ return %0: vector<1xf32>
+}
+
+
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @vector_store_single_elem
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+// CHECK: spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
+func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
+ %idx = arith.constant 0 : index
+ vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+ return
+}
+
// CHECK-LABEL: @vector_store_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
|
@llvm/pr-subscribers-mlir-spirv Author: Md Abdullah Shahneous Bari (mshahneo) ChangesAdd support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}. Full diff: https://github.com/llvm/llvm-project/pull/126294.diff 2 Files Affected:
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1ecb892a4ea9297..bca77ba68fbd181 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = loadOp.getVectorType();
- auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
- Value castedAccessChain =
- rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
- rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+ // Use the converted vector type instead of original (single element vector
+ // would get converted to scalar).
+ auto spirvVectorType = typeConverter.convertType(vectorType);
+ auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+
+ // For single element vectors, we don't need to bitcast the access chain to
+ // the original vector type. Both is going to be the same, a pointer
+ // to a scalar.
+ Value castedAccessChain = (vectorType.getNumElements() == 1)
+ ? accessChain
+ : rewriter.create<spirv::BitcastOp>(
+ loc, vectorPtrType, accessChain);
+
+ rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
castedAccessChain);
return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
spirv::StorageClass storageClass = attr.getValue();
auto vectorType = storeOp.getVectorType();
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
- Value castedAccessChain =
- rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+
+ // For single element vectors, we don't need to bitcast the access chain to
+ // the original vector type. Both is going to be the same, a pointer
+ // to a scalar.
+ Value castedAccessChain = (vectorType.getNumElements() == 1)
+ ? accessChain
+ : rewriter.create<spirv::BitcastOp>(
+ loc, vectorPtrType, accessChain);
+
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
adaptor.getValueToStore());
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 3f0bf1962e299b0..4701ac5d960096d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
return %0: vector<4xf32>
}
+
+// CHECK-LABEL: @vector_load_single_elem
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+// CHECK: %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
+// CHECK: %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
+// CHECK: return %[[R0]] : vector<1xf32>
+func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
+ %idx = arith.constant 0 : index
+ %cst_0 = arith.constant 0.000000e+00 : f32
+ %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+ return %0: vector<1xf32>
+}
+
+
// CHECK-LABEL: @vector_load_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
return
}
+// CHECK-LABEL: @vector_store_single_elem
+// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
+// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+// CHECK: %[[C0:.+]] = arith.constant 0 : index
+// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
+// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
+// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+// CHECK: spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
+func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
+ %idx = arith.constant 0 : index
+ vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+ return
+}
+
// CHECK-LABEL: @vector_store_2d
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>
|
…lvm#126294) Add support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}.
Add support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}.