Skip to content

Commit 46a18cb

Browse files
authored
Implement native_dropout (#10567)
Yet another core ATen op. Test Plan: Comes with test. Imported to fbsource and ran test in ATen mode as well.
1 parent 1ae8c2c commit 46a18cb

File tree

6 files changed

+200
-0
lines changed

6 files changed

+200
-0
lines changed
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/kernels/portable/cpu/util/elementwise_util.h>
10+
#include <executorch/runtime/kernel/kernel_includes.h>
11+
12+
#include <random>
13+
#include <tuple>
14+
15+
namespace torch::executor::native {
16+
std::tuple<Tensor&, Tensor&> native_dropout_out(
17+
KernelRuntimeContext& ctx,
18+
const Tensor& input,
19+
double prob,
20+
torch::executor::optional<bool> train,
21+
Tensor& out,
22+
Tensor& mask) {
23+
std::tuple<Tensor&, Tensor&> ret(out, mask);
24+
ET_KERNEL_CHECK(
25+
ctx, tensors_have_same_dtype(input, out), InvalidArgument, ret);
26+
ET_KERNEL_CHECK(
27+
ctx, tensors_have_same_dim_order(input, out, mask), InvalidArgument, ret);
28+
ET_KERNEL_CHECK(
29+
ctx,
30+
resize_tensor(out, input.sizes()) == Error::Ok,
31+
InvalidArgument,
32+
ret);
33+
ET_KERNEL_CHECK(
34+
ctx,
35+
resize_tensor(mask, input.sizes()) == Error::Ok,
36+
InvalidArgument,
37+
ret);
38+
ET_KERNEL_CHECK(ctx, tensor_is_bool_type(mask), InvalidArgument, ret);
39+
ET_KERNEL_CHECK_MSG(
40+
ctx,
41+
prob >= 0 && prob <= 1,
42+
InvalidArgument,
43+
ret,
44+
"dropout probability has to be between 0 and 1 but got %f",
45+
prob);
46+
47+
// @lint-ignore CLANGTIDY facebook-hte-CArray
48+
static constexpr const char op_name[] = "native_dropout.out";
49+
if ((!train.has_value() || train.value()) && prob != 0) {
50+
{
51+
std::mt19937 gen((std::random_device())());
52+
std::uniform_real_distribution<double> dist;
53+
bool* const mask_data_ptr = mask.mutable_data_ptr<bool>();
54+
for (const auto ii : c10::irange(mask.numel())) {
55+
mask_data_ptr[ii] = dist(gen) >= prob;
56+
}
57+
}
58+
ET_SWITCH_FLOATHBF16_TYPES(
59+
input.scalar_type(), ctx, op_name, CTYPE_COMPUTE, [&]() {
60+
utils::apply_bitensor_elementwise_fn<CTYPE_COMPUTE, op_name>(
61+
[](const auto val, const auto mask_val) {
62+
if (!mask_val) {
63+
return static_cast<decltype(val)>(0);
64+
}
65+
return val;
66+
},
67+
ctx,
68+
input,
69+
utils::SupportedTensorDtypes::FLOATHBF16,
70+
mask,
71+
// TODO: should really be just BOOL
72+
utils::SupportedTensorDtypes::BOOL_OR_BYTE,
73+
out,
74+
utils::SupportedTensorDtypes::SAME_AS_COMMON);
75+
});
76+
} else if (input.numel() > 0) {
77+
std::memcpy(out.mutable_data_ptr(), input.data_ptr(), input.nbytes());
78+
std::memset(mask.mutable_data_ptr(), true, mask.nbytes());
79+
}
80+
return ret;
81+
}
82+
83+
} // namespace torch::executor::native

kernels/portable/functions.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,6 +627,12 @@
627627
- arg_meta: null
628628
kernel_name: torch::executor::narrow_copy_out
629629

630+
- op: native_dropout.out
631+
kernels:
632+
- arg_meta: null
633+
kernel_name: torch::executor::native_dropout_out
634+
tags: nondeterministic_seeded
635+
630636
- op: native_group_norm.out
631637
kernels:
632638
- arg_meta: null

kernels/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,6 +186,7 @@ set(all_test_sources
186186
"op_mul_test.cpp"
187187
"op_pow_test.cpp"
188188
"op_native_batch_norm_test.cpp"
189+
"op_native_dropout_test.cpp"
189190
"op_native_group_norm_test.cpp"
190191
"op_native_layer_norm_test.cpp"
191192
"op_ne_test.cpp"
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <c10/util/irange.h>
10+
#include <executorch/kernels/test/FunctionHeaderWrapper.h> // Declares the operator
11+
#include <executorch/kernels/test/TestUtil.h>
12+
#include <executorch/kernels/test/supported_features.h>
13+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
14+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_factory.h>
15+
#include <executorch/runtime/core/exec_aten/testing_util/tensor_util.h>
16+
17+
#include <gtest/gtest.h>
18+
19+
using executorch::aten::ScalarType;
20+
using executorch::aten::Tensor;
21+
using torch::executor::testing::TensorFactory;
22+
23+
class OpNativeDropoutTest : public OperatorTest {
24+
protected:
25+
void op_native_dropout_out(
26+
const Tensor& self,
27+
double prob,
28+
executorch::aten::optional<bool> train,
29+
Tensor& out,
30+
Tensor& mask) {
31+
torch::executor::aten::native_dropout_outf(
32+
context_, self, prob, train, out, mask);
33+
}
34+
35+
template <typename CTYPE, ScalarType DTYPE>
36+
void test_dropout() {
37+
TensorFactory<DTYPE> tf;
38+
TensorFactory<ScalarType::Bool> tf_bool;
39+
const std::vector<int32_t> sizes = {3, 2};
40+
Tensor in = tf.make(sizes, {1, 2, 3, 4, 5, 6});
41+
Tensor out = tf.zeros(sizes);
42+
Tensor mask = tf_bool.zeros(sizes);
43+
44+
bool* const mask_data = mask.mutable_data_ptr<bool>();
45+
auto expect_no_drops = [&]() {
46+
EXPECT_TENSOR_CLOSE(out, in);
47+
for (const auto ii : c10::irange(mask.numel())) {
48+
EXPECT_TRUE(mask_data[ii]);
49+
mask_data[ii] = false;
50+
}
51+
};
52+
53+
op_native_dropout_out(in, 0, true, out, mask);
54+
expect_no_drops();
55+
56+
op_native_dropout_out(in, 0, false, out, mask);
57+
expect_no_drops();
58+
59+
op_native_dropout_out(in, 1, false, out, mask);
60+
expect_no_drops();
61+
62+
op_native_dropout_out(in, 1, true, out, mask);
63+
auto* const out_data = out.mutable_data_ptr<CTYPE>();
64+
for (const auto ii : c10::irange(out.numel())) {
65+
EXPECT_EQ(out_data[ii], CTYPE(0));
66+
}
67+
for (const auto ii : c10::irange(mask.numel())) {
68+
EXPECT_FALSE(mask_data[ii]);
69+
mask_data[ii] = 0;
70+
}
71+
}
72+
};
73+
74+
TEST_F(OpNativeDropoutTest, Basic) {
75+
#define TEST_ENTRY(ctype, dtype) test_dropout<ctype, ScalarType::dtype>();
76+
ET_FORALL_FLOATHBF16_TYPES(TEST_ENTRY);
77+
#undef TEST_ENTRY
78+
}
79+
80+
TEST_F(OpNativeDropoutTest, ProbabilityRangeCheck) {
81+
TensorFactory<ScalarType::Float> tf_float;
82+
TensorFactory<ScalarType::Bool> tf_bool;
83+
const std::vector<int32_t> sizes = {2, 3};
84+
Tensor a = tf_float.ones(sizes);
85+
Tensor out = tf_float.zeros(sizes);
86+
Tensor mask = tf_bool.zeros(sizes);
87+
ET_EXPECT_KERNEL_FAILURE(
88+
context_, op_native_dropout_out(a, -1, true, out, mask));
89+
}
90+
91+
TEST_F(OpNativeDropoutTest, MaskBoolCheck) {
92+
TensorFactory<ScalarType::Float> tf_float;
93+
TensorFactory<ScalarType::Byte> tf_byte;
94+
const std::vector<int32_t> sizes = {2, 3};
95+
Tensor a = tf_float.ones(sizes);
96+
Tensor out = tf_float.zeros(sizes);
97+
Tensor mask_byte = tf_byte.zeros(sizes);
98+
Tensor mask_float = tf_float.zeros(sizes);
99+
ET_EXPECT_KERNEL_FAILURE(
100+
context_, op_native_dropout_out(a, 0.5, true, out, mask_byte));
101+
ET_EXPECT_KERNEL_FAILURE(
102+
context_, op_native_dropout_out(a, 0.5, true, out, mask_float));
103+
}

kernels/test/targets.bzl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,7 @@ def define_common_targets():
272272
_common_op_test("op_mul_test", ["aten", "portable", "optimized"])
273273
_common_op_test("op_narrow_copy_test", ["aten", "portable"])
274274
_common_op_test("op_native_batch_norm_test", ["aten", "portable"])
275+
_common_op_test("op_native_dropout_test", ["aten", "portable"])
275276
_common_op_test("op_native_group_norm_test", ["aten", "portable"])
276277
_common_op_test("op_native_layer_norm_test", ["aten", "portable", "optimized"])
277278
_common_op_test("op_ne_test", ["aten", "portable"])

shim_et/xplat/executorch/kernels/portable/op_registration_util.bzl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -883,6 +883,12 @@ ATEN_OPS = (
883883
"//executorch/kernels/portable/cpu/util:normalization_ops_util",
884884
],
885885
),
886+
op_target(
887+
name = "op_native_dropout",
888+
deps = [
889+
"//executorch/kernels/portable/cpu/util:elementwise_util",
890+
],
891+
),
886892
op_target(
887893
name = "op_native_group_norm",
888894
deps = [

0 commit comments

Comments
 (0)