Skip to content

Commit fce33e1

Browse files
committed
[mlir][spirv] Consider target when converting one-element vector
Vectors with just one element will be converted into scalars. However, we cannot just return the element types and assume it is supported in the target environment; we need to conver the element type again factoring in those considerations. Reviewed By: kuhar Differential Revision: https://reviews.llvm.org/D136226
1 parent 3ee58e2 commit fce33e1

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "mlir/Dialect/SPIRV/IR/SPIRVDialect.h"
1616
#include "mlir/Dialect/SPIRV/IR/SPIRVEnums.h"
1717
#include "mlir/Dialect/SPIRV/IR/SPIRVOps.h"
18+
#include "mlir/Dialect/SPIRV/IR/SPIRVTypes.h"
1819
#include "mlir/Transforms/DialectConversion.h"
1920
#include "llvm/ADT/Sequence.h"
2021
#include "llvm/ADT/StringExtras.h"
@@ -239,8 +240,9 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
239240
const SPIRVConversionOptions &options,
240241
VectorType type,
241242
Optional<spirv::StorageClass> storageClass = {}) {
243+
auto scalarType = type.getElementType().cast<spirv::ScalarType>();
242244
if (type.getRank() <= 1 && type.getNumElements() == 1)
243-
return type.getElementType();
245+
return convertScalarType(targetEnv, options, scalarType, storageClass);
244246

245247
if (!spirv::CompositeType::isValid(type)) {
246248
// TODO: Vector types with more than four elements can be translated into
@@ -260,9 +262,8 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
260262
succeeded(checkExtensionRequirements(type, targetEnv, extensions)))
261263
return type;
262264

263-
auto elementType = convertScalarType(
264-
targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
265-
storageClass);
265+
auto elementType =
266+
convertScalarType(targetEnv, options, scalarType, storageClass);
266267
if (elementType)
267268
return VectorType::get(type.getShape(), elementType);
268269
return nullptr;

mlir/test/Conversion/FuncToSPIRV/types-to-spirv.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,10 @@ func.func @float_vector(
207207
%arg1: vector<3xf64>
208208
) { return }
209209

210+
// CHECK-LABEL: spirv.func @one_element_vector
211+
// CHECK-SAME: %{{.+}}: i32
212+
func.func @one_element_vector(%arg0: vector<1xi8>) { return }
213+
210214
} // end module
211215

212216
// -----

0 commit comments

Comments
 (0)