Skip to content

Commit bbc7949

Browse files
authored
Merge pull request #543 from NVIDIA/yutec/clone_copy_evaluators
Feat: Add aten::clone and aten::copy_ support.
2 parents 2fbbbc1 + 962bf3b commit bbc7949

File tree

5 files changed

+157
-0
lines changed

5 files changed

+157
-0
lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,41 @@ auto aten_registrations TRTORCH_UNUSED =
576576
Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
577577
R"SIG(aten::arange.start_step(Scalar start, Scalar end, Scalar step, *, ScalarType? dtype=None,
578578
Layout? layout=None, Device? device=None, bool? pin_memory=None) -> (Tensor))SIG",
579+
})})
580+
// Conversion-time evaluator for aten::clone: runs while the graph is being
// converted instead of emitting TensorRT layers.
.evaluator({c10::Symbol::fromQualString("aten::clone"),
581+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
582+
// Input 0 is a TRT ITensor: there is no concrete data to copy at
// conversion time, so re-wrap the same ITensor in a fresh IValue.
// NOTE(review): this aliases the source ITensor rather than producing an
// independent copy — presumably fine because TRT network tensors are
// immutable graph values; confirm.
if (args.at(n->input(0)).isITensor()) {
583+
auto source_tensor = args.at(n->input(0)).ITensor();
584+
auto tensor_holder = TensorContainer();
585+
tensor_holder.hold_tensor(source_tensor);
586+
// NOTE(review): std::move on the make_intrusive prvalue is redundant.
auto clone_tensor = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
587+
// NOTE(review): returning std::move(local) blocks copy elision; harmless but unnecessary.
return std::move(clone_tensor);
588+
} else {
589+
// Input 0 is a concrete at::Tensor (statically evaluatable): do a real
// deep copy via Tensor::clone.
auto source_tensor = args.at(n->input(0)).unwrapToTensor();
590+
auto clone_tensor = source_tensor.clone();
591+
return clone_tensor;
592+
}
593+
},
594+
EvalOptions().validSchemas({
595+
R"SIG(aten::clone(Tensor self, *, int? memory_format=None) -> (Tensor))SIG",
596+
})})
597+
// Conversion-time evaluator for aten::copy_(self, src, non_blocking).
.evaluator({c10::Symbol::fromQualString("aten::copy_"),
598+
[](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
599+
// src (input 1) is a TRT ITensor: wrap it and return it as the result.
// NOTE(review): on this path self (input 0) and non_blocking (input 2)
// are ignored and src is returned directly instead of being copied into
// self — verify this matches the intended conversion-time semantics.
if (args.at(n->input(1)).isITensor()) {
600+
auto source_tensor = args.at(n->input(1)).ITensor();
601+
auto tensor_holder = TensorContainer();
602+
tensor_holder.hold_tensor(source_tensor);
603+
auto clone_tensor = c10::IValue(std::move(c10::make_intrusive<TensorContainer>(tensor_holder)));
604+
return std::move(clone_tensor);
605+
} else {
606+
// Both tensors are concrete: perform the actual in-place copy and
// return self, matching aten::copy_'s contract of returning self.
auto source_tensor = args.at(n->input(1)).unwrapToTensor();
607+
auto self_tensor = args.at(n->input(0)).unwrapToTensor();
608+
self_tensor.copy_(source_tensor);
609+
return self_tensor;
610+
}
611+
},
612+
EvalOptions().validSchemas({
613+
R"SIG(aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!)))SIG",
579614
})});
580615
} // namespace
581616
} // namespace evaluators

tests/core/conversion/converters/BUILD

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ converter_test(
1515
name = "test_batch_norm",
1616
)
1717

18+
converter_test(
19+
name = "test_clone",
20+
)
21+
1822
converter_test(
1923
name = "test_concat",
2024
)
@@ -27,6 +31,10 @@ converter_test(
2731
name = "test_conv_deconv",
2832
)
2933

34+
converter_test(
35+
name = "test_copy",
36+
)
37+
3038
converter_test(
3139
name = "test_cumsum"
3240
)
@@ -112,9 +120,11 @@ test_suite(
112120
tests = [
113121
":test_activation",
114122
":test_batch_norm",
123+
":test_clone",
115124
":test_concat",
116125
":test_constant_pad",
117126
":test_conv_deconv",
127+
":test_copy",
118128
":test_cumsum",
119129
":test_element_wise",
120130
":test_expand",
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

// aten::clone round-trip: the cloned tensor feeding downstream ops must
// produce the same values through the TensorRT engine as through TorchScript.
TEST(Converters, ATenCloneConvertsCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%1 : Tensor = aten::relu(%0)
%2 : None = prim::Constant()
%3 : Tensor = aten::clone(%1, %2)
%4 : Tensor = aten::relu(%3)
%5 : int = prim::Constant[value=1]()
%6 : Tensor = aten::add(%1, %4, %5)
return (%6))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed.get());

  // Random positive input on the GPU (the engine path requires CUDA).
  auto input = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto jit_out = trtorch::tests::util::RunGraph(parsed, jit_params, {input});

  // Same graph lowered to a TensorRT engine.
  auto trt_params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto trt_out = trtorch::tests::util::RunGraphEngine(parsed, trt_params, {input});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
#include <string>
#include "core/compiler.h"
#include "gtest/gtest.h"
#include "tests/util/util.h"
#include "torch/csrc/jit/ir/irparser.h"

// aten::copy_ round-trip: copying the activation into a ones tensor and using
// the result downstream must match between TorchScript and the TRT engine.
TEST(Converters, ATenCopyConvertsCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%0.1 : Tensor = aten::relu(%0)
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=10]()
%4 : int = prim::Constant[value=10]()
%5 : int[] = prim::ListConstruct(%1, %2, %3, %4)
%6 : None = prim::Constant()
%7 : Device = prim::Constant[value="cuda"]()
%8 : Tensor = aten::ones(%5, %6, %6, %7, %6)
%9 : bool = prim::Constant[value=0]()
%10 : Tensor = aten::copy_(%8, %0.1, %9)
%11 : Tensor = aten::relu(%10)
%12 : int = prim::Constant[value=1]()
%13 : Tensor = aten::add(%0.1, %11, %12)
return (%13))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed.get());

  // Random positive input on the GPU; shape matches the aten::ones destination.
  auto input = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA});

  // Reference run through the TorchScript interpreter.
  auto jit_params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto jit_out = trtorch::tests::util::RunGraph(parsed, jit_params, {input});

  // Same graph lowered to a TensorRT engine.
  auto trt_params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto trt_out = trtorch::tests::util::RunGraphEngine(parsed, trt_params, {input});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_out[0], trt_out[0].reshape_as(jit_out[0]), 2e-6));
}

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,4 +355,48 @@ TEST(Evaluators, ATenAppendWithITensorAndTensorEvaluatesCorrectly) {
355355
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in0});
356356

357357
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
358+
}
359+
360+
// aten::clone resolved by the evaluator (no engine): the conversion-time
// result must equal what the TorchScript interpreter produces.
TEST(Evaluators, ATenCloneEvaluatesCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%1 : None = prim::Constant()
%2 : Tensor = aten::clone(%0, %1)
return (%2))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed.get());

  auto input = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA});

  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed, {input});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed->block(), {input});

  // Exact equality is expected — clone copies values bit-for-bit.
  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}
377+
378+
// aten::copy_ resolved by the evaluator (no engine): copying the graph input
// into an aten::ones tensor must match the TorchScript interpreter exactly.
TEST(Evaluators, ATenCopyEvaluatesCorrectly) {
  const auto ir = R"IR(
graph(%0 : Tensor):
%1 : int = prim::Constant[value=1]()
%2 : int = prim::Constant[value=3]()
%3 : int = prim::Constant[value=10]()
%4 : int = prim::Constant[value=10]()
%5 : int[] = prim::ListConstruct(%1, %2, %3, %4)
%6 : None = prim::Constant()
%7 : Device = prim::Constant[value="cuda"]()
%8 : Tensor = aten::ones(%5, %6, %6, %7, %6)
%9 : bool = prim::Constant[value=0]()
%10 : Tensor = aten::copy_(%8, %0, %9)
return (%10))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(ir, parsed.get());

  auto input = at::randint(1, 10, {1, 3, 10, 10}, {at::kCUDA});

  auto jit_out = trtorch::tests::util::EvaluateGraphJIT(parsed, {input});
  auto trt_out = trtorch::tests::util::EvaluateGraph(parsed->block(), {input});

  // Exact equality is expected — copy_ transfers values bit-for-bit.
  ASSERT_TRUE(at::equal(jit_out[0].toTensor().to(at::kCUDA), trt_out[0].toTensor()));
}

0 commit comments

Comments
 (0)