@@ -57,76 +57,41 @@ ET_CHECK_MSG(
57
57
ET_CHECK_MSG (
58
58
b.scalar_type () == torch_to_executorch_scalar_type (a.options ().dtype ()),
59
59
" dtypes dont match a %hhd vs. b %hhd" ,
60
- torch_to_executorch_scalar_type (a.options ().dtype ()),
61
- b.scalar_type ());
60
+ static_cast < int8_t >( torch_to_executorch_scalar_type (a.options ().dtype () )),
61
+ static_cast < int8_t >( b.scalar_type () ));
62
62
}
63
63
} // namespace
64
64
65
- torch::executor ::ScalarType torch_to_executorch_scalar_type (
65
+ executorch::runtime::etensor ::ScalarType torch_to_executorch_scalar_type (
66
66
caffe2::TypeMeta type) {
67
- switch (c10::typeMetaToScalarType (type)) {
68
- case c10::ScalarType::Byte:
69
- return torch::executor::ScalarType::Byte;
70
- case c10::ScalarType::Char:
71
- return torch::executor::ScalarType::Char;
72
- case c10::ScalarType::Short:
73
- return torch::executor::ScalarType::Short;
74
- case c10::ScalarType::Half:
75
- return torch::executor::ScalarType::Half;
76
- case c10::ScalarType::BFloat16:
77
- return torch::executor::ScalarType::BFloat16;
78
- case c10::ScalarType::Int:
79
- return torch::executor::ScalarType::Int;
80
- case c10::ScalarType::Float:
81
- return torch::executor::ScalarType::Float;
82
- case c10::ScalarType::Long:
83
- return torch::executor::ScalarType::Long;
84
- case c10::ScalarType::Double:
85
- return torch::executor::ScalarType::Double;
86
- case c10::ScalarType::Bool:
87
- return torch::executor::ScalarType::Bool;
88
- case c10::ScalarType::QInt8:
89
- return torch::executor::ScalarType::QInt8;
90
- case c10::ScalarType::QUInt8:
91
- return torch::executor::ScalarType::QUInt8;
92
- default :
93
- ET_ASSERT_UNREACHABLE_MSG (
94
- " Unrecognized dtype: %hhd" ,
95
- static_cast <int8_t >(c10::typeMetaToScalarType (type)));
96
- }
67
+ const auto intermediate =
68
+ static_cast <std::underlying_type<c10::ScalarType>::type>(
69
+ c10::typeMetaToScalarType (type));
70
+
71
+ ET_CHECK_MSG (
72
+ intermediate >= 0 &&
73
+ intermediate <= static_cast <std::underlying_type<
74
+ executorch::runtime::etensor::ScalarType>::type>(
75
+ executorch::runtime::etensor::ScalarType::UInt64),
76
+ " ScalarType %d unsupported in Executorch" ,
77
+ intermediate);
78
+ return static_cast <executorch::runtime::etensor::ScalarType>(intermediate);
97
79
}
98
80
99
81
c10::ScalarType executorch_to_torch_scalar_type (
100
82
torch::executor::ScalarType type) {
101
- switch (type) {
102
- case torch::executor::ScalarType::Byte:
103
- return c10::ScalarType::Byte;
104
- case torch::executor::ScalarType::Char:
105
- return c10::ScalarType::Char;
106
- case torch::executor::ScalarType::Short:
107
- return c10::ScalarType::Short;
108
- case torch::executor::ScalarType::Half:
109
- return c10::ScalarType::Half;
110
- case torch::executor::ScalarType::BFloat16:
111
- return c10::ScalarType::BFloat16;
112
- case torch::executor::ScalarType::Int:
113
- return c10::ScalarType::Int;
114
- case torch::executor::ScalarType::Float:
115
- return c10::ScalarType::Float;
116
- case torch::executor::ScalarType::Long:
117
- return c10::ScalarType::Long;
118
- case torch::executor::ScalarType::Double:
119
- return c10::ScalarType::Double;
120
- case torch::executor::ScalarType::Bool:
121
- return c10::ScalarType::Bool;
122
- case torch::executor::ScalarType::QInt8:
123
- return c10::ScalarType::QInt8;
124
- case torch::executor::ScalarType::QUInt8:
125
- return c10::ScalarType::QUInt8;
126
- default :
127
- ET_ASSERT_UNREACHABLE_MSG (
128
- " Unrecognized dtype: %hhd" , static_cast <int8_t >(type));
129
- }
83
+ const auto intermediate = static_cast <
84
+ std::underlying_type<executorch::runtime::etensor::ScalarType>::type>(
85
+ type);
86
+
87
+ ET_CHECK_MSG (
88
+ intermediate >= 0 &&
89
+ intermediate <= static_cast <std::underlying_type<
90
+ executorch::runtime::etensor::ScalarType>::type>(
91
+ executorch::runtime::etensor::ScalarType::UInt64),
92
+ " ScalarType %d unsupported in Executorch" ,
93
+ intermediate);
94
+ return static_cast <c10::ScalarType>(intermediate);
130
95
}
131
96
132
97
/*
0 commit comments