@@ -235,9 +235,9 @@ Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
235
235
}
236
236
237
237
// / Converts a scalar `type` to a suitable type under the given `targetEnv`.
238
- static Optional< Type>
239
- convertScalarType ( const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
240
- Optional<spirv::StorageClass> storageClass = {}) {
238
+ static Type convertScalarType ( const spirv::TargetEnv &targetEnv,
239
+ spirv::ScalarType type,
240
+ Optional<spirv::StorageClass> storageClass = {}) {
241
241
// Get extension and capability requirements for the given type.
242
242
SmallVector<ArrayRef<spirv::Extension>, 1 > extensions;
243
243
SmallVector<ArrayRef<spirv::Capability>, 2 > capabilities;
@@ -271,17 +271,17 @@ convertScalarType(const spirv::TargetEnv &targetEnv, spirv::ScalarType type,
271
271
}
272
272
273
273
// / Converts a vector `type` to a suitable type under the given `targetEnv`.
274
- static Optional< Type>
275
- convertVectorType ( const spirv::TargetEnv &targetEnv, VectorType type,
276
- Optional<spirv::StorageClass> storageClass = {}) {
274
+ static Type convertVectorType ( const spirv::TargetEnv &targetEnv,
275
+ VectorType type,
276
+ Optional<spirv::StorageClass> storageClass = {}) {
277
277
if (type.getRank () == 1 && type.getNumElements () == 1 )
278
278
return type.getElementType ();
279
279
280
280
if (!spirv::CompositeType::isValid (type)) {
281
281
// TODO: Vector types with more than four elements can be translated into
282
282
// array types.
283
283
LLVM_DEBUG (llvm::dbgs () << type << " illegal: > 4-element unimplemented\n " );
284
- return llvm::None ;
284
+ return nullptr ;
285
285
}
286
286
287
287
// Get extension and capability requirements for the given type.
@@ -298,8 +298,8 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
298
298
auto elementType = convertScalarType (
299
299
targetEnv, type.getElementType ().cast <spirv::ScalarType>(), storageClass);
300
300
if (elementType)
301
- return VectorType::get (type.getShape (), * elementType);
302
- return llvm::None ;
301
+ return VectorType::get (type.getShape (), elementType);
302
+ return nullptr ;
303
303
}
304
304
305
305
// / Converts a tensor `type` to a suitable type under the given `targetEnv`.
@@ -308,56 +308,56 @@ convertVectorType(const spirv::TargetEnv &targetEnv, VectorType type,
308
308
// / create composite constants with OpConstantComposite to embed relative large
309
309
// / constant values and use OpCompositeExtract and OpCompositeInsert to
310
310
// / manipulate, like what we do for vectors.
311
- static Optional< Type> convertTensorType (const spirv::TargetEnv &targetEnv,
312
- TensorType type) {
311
+ static Type convertTensorType (const spirv::TargetEnv &targetEnv,
312
+ TensorType type) {
313
313
// TODO: Handle dynamic shapes.
314
314
if (!type.hasStaticShape ()) {
315
315
LLVM_DEBUG (llvm::dbgs ()
316
316
<< type << " illegal: dynamic shape unimplemented\n " );
317
- return llvm::None ;
317
+ return nullptr ;
318
318
}
319
319
320
320
auto scalarType = type.getElementType ().dyn_cast <spirv::ScalarType>();
321
321
if (!scalarType) {
322
322
LLVM_DEBUG (llvm::dbgs ()
323
323
<< type << " illegal: cannot convert non-scalar element type\n " );
324
- return llvm::None ;
324
+ return nullptr ;
325
325
}
326
326
327
327
Optional<int64_t > scalarSize = getTypeNumBytes (scalarType);
328
328
Optional<int64_t > tensorSize = getTypeNumBytes (type);
329
329
if (!scalarSize || !tensorSize) {
330
330
LLVM_DEBUG (llvm::dbgs ()
331
331
<< type << " illegal: cannot deduce element count\n " );
332
- return llvm::None ;
332
+ return nullptr ;
333
333
}
334
334
335
335
auto arrayElemCount = *tensorSize / *scalarSize;
336
336
auto arrayElemType = convertScalarType (targetEnv, scalarType);
337
337
if (!arrayElemType)
338
- return llvm::None ;
339
- Optional<int64_t > arrayElemSize = getTypeNumBytes (* arrayElemType);
338
+ return nullptr ;
339
+ Optional<int64_t > arrayElemSize = getTypeNumBytes (arrayElemType);
340
340
if (!arrayElemSize) {
341
341
LLVM_DEBUG (llvm::dbgs ()
342
342
<< type << " illegal: cannot deduce converted element size\n " );
343
- return llvm::None ;
343
+ return nullptr ;
344
344
}
345
345
346
- return spirv::ArrayType::get (* arrayElemType, arrayElemCount, *arrayElemSize);
346
+ return spirv::ArrayType::get (arrayElemType, arrayElemCount, *arrayElemSize);
347
347
}
348
348
349
- static Optional< Type> convertMemrefType (const spirv::TargetEnv &targetEnv,
350
- MemRefType type) {
349
+ static Type convertMemrefType (const spirv::TargetEnv &targetEnv,
350
+ MemRefType type) {
351
351
Optional<spirv::StorageClass> storageClass =
352
352
SPIRVTypeConverter::getStorageClassForMemorySpace (
353
353
type.getMemorySpaceAsInt ());
354
354
if (!storageClass) {
355
355
LLVM_DEBUG (llvm::dbgs ()
356
356
<< type << " illegal: cannot convert memory space\n " );
357
- return llvm::None ;
357
+ return nullptr ;
358
358
}
359
359
360
- Optional< Type> arrayElemType;
360
+ Type arrayElemType;
361
361
Type elementType = type.getElementType ();
362
362
if (auto vecType = elementType.dyn_cast <VectorType>()) {
363
363
arrayElemType = convertVectorType (targetEnv, vecType, storageClass);
@@ -368,20 +368,20 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
368
368
llvm::dbgs ()
369
369
<< type
370
370
<< " unhandled: can only convert scalar or vector element type\n " );
371
- return llvm::None ;
371
+ return nullptr ;
372
372
}
373
373
if (!arrayElemType)
374
- return llvm::None ;
374
+ return nullptr ;
375
375
376
376
Optional<int64_t > elementSize = getTypeNumBytes (elementType);
377
377
if (!elementSize) {
378
378
LLVM_DEBUG (llvm::dbgs ()
379
379
<< type << " illegal: cannot deduce element size\n " );
380
- return llvm::None ;
380
+ return nullptr ;
381
381
}
382
382
383
383
if (!type.hasStaticShape ()) {
384
- auto arrayType = spirv::RuntimeArrayType::get (* arrayElemType, *elementSize);
384
+ auto arrayType = spirv::RuntimeArrayType::get (arrayElemType, *elementSize);
385
385
// Wrap in a struct to satisfy Vulkan interface requirements.
386
386
auto structType = spirv::StructType::get (arrayType, 0 );
387
387
return spirv::PointerType::get (structType, *storageClass);
@@ -391,20 +391,20 @@ static Optional<Type> convertMemrefType(const spirv::TargetEnv &targetEnv,
391
391
if (!memrefSize) {
392
392
LLVM_DEBUG (llvm::dbgs ()
393
393
<< type << " illegal: cannot deduce element count\n " );
394
- return llvm::None ;
394
+ return nullptr ;
395
395
}
396
396
397
397
auto arrayElemCount = *memrefSize / *elementSize;
398
398
399
- Optional<int64_t > arrayElemSize = getTypeNumBytes (* arrayElemType);
399
+ Optional<int64_t > arrayElemSize = getTypeNumBytes (arrayElemType);
400
400
if (!arrayElemSize) {
401
401
LLVM_DEBUG (llvm::dbgs ()
402
402
<< type << " illegal: cannot deduce converted element size\n " );
403
- return llvm::None ;
403
+ return nullptr ;
404
404
}
405
405
406
406
auto arrayType =
407
- spirv::ArrayType::get (* arrayElemType, arrayElemCount, *arrayElemSize);
407
+ spirv::ArrayType::get (arrayElemType, arrayElemCount, *arrayElemSize);
408
408
409
409
// Wrap in a struct to satisfy Vulkan interface requirements. Memrefs with
410
410
// workgroup storage class do not need the struct to be laid out explicitly.
@@ -418,9 +418,6 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
418
418
: targetEnv(targetAttr) {
419
419
// Add conversions. The order matters here: later ones will be tried earlier.
420
420
421
- // All other cases failed. Then we cannot convert this type.
422
- addConversion ([](Type type) { return llvm::None; });
423
-
424
421
// Allow all SPIR-V dialect specific types. This assumes all builtin types
425
422
// adopted in the SPIR-V dialect (i.e., IntegerType, FloatType, VectorType)
426
423
// were tried before.
@@ -438,13 +435,13 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
438
435
addConversion ([this ](IntegerType intType) -> Optional<Type> {
439
436
if (auto scalarType = intType.dyn_cast <spirv::ScalarType>())
440
437
return convertScalarType (targetEnv, scalarType);
441
- return llvm::None ;
438
+ return Type () ;
442
439
});
443
440
444
441
addConversion ([this ](FloatType floatType) -> Optional<Type> {
445
442
if (auto scalarType = floatType.dyn_cast <spirv::ScalarType>())
446
443
return convertScalarType (targetEnv, scalarType);
447
- return llvm::None ;
444
+ return Type () ;
448
445
});
449
446
450
447
addConversion ([this ](VectorType vectorType) {
0 commit comments