aten::pow support #918

Merged · 4 commits · Mar 22, 2022
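In brief, this PR does three things: it registers an evaluator covering the scalar `aten::pow` schemas (int/int, float/float, int/float, float/int) so such expressions are computed at conversion time; it makes the `ctx.AddLayer` error name the specific input value that could not be retrieved; and it routes unsupported-operator diagnostics through `suppress_errors`, logging at INFO/DEBUG instead of ERROR when suppression is requested.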
23 changes: 16 additions & 7 deletions core/conversion/conversion.cpp
@@ -105,7 +105,8 @@ void AddLayer(ConversionCtx* ctx, const torch::jit::Node* n) {
       // Node input has not been converted yet or is a prim op
       TORCHTRT_THROW_ERROR(
           "Unable to retrieve all node inputs for node: "
-          << util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: " << *input_node);
+          << util::node_info(n) << " (ctx.AddLayer)\nSpecifically failed to retrieve value for input: %"
+          << input->debugName());
     }
   }
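A brief note on the change above: streaming `*input_node` printed the entire producing node, whereas `input->debugName()` prints just the value's SSA name, so the message now pinpoints the missing input with its `%`-prefixed name. A minimal sketch of what `debugName()` yields, assuming only libtorch (the graph and names here are illustrative, not from the PR):

#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>

#include <iostream>
#include <memory>

int main() {
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(R"IR(
    graph(%x.1 : Tensor):
      %2 : Tensor = aten::relu(%x.1)
      return (%2))IR", g.get());

  // debugName() yields the bare SSA name ("x.1"), which the new error
  // message prefixes with '%'; streaming the value's node would print the
  // whole defining expression instead.
  for (const torch::jit::Value* in : g->inputs()) {
    std::cout << "%" << in->debugName() << "\n";
  }
}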

@@ -534,18 +535,22 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
   if (unsupported_ops.size() != 0) {
     std::stringstream unsupported_msg;
     unsupported_msg
-        << "Method requested cannot be compiled by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
+        << "Method requested cannot be compiled end to end by Torch-TensorRT.TorchScript.\nUnsupported operators listed below:"
         << std::endl;
     for (auto s : unsupported_ops) {
       unsupported_msg << " - " << s.second << std::endl;
     }
-    unsupported_msg << "You can either implement converters for these ops in your application or request implementation"
-                    << std::endl;
-    unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
-    unsupported_msg << std::endl << "In Module:" << std::endl;

     if (!suppress_errors) {
+      unsupported_msg
+          << "You can either implement converters for these ops in your application or request implementation"
+          << std::endl;
+      unsupported_msg << "https://www.github.com/nvidia/Torch-TensorRT/issues" << std::endl;
+      unsupported_msg << std::endl << "In Module:" << std::endl;
+
       LOG_ERROR(unsupported_msg.str());
+    } else {
+      LOG_INFO(unsupported_msg.str());
     }

     std::unordered_map<std::string, std::unordered_set<std::string>> unsupported_node_locations;
@@ -572,7 +577,11 @@ bool VerifyConverterSupportForBlock(const torch::jit::Block* b, bool suppress_er
       traceback << str;
     }
     auto tb_str = traceback.str();
-    LOG_ERROR(tb_str);
+    if (!suppress_errors) {
+      LOG_ERROR(tb_str);
+    } else {
+      LOG_DEBUG(tb_str);
+    }
   }

   return false;
13 changes: 13 additions & 0 deletions core/conversion/evaluators/aten.cpp
@@ -1,3 +1,5 @@
+#include <math.h>
+
 #include "ATen/core/List.h"
 #include "ATen/core/functional.h"
 #include "ATen/core/ivalue.h"
@@ -98,6 +100,17 @@ DEFINE_GENERIC_TWO_INPUT_EVALUATOR(
"aten::ge.float_int(float a, int b) -> (bool)",
}));

DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(
pow,
"aten::pow",
pow(a, b),
std::set<std::string>({
"aten::pow.int(int a, int b) -> (float)",
"aten::pow.float(float a, float b) -> (float)",
"aten::pow.int_float(int a, float b) -> (float)",
"aten::pow.float_int(float a, int b) -> (float)",
}));

DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(
and,
"aten::__and__",
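One detail worth spelling out (our observation, not stated in the PR): the macro pastes `pow(a, b)` into each type branch, and with `<math.h>` in scope the integral overloads of `pow` promote their arguments to floating point, which is why every scalar schema registered above declares `-> (float)` even for two `int` inputs. A quick standalone check:

#include <cmath>
#include <cstdint>
#include <iostream>
#include <type_traits>

int main() {
  // pow over two integral arguments goes through the floating-point
  // overloads, so the result is a double, consistent with the
  // "-> (float)" return type in each aten::pow scalar schema.
  auto r = std::pow(int64_t{9}, int64_t{4});
  static_assert(std::is_same_v<decltype(r), double>);
  std::cout << r << "\n"; // 6561
}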
47 changes: 47 additions & 0 deletions core/conversion/evaluators/eval_macros.h
@@ -77,6 +77,53 @@
     }, \
     EvalOptions().validSchemas(schemas)});

+#define DEFINE_ARITHMATIC_TWO_INPUT_EVALUATOR(name, node_kind, operation, schemas) \
+  auto name##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
+      {c10::Symbol::fromQualString(node_kind), \
+       [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> { \
+         if (args.at(n->input(0)).IValue()->isInt()) { \
+           auto a = args.at(n->input(0)).unwrapToInt(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else if (args.at(n->input(0)).IValue()->isDouble()) { \
+           auto a = args.at(n->input(0)).unwrapToDouble(); \
+           if (args.at(n->input(1)).IValue()->isInt()) { \
+             auto b = args.at(n->input(1)).unwrapToInt(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isDouble()) { \
+             auto b = args.at(n->input(1)).unwrapToDouble(); \
+             return operation; \
+           } else if (args.at(n->input(1)).IValue()->isBool()) { \
+             auto b = args.at(n->input(1)).unwrapToBool(); \
+             return operation; \
+           } else { \
+             TORCHTRT_THROW_ERROR( \
+                 "Unimplemented data type for " \
+                 << node_kind << " evaluator b arg:" << args.at(n->input(1)).IValue()->type()->str()); \
+             return {}; \
+           } \
+         } else { \
+           TORCHTRT_THROW_ERROR( \
+               "Unimplemented data type for " \
+               << node_kind << " evaluator a arg: " << args.at(n->input(0)).IValue()->type()->str()); \
+           return {}; \
+         } \
+       }, \
+       EvalOptions().validSchemas(schemas)});
+
 #define DEFINE_TWO_INPUT_SIMPLE_EVALUATOR(node_kind, node_name, operation, type, schemas) \
   auto node_kind##_registrations TORCHTRT_UNUSED = RegisterNodeEvaluators().evaluator( \
       {c10::Symbol::fromQualString(node_name), \
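The macro's shape is easiest to see stripped of the Torch plumbing. Below is a minimal standalone sketch (plain C++17, illustrative names, no Torch dependency) of the same idea: inspect each operand's runtime type, unwrap it, and apply the operation. The real macro enumerates the type pairs explicitly so that `operation` is instantiated with concrete unwrapped types in every branch, and can report the offending operand's type on failure; the sketch collapses that by promoting everything to `double`:

#include <cmath>
#include <cstdint>
#include <iostream>
#include <variant>

// Stand-in for the evaluator's scalar IValue: an int, a double, or a bool.
using Scalar = std::variant<int64_t, double, bool>;

// Mirrors the isInt()/isDouble()/isBool() branches: unwrap whatever the
// operand holds. Promoting to double here is a simplification; the macro
// keeps the concrete types per branch instead.
double unwrap(const Scalar& v) {
  return std::visit([](auto x) { return static_cast<double>(x); }, v);
}

double eval_pow(const Scalar& a, const Scalar& b) {
  return std::pow(unwrap(a), unwrap(b));
}

int main() {
  std::cout << eval_pow(int64_t{9}, int64_t{4}) << "\n"; // like aten::pow.int
  std::cout << eval_pow(9.5, int64_t{4}) << "\n";        // like aten::pow.float_int
}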
5 changes: 5 additions & 0 deletions noxfile.py
@@ -322,3 +322,8 @@ def l2_multi_gpu_tests(session):
 def l2_multi_gpu_tests_host_deps(session):
     """Makes sure that Torch-TensorRT can operate on multi-gpu systems using host dependencies"""
     run_l2_multi_gpu_tests(session, use_host_env=True)
+
+@nox.session(python=["3"], reuse_venv=True)
+def download_test_models(session):
+    """Grab all the models needed for testing"""
+    download_models(session, use_host_env=True)
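Assuming the usual nox workflow (our note, not from the PR), the new session would be invoked as `nox -s download_test_models`, letting the test models be fetched once up front rather than as a side effect of a test run.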
68 changes: 68 additions & 0 deletions tests/core/conversion/evaluators/test_aten_evaluators.cpp
@@ -726,5 +726,73 @@ TEST(Evaluators, RangeLengthNegEvaluatesCorrectly) {
   auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
   auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});

   ASSERT_TRUE(jit_results[0] == trt_results[0]);
 }
+
+TEST(Evaluators, PowIntEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=9]()
+        %2 : int = prim::Constant[value=4]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PowFloatEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : float = prim::Constant[value=9.5]()
+        %2 : float = prim::Constant[value=4.5]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PowIntFloatEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=9]()
+        %2 : float = prim::Constant[value=4.5]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, PowFloatIntEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : float = prim::Constant[value=9.5]()
+        %2 : int = prim::Constant[value=4]()
+        %3 : float = aten::pow(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, g.get());
+
+  auto jit_results = torch_tensorrt::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = torch_tensorrt::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
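Taken together, the four tests exercise each scalar schema registered for the evaluator (int/int, float/float, int/float, float/int), checking the conversion-time result against TorchScript's own evaluation of the same graph.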