Skip to content

Commit 1a4047a

Browse files
authored
Merge pull request #719 from guoruoqian/support_aten_format
feat: support aten::format evaluator
2 parents 09afccb + df9dd6a commit 1a4047a

File tree

2 files changed: +103 −1 lines changed

core/conversion/evaluators/aten.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -706,7 +706,23 @@ auto aten_registrations TORCHTRT_UNUSED =
706706
},
707707
EvalOptions().validSchemas({
708708
R"SIG(aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> (Tensor(a!)))SIG",
709-
})});
709+
})})
710+
// Conversion-time evaluator for aten::format: instead of re-implementing
// Python-style "{}" substitution, delegate to the TorchScript interpreter's
// registered aten::format operator and return its string result.
.evaluator({c10::Symbol::fromQualString("aten::format"),
            [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
              int64_t input_num = n->inputs().size();
              std::vector<torch::jit::IValue> stack;
              // Push the already-evaluated IValue of every node input
              // (format string first, then the substitution arguments).
              for (auto v : n->inputs()) {
                stack.push_back(*args.at(v).IValue());
              }
              // aten::format is a vararg op ("str self, ..."); the interpreter
              // calling convention expects the argument count pushed last so
              // the op knows how many stack entries to consume.
              stack.push_back(input_num);
              // Look up the real interpreter operator for aten::format.
              // NOTE(review): assumes at least one overload is registered —
              // ops.front() would be UB on an empty list; confirm upstream.
              auto& ops = torch::jit::getAllOperatorsFor(c10::Symbol::fromQualString("aten::format"));
              auto& aten_format = ops.front();
              // Execute the op in place; it pops its inputs and leaves the
              // formatted string on the stack.
              aten_format->getOperation()(stack);
              std::string output;
              torch::jit::pop(stack, output);
              // Returning the string lets downstream converters treat the
              // node as a compile-time constant.
              return output;
            },
            EvalOptions().validSchemas({"aten::format(str self, ...) -> (str)"})});
710726
} // namespace
711727
} // namespace evaluators
712728
} // namespace conversion

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "gtest/gtest.h"
44
#include "tests/util/util.h"
55
#include "torch/csrc/jit/ir/irparser.h"
6+
#include "torch/csrc/jit/runtime/jit_exception.h"
67
#include "torch/torch.h"
78

89
TEST(Evaluators, DivIntEvaluatesCorrectly) {
@@ -579,4 +580,89 @@ TEST(Evaluators, AndBoolResultIsFalseEvaluatesCorrectly) {
579580
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
580581

581582
ASSERT_TRUE(jit_results[0] == trt_results[0]);
583+
}
584+
585+
TEST(Evaluators, AtenFormatEvaluatesCorrectly) {
  // aten::format fills "res{}_{}_" with 5 and 2; since the result equals the
  // literal "res5_2_", the prim::If must take the add branch. The test checks
  // that the TRT-converted graph agrees with the TorchScript reference.
  const auto graph = R"IR(
      graph(%x_1 : Tensor, %x_2 : Tensor):
        %0 : int = prim::Constant[value=1]()
        %1 : str = prim::Constant[value="res{}_{}_"]()
        %2 : int = prim::Constant[value=5]()
        %2.1 : int = prim::Constant[value=2]()
        %3 : str = prim::Constant[value="res5_2_"]()
        %4 : str = aten::format(%1, %2, %2.1)
        %5 : bool = aten::eq(%3, %4)
        %y : Tensor = prim::If(%5)
          block0():
            %194 : Tensor = aten::add(%x_1, %x_2, %0)
            -> (%194)
          block1():
            %195 : Tensor = aten::sub(%x_1, %x_2, %0)
            -> (%195)
        return (%y))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, g.get());

  // Two identical random integer tensors as graph inputs.
  auto lhs = at::randint(1, 10, {3, 4}, {at::kCUDA});
  auto rhs = lhs.clone();

  // Reference run through TorchScript.
  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {lhs, rhs});

  // Run through the compiled TensorRT engine.
  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  auto trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {lhs, rhs});

  ASSERT_TRUE(
      torch_tensorrt::tests::util::almostEqual(jit_results[0], trt_results[0].reshape_as(jit_results[0]), 2e-6));
}
618+
619+
TEST(Evaluators, AtenFormatRaiseExceptionEvaluatesCorrectly) {
  // "res5_1" != "res5_2", so prim::If takes block1 and raises the message
  // built by aten::format. The test verifies TorchScript and Torch-TensorRT
  // surface the same exception text (after stripping each runtime's prefix).
  const auto graph = R"IR(
      graph(%x_1 : Tensor, %x_2 : Tensor):
        %0 : int = prim::Constant[value=1]()
        %1 : str = prim::Constant[value="res5_1"]()
        %2 : str = prim::Constant[value="{} is not equal to {}"]()
        %3 : str = prim::Constant[value="res5_2"]()
        %5713 : Tensor = prim::Uninitialized()
        %4 : str = aten::format(%2, %1, %3)
        %5 : bool = aten::eq(%1, %3)
        %y : Tensor = prim::If(%5)
          block0():
            %194 : Tensor = aten::add(%x_1, %x_2, %0)
            -> (%194)
          block1():
            prim::RaiseException(%4)
            -> (%5713)
        return (%y))IR";
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*g);

  auto in0 = at::randint(1, 10, {3, 4}, {at::kCUDA});
  auto in1 = in0.clone();

  auto params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  std::vector<at::Tensor> jit_results, trt_results;
  std::string error_jit, error_torch_trt;
  // Both runs are expected to throw; capture the message of each.
  try {
    jit_results = torch_tensorrt::tests::util::RunGraph(g, params, {in0, in1});
  } catch (const torch::jit::JITException& error) {
    error_jit = error.what();
  }

  params = torch_tensorrt::core::ir::get_static_params(g->inputs(), {});
  try {
    trt_results = torch_tensorrt::tests::util::RunGraphEngine(g, params, {in0, in1});
  } catch (const torch_tensorrt::Error& error) {
    error_torch_trt = error.what();
  }

  // Derive the skip offsets from the prefix strings themselves instead of the
  // former hard-coded 13/23 magic numbers, and fail loudly (rather than
  // slicing at npos + offset) if either prefix is missing — e.g. when a run
  // unexpectedly did not throw.
  const std::string jit_prefix = "RuntimeError:";
  const std::string trt_prefix = "Error from TorchScript:";
  auto position1 = error_jit.find(jit_prefix);
  auto position2 = error_torch_trt.find(trt_prefix);
  ASSERT_TRUE(position1 != std::string::npos);
  ASSERT_TRUE(position2 != std::string::npos);
  std::string jit_msg = error_jit.substr(position1 + jit_prefix.size());
  std::string torch_trt_msg = error_torch_trt.substr(position2 + trt_prefix.size());
  // ASSERT_EQ prints both strings on failure, unlike the old
  // if/ASSERT_TRUE(true)/ASSERT_TRUE(false) pattern.
  ASSERT_EQ(jit_msg, torch_trt_msg);
}

0 commit comments

Comments
 (0)