Skip to content

Commit f13050d

Browse files
JacobSzwejbkafacebook-github-bot
authored andcommitted
expand dtype conversion support in aten_bridge
Summary: The dtype tables by necessity have to match exaclty so just casting to the int and then recasting to the other enum is safe Differential Revision: D65897501
1 parent ecdc007 commit f13050d

File tree

1 file changed

+15
-59
lines changed

1 file changed

+15
-59
lines changed

extension/aten_util/aten_bridge.cpp

Lines changed: 15 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -57,73 +57,29 @@ ET_CHECK_MSG(
5757
ET_CHECK_MSG(
5858
b.scalar_type() == torch_to_executorch_scalar_type(a.options().dtype()),
5959
"dtypes dont match a %hhd vs. b %hhd",
60-
torch_to_executorch_scalar_type(a.options().dtype()),
61-
b.scalar_type());
60+
static_cast<int8_t>(torch_to_executorch_scalar_type(a.options().dtype())),
61+
static_cast<int8_t>(b.scalar_type()));
6262
}
6363
} // namespace
6464

65-
torch::executor::ScalarType torch_to_executorch_scalar_type(
65+
executorch::runtime::etensor::ScalarType torch_to_executorch_scalar_type(
6666
caffe2::TypeMeta type) {
67-
switch (c10::typeMetaToScalarType(type)) {
68-
case c10::ScalarType::Byte:
69-
return torch::executor::ScalarType::Byte;
70-
case c10::ScalarType::Char:
71-
return torch::executor::ScalarType::Char;
72-
case c10::ScalarType::Short:
73-
return torch::executor::ScalarType::Short;
74-
case c10::ScalarType::Half:
75-
return torch::executor::ScalarType::Half;
76-
case c10::ScalarType::BFloat16:
77-
return torch::executor::ScalarType::BFloat16;
78-
case c10::ScalarType::Int:
79-
return torch::executor::ScalarType::Int;
80-
case c10::ScalarType::Float:
81-
return torch::executor::ScalarType::Float;
82-
case c10::ScalarType::Long:
83-
return torch::executor::ScalarType::Long;
84-
case c10::ScalarType::Double:
85-
return torch::executor::ScalarType::Double;
86-
case c10::ScalarType::Bool:
87-
return torch::executor::ScalarType::Bool;
88-
case c10::ScalarType::QInt8:
89-
return torch::executor::ScalarType::QInt8;
90-
case c10::ScalarType::QUInt8:
91-
return torch::executor::ScalarType::QUInt8;
92-
default:
93-
ET_ASSERT_UNREACHABLE();
94-
}
67+
int8_t intermediate = static_cast<int8_t>(c10::typeMetaToScalarType(type));
68+
// 29 is the latest scalartype entry ET added support for in scalar_type.h
69+
ET_CHECK_MSG(
70+
intermediate >= 0 && intermediate <= 29,
71+
"ScalarType %d unsupported in Executorch", intermediate);
72+
return static_cast<executorch::runtime::etensor::ScalarType>(intermediate);
9573
}
9674

9775
c10::ScalarType executorch_to_torch_scalar_type(
9876
torch::executor::ScalarType type) {
99-
switch (type) {
100-
case torch::executor::ScalarType::Byte:
101-
return c10::ScalarType::Byte;
102-
case torch::executor::ScalarType::Char:
103-
return c10::ScalarType::Char;
104-
case torch::executor::ScalarType::Short:
105-
return c10::ScalarType::Short;
106-
case torch::executor::ScalarType::Half:
107-
return c10::ScalarType::Half;
108-
case torch::executor::ScalarType::BFloat16:
109-
return c10::ScalarType::BFloat16;
110-
case torch::executor::ScalarType::Int:
111-
return c10::ScalarType::Int;
112-
case torch::executor::ScalarType::Float:
113-
return c10::ScalarType::Float;
114-
case torch::executor::ScalarType::Long:
115-
return c10::ScalarType::Long;
116-
case torch::executor::ScalarType::Double:
117-
return c10::ScalarType::Double;
118-
case torch::executor::ScalarType::Bool:
119-
return c10::ScalarType::Bool;
120-
case torch::executor::ScalarType::QInt8:
121-
return c10::ScalarType::QInt8;
122-
case torch::executor::ScalarType::QUInt8:
123-
return c10::ScalarType::QUInt8;
124-
default:
125-
ET_ASSERT_UNREACHABLE();
126-
}
77+
int8_t intermediate = static_cast<int8_t>(type);
78+
// 29 is the latest scalartype entry ET added support for in scalar_type.h
79+
ET_CHECK_MSG(
80+
intermediate >= 0 && intermediate <= 29,
81+
"ScalarType %d unsupported in Executorch", intermediate);
82+
return static_cast<c10::ScalarType>(intermediate);
12783
}
12884

12985
/*

0 commit comments

Comments
 (0)