Skip to content

Commit 4839b11

Browse files
authored
Merge pull request #918 from NVIDIA/aten_pow
aten::pow support
2 parents de0e615 + 8b33dc0 commit 4839b11

File tree

5 files changed

+149
-7
lines changed

5 files changed

+149
-7
lines changed

core/conversion/conversion.cpp

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,8 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
105105
// Node input has not been converted yet or is a prim op
106106
TORCHTRT_THROW_ERROR(
107107
"Unable to retrieve all node inputs for node: "
108-
<< util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: " << *input_node);
108+
<< util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: %"
109+
<< input->debugName());
109110
}
110111
}
111112

@@ -534,18 +535,22 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
534535
if (unsupported_ops.size() != 0) {
535536
std::stringstream unsupported_msg;
536537
unsupported_msg
537-
<< "Method requested cannot be compiled by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
538+
<< "Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
538539
<< std::endl;
539540
for (auto s : unsupported_ops) {
540541
unsupported_msg << " - " << s.second << std::endl;
541542
}
542-
unsupported_msg << "You can either implement converters for these ops in your application or request implementation"
543-
<< std::endl;
544-
unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
545-
unsupported_msg << std::endl << "In Module:" << std::endl;
546543

547544
if (!suppress_errors) {
545+
unsupported_msg
546+
<< "You can either implement converters for these ops in your application or request implementation"
547+
<< std::endl;
548+
unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
549+
unsupported_msg << std::endl << "In Module:" << std::endl;
550+
548551
LOG_ERROR(unsupported_msg.str());
552+
} else {
553+
LOG_INFO(unsupported_msg.str());
549554
}
550555

551556
std::unordered_map<std::string, std::unordered_set<std::string>> unsupported_node_locations;
@@ -572,7 +577,11 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
572577
traceback << str;
573578
}
574579
auto tb_str = traceback.str();
575-
LOG_ERROR(tb_str);
580+
if (!suppress_errors) {
581+
LOG_ERROR(tb_str);
582+
} else {
583+
LOG_DEBUG(tb_str);
584+
}
576585
}
577586

578587
return false;

core/conversion/evaluators/aten.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
#include <math.h>
2+
13
#include "ATen/core/List.h"
24
#include "ATen/core/functional.h"
35
#include "ATen/core/ivalue.h"
@@ -98,6 +100,17 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
98100
"aten::ge.float_int(float a, int b) -> (bool)",
99101
}));
100102

103+
// Registers a compile-time evaluator for aten::pow on scalar operands.
// `pow(a, b)` resolves to the C library pow() (see the <math.h> include at the
// top of this file), so every variant yields a double, which matches the
// `-> (float)` return in each TorchScript schema below.
DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
    pow,
    "aten::pow",
    pow(a, b),
    std::set<std::string>({
        "aten::pow.int(int a, int b) -> (float)",
        "aten::pow.float(float a, float b) -> (float)",
        "aten::pow.int_float(int a, float b) -> (float)",
        "aten::pow.float_int(float a, int b) -> (float)",
    }));
113+
101114
DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
102115
and,
103116
"aten::__and__",

core/conversion/evaluators/eval_macros.h

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,53 @@
7777
}, \
7878
EvalOptions().validSchemas(schemas)});
7979

80+
// Defines and registers a node evaluator for a two-input arithmetic op
// (e.g. aten::pow) whose operands are resolved at conversion time.
//
//   name      - prefix for the generated `<name>_registrations` variable
//   node_kind - qualified op string, e.g. "aten::pow"
//   operation - an expression over locals `a` and `b` (e.g. `pow(a, b)`)
//   schemas   - std::set of TorchScript schemas this evaluator services
//
// The lambda dispatches on the runtime IValue type of each operand:
// `a` may be an int or double; `b` may be an int, double, or bool.
// Any other operand type raises a conversion error and yields no value.
// NOTE(review): comments cannot appear inside the macro body itself — a `//`
// would swallow the line-continuation backslash — so all documentation lives
// here. Fixed: the `b` arg error messages now have a space after "arg:" to
// match the `a` arg message.
#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas)                          \
  auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator(                           \
      {c10::Symbol::fromQualString(node_kind),                                                              \
       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {                   \
         if (args.at(n->input(0)).IValue()->isInt()) {                                                      \
           auto a = args.at(n->input(0)).unwrapToInt();                                                     \
           if (args.at(n->input(1)).IValue()->isInt()) {                                                    \
             auto b = args.at(n->input(1)).unwrapToInt();                                                   \
             return operation;                                                                              \
           } else if (args.at(n->input(1)).IValue()->isDouble()) {                                          \
             auto b = args.at(n->input(1)).unwrapToDouble();                                                \
             return operation;                                                                              \
           } else if (args.at(n->input(1)).IValue()->isBool()) {                                            \
             auto b = args.at(n->input(1)).unwrapToBool();                                                  \
             return operation;                                                                              \
           } else {                                                                                         \
             TORCHTRT_THROW_ERROR(                                                                          \
                 "Unimplemented data type for "                                                             \
                 << node_kind << " evaluator b arg: " << args.at(n->input(1)).IValue()->type()->str());     \
             return {};                                                                                     \
           }                                                                                                \
         } else if (args.at(n->input(0)).IValue()->isDouble()) {                                            \
           auto a = args.at(n->input(0)).unwrapToDouble();                                                  \
           if (args.at(n->input(1)).IValue()->isInt()) {                                                    \
             auto b = args.at(n->input(1)).unwrapToInt();                                                   \
             return operation;                                                                              \
           } else if (args.at(n->input(1)).IValue()->isDouble()) {                                          \
             auto b = args.at(n->input(1)).unwrapToDouble();                                                \
             return operation;                                                                              \
           } else if (args.at(n->input(1)).IValue()->isBool()) {                                            \
             auto b = args.at(n->input(1)).unwrapToBool();                                                  \
             return operation;                                                                              \
           } else {                                                                                         \
             TORCHTRT_THROW_ERROR(                                                                          \
                 "Unimplemented data type for "                                                             \
                 << node_kind << " evaluator b arg: " << args.at(n->input(1)).IValue()->type()->str());     \
             return {};                                                                                     \
           }                                                                                                \
         } else {                                                                                           \
           TORCHTRT_THROW_ERROR(                                                                            \
               "Unimplemented data type for "                                                               \
               << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str());       \
           return {};                                                                                       \
         }                                                                                                  \
       },                                                                                                   \
       EvalOptions().validSchemas(schemas)});
126+
80127
#define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
81128
auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
82129
{c10::Symbol::fromQualString(node_name), \

noxfile.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -322,3 +322,8 @@ def l2_multi_gpu_tests(session):
322322
def l2_multi_gpu_tests_host_deps(session):
323323
"""Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
324324
run_l2_multi_gpu_tests(session, use_host_env=True)
325+
326+
@nox.session(python=["3"], reuse_venv=True)
def download_test_models(session):
    """Grab all the models needed for testing"""
    # use_host_env=True: run against host-installed dependencies rather than a
    # nox-managed virtualenv — presumably to reuse the host's torch install;
    # TODO(review): confirm against download_models' handling of use_host_env.
    download_models(session, use_host_env=True)

tests/core/conversion/evaluators/test_aten_evaluators.cpp

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,5 +726,73 @@ TEST(Evaluators, RangeLengthNegEvaluatesCorrectly) {
726726
auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
727727
auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
728728

729+
ASSERT_TRUE(jit_results[0] == trt_results[0]);
730+
}
731+
732+
// aten::pow with two int operands must evaluate to the same value in
// Torch-TensorRT's evaluator as in TorchScript.
TEST(Evaluators, PowIntEvaluatesCorrectly) {
  const auto source = R"IR(
    graph():
      %1 : int = prim::Constant[value=9]()
      %2 : int = prim::Constant[value=4]()
      %3 : float = aten::pow(%1, %2)
      return (%3))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, parsed.get());

  // Run both backends over the same graph and compare the evaluated IValues.
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(parsed->block(), {});
  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed, {});

  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
748+
749+
// aten::pow with two float operands must match TorchScript's result exactly.
TEST(Evaluators, PowFloatEvaluatesCorrectly) {
  const auto source = R"IR(
    graph():
      %1 : float = prim::Constant[value=9.5]()
      %2 : float = prim::Constant[value=4.5]()
      %3 : float = aten::pow(%1, %2)
      return (%3))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, parsed.get());

  // Run both backends over the same graph and compare the evaluated IValues.
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(parsed->block(), {});
  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed, {});

  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
765+
766+
// Mixed operands (int base, float exponent) exercise the int->double dispatch
// path of the pow evaluator; result must match TorchScript.
TEST(Evaluators, PowIntFloatEvaluatesCorrectly) {
  const auto source = R"IR(
    graph():
      %1 : int = prim::Constant[value=9]()
      %2 : float = prim::Constant[value=4.5]()
      %3 : float = aten::pow(%1, %2)
      return (%3))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, parsed.get());

  // Run both backends over the same graph and compare the evaluated IValues.
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(parsed->block(), {});
  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed, {});

  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}
782+
783+
// Mixed operands (float base, int exponent) exercise the double->int dispatch
// path of the pow evaluator; result must match TorchScript.
TEST(Evaluators, PowFloatIntEvaluatesCorrectly) {
  const auto source = R"IR(
    graph():
      %1 : float = prim::Constant[value=9.5]()
      %2 : int = prim::Constant[value=4]()
      %3 : float = aten::pow(%1, %2)
      return (%3))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(source, parsed.get());

  // Run both backends over the same graph and compare the evaluated IValues.
  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(parsed->block(), {});
  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(parsed, {});

  ASSERT_TRUE(jit_results[0] == trt_results[0]);
}

0 commit comments

Comments
 (0)