@@ -57,73 +57,29 @@ ET_CHECK_MSG(
57
57
ET_CHECK_MSG (
58
58
b.scalar_type () == torch_to_executorch_scalar_type (a.options ().dtype ()),
59
59
" dtypes dont match a %hhd vs. b %hhd" ,
60
- torch_to_executorch_scalar_type (a.options ().dtype ()),
61
- b.scalar_type ());
60
+ static_cast < int8_t >( torch_to_executorch_scalar_type (a.options ().dtype () )),
61
+ static_cast < int8_t >( b.scalar_type () ));
62
62
}
63
63
} // namespace
64
64
65
- torch::executor ::ScalarType torch_to_executorch_scalar_type (
65
+ executorch::runtime::etensor ::ScalarType torch_to_executorch_scalar_type (
66
66
caffe2::TypeMeta type) {
67
- switch (c10::typeMetaToScalarType (type)) {
68
- case c10::ScalarType::Byte:
69
- return torch::executor::ScalarType::Byte;
70
- case c10::ScalarType::Char:
71
- return torch::executor::ScalarType::Char;
72
- case c10::ScalarType::Short:
73
- return torch::executor::ScalarType::Short;
74
- case c10::ScalarType::Half:
75
- return torch::executor::ScalarType::Half;
76
- case c10::ScalarType::BFloat16:
77
- return torch::executor::ScalarType::BFloat16;
78
- case c10::ScalarType::Int:
79
- return torch::executor::ScalarType::Int;
80
- case c10::ScalarType::Float:
81
- return torch::executor::ScalarType::Float;
82
- case c10::ScalarType::Long:
83
- return torch::executor::ScalarType::Long;
84
- case c10::ScalarType::Double:
85
- return torch::executor::ScalarType::Double;
86
- case c10::ScalarType::Bool:
87
- return torch::executor::ScalarType::Bool;
88
- case c10::ScalarType::QInt8:
89
- return torch::executor::ScalarType::QInt8;
90
- case c10::ScalarType::QUInt8:
91
- return torch::executor::ScalarType::QUInt8;
92
- default :
93
- ET_ASSERT_UNREACHABLE ();
94
- }
67
+ int8_t intermediate = static_cast <int8_t >(c10::typeMetaToScalarType (type));
68
+ // 29 is the latest scalartype entry ET added support for in scalar_type.h
69
+ ET_CHECK_MSG (
70
+ intermediate >= 0 && intermediate <= 29 ,
71
+ " ScalarType %d unsupported in Executorch" , intermediate);
72
+ return static_cast <executorch::runtime::etensor::ScalarType>(intermediate);
95
73
}
96
74
97
75
c10::ScalarType executorch_to_torch_scalar_type (
98
76
torch::executor::ScalarType type) {
99
- switch (type) {
100
- case torch::executor::ScalarType::Byte:
101
- return c10::ScalarType::Byte;
102
- case torch::executor::ScalarType::Char:
103
- return c10::ScalarType::Char;
104
- case torch::executor::ScalarType::Short:
105
- return c10::ScalarType::Short;
106
- case torch::executor::ScalarType::Half:
107
- return c10::ScalarType::Half;
108
- case torch::executor::ScalarType::BFloat16:
109
- return c10::ScalarType::BFloat16;
110
- case torch::executor::ScalarType::Int:
111
- return c10::ScalarType::Int;
112
- case torch::executor::ScalarType::Float:
113
- return c10::ScalarType::Float;
114
- case torch::executor::ScalarType::Long:
115
- return c10::ScalarType::Long;
116
- case torch::executor::ScalarType::Double:
117
- return c10::ScalarType::Double;
118
- case torch::executor::ScalarType::Bool:
119
- return c10::ScalarType::Bool;
120
- case torch::executor::ScalarType::QInt8:
121
- return c10::ScalarType::QInt8;
122
- case torch::executor::ScalarType::QUInt8:
123
- return c10::ScalarType::QUInt8;
124
- default :
125
- ET_ASSERT_UNREACHABLE ();
126
- }
77
+ int8_t intermediate = static_cast <int8_t >(type);
78
+ // 29 is the latest scalartype entry ET added support for in scalar_type.h
79
+ ET_CHECK_MSG (
80
+ intermediate >= 0 && intermediate <= 29 ,
81
+ " ScalarType %d unsupported in Executorch" , intermediate);
82
+ return static_cast <c10::ScalarType>(intermediate);
127
83
}
128
84
129
85
/*
0 commit comments