Skip to content

Commit 004f29c

Browse files
committed
[mlir][spirv] Timely fail type conversion
Per the TypeConverter API contract, returning `llvm:None` means other conversion rules should be tried. But we only have one rule per input type. So there is no need to try others and we can just directly fail, which should return `nullptr`. This avoids unnecessary checks. Differential Revision: https://reviews.llvm.org/D100058
1 parent 94a6fe4 commit 004f29c

File tree

1 file changed

+32
-35
lines changed

1 file changed

+32
-35
lines changed

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 32 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -235,9 +235,9 @@ Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
235235
}
236236

237237
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
238-
static Optional<Type>
239-
convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
240-
Optional<spirv::StorageClass> storageClass = {}) {
238+
static Type convertScalarType(const spirv::TargetEnv &targetEnv,
239+
spirv::ScalarType type,
240+
Optional<spirv::StorageClass> storageClass = {}) {
241241
// Get extension and capability requirements for the given type.
242242
SmallVector<ArrayRef<spirv::Extension>, 1> extensions;
243243
SmallVector<ArrayRef<spirv::Capability>, 2> capabilities;
@@ -271,17 +271,17 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
271271
}
272272

273273
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
274-
static Optional<Type>
275-
convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
276-
Optional<spirv::StorageClass> storageClass = {}) {
274+
static Type convertVectorType(const spirv::TargetEnv &targetEnv,
275+
VectorType type,
276+
Optional<spirv::StorageClass> storageClass = {}) {
277277
if (type.getRank() == 1 && type.getNumElements() == 1)
278278
return type.getElementType();
279279

280280
if (!spirv::CompositeType::isValid(type)) {
281281
// TODO: Vector types with more than four elements can be translated into
282282
// array types.
283283
LLVM_DEBUG(llvm::dbgs() << type << " illegal: > 4-element unimplemented\n");
284-
return llvm::None;
284+
return nullptr;
285285
}
286286

287287
// Get extension and capability requirements for the given type.
@@ -298,8 +298,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
298298
auto elementType = convertScalarType(
299299
targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
300300
if (elementType)
301-
return VectorType::get(type.getShape(), *elementType);
302-
return llvm::None;
301+
return VectorType::get(type.getShape(), elementType);
302+
return nullptr;
303303
}
304304

305305
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
@@ -308,56 +308,56 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
308308
/// create composite constants with OpConstantComposite to embed relative large
309309
/// constant values and use OpCompositeExtract and OpCompositeInsert to
310310
/// manipulate, like what we do for vectors.
311-
static Optional<Type> convertTensorType(const spirv::TargetEnv &targetEnv,
312-
TensorType type) {
311+
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
312+
TensorType type) {
313313
// TODO: Handle dynamic shapes.
314314
if (!type.hasStaticShape()) {
315315
LLVM_DEBUG(llvm::dbgs()
316316
<< type << " illegal: dynamic shape unimplemented\n");
317-
return llvm::None;
317+
return nullptr;
318318
}
319319

320320
auto scalarType = type.getElementType().dyn_cast<spirv::ScalarType>();
321321
if (!scalarType) {
322322
LLVM_DEBUG(llvm::dbgs()
323323
<< type << " illegal: cannot convert non-scalar element type\n");
324-
return llvm::None;
324+
return nullptr;
325325
}
326326

327327
Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
328328
Optional<int64_t> tensorSize = getTypeNumBytes(type);
329329
if (!scalarSize || !tensorSize) {
330330
LLVM_DEBUG(llvm::dbgs()
331331
<< type << " illegal: cannot deduce element count\n");
332-
return llvm::None;
332+
return nullptr;
333333
}
334334

335335
auto arrayElemCount = *tensorSize / *scalarSize;
336336
auto arrayElemType = convertScalarType(targetEnv, scalarType);
337337
if (!arrayElemType)
338-
return llvm::None;
339-
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
338+
return nullptr;
339+
Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
340340
if (!arrayElemSize) {
341341
LLVM_DEBUG(llvm::dbgs()
342342
<< type << " illegal: cannot deduce converted element size\n");
343-
return llvm::None;
343+
return nullptr;
344344
}
345345

346-
return spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
346+
return spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
347347
}
348348

349-
static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
350-
MemRefType type) {
349+
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
350+
MemRefType type) {
351351
Optional<spirv::StorageClass> storageClass =
352352
SPIRVTypeConverter::getStorageClassForMemorySpace(
353353
type.getMemorySpaceAsInt());
354354
if (!storageClass) {
355355
LLVM_DEBUG(llvm::dbgs()
356356
<< type << " illegal: cannot convert memory space\n");
357-
return llvm::None;
357+
return nullptr;
358358
}
359359

360-
Optional<Type> arrayElemType;
360+
Type arrayElemType;
361361
Type elementType = type.getElementType();
362362
if (auto vecType = elementType.dyn_cast<VectorType>()) {
363363
arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
@@ -368,20 +368,20 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
368368
llvm::dbgs()
369369
<< type
370370
<< " unhandled: can only convert scalar or vector element type\n");
371-
return llvm::None;
371+
return nullptr;
372372
}
373373
if (!arrayElemType)
374-
return llvm::None;
374+
return nullptr;
375375

376376
Optional<int64_t> elementSize = getTypeNumBytes(elementType);
377377
if (!elementSize) {
378378
LLVM_DEBUG(llvm::dbgs()
379379
<< type << " illegal: cannot deduce element size\n");
380-
return llvm::None;
380+
return nullptr;
381381
}
382382

383383
if (!type.hasStaticShape()) {
384-
auto arrayType = spirv::RuntimeArrayType::get(*arrayElemType, *elementSize);
384+
auto arrayType = spirv::RuntimeArrayType::get(arrayElemType, *elementSize);
385385
// Wrap in a struct to satisfy Vulkan interface requirements.
386386
auto structType = spirv::StructType::get(arrayType, 0);
387387
return spirv::PointerType::get(structType, *storageClass);
@@ -391,20 +391,20 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
391391
if (!memrefSize) {
392392
LLVM_DEBUG(llvm::dbgs()
393393
<< type << " illegal: cannot deduce element count\n");
394-
return llvm::None;
394+
return nullptr;
395395
}
396396

397397
auto arrayElemCount = *memrefSize / *elementSize;
398398

399-
Optional<int64_t> arrayElemSize = getTypeNumBytes(*arrayElemType);
399+
Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
400400
if (!arrayElemSize) {
401401
LLVM_DEBUG(llvm::dbgs()
402402
<< type << " illegal: cannot deduce converted element size\n");
403-
return llvm::None;
403+
return nullptr;
404404
}
405405

406406
auto arrayType =
407-
spirv::ArrayType::get(*arrayElemType, arrayElemCount, *arrayElemSize);
407+
spirv::ArrayType::get(arrayElemType, arrayElemCount, *arrayElemSize);
408408

409409
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
410410
// workgroup storage class do not need the struct to be laid out explicitly.
@@ -418,9 +418,6 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
418418
: targetEnv(targetAttr) {
419419
// Add conversions. The order matters here: later ones will be tried earlier.
420420

421-
// All other cases failed. Then we cannot convert this type.
422-
addConversion([](Type type) { return llvm::None; });
423-
424421
// Allow all SPIR-V dialect specific types. This assumes all builtin types
425422
// adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
426423
// were tried before.
@@ -438,13 +435,13 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
438435
addConversion([this](IntegerType intType) -> Optional<Type> {
439436
if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
440437
return convertScalarType(targetEnv, scalarType);
441-
return llvm::None;
438+
return Type();
442439
});
443440

444441
addConversion([this](FloatType floatType) -> Optional<Type> {
445442
if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
446443
return convertScalarType(targetEnv, scalarType);
447-
return llvm::None;
444+
return Type();
448445
});
449446

450447
addConversion([this](VectorType vectorType) {

0 commit comments

Comments
 (0)