Skip to content

Commit 21eecff

Browse files
add up to Uint64
Differential Revision: D65846964 Pull Request resolved: #6825
1 parent 6f63893 commit 21eecff

File tree

4 files changed

+133
-26
lines changed

4 files changed

+133
-26
lines changed

kernels/portable/cpu/test/scalar_utils_test.cpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,12 @@ struct promote_type_with_scalar_type_is_valid
1818
std::is_same<T2, torch::executor::internal::F8>::value) &&
1919
!std::is_same<T1, exec_aten::BFloat16>::value &&
2020
!torch::executor::is_qint_type<T1>::value &&
21-
!torch::executor::is_bits_type<T1>::value> {};
21+
!torch::executor::is_bits_type<T1>::value &&
22+
!executorch::runtime::is_bits_type<T2>::value &&
23+
!executorch::runtime::is_float8_type<T1>::value &&
24+
!executorch::runtime::is_float8_type<T2>::value &&
25+
!executorch::runtime::is_barebones_unsigned_type<T1>::value &&
26+
!executorch::runtime::is_barebones_unsigned_type<T2>::value> {};
2227

2328
template <typename T1, bool half_to_float>
2429
struct CompileTimePromoteTypeWithScalarTypeTestCase {

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,33 @@ struct is_qint_type
503503
: std::integral_constant<bool, isQIntType(CppTypeToScalarType<T>::value)> {
504504
};
505505

506+
constexpr bool isFloat8Type(::executorch::aten::ScalarType t) {
507+
// Don't forget to extend this when adding new QInt types
508+
return t == ::executorch::aten::ScalarType::Float8_e5m2 ||
509+
t == ::executorch::aten::ScalarType::Float8_e4m3fn ||
510+
t == ::executorch::aten::ScalarType::Float8_e5m2fnuz ||
511+
t == ::executorch::aten::ScalarType::Float8_e4m3fnuz;
512+
}
513+
514+
template <typename T>
515+
struct is_float8_type
516+
: std::
517+
integral_constant<bool, isFloat8Type(CppTypeToScalarType<T>::value)> {
518+
};
519+
520+
constexpr bool isBarebonesUnsignedType(::executorch::aten::ScalarType t) {
521+
// Don't forget to extend this when adding new QInt types
522+
return t == ::executorch::aten::ScalarType::UInt16 ||
523+
t == ::executorch::aten::ScalarType::UInt32 ||
524+
t == ::executorch::aten::ScalarType::UInt64;
525+
}
526+
527+
template <typename T>
528+
struct is_barebones_unsigned_type
529+
: std::integral_constant<
530+
bool,
531+
isBarebonesUnsignedType(CppTypeToScalarType<T>::value)> {};
532+
506533
inline ::executorch::aten::ScalarType toQIntType(
507534
::executorch::aten::ScalarType t) {
508535
switch (t) {
@@ -883,6 +910,15 @@ struct promote_types {
883910
std::is_same<T1, T2>::value ||
884911
(!is_bits_type<T1>::value && !is_bits_type<T2>::value),
885912
"promote_types not valid for bits dtypes");
913+
static_assert(
914+
std::is_same<T1, T2>::value ||
915+
(!is_float8_type<T1>::value && !is_float8_type<T2>::value),
916+
"promote_types not valid for float8 dtypes");
917+
static_assert(
918+
std::is_same<T1, T2>::value ||
919+
(!is_barebones_unsigned_type<T1>::value &&
920+
!is_barebones_unsigned_type<T2>::value),
921+
"promote_types not valid for barebones unsigned dtypes");
886922

887923
using promoted_type_not_respecting_half_to_float =
888924
typename internal::promote_types_lookup<T1, T2>::type;
@@ -945,6 +981,24 @@ inline ::executorch::aten::ScalarType promoteTypes(
945981
ET_CHECK_MSG(false, "promoteTypes not valid for bits dtypes");
946982
}
947983

984+
// For Float8 types, only allow exact match
985+
if (::executorch::runtime::isFloat8Type(a) && a == b) {
986+
return a;
987+
}
988+
if (::executorch::runtime::isFloat8Type(a) ||
989+
::executorch::runtime::isFloat8Type(b)) {
990+
ET_CHECK_MSG(false, "promoteTypes not valid for float8 dtypes");
991+
}
992+
993+
// For barebones uint types, only allow exact match
994+
if (::executorch::runtime::isBarebonesUnsignedType(a) && a == b) {
995+
return a;
996+
}
997+
if (::executorch::runtime::isBarebonesUnsignedType(a) ||
998+
::executorch::runtime::isBarebonesUnsignedType(b)) {
999+
ET_CHECK_MSG(false, "promoteTypes not valid for barebone unsigned dtypes");
1000+
}
1001+
9481002
// 12 types are handled by this function, see the constexpr definitions above
9491003
const int NUM_PROMOTE_TYPES = 13;
9501004

@@ -1433,8 +1487,10 @@ using ::executorch::runtime::canCast;
14331487
using ::executorch::runtime::convert;
14341488
using ::executorch::runtime::CppTypeToScalarType;
14351489
using ::executorch::runtime::elementSize;
1490+
using ::executorch::runtime::is_barebones_unsigned_type;
14361491
using ::executorch::runtime::is_bits_type;
14371492
using ::executorch::runtime::is_complex_type;
1493+
using ::executorch::runtime::is_float8_type;
14381494
using ::executorch::runtime::is_integral_type;
14391495
using ::executorch::runtime::is_qint_type;
14401496
using ::executorch::runtime::isBitsType;

runtime/core/exec_aten/util/test/scalar_type_util_test.cpp

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,11 @@ struct promote_types_is_valid
170170
(!executorch::runtime::is_qint_type<T1>::value &&
171171
!executorch::runtime::is_qint_type<T2>::value &&
172172
!executorch::runtime::is_bits_type<T1>::value &&
173-
!executorch::runtime::is_bits_type<T2>::value))> {};
173+
!executorch::runtime::is_bits_type<T2>::value &&
174+
!executorch::runtime::is_float8_type<T1>::value &&
175+
!executorch::runtime::is_float8_type<T2>::value &&
176+
!executorch::runtime::is_barebones_unsigned_type<T1>::value &&
177+
!executorch::runtime::is_barebones_unsigned_type<T2>::value))> {};
174178

175179
template <typename T1, bool half_to_float>
176180
struct CompileTimePromoteTypesTestCase {

runtime/core/portable_type/scalar_type.h

Lines changed: 66 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,36 @@ namespace executorch {
4747
namespace runtime {
4848
namespace etensor {
4949

50+
// Placing a bunch of unused dtypes here as our macros don't make it easy
51+
// to skip scalar types defined in aten that we dont have.
52+
namespace unused_dtype {
53+
struct alignas(1) Float8_e5m2 {
54+
uint8_t x;
55+
using underlying = uint8_t;
56+
Float8_e5m2() = default;
57+
explicit Float8_e5m2(uint8_t val) : x(val) {}
58+
};
59+
struct alignas(1) Float8_e4m3fn {
60+
uint8_t x;
61+
using underlying = uint8_t;
62+
Float8_e4m3fn() = default;
63+
explicit Float8_e4m3fn(uint8_t val) : x(val) {}
64+
};
65+
struct alignas(1) Float8_e5m2fnuz {
66+
uint8_t x;
67+
using underlying = uint8_t;
68+
Float8_e5m2fnuz() = default;
69+
explicit Float8_e5m2fnuz(uint8_t val) : x(val) {}
70+
};
71+
struct alignas(1) Float8_e4m3fnuz {
72+
uint8_t x;
73+
using underlying = uint8_t;
74+
Float8_e4m3fnuz() = default;
75+
explicit Float8_e4m3fnuz(uint8_t val) : x(val) {}
76+
};
77+
78+
} // namespace unused_dtype
79+
5080
/**
5181
* Calls the provided macro on every ScalarType, providing the C type and the
5282
* ScalarType name to each call.
@@ -59,30 +89,42 @@ namespace etensor {
5989
* @param _ A macro that takes two parameters: the name of a C type, and the
6090
* name of the corresponding ScalarType enumerator.
6191
*/
62-
#define ET_FORALL_SCALAR_TYPES(_) \
63-
_(uint8_t, Byte) /* 0 */ \
64-
_(int8_t, Char) /* 1 */ \
65-
_(int16_t, Short) /* 2 */ \
66-
_(int32_t, Int) /* 3 */ \
67-
_(int64_t, Long) /* 4 */ \
68-
_(::torch::executor::Half, Half) /* 5 */ \
69-
_(float, Float) /* 6 */ \
70-
_(double, Double) /* 7 */ \
71-
_(::torch::executor::complex<::torch::executor::Half>, ComplexHalf) /* 8 */ \
72-
_(::torch::executor::complex<float>, ComplexFloat) /* 9 */ \
73-
_(::torch::executor::complex<double>, ComplexDouble) /* 10 */ \
74-
_(bool, Bool) /* 11 */ \
75-
_(::torch::executor::qint8, QInt8) /* 12 */ \
76-
_(::torch::executor::quint8, QUInt8) /* 13 */ \
77-
_(::torch::executor::qint32, QInt32) /* 14 */ \
78-
_(::torch::executor::BFloat16, BFloat16) /* 15 */ \
79-
_(::torch::executor::quint4x2, QUInt4x2) /* 16 */ \
80-
_(::torch::executor::quint2x4, QUInt2x4) /* 17 */ \
81-
_(::torch::executor::bits1x8, Bits1x8) /* 18 */ \
82-
_(::torch::executor::bits2x4, Bits2x4) /* 19 */ \
83-
_(::torch::executor::bits4x2, Bits4x2) /* 20 */ \
84-
_(::torch::executor::bits8, Bits8) /* 21 */ \
85-
_(::torch::executor::bits16, Bits16) /* 22 */
92+
#define ET_FORALL_SCALAR_TYPES(_) \
93+
_(uint8_t, Byte) /* 0 */ \
94+
_(int8_t, Char) /* 1 */ \
95+
_(int16_t, Short) /* 2 */ \
96+
_(int32_t, Int) /* 3 */ \
97+
_(int64_t, Long) /* 4 */ \
98+
_(::executorch::runtime::etensor::Half, Half) /* 5 */ \
99+
_(float, Float) /* 6 */ \
100+
_(double, Double) /* 7 */ \
101+
_(::executorch::runtime::etensor::complex<::torch::executor::Half>, \
102+
ComplexHalf) /* 8 */ \
103+
_(::executorch::runtime::etensor::complex<float>, ComplexFloat) /* 9 */ \
104+
_(::executorch::runtime::etensor::complex<double>, ComplexDouble) /* 10 */ \
105+
_(bool, Bool) /* 11 */ \
106+
_(::executorch::runtime::etensor::qint8, QInt8) /* 12 */ \
107+
_(::executorch::runtime::etensor::quint8, QUInt8) /* 13 */ \
108+
_(::executorch::runtime::etensor::qint32, QInt32) /* 14 */ \
109+
_(::executorch::runtime::etensor::BFloat16, BFloat16) /* 15 */ \
110+
_(::executorch::runtime::etensor::quint4x2, QUInt4x2) /* 16 */ \
111+
_(::executorch::runtime::etensor::quint2x4, QUInt2x4) /* 17 */ \
112+
_(::executorch::runtime::etensor::bits1x8, Bits1x8) /* 18 */ \
113+
_(::executorch::runtime::etensor::bits2x4, Bits2x4) /* 19 */ \
114+
_(::executorch::runtime::etensor::bits4x2, Bits4x2) /* 20 */ \
115+
_(::executorch::runtime::etensor::bits8, Bits8) /* 21 */ \
116+
_(::executorch::runtime::etensor::bits16, Bits16) /* 22 */ \
117+
_(::executorch::runtime::etensor::unused_dtype::Float8_e5m2, \
118+
Float8_e5m2) /* 23 */ \
119+
_(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fn, \
120+
Float8_e4m3fn) /* 24 */ \
121+
_(::executorch::runtime::etensor::unused_dtype::Float8_e5m2fnuz, \
122+
Float8_e5m2fnuz) /* 25 */ \
123+
_(::executorch::runtime::etensor::unused_dtype::Float8_e4m3fnuz, \
124+
Float8_e4m3fnuz) /* 26 */ \
125+
_(uint16_t, UInt16) /* 27 */ \
126+
_(uint32_t, UInt32) /* 28 */ \
127+
_(uint64_t, UInt64) /* 29 */
86128

87129
/**
88130
* Data types (dtypes) that can be used as element types in ETensors.

0 commit comments

Comments
 (0)