Skip to content

Commit 5299843

Browse files
committed
[mlir][spirv] Add control for non-32-bit scalar type emulation
Non-32-bit scalar types require special hardware support that may not exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar types require special capabilities or extensions. Previously when there is a non-32-bit type and no native support, we unconditionally emulate it with 32-bit ones. This isn't good given that it can have implications over ABI and data layout consistency. This commit introduces an option to control whether to use 32-bit types to emulate. Differential Revision: https://reviews.llvm.org/D100059
1 parent 004f29c commit 5299843

File tree

5 files changed

+169
-79
lines changed

5 files changed

+169
-79
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,6 +425,12 @@ def ConvertStandardToSPIRV : Pass<"convert-std-to-spirv", "ModuleOp"> {
425425
let summary = "Convert Standard dialect to SPIR-V dialect";
426426
let constructor = "mlir::createConvertStandardToSPIRVPass()";
427427
let dependentDialects = ["spirv::SPIRVDialect"];
428+
let options = [
429+
Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
430+
"bool", /*default=*/"true",
431+
"Emulate non-32-bit scalar types with 32-bit ones if "
432+
"missing native support">
433+
];
428434
}
429435

430436
//===----------------------------------------------------------------------===//

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -27,29 +27,35 @@ namespace mlir {
2727

2828
/// Type conversion from builtin types to SPIR-V types for shader interface.
2929
///
30-
/// Non-32-bit scalar types require special hardware support that may not exist
31-
/// on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar types
32-
/// require special capabilities or extensions. Right now if a scalar type of a
33-
/// certain bitwidth is not supported in the target environment, we use 32-bit
34-
/// ones unconditionally. This requires the runtime to also feed in data with
35-
/// a matched bitwidth and layout for interface types. The runtime can do that
36-
/// by inspecting the SPIR-V module.
37-
///
3830
/// For memref types, this converter additionally performs type wrapping to
3931
/// satisfy shader interface requirements: shader interface types must be
4032
/// pointers to structs.
41-
///
42-
/// TODO: We might want to introduce a way to control how unsupported bitwidth
43-
/// are handled and explicitly fail if wanted.
4433
class SPIRVTypeConverter : public TypeConverter {
4534
public:
46-
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr);
47-
48-
/// Gets the number of bytes used for a type when converted to SPIR-V
49-
/// type. Note that it doesnt account for whether the type is legal for a
50-
/// SPIR-V target (described by spirv::TargetEnvAttr). Returns None on
51-
/// failure.
52-
static Optional<int64_t> getConvertedTypeNumBytes(Type);
35+
struct Options {
36+
/// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
37+
/// no native support.
38+
///
39+
/// Non-32-bit scalar types require special hardware support that may not
40+
/// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
41+
/// types require special capabilities or extensions. This option controls
42+
/// whether to use 32-bit types to emulate, if a scalar type of a certain
43+
/// bitwidth is not supported in the target environment. This requires the
44+
/// runtime to also feed in data with a matched bitwidth and layout for
45+
/// interface types. The runtime can do that by inspecting the SPIR-V
46+
/// module.
47+
///
48+
/// If the original scalar type has less than 32-bit, a multiple of its
49+
/// values will be packed into one 32-bit value to be memory efficient.
50+
bool emulateNon32BitScalarTypes;
51+
52+
// Note: we need this instead of inline initializers becuase of
53+
// https://bugs.llvm.org/show_bug.cgi?id=36684
54+
Options() : emulateNon32BitScalarTypes(true) {}
55+
};
56+
57+
explicit SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
58+
Options options = {});
5359

5460
/// Gets the SPIR-V correspondence for the standard index type.
5561
static Type getIndexType(MLIRContext *context);
@@ -63,8 +69,12 @@ class SPIRVTypeConverter : public TypeConverter {
6369
static Optional<spirv::StorageClass>
6470
getStorageClassForMemorySpace(unsigned space);
6571

72+
/// Returns the options controlling the SPIR-V type converter.
73+
const Options &getOptions() const;
74+
6675
private:
6776
spirv::TargetEnv targetEnv;
77+
Options options;
6878
};
6979

7080
//===----------------------------------------------------------------------===//

mlir/lib/Conversion/StandardToSPIRV/StandardToSPIRVPass.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,10 @@ void ConvertStandardToSPIRVPass::runOnOperation() {
3434
std::unique_ptr<ConversionTarget> target =
3535
SPIRVConversionTarget::get(targetAttr);
3636

37-
SPIRVTypeConverter typeConverter(targetAttr);
37+
SPIRVTypeConverter::Options options;
38+
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
39+
SPIRVTypeConverter typeConverter(targetAttr, options);
40+
3841
RewritePatternSet patterns(context);
3942
populateStandardToSPIRVPatterns(typeConverter, patterns);
4043
populateTensorToSPIRVPatterns(typeConverter,

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 63 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -155,87 +155,84 @@ SPIRVTypeConverter::getStorageClassForMemorySpace(unsigned space) {
155155

156156
#undef STORAGE_SPACE_MAP_LIST
157157

158-
// TODO: This is a utility function that should probably be
159-
// exposed by the SPIR-V dialect. Keeping it local till the use case arises.
160-
static Optional<int64_t> getTypeNumBytes(Type t) {
161-
if (t.isa<spirv::ScalarType>()) {
162-
auto bitWidth = t.getIntOrFloatBitWidth();
158+
// TODO: This is a utility function that should probably be exposed by the
159+
// SPIR-V dialect. Keeping it local till the use case arises.
160+
static Optional<int64_t>
161+
getTypeNumBytes(const SPIRVTypeConverter::Options &options, Type type) {
162+
if (type.isa<spirv::ScalarType>()) {
163+
auto bitWidth = type.getIntOrFloatBitWidth();
163164
// According to the SPIR-V spec:
164165
// "There is no physical size or bit pattern defined for values with boolean
165166
// type. If they are stored (in conjunction with OpVariable), they can only
166167
// be used with logical addressing operations, not physical, and only with
167168
// non-externally visible shader Storage Classes: Workgroup, CrossWorkgroup,
168169
// Private, Function, Input, and Output."
169-
if (bitWidth == 1) {
170+
if (bitWidth == 1)
170171
return llvm::None;
171-
}
172172
return bitWidth / 8;
173173
}
174174

175-
if (auto vecType = t.dyn_cast<VectorType>()) {
176-
auto elementSize = getTypeNumBytes(vecType.getElementType());
175+
if (auto vecType = type.dyn_cast<VectorType>()) {
176+
auto elementSize = getTypeNumBytes(options, vecType.getElementType());
177177
if (!elementSize)
178178
return llvm::None;
179-
return vecType.getNumElements() * *elementSize;
179+
return vecType.getNumElements() * elementSize.getValue();
180180
}
181181

182-
if (auto memRefType = t.dyn_cast<MemRefType>()) {
182+
if (auto memRefType = type.dyn_cast<MemRefType>()) {
183183
// TODO: Layout should also be controlled by the ABI attributes. For now
184184
// using the layout from MemRef.
185185
int64_t offset;
186186
SmallVector<int64_t, 4> strides;
187187
if (!memRefType.hasStaticShape() ||
188-
failed(getStridesAndOffset(memRefType, strides, offset))) {
188+
failed(getStridesAndOffset(memRefType, strides, offset)))
189189
return llvm::None;
190-
}
190+
191191
// To get the size of the memref object in memory, the total size is the
192192
// max(stride * dimension-size) computed for all dimensions times the size
193193
// of the element.
194-
auto elementSize = getTypeNumBytes(memRefType.getElementType());
195-
if (!elementSize) {
194+
auto elementSize = getTypeNumBytes(options, memRefType.getElementType());
195+
if (!elementSize)
196196
return llvm::None;
197-
}
198-
if (memRefType.getRank() == 0) {
197+
198+
if (memRefType.getRank() == 0)
199199
return elementSize;
200-
}
200+
201201
auto dims = memRefType.getShape();
202202
if (llvm::is_contained(dims, ShapedType::kDynamicSize) ||
203203
offset == MemRefType::getDynamicStrideOrOffset() ||
204-
llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset())) {
204+
llvm::is_contained(strides, MemRefType::getDynamicStrideOrOffset()))
205205
return llvm::None;
206-
}
206+
207207
int64_t memrefSize = -1;
208-
for (auto shape : enumerate(dims)) {
208+
for (auto shape : enumerate(dims))
209209
memrefSize = std::max(memrefSize, shape.value() * strides[shape.index()]);
210-
}
210+
211211
return (offset + memrefSize) * elementSize.getValue();
212212
}
213213

214-
if (auto tensorType = t.dyn_cast<TensorType>()) {
215-
if (!tensorType.hasStaticShape()) {
214+
if (auto tensorType = type.dyn_cast<TensorType>()) {
215+
if (!tensorType.hasStaticShape())
216216
return llvm::None;
217-
}
218-
auto elementSize = getTypeNumBytes(tensorType.getElementType());
219-
if (!elementSize) {
217+
218+
auto elementSize = getTypeNumBytes(options, tensorType.getElementType());
219+
if (!elementSize)
220220
return llvm::None;
221-
}
221+
222222
int64_t size = elementSize.getValue();
223-
for (auto shape : tensorType.getShape()) {
223+
for (auto shape : tensorType.getShape())
224224
size *= shape;
225-
}
225+
226226
return size;
227227
}
228228

229229
// TODO: Add size computation for other types.
230230
return llvm::None;
231231
}
232232

233-
Optional<int64_t> SPIRVTypeConverter::getConvertedTypeNumBytes(Type t) {
234-
return getTypeNumBytes(t);
235-
}
236-
237233
/// Converts a scalar `type` to a suitable type under the given `targetEnv`.
238234
static Type convertScalarType(const spirv::TargetEnv &targetEnv,
235+
const SPIRVTypeConverter::Options &options,
239236
spirv::ScalarType type,
240237
Optional<spirv::StorageClass> storageClass = {}) {
241238
// Get extension and capability requirements for the given type.
@@ -251,13 +248,9 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
251248

252249
// Otherwise we need to adjust the type, which really means adjusting the
253250
// bitwidth given this is a scalar type.
254-
// TODO: We are unconditionally converting the bitwidth here,
255-
// this might be okay for non-interface types (i.e., types used in
256-
// Private/Function storage classes), but not for interface types (i.e.,
257-
// types used in StorageBuffer/Uniform/PushConstant/etc. storage classes).
258-
// This is because the later actually affects the ABI contract with the
259-
// runtime. So we may want to expose a control on SPIRVTypeConverter to fail
260-
// conversion if we cannot change there.
251+
252+
if (!options.emulateNon32BitScalarTypes)
253+
return nullptr;
261254

262255
if (auto floatType = type.dyn_cast<FloatType>()) {
263256
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");
@@ -272,6 +265,7 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
272265

273266
/// Converts a vector `type` to a suitable type under the given `targetEnv`.
274267
static Type convertVectorType(const spirv::TargetEnv &targetEnv,
268+
const SPIRVTypeConverter::Options &options,
275269
VectorType type,
276270
Optional<spirv::StorageClass> storageClass = {}) {
277271
if (type.getRank() == 1 && type.getNumElements() == 1)
@@ -296,19 +290,21 @@ static Type convertVectorType(const spirv::TargetEnv &targetEnv,
296290
return type;
297291

298292
auto elementType = convertScalarType(
299-
targetEnv, type.getElementType().cast<spirv::ScalarType>(), storageClass);
293+
targetEnv, options, type.getElementType().cast<spirv::ScalarType>(),
294+
storageClass);
300295
if (elementType)
301296
return VectorType::get(type.getShape(), elementType);
302297
return nullptr;
303298
}
304299

305300
/// Converts a tensor `type` to a suitable type under the given `targetEnv`.
306301
///
307-
/// Note that this is mainly for lowering constant tensors.In SPIR-V one can
302+
/// Note that this is mainly for lowering constant tensors. In SPIR-V one can
308303
/// create composite constants with OpConstantComposite to embed relative large
309304
/// constant values and use OpCompositeExtract and OpCompositeInsert to
310305
/// manipulate, like what we do for vectors.
311306
static Type convertTensorType(const spirv::TargetEnv &targetEnv,
307+
const SPIRVTypeConverter::Options &options,
312308
TensorType type) {
313309
// TODO: Handle dynamic shapes.
314310
if (!type.hasStaticShape()) {
@@ -324,19 +320,19 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
324320
return nullptr;
325321
}
326322

327-
Optional<int64_t> scalarSize = getTypeNumBytes(scalarType);
328-
Optional<int64_t> tensorSize = getTypeNumBytes(type);
323+
Optional<int64_t> scalarSize = getTypeNumBytes(options, scalarType);
324+
Optional<int64_t> tensorSize = getTypeNumBytes(options, type);
329325
if (!scalarSize || !tensorSize) {
330326
LLVM_DEBUG(llvm::dbgs()
331327
<< type << " illegal: cannot deduce element count\n");
332328
return nullptr;
333329
}
334330

335331
auto arrayElemCount = *tensorSize / *scalarSize;
336-
auto arrayElemType = convertScalarType(targetEnv, scalarType);
332+
auto arrayElemType = convertScalarType(targetEnv, options, scalarType);
337333
if (!arrayElemType)
338334
return nullptr;
339-
Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
335+
Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
340336
if (!arrayElemSize) {
341337
LLVM_DEBUG(llvm::dbgs()
342338
<< type << " illegal: cannot deduce converted element size\n");
@@ -347,6 +343,7 @@ static Type convertTensorType(const spirv::TargetEnv &targetEnv,
347343
}
348344

349345
static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
346+
const SPIRVTypeConverter::Options &options,
350347
MemRefType type) {
351348
Optional<spirv::StorageClass> storageClass =
352349
SPIRVTypeConverter::getStorageClassForMemorySpace(
@@ -360,9 +357,11 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
360357
Type arrayElemType;
361358
Type elementType = type.getElementType();
362359
if (auto vecType = elementType.dyn_cast<VectorType>()) {
363-
arrayElemType = convertVectorType(targetEnv, vecType, storageClass);
360+
arrayElemType =
361+
convertVectorType(targetEnv, options, vecType, storageClass);
364362
} else if (auto scalarType = elementType.dyn_cast<spirv::ScalarType>()) {
365-
arrayElemType = convertScalarType(targetEnv, scalarType, storageClass);
363+
arrayElemType =
364+
convertScalarType(targetEnv, options, scalarType, storageClass);
366365
} else {
367366
LLVM_DEBUG(
368367
llvm::dbgs()
@@ -373,7 +372,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
373372
if (!arrayElemType)
374373
return nullptr;
375374

376-
Optional<int64_t> elementSize = getTypeNumBytes(elementType);
375+
Optional<int64_t> elementSize = getTypeNumBytes(options, elementType);
377376
if (!elementSize) {
378377
LLVM_DEBUG(llvm::dbgs()
379378
<< type << " illegal: cannot deduce element size\n");
@@ -387,7 +386,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
387386
return spirv::PointerType::get(structType, *storageClass);
388387
}
389388

390-
Optional<int64_t> memrefSize = getTypeNumBytes(type);
389+
Optional<int64_t> memrefSize = getTypeNumBytes(options, type);
391390
if (!memrefSize) {
392391
LLVM_DEBUG(llvm::dbgs()
393392
<< type << " illegal: cannot deduce element count\n");
@@ -396,7 +395,7 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
396395

397396
auto arrayElemCount = *memrefSize / *elementSize;
398397

399-
Optional<int64_t> arrayElemSize = getTypeNumBytes(arrayElemType);
398+
Optional<int64_t> arrayElemSize = getTypeNumBytes(options, arrayElemType);
400399
if (!arrayElemSize) {
401400
LLVM_DEBUG(llvm::dbgs()
402401
<< type << " illegal: cannot deduce converted element size\n");
@@ -414,8 +413,9 @@ static Type convertMemrefType(const spirv::TargetEnv &targetEnv,
414413
return spirv::PointerType::get(structType, *storageClass);
415414
}
416415

417-
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
418-
: targetEnv(targetAttr) {
416+
SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr,
417+
Options options)
418+
: targetEnv(targetAttr), options(options) {
419419
// Add conversions. The order matters here: later ones will be tried earlier.
420420

421421
// Allow all SPIR-V dialect specific types. This assumes all builtin types
@@ -434,26 +434,26 @@ SPIRVTypeConverter::SPIRVTypeConverter(spirv::TargetEnvAttr targetAttr)
434434

435435
addConversion([this](IntegerType intType) -> Optional<Type> {
436436
if (auto scalarType = intType.dyn_cast<spirv::ScalarType>())
437-
return convertScalarType(targetEnv, scalarType);
437+
return convertScalarType(this->targetEnv, this->options, scalarType);
438438
return Type();
439439
});
440440

441441
addConversion([this](FloatType floatType) -> Optional<Type> {
442442
if (auto scalarType = floatType.dyn_cast<spirv::ScalarType>())
443-
return convertScalarType(targetEnv, scalarType);
443+
return convertScalarType(this->targetEnv, this->options, scalarType);
444444
return Type();
445445
});
446446

447447
addConversion([this](VectorType vectorType) {
448-
return convertVectorType(targetEnv, vectorType);
448+
return convertVectorType(this->targetEnv, this->options, vectorType);
449449
});
450450

451451
addConversion([this](TensorType tensorType) {
452-
return convertTensorType(targetEnv, tensorType);
452+
return convertTensorType(this->targetEnv, this->options, tensorType);
453453
});
454454

455455
addConversion([this](MemRefType memRefType) {
456-
return convertMemrefType(targetEnv, memRefType);
456+
return convertMemrefType(this->targetEnv, this->options, memRefType);
457457
});
458458
}
459459

@@ -490,8 +490,11 @@ FuncOpConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value> operands,
490490
}
491491

492492
Type resultType;
493-
if (fnType.getNumResults() == 1)
493+
if (fnType.getNumResults() == 1) {
494494
resultType = getTypeConverter()->convertType(fnType.getResult(0));
495+
if (!resultType)
496+
return failure();
497+
}
495498

496499
// Create the converted spv.func op.
497500
auto newFuncOp = rewriter.create<spirv::FuncOp>(

0 commit comments

Comments
 (0)