Skip to content

Commit c064545

Browse files
committed
[mlir][spirv] Do not truncate i/f64 -> i/f32 in SPIRVConversion
This truncation can be unexpected and break program behavior. Dedicated emulation passes should be used instead. Also rename pass options to "emulate-lt-32-bit-scalar-types". Fixes: #57917 Reviewed By: antiagainst Differential Revision: https://reviews.llvm.org/D137115
1 parent 9a456b7 commit c064545

File tree

10 files changed

+138
-222
lines changed

10 files changed

+138
-222
lines changed

mlir/include/mlir/Conversion/Passes.td

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ def ConvertArithToSPIRV : Pass<"convert-arith-to-spirv"> {
118118
let constructor = "mlir::arith::createConvertArithToSPIRVPass()";
119119
let dependentDialects = ["spirv::SPIRVDialect"];
120120
let options = [
121-
Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
121+
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
122122
"bool", /*default=*/"true",
123-
"Emulate non-32-bit scalar types with 32-bit ones if "
124-
"missing native support">,
123+
"Emulate narrower scalar types with 32-bit ones if not supported by "
124+
"the target">,
125125
Option<"enableFastMath", "enable-fast-math",
126126
"bool", /*default=*/"false",
127127
"Enable fast math mode (assuming no NaN and infinity for floating "
@@ -259,10 +259,10 @@ def ConvertControlFlowToSPIRV : Pass<"convert-cf-to-spirv"> {
259259
let constructor = "mlir::createConvertControlFlowToSPIRVPass()";
260260
let dependentDialects = ["spirv::SPIRVDialect"];
261261
let options = [
262-
Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
262+
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
263263
"bool", /*default=*/"true",
264-
"Emulate non-32-bit scalar types with 32-bit ones if "
265-
"missing native support">
264+
"Emulate narrower scalar types with 32-bit ones if not supported by"
265+
" the target">
266266
];
267267
}
268268

@@ -320,10 +320,10 @@ def ConvertFuncToSPIRV : Pass<"convert-func-to-spirv"> {
320320
let constructor = "mlir::createConvertFuncToSPIRVPass()";
321321
let dependentDialects = ["spirv::SPIRVDialect"];
322322
let options = [
323-
Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
323+
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
324324
"bool", /*default=*/"true",
325-
"Emulate non-32-bit scalar types with 32-bit ones if "
326-
"missing native support">
325+
"Emulate narrower scalar types with 32-bit ones if not supported by"
326+
" the target">
327327
];
328328
}
329329

@@ -815,10 +815,10 @@ def ConvertTensorToSPIRV : Pass<"convert-tensor-to-spirv"> {
815815
let constructor = "mlir::createConvertTensorToSPIRVPass()";
816816
let dependentDialects = ["spirv::SPIRVDialect"];
817817
let options = [
818-
Option<"emulateNon32BitScalarTypes", "emulate-non-32-bit-scalar-types",
818+
Option<"emulateLT32BitScalarTypes", "emulate-lt-32-bit-scalar-types",
819819
"bool", /*default=*/"true",
820-
"Emulate non-32-bit scalar types with 32-bit ones if "
821-
"missing native support">
820+
"Emulate narrower scalar types with 32-bit ones if not supported by"
821+
" the target">
822822
];
823823
}
824824

mlir/include/mlir/Dialect/SPIRV/Transforms/SPIRVConversion.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,21 @@ struct SPIRVConversionOptions {
3030
/// The number of bits to store a boolean value.
3131
unsigned boolNumBits{8};
3232

33-
/// Whether to emulate non-32-bit scalar types with 32-bit scalar types if
34-
/// no native support.
33+
/// Whether to emulate narrower scalar types with 32-bit scalar types if not
34+
/// supported by the target.
3535
///
3636
/// Non-32-bit scalar types require special hardware support that may not
3737
/// exist on all GPUs. This is reflected in SPIR-V as that non-32-bit scalar
3838
/// types require special capabilities or extensions. This option controls
39-
/// whether to use 32-bit types to emulate, if a scalar type of a certain
40-
/// bitwidth is not supported in the target environment. This requires the
41-
/// runtime to also feed in data with a matched bitwidth and layout for
42-
/// interface types. The runtime can do that by inspecting the SPIR-V
43-
/// module.
39+
/// whether to use 32-bit types to emulate < 32-bits-wide scalars, if a scalar
40+
/// type of a certain bitwidth is not supported in the target environment.
41+
/// This requires the runtime to also feed in data with a matched bitwidth and
42+
/// layout for interface types. The runtime can do that by inspecting the
43+
/// SPIR-V module.
4444
///
4545
/// If the original scalar type has less than 32-bit, a multiple of its
4646
/// values will be packed into one 32-bit value to be memory efficient.
47-
bool emulateNon32BitScalarTypes{true};
47+
bool emulateLT32BitScalarTypes{true};
4848

4949
/// Use 64-bit integers to convert index types.
5050
bool use64bitIndex{false};

mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1031,7 +1031,7 @@ struct ConvertArithToSPIRVPass
10311031
auto target = SPIRVConversionTarget::get(targetAttr);
10321032

10331033
SPIRVConversionOptions options;
1034-
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
1034+
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
10351035
options.enableFastMathMode = this->enableFastMath;
10361036
SPIRVTypeConverter typeConverter(targetAttr, options);
10371037

mlir/lib/Conversion/ControlFlowToSPIRV/ControlFlowToSPIRVPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ void ConvertControlFlowToSPIRVPass::runOnOperation() {
4141
SPIRVConversionTarget::get(targetAttr);
4242

4343
SPIRVConversionOptions options;
44-
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
44+
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
4545
SPIRVTypeConverter typeConverter(targetAttr, options);
4646

4747
RewritePatternSet patterns(context);

mlir/lib/Conversion/FuncToSPIRV/FuncToSPIRVPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ void ConvertFuncToSPIRVPass::runOnOperation() {
4040
SPIRVConversionTarget::get(targetAttr);
4141

4242
SPIRVConversionOptions options;
43-
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
43+
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
4444
SPIRVTypeConverter typeConverter(targetAttr, options);
4545

4646
RewritePatternSet patterns(context);

mlir/lib/Conversion/TensorToSPIRV/TensorToSPIRVPass.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class ConvertTensorToSPIRVPass
3838
SPIRVConversionTarget::get(targetAttr);
3939

4040
SPIRVConversionOptions options;
41-
options.emulateNon32BitScalarTypes = this->emulateNon32BitScalarTypes;
41+
options.emulateLT32BitScalarTypes = this->emulateLT32BitScalarTypes;
4242
SPIRVTypeConverter typeConverter(targetAttr, options);
4343

4444
RewritePatternSet patterns(context);

mlir/lib/Dialect/SPIRV/Transforms/SPIRVConversion.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,9 +220,16 @@ static Type convertScalarType(const spirv::TargetEnv &targetEnv,
220220

221221
// Otherwise we need to adjust the type, which really means adjusting the
222222
// bitwidth given this is a scalar type.
223+
if (!options.emulateLT32BitScalarTypes)
224+
return nullptr;
223225

224-
if (!options.emulateNon32BitScalarTypes)
226+
// We only emulate narrower scalar types here and do not truncate results.
227+
if (type.getIntOrFloatBitWidth() > 32) {
228+
LLVM_DEBUG(llvm::dbgs()
229+
<< type
230+
<< " not converted to 32-bit for SPIR-V to avoid truncation\n");
225231
return nullptr;
232+
}
226233

227234
if (auto floatType = type.dyn_cast<FloatType>()) {
228235
LLVM_DEBUG(llvm::dbgs() << type << " converted to 32-bit for SPIR-V\n");

mlir/test/Conversion/ArithToSPIRV/arith-to-spirv-unsupported.mlir

Lines changed: 65 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -49,24 +49,84 @@ func.func @int_vector4_invalid(%arg0: vector<2xi16>) {
4949

5050
// -----
5151

52-
func.func @unsupported_constant_0() {
52+
func.func @unsupported_constant_i64_0() {
53+
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
54+
%0 = arith.constant 0 : i64
55+
return
56+
}
57+
58+
// -----
59+
60+
func.func @unsupported_constant_i64_1() {
5361
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
5462
%0 = arith.constant 4294967296 : i64 // 2^32
5563
return
5664
}
5765

5866
// -----
5967

60-
func.func @unsupported_constant_1() {
68+
func.func @unsupported_constant_vector_2xi64_0() {
69+
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
70+
%1 = arith.constant dense<0> : vector<2xi64>
71+
return
72+
}
73+
74+
// -----
75+
76+
func.func @unsupported_constant_f64_0() {
6177
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
62-
%1 = arith.constant -2147483649 : i64 // -2^31 - 1
78+
%1 = arith.constant 0.0 : f64
6379
return
6480
}
6581

6682
// -----
6783

68-
func.func @unsupported_constant_2() {
84+
func.func @unsupported_constant_vector_2xf64_0() {
6985
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
70-
%2 = arith.constant -2147483649 : i64 // -2^31 - 1
86+
%1 = arith.constant dense<0.0> : vector<2xf64>
7187
return
7288
}
89+
90+
// -----
91+
92+
func.func @unsupported_constant_tensor_2xf64_0() {
93+
// expected-error @+1 {{failed to legalize operation 'arith.constant'}}
94+
%1 = arith.constant dense<0.0> : tensor<2xf64>
95+
return
96+
}
97+
98+
///===----------------------------------------------------------------------===//
99+
// Type emulation
100+
//===----------------------------------------------------------------------===//
101+
102+
// -----
103+
104+
module attributes {
105+
spirv.target_env = #spirv.target_env<
106+
#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
107+
} {
108+
109+
// Check that we do not emualte i64 by truncating to i32.
110+
func.func @unsupported_i64(%arg0: i64) {
111+
// expected-error@+1 {{failed to legalize operation 'arith.addi'}}
112+
%2 = arith.addi %arg0, %arg0: i64
113+
return
114+
}
115+
116+
} // end module
117+
118+
// -----
119+
120+
module attributes {
121+
spirv.target_env = #spirv.target_env<
122+
#spirv.vce<v1.0, [], []>, #spirv.resource_limits<>>
123+
} {
124+
125+
// Check that we do not emualte f64 by truncating to i32.
126+
func.func @unsupported_f64(%arg0: f64) {
127+
// expected-error@+1 {{failed to legalize operation 'arith.addf'}}
128+
%2 = arith.addf %arg0, %arg0: f64
129+
return
130+
}
131+
132+
} // end module

0 commit comments

Comments
 (0)