Skip to content

Commit 8b12acd

Browse files
authored
[mlir][vector][spirv] Handle 1-element vector.{load|store} lowering. (#126294)
Add support for single element vector{load|store} lowering to SPIR-V. Since, SPIR-V converts single element vector to scalars, it needs special attention for vector{load|store} lowering to spirv{load|store}.
1 parent 2e0c093 commit 8b12acd

File tree

2 files changed

+62
-6
lines changed

2 files changed

+62
-6
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -770,10 +770,20 @@ struct VectorLoadOpConverter final
770770

771771
spirv::StorageClass storageClass = attr.getValue();
772772
auto vectorType = loadOp.getVectorType();
773-
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
774-
Value castedAccessChain =
775-
rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
776-
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, vectorType,
773+
// Use the converted vector type instead of original (single element vector
774+
// would get converted to scalar).
775+
auto spirvVectorType = typeConverter.convertType(vectorType);
776+
auto vectorPtrType = spirv::PointerType::get(spirvVectorType, storageClass);
777+
778+
// For single element vectors, we don't need to bitcast the access chain to
779+
// the original vector type. Both is going to be the same, a pointer
780+
// to a scalar.
781+
Value castedAccessChain = (vectorType.getNumElements() == 1)
782+
? accessChain
783+
: rewriter.create<spirv::BitcastOp>(
784+
loc, vectorPtrType, accessChain);
785+
786+
rewriter.replaceOpWithNewOp<spirv::LoadOp>(loadOp, spirvVectorType,
777787
castedAccessChain);
778788

779789
return success();
@@ -806,8 +816,15 @@ struct VectorStoreOpConverter final
806816
spirv::StorageClass storageClass = attr.getValue();
807817
auto vectorType = storeOp.getVectorType();
808818
auto vectorPtrType = spirv::PointerType::get(vectorType, storageClass);
809-
Value castedAccessChain =
810-
rewriter.create<spirv::BitcastOp>(loc, vectorPtrType, accessChain);
819+
820+
// For single element vectors, we don't need to bitcast the access chain to
821+
// the original vector type. Both is going to be the same, a pointer
822+
// to a scalar.
823+
Value castedAccessChain = (vectorType.getNumElements() == 1)
824+
? accessChain
825+
: rewriter.create<spirv::BitcastOp>(
826+
loc, vectorPtrType, accessChain);
827+
811828
rewriter.replaceOpWithNewOp<spirv::StoreOp>(storeOp, castedAccessChain,
812829
adaptor.getValueToStore());
813830

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1004,6 +1004,27 @@ func.func @vector_load(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>
10041004
return %0: vector<4xf32>
10051005
}
10061006

1007+
1008+
// CHECK-LABEL: @vector_load_single_elem
1009+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>)
1010+
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
1011+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
1012+
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
1013+
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
1014+
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
1015+
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
1016+
// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S1]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32
1017+
// CHECK: %[[S5:.+]] = spirv.Load "StorageBuffer" %[[S4]] : f32
1018+
// CHECK: %[[R0:.+]] = builtin.unrealized_conversion_cast %[[S5]] : f32 to vector<1xf32>
1019+
// CHECK: return %[[R0]] : vector<1xf32>
1020+
func.func @vector_load_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<1xf32> {
1021+
%idx = arith.constant 0 : index
1022+
%cst_0 = arith.constant 0.000000e+00 : f32
1023+
%0 = vector.load %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
1024+
return %0: vector<1xf32>
1025+
}
1026+
1027+
10071028
// CHECK-LABEL: @vector_load_2d
10081029
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>) -> vector<4xf32> {
10091030
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4x4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<16 x f32, stride=4> [0])>, StorageBuffer>
@@ -1046,6 +1067,24 @@ func.func @vector_store(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer
10461067
return
10471068
}
10481069

1070+
// CHECK-LABEL: @vector_store_single_elem
1071+
// CHECK-SAME: (%[[ARG0:.*]]: memref<4xf32, #spirv.storage_class<StorageBuffer>>
1072+
// CHECK-SAME: %[[ARG1:.*]]: vector<1xf32>
1073+
// CHECK: %[[S0:.+]] = builtin.unrealized_conversion_cast %[[ARG0]] : memref<4xf32, #spirv.storage_class<StorageBuffer>> to !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>
1074+
// CHECK: %[[S1:.+]] = builtin.unrealized_conversion_cast %[[ARG1]] : vector<1xf32> to f32
1075+
// CHECK: %[[C0:.+]] = arith.constant 0 : index
1076+
// CHECK: %[[S2:.+]] = builtin.unrealized_conversion_cast %[[C0]] : index to i32
1077+
// CHECK: %[[CST1:.+]] = spirv.Constant 0 : i32
1078+
// CHECK: %[[CST2:.+]] = spirv.Constant 0 : i32
1079+
// CHECK: %[[CST3:.+]] = spirv.Constant 1 : i32
1080+
// CHECK: %[[S4:.+]] = spirv.AccessChain %[[S0]][%[[CST1]], %[[S2]]] : !spirv.ptr<!spirv.struct<(!spirv.array<4 x f32, stride=4> [0])>, StorageBuffer>, i32, i32 -> !spirv.ptr<f32, StorageBuffer>
1081+
// CHECK: spirv.Store "StorageBuffer" %[[S4]], %[[S1]] : f32
1082+
func.func @vector_store_single_elem(%arg0 : memref<4xf32, #spirv.storage_class<StorageBuffer>>, %arg1 : vector<1xf32>) {
1083+
%idx = arith.constant 0 : index
1084+
vector.store %arg1, %arg0[%idx] : memref<4xf32, #spirv.storage_class<StorageBuffer>>, vector<1xf32>
1085+
return
1086+
}
1087+
10491088
// CHECK-LABEL: @vector_store_2d
10501089
// CHECK-SAME: (%[[ARG0:.*]]: memref<4x4xf32, #spirv.storage_class<StorageBuffer>>
10511090
// CHECK-SAME: %[[ARG1:.*]]: vector<4xf32>

0 commit comments

Comments
 (0)