Skip to content

Commit bb9bb68

Browse files
authored
[mlir][spirv] Handle vectors of integers of unsupported width (#118663)
Fixes: #118612
1 parent 4639a9a commit bb9bb68

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,13 +292,19 @@ convertScalarType(const spirv::TargetEnv &targetEnv,
292292
}
293293

294294
/// Converts a sub-byte integer `type` to i32 regardless of target environment.
295+
/// Returns a nullptr for unsupported integer types, including non sub-byte
296+
/// types.
295297
///
296298
/// Note that we don't recognize sub-byte types in `spirv::ScalarType` and use
297299
/// the above given that these sub-byte types are not supported at all in
298300
/// SPIR-V; there are no compute/storage capability for them like other
299301
/// supported integer types.
300302
static Type convertSubByteIntegerType(const SPIRVConversionOptions &options,
301303
IntegerType type) {
304+
if (type.getWidth() > 8) {
305+
LLVM_DEBUG(llvm::dbgs() << "not a subbyte type\n");
306+
return nullptr;
307+
}
302308
if (options.subByteTypeStorage != SPIRVSubByteTypeStorage::Packed) {
303309
LLVM_DEBUG(llvm::dbgs() << "unsupported sub-byte storage kind\n");
304310
return nullptr;
@@ -348,6 +354,9 @@ convertVectorType(const spirv::TargetEnv &targetEnv,
348354
}
349355

350356
Type elementType = convertSubByteIntegerType(options, intType);
357+
if (!elementType)
358+
return nullptr;
359+
351360
if (type.getRank() <= 1 && type.getNumElements() == 1)
352361
return elementType;
353362

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,14 @@ func.func @int_vector4_invalid(%arg0: vector<2xi16>) {
6060
return
6161
}
6262

63+
// -----
64+
65+
func.func @int_vector_invalid_bitwidth(%arg0: vector<2xi12>) {
66+
// expected-error @+1 {{failed to legalize operation 'arith.addi'}}
67+
%0 = arith.addi %arg0, %arg0: vector<2xi12>
68+
return
69+
}
70+
6371
///===----------------------------------------------------------------------===//
6472
// Constant ops
6573
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)