Skip to content

Commit 9cb0be6

Browse files
manuelcandalesfacebook-github-bot
authored andcommitted
Add copy_ (#2432)
Summary: Pull Request resolved: #2432 Reviewed By: JacobSzwejbka Differential Revision: D54909962 fbshipit-source-id: d3cca3e8c6d3406dc502ca62fa58f0bce2a3ae4d
1 parent 63a1fde commit 9cb0be6

File tree

3 files changed

+59
-0
lines changed

3 files changed

+59
-0
lines changed

kernels/portable/cpu/op_copy.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,33 @@ Tensor& copy_out(
5757
return out;
5858
}
5959

60+
Tensor&
61+
copy_(RuntimeContext& ctx, Tensor& in, const Tensor& src, bool non_blocking) {
62+
(void)ctx;
63+
// Right now we only support blocking data transfer
64+
ET_KERNEL_CHECK(ctx, non_blocking == false, InvalidArgument, in);
65+
66+
ET_KERNEL_CHECK(
67+
ctx, tensor_is_broadcastable_to(src, in), InvalidArgument, in);
68+
69+
ScalarType in_type = in.scalar_type();
70+
ScalarType src_type = src.scalar_type();
71+
72+
ET_SWITCH_REAL_TYPES_AND(Bool, in_type, ctx, "copy_", CTYPE, [&]() {
73+
ET_SWITCH_REAL_TYPES_AND(Bool, src_type, ctx, "copy_", CTYPE_SRC, [&]() {
74+
apply_binary_elementwise_fn<CTYPE, CTYPE_SRC, CTYPE>(
75+
[](const CTYPE val_in, const CTYPE_SRC val_src) {
76+
return convert<CTYPE, CTYPE_SRC>(val_src);
77+
},
78+
in,
79+
src,
80+
in);
81+
});
82+
});
83+
84+
return in;
85+
}
86+
6087
} // namespace native
6188
} // namespace executor
6289
} // namespace torch

kernels/portable/functions.yaml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,11 @@
253253
- arg_meta: null
254254
kernel_name: torch::executor::copy_out
255255

256+
- op: copy_
257+
kernels:
258+
- arg_meta: null
259+
kernel_name: torch::executor::copy_
260+
256261
- op: cos.out
257262
kernels:
258263
- arg_meta: null

kernels/test/op_copy_test.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,13 @@ class OpCopyTest : public OperatorTest {
115115
}
116116
};
117117

118+
class OpCopyInplaceTest : public OperatorTest {
119+
protected:
120+
Tensor& op_copy_(Tensor& self, const Tensor& src, bool non_blocking) {
121+
return torch::executor::aten::copy_(context_, self, src, non_blocking);
122+
}
123+
};
124+
118125
// regular test for copy.out
119126
TEST_F(OpCopyTest, AllRealDtypesSupported) {
120127
#define TEST_ENTRY(ctype, dtype) test_dtype<ctype, ScalarType::dtype>();
@@ -255,3 +262,23 @@ TEST_F(OpCopyTest, DynamicShapeUnbound) {
255262
test_dynamic_shape(
256263
{1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND);
257264
}
265+
266+
TEST_F(OpCopyInplaceTest, SmokeTest) {
267+
TensorFactory<ScalarType::Int> tf;
268+
Tensor in = tf.zeros({2, 2});
269+
Tensor src = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 2, 3, 4});
270+
bool non_blocking = false;
271+
op_copy_(in, src, non_blocking);
272+
Tensor expected = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 2, 3, 4});
273+
EXPECT_TENSOR_EQ(in, expected);
274+
}
275+
276+
TEST_F(OpCopyInplaceTest, BroadCastSrcSupported) {
277+
TensorFactory<ScalarType::Int> tf;
278+
Tensor in = tf.make(/*sizes=*/{2, 2}, /*data=*/{1, 2, 3, 4});
279+
Tensor src = tf.make(/*sizes=*/{1, 2}, /*data=*/{3, 3});
280+
bool non_blocking = false;
281+
op_copy_(in, src, non_blocking);
282+
Tensor expected = tf.make(/*sizes=*/{2, 2}, /*data=*/{3, 3, 3, 3});
283+
EXPECT_TENSOR_EQ(in, expected);
284+
}

0 commit comments

Comments
 (0)