Skip to content

Commit b308744

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Move to dtype_utils (#6016)
Summary: Pull Request resolved: #6016 ghstack-source-id: 246985129 exported-using-ghexport Reviewed By: swolchok Differential Revision: D63994875 fbshipit-source-id: 54cab1c27e7a18271426a9bc26f2ee3df57b74ec
1 parent 5937f4a commit b308744

File tree

5 files changed

+334
-275
lines changed

5 files changed

+334
-275
lines changed

kernels/portable/cpu/util/elementwise_util.cpp renamed to kernels/portable/cpu/util/dtype_util.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
* LICENSE file in the root directory of this source tree.
77
*/
88

9-
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
9+
#include <executorch/kernels/portable/cpu/util/dtype_util.h>
1010

1111
namespace torch {
1212
namespace executor {
Lines changed: 292 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,292 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#include <executorch/runtime/kernel/kernel_includes.h>
12+
13+
namespace torch {
14+
namespace executor {
15+
namespace native {
16+
namespace utils {
17+
namespace internal {
18+
19+
template <typename To, typename From>
20+
To load_and_convert(const void* fromPtr) {
21+
return static_cast<To>(*reinterpret_cast<const From*>(fromPtr));
22+
}
23+
24+
template <typename To, typename From>
25+
void convert_and_store(From f, void* dst) {
26+
*reinterpret_cast<To*>(dst) = static_cast<To>(f);
27+
}
28+
29+
template <typename CTYPE_COMMON>
30+
using load_to_common_fn = CTYPE_COMMON (*)(const void*);
31+
32+
template <typename CTYPE_COMMON, const char* op_name>
33+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbbf16(
34+
const Tensor& t) {
35+
CTYPE_COMMON (*result)(const void*) = nullptr;
36+
ET_SWITCH_REALHBBF16_TYPES(
37+
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
38+
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
39+
});
40+
return result;
41+
}
42+
43+
template <typename CTYPE_COMMON, const char* op_name>
44+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_realhbf16(
45+
const Tensor& t) {
46+
CTYPE_COMMON (*result)(const void*) = nullptr;
47+
ET_SWITCH_REALHBF16_TYPES(
48+
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
49+
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
50+
});
51+
return result;
52+
}
53+
54+
template <typename CTYPE_COMMON, const char* op_name>
55+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_floathbf16(
56+
const Tensor& t) {
57+
CTYPE_COMMON (*result)(const void*) = nullptr;
58+
ET_SWITCH_FLOATHBF16_TYPES(
59+
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
60+
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
61+
});
62+
return result;
63+
}
64+
65+
template <typename CTYPE_COMMON, const char* op_name>
66+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_intb(const Tensor& t) {
67+
CTYPE_COMMON (*result)(const void*) = nullptr;
68+
ET_SWITCH_INT_TYPES_AND(
69+
Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
70+
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
71+
});
72+
return result;
73+
}
74+
75+
template <typename CTYPE_COMMON, const char* op_name>
76+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_bool_or_byte(
77+
const Tensor& t) {
78+
CTYPE_COMMON (*result)(const void*) = nullptr;
79+
ET_SWITCH_TWO_TYPES(
80+
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
81+
result = internal::load_and_convert<CTYPE_COMMON, TENSOR_CTYPE>;
82+
});
83+
return result;
84+
}
85+
86+
template <typename CTYPE_COMMON, const char* op_name>
87+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_compute(
88+
const Tensor& t) {
89+
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
90+
ET_CHECK_MSG(
91+
t.scalar_type() == common_scalar_type,
92+
"Unhandled dtype %s for %s",
93+
::executorch::runtime::toString(common_scalar_type),
94+
op_name);
95+
return internal::load_and_convert<CTYPE_COMMON, CTYPE_COMMON>;
96+
}
97+
98+
template <
99+
typename CTYPE_COMMON,
100+
const char* op_name,
101+
std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
102+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
103+
const Tensor& t) {
104+
CTYPE_COMMON (*result)(const void*) = nullptr;
105+
ET_SWITCH_THREE_TYPES(
106+
Float, Half, BFloat16, t.scalar_type(), unused, op_name, T, [&]() {
107+
result = internal::load_and_convert<CTYPE_COMMON, T>;
108+
});
109+
return result;
110+
}
111+
112+
template <
113+
typename CTYPE_COMMON,
114+
const char* op_name,
115+
std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
116+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn_same_as_common(
117+
const Tensor& t) {
118+
return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
119+
}
120+
121+
template <typename CTYPE_COMMON>
122+
using store_common_to_tensor_fn = void (*)(CTYPE_COMMON, void*);
123+
124+
template <typename CTYPE_COMMON, const char* op_name>
125+
store_common_to_tensor_fn<CTYPE_COMMON>
126+
get_store_common_to_tensor_fn_realhbbf16(const Tensor& t) {
127+
void (*result)(CTYPE_COMMON, void*) = nullptr;
128+
ET_SWITCH_REALHBBF16_TYPES(
129+
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
130+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
131+
});
132+
return result;
133+
}
134+
135+
template <typename CTYPE_COMMON, const char* op_name>
136+
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_realhbf16(
137+
const Tensor& t) {
138+
void (*result)(CTYPE_COMMON, void*) = nullptr;
139+
ET_SWITCH_REALHBF16_TYPES(
140+
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
141+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
142+
});
143+
return result;
144+
}
145+
146+
template <typename CTYPE_COMMON, const char* op_name>
147+
store_common_to_tensor_fn<CTYPE_COMMON>
148+
get_store_common_to_tensor_fn_floathbf16(const Tensor& t) {
149+
void (*result)(CTYPE_COMMON, void*) = nullptr;
150+
ET_SWITCH_FLOATHBF16_TYPES(
151+
t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
152+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
153+
});
154+
return result;
155+
}
156+
157+
template <typename CTYPE_COMMON, const char* op_name>
158+
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn_intb(
159+
const Tensor& t) {
160+
void (*result)(CTYPE_COMMON, void*) = nullptr;
161+
ET_SWITCH_INT_TYPES_AND(
162+
Bool, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
163+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
164+
});
165+
return result;
166+
}
167+
168+
template <typename CTYPE_COMMON, const char* op_name>
169+
store_common_to_tensor_fn<CTYPE_COMMON>
170+
get_store_common_to_tensor_fn_bool_or_byte(const Tensor& t) {
171+
void (*result)(CTYPE_COMMON, void*) = nullptr;
172+
ET_SWITCH_TWO_TYPES(
173+
Bool, Byte, t.scalar_type(), unused, op_name, TENSOR_CTYPE, [&]() {
174+
result = internal::convert_and_store<TENSOR_CTYPE, CTYPE_COMMON>;
175+
});
176+
return result;
177+
}
178+
179+
template <typename CTYPE_COMMON, const char* op_name>
180+
store_common_to_tensor_fn<CTYPE_COMMON>
181+
get_store_common_to_tensor_fn_same_as_compute(const Tensor& t) {
182+
constexpr auto common_scalar_type = CppTypeToScalarType<CTYPE_COMMON>::value;
183+
ET_CHECK_MSG(
184+
t.scalar_type() == common_scalar_type,
185+
"Unhandled dtype %s for %s",
186+
::executorch::runtime::toString(common_scalar_type),
187+
op_name);
188+
return internal::convert_and_store<CTYPE_COMMON, CTYPE_COMMON>;
189+
}
190+
191+
template <
192+
typename CTYPE_COMMON,
193+
const char* op_name,
194+
std::enable_if_t<std::is_same_v<CTYPE_COMMON, float>, bool> = true>
195+
store_common_to_tensor_fn<CTYPE_COMMON>
196+
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
197+
void (*result)(CTYPE_COMMON, void*) = nullptr;
198+
ET_SWITCH_THREE_TYPES(
199+
Float, Half, BFloat16, t.scalar_type(), unused, op_name, CTYPE, [&]() {
200+
result = internal::convert_and_store<CTYPE, CTYPE_COMMON>;
201+
});
202+
return result;
203+
}
204+
205+
template <
206+
typename CTYPE_COMMON,
207+
const char* op_name,
208+
std::enable_if_t<!std::is_same_v<CTYPE_COMMON, float>, bool> = true>
209+
store_common_to_tensor_fn<CTYPE_COMMON>
210+
get_store_common_to_tensor_fn_same_as_common(const Tensor& t) {
211+
return get_store_common_to_tensor_fn_same_as_compute<CTYPE_COMMON, op_name>(
212+
t);
213+
}
214+
215+
} // namespace internal
216+
217+
enum class SupportedTensorDtypes {
218+
REALHBBF16,
219+
REALHBF16,
220+
FLOATHBF16,
221+
INTB,
222+
BOOL_OR_BYTE,
223+
SAME_AS_COMPUTE,
224+
SAME_AS_COMMON,
225+
};
226+
227+
namespace internal {
228+
229+
template <typename CTYPE_COMMON, const char* op_name>
230+
load_to_common_fn<CTYPE_COMMON> get_load_to_common_fn(
231+
const Tensor& t,
232+
SupportedTensorDtypes dtypes) {
233+
switch (dtypes) {
234+
case SupportedTensorDtypes::REALHBBF16:
235+
return get_load_to_common_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
236+
case SupportedTensorDtypes::REALHBF16:
237+
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
238+
case SupportedTensorDtypes::FLOATHBF16:
239+
return get_load_to_common_fn_realhbf16<CTYPE_COMMON, op_name>(t);
240+
case SupportedTensorDtypes::INTB:
241+
return get_load_to_common_fn_intb<CTYPE_COMMON, op_name>(t);
242+
case SupportedTensorDtypes::BOOL_OR_BYTE:
243+
return get_load_to_common_fn_bool_or_byte<CTYPE_COMMON, op_name>(t);
244+
case SupportedTensorDtypes::SAME_AS_COMPUTE:
245+
return get_load_to_common_fn_same_as_compute<CTYPE_COMMON, op_name>(t);
246+
case SupportedTensorDtypes::SAME_AS_COMMON:
247+
return get_load_to_common_fn_same_as_common<CTYPE_COMMON, op_name>(t);
248+
}
249+
ET_CHECK(false);
250+
return nullptr;
251+
}
252+
253+
template <typename CTYPE_COMMON, const char* op_name>
254+
store_common_to_tensor_fn<CTYPE_COMMON> get_store_common_to_tensor_fn(
255+
const Tensor& t,
256+
SupportedTensorDtypes dtypes) {
257+
switch (dtypes) {
258+
case SupportedTensorDtypes::REALHBBF16:
259+
return get_store_common_to_tensor_fn_realhbbf16<CTYPE_COMMON, op_name>(t);
260+
case SupportedTensorDtypes::REALHBF16:
261+
return get_store_common_to_tensor_fn_realhbf16<CTYPE_COMMON, op_name>(t);
262+
case SupportedTensorDtypes::FLOATHBF16:
263+
return get_store_common_to_tensor_fn_floathbf16<CTYPE_COMMON, op_name>(t);
264+
case SupportedTensorDtypes::INTB:
265+
return get_store_common_to_tensor_fn_intb<CTYPE_COMMON, op_name>(t);
266+
case SupportedTensorDtypes::BOOL_OR_BYTE:
267+
return get_store_common_to_tensor_fn_bool_or_byte<CTYPE_COMMON, op_name>(
268+
t);
269+
case SupportedTensorDtypes::SAME_AS_COMPUTE:
270+
return get_store_common_to_tensor_fn_same_as_compute<
271+
CTYPE_COMMON,
272+
op_name>(t);
273+
case SupportedTensorDtypes::SAME_AS_COMMON: {
274+
return get_store_common_to_tensor_fn_same_as_common<
275+
CTYPE_COMMON,
276+
op_name>(t);
277+
}
278+
}
279+
ET_CHECK(false);
280+
return nullptr;
281+
}
282+
283+
bool check_tensor_dtype(
284+
const Tensor t,
285+
SupportedTensorDtypes dtypes,
286+
const ScalarType compute_type);
287+
288+
} // namespace internal
289+
} // namespace utils
290+
} // namespace native
291+
} // namespace executor
292+
} // namespace torch

0 commit comments

Comments
 (0)