@@ -155,87 +155,84 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
155
155
156
156
#undef STORAGE_SPACE_MAP_LIST
157
157
158
- // TODO: This is a utility function that should probably be
159
- // exposed by the SPIR-V dialect. Keeping it local till the use case arises.
160
- static Optional<int64_t > getTypeNumBytes (Type t) {
161
- if (t.isa <spirv::ScalarType>()) {
162
- auto bitWidth = t.getIntOrFloatBitWidth ();
158
+ // TODO: This is a utility function that should probably be exposed by the
159
+ // SPIR-V dialect. Keeping it local till the use case arises.
160
+ static Optional<int64_t >
161
+ getTypeNumBytes (const SPIRVTypeConverter::Options &options, Type type) {
162
+ if (type.isa <spirv::ScalarType>()) {
163
+ auto bitWidth = type.getIntOrFloatBitWidth ();
163
164
// According to the SPIR-V spec:
164
165
// "There is no physical size or bit pattern defined for values with boolean
165
166
// type. If they are stored (in conjunction with OpVariable), they can only
166
167
// be used with logical addressing operations, not physical, and only with
167
168
// non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
168
169
// Private, Function, Input, and Output."
169
- if (bitWidth == 1 ) {
170
+ if (bitWidth == 1 )
170
171
return llvm::None;
171
- }
172
172
return bitWidth / 8 ;
173
173
}
174
174
175
- if (auto vecType = t .dyn_cast <VectorType>()) {
176
- auto elementSize = getTypeNumBytes (vecType.getElementType ());
175
+ if (auto vecType = type .dyn_cast <VectorType>()) {
176
+ auto elementSize = getTypeNumBytes (options, vecType.getElementType ());
177
177
if (!elementSize)
178
178
return llvm::None;
179
- return vecType.getNumElements () * * elementSize;
179
+ return vecType.getNumElements () * elementSize. getValue () ;
180
180
}
181
181
182
- if (auto memRefType = t .dyn_cast <MemRefType>()) {
182
+ if (auto memRefType = type .dyn_cast <MemRefType>()) {
183
183
// TODO: Layout should also be controlled by the ABI attributes. For now
184
184
// using the layout from MemRef.
185
185
int64_t offset;
186
186
SmallVector<int64_t , 4 > strides;
187
187
if (!memRefType.hasStaticShape () ||
188
- failed (getStridesAndOffset (memRefType, strides, offset))) {
188
+ failed (getStridesAndOffset (memRefType, strides, offset)))
189
189
return llvm::None;
190
- }
190
+
191
191
// To get the size of the memref object in memory, the total size is the
192
192
// max(stride * dimension-size) computed for all dimensions times the size
193
193
// of the element.
194
- auto elementSize = getTypeNumBytes (memRefType.getElementType ());
195
- if (!elementSize) {
194
+ auto elementSize = getTypeNumBytes (options, memRefType.getElementType ());
195
+ if (!elementSize)
196
196
return llvm::None;
197
- }
198
- if (memRefType.getRank () == 0 ) {
197
+
198
+ if (memRefType.getRank () == 0 )
199
199
return elementSize;
200
- }
200
+
201
201
auto dims = memRefType.getShape ();
202
202
if (llvm::is_contained (dims, ShapedType::kDynamicSize ) ||
203
203
offset == MemRefType::getDynamicStrideOrOffset () ||
204
- llvm::is_contained (strides, MemRefType::getDynamicStrideOrOffset ())) {
204
+ llvm::is_contained (strides, MemRefType::getDynamicStrideOrOffset ()))
205
205
return llvm::None;
206
- }
206
+
207
207
int64_t memrefSize = -1 ;
208
- for (auto shape : enumerate(dims)) {
208
+ for (auto shape : enumerate(dims))
209
209
memrefSize = std::max (memrefSize, shape.value () * strides[shape.index ()]);
210
- }
210
+
211
211
return (offset + memrefSize) * elementSize.getValue ();
212
212
}
213
213
214
- if (auto tensorType = t .dyn_cast <TensorType>()) {
215
- if (!tensorType.hasStaticShape ()) {
214
+ if (auto tensorType = type .dyn_cast <TensorType>()) {
215
+ if (!tensorType.hasStaticShape ())
216
216
return llvm::None;
217
- }
218
- auto elementSize = getTypeNumBytes (tensorType.getElementType ());
219
- if (!elementSize) {
217
+
218
+ auto elementSize = getTypeNumBytes (options, tensorType.getElementType ());
219
+ if (!elementSize)
220
220
return llvm::None;
221
- }
221
+
222
222
int64_t size = elementSize.getValue ();
223
- for (auto shape : tensorType.getShape ()) {
223
+ for (auto shape : tensorType.getShape ())
224
224
size *= shape;
225
- }
225
+
226
226
return size;
227
227
}
228
228
229
229
// TODO: Add size computation for other types.
230
230
return llvm::None;
231
231
}
232
232
233
- Optional<int64_t > SPIRVTypeConverter::getConvertedTypeNumBytes (Type t) {
234
- return getTypeNumBytes (t);
235
- }
236
-
237
233
// / Converts a scalar `type` to a suitable type under the given `targetEnv`.
238
234
static Type convertScalarType (const spirv::TargetEnv &targetEnv,
235
+ const SPIRVTypeConverter::Options &options,
239
236
spirv::ScalarType type,
240
237
Optional<spirv::StorageClass> storageClass = {}) {
241
238
// Get extension and capability requirements for the given type.
@@ -251,13 +248,9 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
251
248
252
249
// Otherwise we need to adjust the type, which really means adjusting the
253
250
// bitwidth given this is a scalar type.
254
- // TODO: We are unconditionally converting the bitwidth here,
255
- // this might be okay for non-interface types (i.e., types used in
256
- // Private/Function storage classes), but not for interface types (i.e.,
257
- // types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
258
- // This is because the later actually affects the ABI contract with the
259
- // runtime. So we may want to expose a control on SPIRVTypeConverter to fail
260
- // conversion if we cannot change there.
251
+
252
+ if (!options.emulateNon32BitScalarTypes )
253
+ return nullptr ;
261
254
262
255
if (auto floatType = type.dyn_cast <FloatType>()) {
263
256
LLVM_DEBUG (llvm::dbgs () << type << " converted to 32-bit for SPIR-V\n " );
@@ -272,6 +265,7 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
272
265
273
266
// / Converts a vector `type` to a suitable type under the given `targetEnv`.
274
267
static Type convertVectorType (const spirv::TargetEnv &targetEnv,
268
+ const SPIRVTypeConverter::Options &options,
275
269
VectorType type,
276
270
Optional<spirv::StorageClass> storageClass = {}) {
277
271
if (type.getRank () == 1 && type.getNumElements () == 1 )
@@ -296,19 +290,21 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
296
290
return type;
297
291
298
292
auto elementType = convertScalarType (
299
- targetEnv, type.getElementType ().cast <spirv::ScalarType>(), storageClass);
293
+ targetEnv, options, type.getElementType ().cast <spirv::ScalarType>(),
294
+ storageClass);
300
295
if (elementType)
301
296
return VectorType::get (type.getShape (), elementType);
302
297
return nullptr ;
303
298
}
304
299
305
300
// / Converts a tensor `type` to a suitable type under the given `targetEnv`.
306
301
// /
307
- // / Note that this is mainly for lowering constant tensors.In SPIR-V one can
302
+ // / Note that this is mainly for lowering constant tensors. In SPIR-V one can
308
303
// / create composite constants with OpConstantComposite to embed relative large
309
304
// / constant values and use OpCompositeExtract and OpCompositeInsert to
310
305
// / manipulate, like what we do for vectors.
311
306
static Type convertTensorType (const spirv::TargetEnv &targetEnv,
307
+ const SPIRVTypeConverter::Options &options,
312
308
TensorType type) {
313
309
// TODO: Handle dynamic shapes.
314
310
if (!type.hasStaticShape ()) {
@@ -324,19 +320,19 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
324
320
return nullptr ;
325
321
}
326
322
327
- Optional<int64_t > scalarSize = getTypeNumBytes (scalarType);
328
- Optional<int64_t > tensorSize = getTypeNumBytes (type);
323
+ Optional<int64_t > scalarSize = getTypeNumBytes (options, scalarType);
324
+ Optional<int64_t > tensorSize = getTypeNumBytes (options, type);
329
325
if (!scalarSize || !tensorSize) {
330
326
LLVM_DEBUG (llvm::dbgs ()
331
327
<< type << " illegal: cannot deduce element count\n " );
332
328
return nullptr ;
333
329
}
334
330
335
331
auto arrayElemCount = *tensorSize / *scalarSize;
336
- auto arrayElemType = convertScalarType (targetEnv, scalarType);
332
+ auto arrayElemType = convertScalarType (targetEnv, options, scalarType);
337
333
if (!arrayElemType)
338
334
return nullptr ;
339
- Optional<int64_t > arrayElemSize = getTypeNumBytes (arrayElemType);
335
+ Optional<int64_t > arrayElemSize = getTypeNumBytes (options, arrayElemType);
340
336
if (!arrayElemSize) {
341
337
LLVM_DEBUG (llvm::dbgs ()
342
338
<< type << " illegal: cannot deduce converted element size\n " );
@@ -347,6 +343,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
347
343
}
348
344
349
345
static Type convertMemrefType (const spirv::TargetEnv &targetEnv,
346
+ const SPIRVTypeConverter::Options &options,
350
347
MemRefType type) {
351
348
Optional<spirv::StorageClass> storageClass =
352
349
SPIRVTypeConverter::getStorageClassForMemorySpace (
@@ -360,9 +357,11 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
360
357
Type arrayElemType;
361
358
Type elementType = type.getElementType ();
362
359
if (auto vecType = elementType.dyn_cast <VectorType>()) {
363
- arrayElemType = convertVectorType (targetEnv, vecType, storageClass);
360
+ arrayElemType =
361
+ convertVectorType (targetEnv, options, vecType, storageClass);
364
362
} else if (auto scalarType = elementType.dyn_cast <spirv::ScalarType>()) {
365
- arrayElemType = convertScalarType (targetEnv, scalarType, storageClass);
363
+ arrayElemType =
364
+ convertScalarType (targetEnv, options, scalarType, storageClass);
366
365
} else {
367
366
LLVM_DEBUG (
368
367
llvm::dbgs ()
@@ -373,7 +372,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
373
372
if (!arrayElemType)
374
373
return nullptr ;
375
374
376
- Optional<int64_t > elementSize = getTypeNumBytes (elementType);
375
+ Optional<int64_t > elementSize = getTypeNumBytes (options, elementType);
377
376
if (!elementSize) {
378
377
LLVM_DEBUG (llvm::dbgs ()
379
378
<< type << " illegal: cannot deduce element size\n " );
@@ -387,7 +386,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
387
386
return spirv::PointerType::get (structType, *storageClass);
388
387
}
389
388
390
- Optional<int64_t > memrefSize = getTypeNumBytes (type);
389
+ Optional<int64_t > memrefSize = getTypeNumBytes (options, type);
391
390
if (!memrefSize) {
392
391
LLVM_DEBUG (llvm::dbgs ()
393
392
<< type << " illegal: cannot deduce element count\n " );
@@ -396,7 +395,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
396
395
397
396
auto arrayElemCount = *memrefSize / *elementSize;
398
397
399
- Optional<int64_t > arrayElemSize = getTypeNumBytes (arrayElemType);
398
+ Optional<int64_t > arrayElemSize = getTypeNumBytes (options, arrayElemType);
400
399
if (!arrayElemSize) {
401
400
LLVM_DEBUG (llvm::dbgs ()
402
401
<< type << " illegal: cannot deduce converted element size\n " );
@@ -414,8 +413,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
414
413
return spirv::PointerType::get (structType, *storageClass);
415
414
}
416
415
417
- SPIRVTypeConverter::SPIRVTypeConverter (spirv::TargetEnvAttr targetAttr)
418
- : targetEnv(targetAttr) {
416
+ SPIRVTypeConverter::SPIRVTypeConverter (spirv::TargetEnvAttr targetAttr,
417
+ Options options)
418
+ : targetEnv(targetAttr), options(options) {
419
419
// Add conversions. The order matters here: later ones will be tried earlier.
420
420
421
421
// Allow all SPIR-V dialect specific types. This assumes all builtin types
@@ -434,26 +434,26 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
434
434
435
435
addConversion ([this ](IntegerType intType) -> Optional<Type> {
436
436
if (auto scalarType = intType.dyn_cast <spirv::ScalarType>())
437
- return convertScalarType (targetEnv, scalarType);
437
+ return convertScalarType (this -> targetEnv , this -> options , scalarType);
438
438
return Type ();
439
439
});
440
440
441
441
addConversion ([this ](FloatType floatType) -> Optional<Type> {
442
442
if (auto scalarType = floatType.dyn_cast <spirv::ScalarType>())
443
- return convertScalarType (targetEnv, scalarType);
443
+ return convertScalarType (this -> targetEnv , this -> options , scalarType);
444
444
return Type ();
445
445
});
446
446
447
447
addConversion ([this ](VectorType vectorType) {
448
- return convertVectorType (targetEnv, vectorType);
448
+ return convertVectorType (this -> targetEnv , this -> options , vectorType);
449
449
});
450
450
451
451
addConversion ([this ](TensorType tensorType) {
452
- return convertTensorType (targetEnv, tensorType);
452
+ return convertTensorType (this -> targetEnv , this -> options , tensorType);
453
453
});
454
454
455
455
addConversion ([this ](MemRefType memRefType) {
456
- return convertMemrefType (targetEnv, memRefType);
456
+ return convertMemrefType (this -> targetEnv , this -> options , memRefType);
457
457
});
458
458
}
459
459
@@ -490,8 +490,11 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
490
490
}
491
491
492
492
Type resultType;
493
- if (fnType.getNumResults () == 1 )
493
+ if (fnType.getNumResults () == 1 ) {
494
494
resultType = getTypeConverter ()->convertType (fnType.getResult (0 ));
495
+ if (!resultType)
496
+ return failure ();
497
+ }
495
498
496
499
// Create the converted spv.func op.
497
500
auto newFuncOp = rewriter.create <spirv::FuncOp>(
0 commit comments