
Commit 23f8e9d
Merge pull request #310 from inocsin/div_scalar: support div.Scalar(Tensor self, Scalar other)
2 parents: b228bf2 + 071c1d6

File tree: 5 files changed (+84 -2 lines)

core/conversion/converters/impl/element_wise.cpp

Lines changed: 28 additions & 0 deletions
@@ -213,6 +213,20 @@ auto element_wise_registrations TRTORCH_UNUSED =
                    div->setName(util::node_info(n).c_str());
                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));

+                   LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                   return true;
+                 }})
+      .pattern({"aten::div.Scalar(Tensor self, Scalar other) -> (Tensor)",
+                [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                  auto self = args[0].ITensorOrFreeze(ctx);
+                  auto otherScalar = args[1].unwrapToScalar().to<float>();
+                  auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                  auto div =
+                      add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
+                  TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
+
+                  div->setName(util::node_info(n).c_str());
+                  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                    return true;
                  }})
@@ -229,6 +243,20 @@ auto element_wise_registrations TRTORCH_UNUSED =
                    div->setName(util::node_info(n).c_str());
                    auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));

+                   LOG_DEBUG("Output tensor shape: " << out->getDimensions());
+                   return true;
+                 }})
+      .pattern({"aten::div_.Scalar(Tensor self, Scalar other) -> (Tensor)",
+                [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
+                  auto self = args[0].ITensorOrFreeze(ctx);
+                  auto otherScalar = args[1].unwrapToScalar().to<float>();
+                  auto other = tensor_to_const(ctx, torch::tensor({otherScalar}));
+                  auto div =
+                      add_elementwise(ctx, nvinfer1::ElementWiseOperation::kDIV, self, other, util::node_info(n));
+                  TRTORCH_CHECK(div, "Unable to create div layer from node: " << *n);
+
+                  div->setName(util::node_info(n).c_str());
+                  auto out = ctx->AssociateValueAndTensor(n->outputs()[0], div->getOutput(0));
                    LOG_DEBUG("Output tensor shape: " << out->getDimensions());
                    return true;
                  }})
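Both new patterns take the same approach: the Scalar operand is unwrapped to a float, promoted to a one-element constant tensor via tensor_to_const, and fed to an elementwise kDIV layer, so TensorRT's broadcasting does the rest. A minimal eager-mode libtorch sketch of that broadcasting behavior (illustration only, not part of the commit):

#include <torch/torch.h>
#include <iostream>

int main() {
  // Stand-in for `self`: any tensor input reaching the converter.
  auto self = torch::arange(6, torch::kFloat).reshape({2, 3});
  // The converter promotes the Scalar `other` to a one-element constant
  // tensor; eager-mode division broadcasts it the same way kDIV does.
  auto other = torch::tensor({2.4f});
  auto out = self / other;
  std::cout << out << std::endl;  // every element of `self` divided by 2.4
  return 0;
}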

core/conversion/evaluators/aten.cpp

Lines changed: 2 additions & 1 deletion
@@ -426,7 +426,8 @@ auto aten_registrations TRTORCH_UNUSED =
                     }
                   },
                   EvalOptions().validSchemas({
-                      "aten::div.Scalar(Scalar a, Scalar b) -> (float)",
+                      "aten::div.float(float a, float b) -> (float)",
+                      "aten::div.int(int a, int b) -> (float)",
                   })})
       .evaluator({c10::Symbol::fromQualString("aten::floordiv"),
                   [](const torch::jit::Node* n, kwargs& args) -> c10::optional<torch::jit::IValue> {
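The made-up `div.Scalar(Scalar a, Scalar b)` schema is replaced by the real `aten::div.float` and `aten::div.int` overloads, both of which return a float: `aten::div` on two ints performs true division, matching Python 3's `/`. A tiny standalone sketch of the semantics the evaluator has to reproduce (illustration only, not part of the commit):

#include <iostream>

int main() {
  int a = 9, b = 4;
  // C++ integer division would give 2; the aten::div.int evaluator must
  // instead yield Python-style true division, i.e. 2.25 as a float.
  double true_div = static_cast<double>(a) / static_cast<double>(b);
  std::cout << (a / b) << " vs " << true_div << std::endl;  // prints "2 vs 2.25"
  return 0;
}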

tests/core/conversion/converters/test_element_wise.cpp

Lines changed: 9 additions & 0 deletions
@@ -123,6 +123,15 @@ TEST(Converters, ATenDivConvertsCorrectly) {
   pointwise_test_helper(graph, false, true, {4, 3}, {3, 4, 3});
 }

+TEST(Converters, ATenDivWithScalarConvertsCorrectly) {
+  const auto graph = R"IR(
+      graph(%0 : Tensor):
+        %scalar : float = prim::Constant[value=2.4]()
+        %1 : Tensor = aten::div(%0, %scalar)
+        return (%1))IR";
+  pointwise_test_helper(graph, true);
+}
+
 TEST(Converters, ATenPowTensorConvertsCorrectly) {
   const auto graph = R"IR(
       graph(%x.1 : Tensor, %x2.1 : Tensor):
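The converter change also registers `aten::div_.Scalar`, but only the out-of-place `aten::div` form is exercised by the test above. A sketch of an analogous test for the in-place variant, assuming the same `pointwise_test_helper` harness accepts it (not part of the commit):

TEST(Converters, ATenDivInplaceWithScalarConvertsCorrectly) {
  // Hypothetical test: exercises the aten::div_.Scalar pattern added above.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %scalar : float = prim::Constant[value=2.4]()
        %1 : Tensor = aten::div_(%0, %scalar)
        return (%1))IR";
  pointwise_test_helper(graph, true);
}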

tests/core/conversion/evaluators/BUILD

Lines changed: 6 additions & 1 deletion
@@ -11,9 +11,14 @@ evaluator_test(
     name = "test_prim_evaluators",
 )

+evaluator_test(
+    name = "test_aten_evaluators",
+)
+
 test_suite(
     name = "evaluator_tests",
     tests = [
-        ":test_prim_evaluators"
+        ":test_prim_evaluators",
+        ":test_aten_evaluators"
     ]
 )
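With the new `evaluator_test` target wired into the `evaluator_tests` suite, the checks in the new file below can presumably be run with `bazel test //tests/core/conversion/evaluators:test_aten_evaluators`, or through the suite target alongside the existing prim evaluator tests.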
tests/core/conversion/evaluators/test_aten_evaluators.cpp (new file)

Lines changed: 39 additions & 0 deletions

@@ -0,0 +1,39 @@
+#include <string>
+#include "core/compiler.h"
+#include "gtest/gtest.h"
+#include "tests/util/util.h"
+#include "torch/csrc/jit/ir/irparser.h"
+
+TEST(Evaluators, DivIntEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : int = prim::Constant[value=9]()
+        %2 : int = prim::Constant[value=4]()
+        %3 : float = aten::div(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
+
+TEST(Evaluators, DivFloatEvaluatesCorrectly) {
+  const auto graph = R"IR(
+      graph():
+        %1 : float = prim::Constant[value=9.1]()
+        %2 : float = prim::Constant[value=4.2]()
+        %3 : float = aten::div(%1, %2)
+        return (%3))IR";
+
+  auto g = std::make_shared<torch::jit::Graph>();
+  torch::jit::parseIR(graph, &*g);
+
+  auto jit_results = trtorch::tests::util::EvaluateGraphJIT(g, {});
+  auto trt_results = trtorch::tests::util::EvaluateGraph(g->block(), {});
+
+  ASSERT_TRUE(jit_results[0] == trt_results[0]);
+}
