Skip to content

[ExecuTorch] Allow setting dtype to bf16 in export_llama #4985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 23 commits into from
Sep 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
4fb8d62
[ExecuTorch] Implement BFloat16 and hook it up to scalar_type_util
swolchok Aug 29, 2024
4e8b86b
[ExecuTorch] support BF16 in op_to_copy
swolchok Aug 29, 2024
ff88974
[ExecuTorch] support BF16 in op_mul
swolchok Aug 29, 2024
46579e4
[ExecuTorch] support BF16 in op_mm
swolchok Aug 29, 2024
8388e56
[ExecuTorch] support BF16 in op_copy
swolchok Aug 29, 2024
8e0b9d3
[ExecuTorch] support BF16 in op_slice_scatter
swolchok Aug 29, 2024
741f777
[ExecuTorch] support BF16 in op_scalar_tensor
swolchok Aug 29, 2024
ebdde77
[ExecuTorch] support BF16 in op_where
swolchok Aug 29, 2024
1addb7d
[ExecuTorch] support BF16 in op_add
swolchok Aug 29, 2024
7127237
[ExecuTorch] support BF16 in LLM runner & sampler
swolchok Aug 29, 2024
79021ef
[ExecuTorch] Allow setting dtype to bf16 in export_llama
swolchok Aug 29, 2024
9f13154
Update base for Update on "[ExecuTorch] Allow setting dtype to bf16 i…
swolchok Aug 30, 2024
eac4e38
Update on "[ExecuTorch] Allow setting dtype to bf16 in export_llama"
swolchok Aug 30, 2024
5917de2
Update base for Update on "[ExecuTorch] Allow setting dtype to bf16 i…
swolchok Aug 30, 2024
d67b816
Update on "[ExecuTorch] Allow setting dtype to bf16 in export_llama"
swolchok Aug 30, 2024
4f1d89b
Update base for Update on "[ExecuTorch] Allow setting dtype to bf16 i…
swolchok Sep 3, 2024
8f466a8
Update on "[ExecuTorch] Allow setting dtype to bf16 in export_llama"
swolchok Sep 3, 2024
fc9bf07
Update base for Update on "[ExecuTorch] Allow setting dtype to bf16 i…
swolchok Sep 3, 2024
413a847
Update on "[ExecuTorch] Allow setting dtype to bf16 in export_llama"
swolchok Sep 3, 2024
b88b3f3
Update base for Update on "[ExecuTorch] Allow setting dtype to bf16 i…
swolchok Sep 4, 2024
b8003af
Update on "[ExecuTorch] Allow setting dtype to bf16 in export_llama"
swolchok Sep 4, 2024
862cd85
Update base for Update on "[ExecuTorch] Allow setting dtype to bf16 i…
swolchok Sep 6, 2024
b211c00
Update on "[ExecuTorch] Allow setting dtype to bf16 in export_llama"
swolchok Sep 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions examples/models/llama2/export_llama_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,9 @@ def build_args_parser() -> argparse.ArgumentParser:
"--dtype-override",
default="fp32",
type=str,
choices=["fp32", "fp16"],
choices=["fp32", "fp16", "bf16"],
help="Override the dtype of the model (default is the checkpoint dtype)."
"Options: fp32, fp16. Please be aware that only some backends support fp16.",
"Options: fp32, fp16, bf16. Please be aware that only some backends support fp16 and bf16.",
)

parser.add_argument(
Expand Down
1 change: 1 addition & 0 deletions extension/llm/export/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def to_torch_dtype(self) -> torch.dtype:
mapping = {
DType.fp32: torch.float32,
DType.fp16: torch.float16,
DType.bf16: torch.bfloat16,
}
if self not in mapping:
raise ValueError(f"Unsupported dtype {self}")
Expand Down
58 changes: 25 additions & 33 deletions extension/llm/runner/text_decoder_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,39 +67,31 @@ class TextDecoderRunner {
* @return The next token.
*/
inline int32_t logits_to_token(const exec_aten::Tensor& logits_tensor) {
switch (logits_tensor.scalar_type()) {
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
// vocab_size], get the last logits, sample and return. Else the model
// outputs the last logit, directly sample and return.
case exec_aten::ScalarType::Float: {
float* logits = logits_tensor.mutable_data_ptr<float>();
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
float* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
case exec_aten::ScalarType::Half: {
exec_aten::Half* logits =
logits_tensor.mutable_data_ptr<exec_aten::Half>();
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
exec_aten::Half* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
return sampler_->sample(logits_last);
}
return sampler_->sample(logits);
}
default:
ET_CHECK_MSG(
false,
"Unsupported dtype output %hhd",
static_cast<int8_t>(logits_tensor.scalar_type()));
}
int32_t result = 0;
ET_SWITCH_THREE_TYPES(
Float,
Half,
BFloat16,
logits_tensor.scalar_type(),
unused,
"logits_to_token",
CTYPE,
[&]() {
// If the logit_tensor rank is 3, the shape is [batch, seq_length,
// vocab_size], get the last logits, sample and return. Else the model
// outputs the last logit, directly sample and return.
auto* logits = logits_tensor.mutable_data_ptr<CTYPE>();
if (logits_tensor.dim() == 3) {
auto num_tokens = logits_tensor.size(1);
auto vocab_size = logits_tensor.size(2);
auto* logits_last = logits;
logits_last += (num_tokens - 1) * vocab_size;
result = sampler_->sample(logits_last);
} else {
result = sampler_->sample(logits);
}
});
return result;
}

protected:
Expand Down
2 changes: 2 additions & 0 deletions extension/llm/sampler/sampler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,8 @@ int32_t Sampler::sample(T* logits) {

template int32_t Sampler::sample<float>(float* logits);
template int32_t Sampler::sample<exec_aten::Half>(exec_aten::Half* logits);
template int32_t Sampler::sample<exec_aten::BFloat16>(
exec_aten::BFloat16* logits);

} // namespace llm
} // namespace extension
Expand Down
17 changes: 9 additions & 8 deletions kernels/optimized/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ Tensor& opt_add_out(
ScalarType out_type = out.scalar_type();

if (b.numel() == 1) {
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half) {
if (a_type == b_type && a_type == out_type && a_type != ScalarType::Half &&
a_type != ScalarType::BFloat16) {
auto error = resize_tensor(out, a.sizes());
ET_KERNEL_CHECK_MSG(
ctx,
Expand Down Expand Up @@ -186,12 +187,12 @@ Tensor& opt_add_out(
InvalidArgument,
out);

ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, "add.out", CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
ET_SWITCH_REALHB_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, "add.out", CTYPE_OUT, [&]() {
CTYPE_IN alpha_val;
ET_KERNEL_CHECK(
ctx, utils::extract_scalar(alpha, &alpha_val), InvalidArgument, );
Expand Down Expand Up @@ -226,7 +227,7 @@ Tensor& opt_add_scalar_out(

ET_CHECK(common_type == out_type);

if (common_type == ScalarType::Half) {
if (common_type == ScalarType::Half || common_type == ScalarType::BFloat16) {
common_type = ScalarType::Float;
}

Expand All @@ -235,7 +236,7 @@ Tensor& opt_add_scalar_out(
ET_CHECK_MSG(error == Error::Ok, "Failed to resize output tensor.");

if (a_type == common_type && a_type == out_type &&
a_type != ScalarType::Half) {
a_type != ScalarType::Half && a_type != ScalarType::BFloat16) {
ET_SWITCH_REALB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
CTYPE_B b_val;
Expand All @@ -255,11 +256,11 @@ Tensor& opt_add_scalar_out(
});
});
} else {
ET_SWITCH_REALHB_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, "add.Scalar_out", CTYPE_A, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, "add.Scalar_out", CTYPE_B, [&]() {
ET_SWITCH_REALB_TYPES(
common_type, ctx, "add.Scalar_out", CTYPE_IN, [&]() {
ET_SWITCH_REALHB_TYPES(
ET_SWITCH_REALHBBF16_TYPES(
out_type, ctx, "add.Scalar_out", CTYPE_OUT, [&]() {
CTYPE_B b_val;
ET_EXTRACT_SCALAR(b, b_val);
Expand Down
20 changes: 14 additions & 6 deletions kernels/portable/cpu/op_add.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,11 @@ Tensor& add_out(
InvalidArgument,
out);

ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensor_is_realhbbf16_type(out),
InvalidArgument,
out);
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out);

Expand All @@ -94,15 +98,15 @@ Tensor& add_out(

constexpr auto name = "add.out";

ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
using CTYPE_IN = typename torch::executor::
promote_types<CTYPE_A, CTYPE_B, /*half_to_float*/ true>::type;
ET_DCHECK(CppTypeToScalarType<CTYPE_IN>::value == common_type);
CTYPE_IN alpha_val;
utils::extract_scalar(alpha, &alpha_val);

ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
ET_SWITCH_REALHBBF16_TYPES(out_type, ctx, name, CTYPE_OUT, [&]() {
AddInner<
can_cast<CTYPE_IN, CTYPE_OUT>::value,
CTYPE_A,
Expand Down Expand Up @@ -132,7 +136,11 @@ Tensor& add_scalar_out(
out,
"Failed to resize output tensor.");

ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out);
ET_KERNEL_CHECK(
ctx,
executorch::runtime::tensor_is_realhbbf16_type(out),
InvalidArgument,
out);
ET_KERNEL_CHECK(
ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out);

Expand All @@ -153,7 +161,7 @@ Tensor& add_scalar_out(

constexpr auto name = "add.Scalar_out";

ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
using CTYPE_IN = typename utils::promote_type_with_scalar_type<
CTYPE_A,
Expand Down
8 changes: 4 additions & 4 deletions kernels/portable/cpu/op_copy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ Tensor& copy_out(
ScalarType in_type = in.scalar_type();
ScalarType src_type = src.scalar_type();

ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() {
ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() {
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy.out", CTYPE, [&]() {
ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy.out", CTYPE_SRC, [&]() {
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
[](const CTYPE val_in, const CTYPE_SRC val_src) {
return convert<CTYPE, CTYPE_SRC>(val_src);
Expand Down Expand Up @@ -75,8 +75,8 @@ copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) {
ScalarType in_type = in.scalar_type();
ScalarType src_type = src.scalar_type();

ET_SWITCH_REALHB_TYPES(in_type, ctx, "copy_", CTYPE, [&]() {
ET_SWITCH_REALHB_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() {
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "copy_", CTYPE, [&]() {
ET_SWITCH_REALHBBF16_TYPES(src_type, ctx, "copy_", CTYPE_SRC, [&]() {
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
[](const CTYPE val_in, const CTYPE_SRC val_src) {
return convert<CTYPE, CTYPE_SRC>(val_src);
Expand Down
27 changes: 14 additions & 13 deletions kernels/portable/cpu/op_mm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,20 @@ mm_out(RuntimeContext& ctx, const Tensor& in, const Tensor& mat2, Tensor& out) {

ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out);

ET_SWITCH_REAL_TYPES_AND(Half, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
size_t m = in.size(0);
size_t n = in.size(1);
size_t p = mat2.size(1);

vec_matmul<CTYPE>(
out.mutable_data_ptr<CTYPE>(),
in.const_data_ptr<CTYPE>(),
mat2.const_data_ptr<CTYPE>(),
m,
n,
p);
});
ET_SWITCH_REAL_TYPES_AND2(
Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() {
size_t m = in.size(0);
size_t n = in.size(1);
size_t p = mat2.size(1);

vec_matmul<CTYPE>(
out.mutable_data_ptr<CTYPE>(),
in.const_data_ptr<CTYPE>(),
mat2.const_data_ptr<CTYPE>(),
m,
n,
p);
});

return out;
}
Expand Down
15 changes: 8 additions & 7 deletions kernels/portable/cpu/op_scalar_tensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,14 @@ Tensor& scalar_tensor_out(RuntimeContext& ctx, const Scalar& s, Tensor& out) {

constexpr auto name = "scalar_tensor.out";

ET_SWITCH_REALHB_TYPES(out_type, ctx, name, CTYPE, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() {
CTYPE_S val_s;
utils::extract_scalar(s, &val_s);
out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
});
});
ET_SWITCH_REAL_TYPES_AND3(
Half, Bool, BFloat16, out_type, ctx, name, CTYPE, [&]() {
ET_SWITCH_SCALAR_OBJ_TYPES(s_type, ctx, name, CTYPE_S, [&]() {
CTYPE_S val_s;
utils::extract_scalar(s, &val_s);
out.mutable_data_ptr<CTYPE>()[0] = convert<CTYPE, CTYPE_S>(val_s);
});
});

return out;
}
Expand Down
4 changes: 2 additions & 2 deletions kernels/portable/cpu/op_slice_scatter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,8 @@ Tensor& slice_scatter_out(
ScalarType in_type = input.scalar_type();
ScalarType src_type = src.scalar_type();

ET_SWITCH_REALHB_TYPES(in_type, ctx, "slice_scatter.out", CTYPE, [&]() {
ET_SWITCH_REALHB_TYPES(
ET_SWITCH_REALHBBF16_TYPES(in_type, ctx, "slice_scatter.out", CTYPE, [&]() {
ET_SWITCH_REALHBBF16_TYPES(
src_type, ctx, "slice_scatter.out", CTYPE_SRC, [&]() {
CTYPE* out_data = out.mutable_data_ptr<CTYPE>();
const CTYPE_SRC* src_data = src.const_data_ptr<CTYPE_SRC>();
Expand Down
4 changes: 2 additions & 2 deletions kernels/portable/cpu/op_where.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ Tensor& where_out(
cond_type == ScalarType::Bool || cond_type == ScalarType::Byte,
"Unhandled dtype %s for where.self_out",
torch::executor::toString(cond_type));
ET_SWITCH_REALHB_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_REALHB_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
ET_SWITCH_REALHBBF16_TYPES(a_type, ctx, name, CTYPE_A, [&]() {
ET_SWITCH_REALHBBF16_TYPES(b_type, ctx, name, CTYPE_B, [&]() {
using CTYPE_OUT =
typename torch::executor::promote_types<CTYPE_A, CTYPE_B>::type;
apply_ternary_elementwise_fn<CTYPE_A, CTYPE_B, uint8_t, CTYPE_OUT>(
Expand Down
23 changes: 17 additions & 6 deletions kernels/test/op_add_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class OpAddOutKernelTest : public OperatorTest {

template <ScalarType DTYPE_A, ScalarType DTYPE_B>
void test_add_enumerate_out_types() {
test_add<DTYPE_A, DTYPE_B, ScalarType::BFloat16>();
test_add<DTYPE_A, DTYPE_B, ScalarType::Half>();
test_add<DTYPE_A, DTYPE_B, ScalarType::Float>();
test_add<DTYPE_A, DTYPE_B, ScalarType::Double>();
Expand All @@ -73,7 +74,7 @@ class OpAddOutKernelTest : public OperatorTest {
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
test_add_enumerate_out_types<DTYPE_A, ScalarType::dtype>();

ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)

#undef ENUMERATE_TEST_ENTRY
}
Expand All @@ -82,7 +83,7 @@ class OpAddOutKernelTest : public OperatorTest {
#define ENUMERATE_TEST_ENTRY(ctype, dtype) \
test_add_enumerate_b_types<ScalarType::dtype>();

ET_FORALL_REAL_TYPES_AND(Half, ENUMERATE_TEST_ENTRY)
ET_FORALL_REALHBF16_TYPES(ENUMERATE_TEST_ENTRY)

#undef ENUMERATE_TEST_ENTRY
}
Expand All @@ -99,13 +100,15 @@ class OpAddOutKernelTest : public OperatorTest {

// Add two tensors.
op_add_out(
tf.make(sizes, /*data=*/{1.1, 2.2, 4.4, 8.8}),
tf.make(sizes, /*data=*/{1.25, 2.25, 4.5, 8.875}),
tf.ones(sizes),
/*alpha=*/1.1,
/*alpha=*/1.25,
out);

// Check that it matches the expected output.
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.2, 3.3, 5.5, 9.9}));
// Check that it matches the expected output. Values selected to
// be exactly representable to avoid throwing off half/bfloat16
// tests.
EXPECT_TENSOR_CLOSE(out, tf.make(sizes, /*data=*/{2.5, 3.5, 5.75, 10.125}));
}
};

Expand Down Expand Up @@ -136,6 +139,14 @@ TEST_F(OpAddOutKernelTest, DoubleTensors) {
test_floating_point_add_out<ScalarType::Double>();
}

TEST_F(OpAddOutKernelTest, HalfTensors) {
test_floating_point_add_out<ScalarType::Half>();
}

TEST_F(OpAddOutKernelTest, BFloat16Tensors) {
test_floating_point_add_out<ScalarType::BFloat16>();
}

TEST_F(OpAddOutKernelTest, BoolAndIntInputTensor) {
TensorFactory<ScalarType::Bool> tf;
TensorFactory<ScalarType::Int> tfi;
Expand Down
4 changes: 2 additions & 2 deletions kernels/test/op_copy_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,13 +125,13 @@ class OpCopyInplaceTest : public OperatorTest {
// regular test for copy.out
TEST_F(OpCopyTest, AllRealDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES(TEST_ENTRY);
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

TEST_F(OpCopyTest, EmptyInputSupported) {
#define TEST_ENTRY(ctype, dtype) test_empty_input<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES_AND(Bool, TEST_ENTRY);
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
}

Expand Down
2 changes: 1 addition & 1 deletion kernels/test/op_mm_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ TEST_F(OpMmOutTest, OutputDim) {
/// zeros().
TEST_F(OpMmOutTest, AllDtypesSupported) {
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
ET_FORALL_REAL_TYPES_AND(Half, TEST_ENTRY);
ET_FORALL_REALHBF16_TYPES(TEST_ENTRY);
#undef TEST_ENTRY
// TODO: Also add tests for half, complex, quantized, and other types. Easiest
// way to do that would be to make TensorFactory support zeros() and ones()
Expand Down
Loading
Loading