Skip to content

[mlir][vector][spirv] Handle 1-element vector.{load|store} lowering. #126294

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 7, 2025

Conversation

mshahneo
Copy link
Contributor

@mshahneo mshahneo commented Feb 7, 2025

Add support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}.

Add support for single element vector{load|store} lowering to SPIR-V.
Since, SPIR-V converts single element vector to scalars, it needs special
attention for vector{load|store} lowering to spirv{load|store}.
@mshahneo mshahneo requested review from kuhar and removed request for kuhar February 7, 2025 19:30
@llvmbot
Copy link
Member

llvmbot commented Feb 7, 2025

@llvm/pr-subscribers-mlir

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

Add support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}.


Full diff: https://github.com/llvm/llvm-project/pull/126294.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+23-6)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+39)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1ecb892a4ea9297..bca77ba68fbd181 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
 
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = loadOp.getVectorType();
-    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
-    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+    // Use the converted vector type instead of original (single element vector
+    // would get converted to scalar).
+    auto spirvVectorType = typeConverter.convertType(vectorType);
+    auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both is going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
                                                castedAccessChain);
 
     return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = storeOp.getVectorType();
     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both is going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
                                                 adaptor.getValueToStore());
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 3f0bf1962e299b0..4701ac5d960096d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
   return %0: vector<4xf32>
 }
 
+
+// CHECK-LABEL: @vector_load_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
+//       CHECK:   %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
+//       CHECK:   return %[[R0]] : vector<1xf32>
+func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
+  %idx = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return %0: vector<1xf32>
+}
+
+
 // CHECK-LABEL: @vector_load_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
 //       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
   return
 }
 
+// CHECK-LABEL: @vector_store_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<1xf32>
+//       CHECK:  %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:  %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+//       CHECK:  %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:  %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:  %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:  %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+//       CHECK:  spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
+func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
+  %idx = arith.constant 0 : index
+  vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_store_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
 //  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>

@llvmbot
Copy link
Member

llvmbot commented Feb 7, 2025

@llvm/pr-subscribers-mlir-spirv

Author: Md Abdullah Shahneous Bari (mshahneo)

Changes

Add support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}.


Full diff: https://github.com/llvm/llvm-project/pull/126294.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp (+23-6)
  • (modified) mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir (+39)
diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
index 1ecb892a4ea9297..bca77ba68fbd181 100644
--- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
+++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
 
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = loadOp.getVectorType();
-    auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
-    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
+    // Use the converted vector type instead of original (single element vector
+    // would get converted to scalar).
+    auto spirvVectorType = typeConverter.convertType(vectorType);
+    auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both is going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
+    rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
                                                castedAccessChain);
 
     return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
     spirv::StorageClass storageClass = attr.getValue();
     auto vectorType = storeOp.getVectorType();
     auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
-    Value castedAccessChain =
-        rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
+
+    // For single element vectors, we don't need to bitcast the access chain to
+    // the original vector type. Both is going to be the same, a pointer
+    // to a scalar.
+    Value castedAccessChain = (vectorType.getNumElements() == 1)
+                                  ? accessChain
+                                  : rewriter.create<spirv::BitcastOp>(
+                                        loc, vectorPtrType, accessChain);
+
     rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
                                                 adaptor.getValueToStore());
 
diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
index 3f0bf1962e299b0..4701ac5d960096d 100644
--- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
+++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
   return %0: vector<4xf32>
 }
 
+
+// CHECK-LABEL: @vector_load_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
+//       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:   %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:   %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:   %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:   %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:   %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
+//       CHECK:   %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
+//       CHECK:   %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
+//       CHECK:   return %[[R0]] : vector<1xf32>
+func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
+  %idx = arith.constant 0 : index
+  %cst_0 = arith.constant 0.000000e+00 : f32
+  %0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return %0: vector<1xf32>
+}
+
+
 // CHECK-LABEL: @vector_load_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
 //       CHECK:   %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
   return
 }
 
+// CHECK-LABEL: @vector_store_single_elem
+//  CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
+//  CHECK-SAME:  %[[ARG1:.*]]: vector<1xf32>
+//       CHECK:  %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
+//       CHECK:  %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
+//       CHECK:  %[[C0:.+]] = arith.constant 0 : index
+//       CHECK:  %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
+//       CHECK:  %[[CST1:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST2:.+]] = spirv.Constant 0 : i32
+//       CHECK:  %[[CST3:.+]] = spirv.Constant 1 : i32
+//       CHECK:  %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
+//       CHECK:  spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
+func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
+  %idx = arith.constant 0 : index
+  vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
+  return
+}
+
 // CHECK-LABEL: @vector_store_2d
 //  CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
 //  CHECK-SAME:  %[[ARG1:.*]]: vector<4xf32>

@kuhar kuhar requested a review from andfau-amd February 7, 2025 20:27
@mshahneo mshahneo merged commit 8b12acd into llvm:main Feb 7, 2025
11 checks passed
Icohedron pushed a commit to Icohedron/llvm-project that referenced this pull request Feb 11, 2025
…lvm#126294)

Add support for single element vector{load|store} lowering to SPIR-V.
Since, SPIR-V converts single element vector to scalars, it needs
special attention for vector{load|store} lowering to spirv{load|store}.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants