Skip to content

Commit c1c88b0

Browse files
committed
[ExecuTorch] support BF16 in LLM runner & sampler
Pull Request resolved: #4984 The LLM runner assumed that the data type could only be float or half. Suport bfloat16 and neaten up the code while we're at it. ghstack-source-id: 241050113 @exported-using-ghexport Differential Revision: [D61981354](https://our.internmc.facebook.com/intern/diff/D61981354/) **NOTE FOR REVIEWERS**: This PR has internal Meta-specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D61981354/)!
1 parent 6d1af72 commit c1c88b0

File tree

3 files changed

+51
-42
lines changed

3 files changed

+51
-42
lines changed

extension/llm/runner/text_decoder_runner.h

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -67,39 +67,31 @@ class TextDecoderRunner {
6767
* @return The next token.
6868
*/
6969
inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
70-
switch (logits_tensor.scalar_type()) {
71-
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
72-
// vocab_size], get the last logits, sample and return. Else the model
73-
// outputs the last logit, directly sample and return.
74-
case exec_aten::ScalarType::Float: {
75-
float* logits = logits_tensor.mutable_data_ptr<float>();
76-
if (logits_tensor.dim() == 3) {
77-
auto num_tokens = logits_tensor.size(1);
78-
auto vocab_size = logits_tensor.size(2);
79-
float* logits_last = logits;
80-
logits_last += (num_tokens - 1) * vocab_size;
81-
return sampler_->sample(logits_last);
82-
}
83-
return sampler_->sample(logits);
84-
}
85-
case exec_aten::ScalarType::Half: {
86-
exec_aten::Half* logits =
87-
logits_tensor.mutable_data_ptr<exec_aten::Half>();
88-
if (logits_tensor.dim() == 3) {
89-
auto num_tokens = logits_tensor.size(1);
90-
auto vocab_size = logits_tensor.size(2);
91-
exec_aten::Half* logits_last = logits;
92-
logits_last += (num_tokens - 1) * vocab_size;
93-
return sampler_->sample(logits_last);
94-
}
95-
return sampler_->sample(logits);
96-
}
97-
default:
98-
ET_CHECK_MSG(
99-
false,
100-
"Unsupported dtype output %hhd",
101-
static_cast<int8_t>(logits_tensor.scalar_type()));
102-
}
70+
int32_t result = 0;
71+
ET_SWITCH_THREE_TYPES(
72+
Float,
73+
Half,
74+
BFloat16,
75+
logits_tensor.scalar_type(),
76+
unused,
77+
"logits_to_token",
78+
CTYPE,
79+
[&]() {
80+
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
81+
// vocab_size], get the last logits, sample and return. Else the model
82+
// outputs the last logit, directly sample and return.
83+
auto* logits = logits_tensor.mutable_data_ptr<CTYPE>();
84+
if (logits_tensor.dim() == 3) {
85+
auto num_tokens = logits_tensor.size(1);
86+
auto vocab_size = logits_tensor.size(2);
87+
auto* logits_last = logits;
88+
logits_last += (num_tokens - 1) * vocab_size;
89+
result = sampler_->sample(logits_last);
90+
} else {
91+
result = sampler_->sample(logits);
92+
}
93+
});
94+
return result;
10395
}
10496

10597
protected:

extension/llm/sampler/sampler.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,8 @@ int32_t Sampler::sample(T* logits) {
192192

193193
template int32_t Sampler::sample<float>(float* logits);
194194
template int32_t Sampler::sample<exec_aten::Half>(exec_aten::Half* logits);
195+
template int32_t Sampler::sample<exec_aten::BFloat16>(
196+
exec_aten::BFloat16* logits);
195197

196198
} // namespace llm
197199
} // namespace extension

runtime/core/exec_aten/util/scalar_type_util.h

Lines changed: 24 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -953,17 +953,19 @@ inline exec_aten::ScalarType promoteTypes(
953953
//
954954

955955
#ifdef ET_INTERNAL_CHECK_SELECTIVE_BUILD
956-
#define ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
957-
case enum_type: { \
958-
ET_INTERNAL_CHECK_SELECTIVE_BUILD(enum_type); \
959-
using CTYPE_ALIAS = ScalarTypeToCppType<enum_type>::type; \
960-
return __VA_ARGS__(); \
956+
#define ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
957+
case enum_type: { \
958+
ET_INTERNAL_CHECK_SELECTIVE_BUILD(enum_type); \
959+
using CTYPE_ALIAS = \
960+
::executorch::runtime::ScalarTypeToCppType<enum_type>::type; \
961+
return __VA_ARGS__(); \
961962
}
962963
#else
963-
#define ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
964-
case enum_type: { \
965-
using CTYPE_ALIAS = ScalarTypeToCppType<enum_type>::type; \
966-
return __VA_ARGS__(); \
964+
#define ET_INTERNAL_SWITCH_CASE(enum_type, CTYPE_ALIAS, ...) \
965+
case enum_type: { \
966+
using CTYPE_ALIAS = \
967+
::executorch::runtime::ScalarTypeToCppType<enum_type>::type; \
968+
return __VA_ARGS__(); \
967969
}
968970
#endif
969971

@@ -1343,6 +1345,19 @@ inline exec_aten::ScalarType promoteTypes(
13431345
ET_INTERNAL_SWITCH_CASE( \
13441346
exec_aten::ScalarType::T2, CTYPE_ALIAS, __VA_ARGS__))
13451347

1348+
#define ET_SWITCH_THREE_TYPES( \
1349+
T1, T2, T3, TYPE, CONTEXT, NAME, CTYPE_ALIAS, ...) \
1350+
ET_INTERNAL_SWITCH( \
1351+
TYPE, \
1352+
CONTEXT, \
1353+
NAME, \
1354+
ET_INTERNAL_SWITCH_CASE( \
1355+
exec_aten::ScalarType::T1, CTYPE_ALIAS, __VA_ARGS__) \
1356+
ET_INTERNAL_SWITCH_CASE( \
1357+
exec_aten::ScalarType::T2, CTYPE_ALIAS, __VA_ARGS__) \
1358+
ET_INTERNAL_SWITCH_CASE( \
1359+
exec_aten::ScalarType::T3, CTYPE_ALIAS, __VA_ARGS__))
1360+
13461361
} // namespace runtime
13471362
} // namespace executorch
13481363

0 commit comments

Comments
 (0)