Skip to content

Commit 7ec95a9

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
expand dtype conversion support in aten_bridge (#6845)
Summary: Pull Request resolved: #6845 The dtype tables by necessity have to match exaclty so just casting to the int and then recasting to the other enum is safe Reviewed By: dulinriley Differential Revision: D65897501
1 parent e229d7c commit 7ec95a9

File tree

1 file changed

+27
-62
lines changed

1 file changed

+27
-62
lines changed

extension/aten_util/aten_bridge.cpp

Lines changed: 27 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -57,76 +57,41 @@ ET_CHECK_MSG(
5757
ET_CHECK_MSG(
5858
b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()),
5959
"dtypes dont match a %hhd vs. b %hhd",
60-
torch_to_executorch_scalar_type(a.options().dtype()),
61-
b.scalar_type());
60+
static_cast<int8_t>(torch_to_executorch_scalar_type(a.options().dtype())),
61+
static_cast<int8_t>(b.scalar_type()));
6262
}
6363
} // namespace
6464

65-
torch::executor::ScalarType torch_to_executorch_scalar_type(
65+
executorch::runtime::etensor::ScalarType torch_to_executorch_scalar_type(
6666
caffe2::TypeMeta type) {
67-
switch (c10::typeMetaToScalarType(type)) {
68-
case c10::ScalarType::Byte:
69-
return torch::executor::ScalarType::Byte;
70-
case c10::ScalarType::Char:
71-
return torch::executor::ScalarType::Char;
72-
case c10::ScalarType::Short:
73-
return torch::executor::ScalarType::Short;
74-
case c10::ScalarType::Half:
75-
return torch::executor::ScalarType::Half;
76-
case c10::ScalarType::BFloat16:
77-
return torch::executor::ScalarType::BFloat16;
78-
case c10::ScalarType::Int:
79-
return torch::executor::ScalarType::Int;
80-
case c10::ScalarType::Float:
81-
return torch::executor::ScalarType::Float;
82-
case c10::ScalarType::Long:
83-
return torch::executor::ScalarType::Long;
84-
case c10::ScalarType::Double:
85-
return torch::executor::ScalarType::Double;
86-
case c10::ScalarType::Bool:
87-
return torch::executor::ScalarType::Bool;
88-
case c10::ScalarType::QInt8:
89-
return torch::executor::ScalarType::QInt8;
90-
case c10::ScalarType::QUInt8:
91-
return torch::executor::ScalarType::QUInt8;
92-
default:
93-
ET_ASSERT_UNREACHABLE_MSG(
94-
"Unrecognized dtype: %hhd",
95-
static_cast<int8_t>(c10::typeMetaToScalarType(type)));
96-
}
67+
const auto intermediate =
68+
static_cast<std::underlying_type<c10::ScalarType>::type>(
69+
c10::typeMetaToScalarType(type));
70+
71+
ET_CHECK_MSG(
72+
intermediate >= 0 &&
73+
intermediate <= static_cast<std::underlying_type<
74+
executorch::runtime::etensor::ScalarType>::type>(
75+
executorch::runtime::etensor::ScalarType::UInt64),
76+
"ScalarType %d unsupported in Executorch",
77+
intermediate);
78+
return static_cast<executorch::runtime::etensor::ScalarType>(intermediate);
9779
}
9880

9981
c10::ScalarType executorch_to_torch_scalar_type(
10082
torch::executor::ScalarType type) {
101-
switch (type) {
102-
case torch::executor::ScalarType::Byte:
103-
return c10::ScalarType::Byte;
104-
case torch::executor::ScalarType::Char:
105-
return c10::ScalarType::Char;
106-
case torch::executor::ScalarType::Short:
107-
return c10::ScalarType::Short;
108-
case torch::executor::ScalarType::Half:
109-
return c10::ScalarType::Half;
110-
case torch::executor::ScalarType::BFloat16:
111-
return c10::ScalarType::BFloat16;
112-
case torch::executor::ScalarType::Int:
113-
return c10::ScalarType::Int;
114-
case torch::executor::ScalarType::Float:
115-
return c10::ScalarType::Float;
116-
case torch::executor::ScalarType::Long:
117-
return c10::ScalarType::Long;
118-
case torch::executor::ScalarType::Double:
119-
return c10::ScalarType::Double;
120-
case torch::executor::ScalarType::Bool:
121-
return c10::ScalarType::Bool;
122-
case torch::executor::ScalarType::QInt8:
123-
return c10::ScalarType::QInt8;
124-
case torch::executor::ScalarType::QUInt8:
125-
return c10::ScalarType::QUInt8;
126-
default:
127-
ET_ASSERT_UNREACHABLE_MSG(
128-
"Unrecognized dtype: %hhd", static_cast<int8_t>(type));
129-
}
83+
const auto intermediate = static_cast<
84+
std::underlying_type<executorch::runtime::etensor::ScalarType>::type>(
85+
type);
86+
87+
ET_CHECK_MSG(
88+
intermediate >= 0 &&
89+
intermediate <= static_cast<std::underlying_type<
90+
executorch::runtime::etensor::ScalarType>::type>(
91+
executorch::runtime::etensor::ScalarType::UInt64),
92+
"ScalarType %d unsupported in Executorch",
93+
intermediate);
94+
return static_cast<c10::ScalarType>(intermediate);
13095
}
13196

13297
/*

0 commit comments

Comments
 (0)