Skip to content

Commit 6bbcecb

Browse files
inocsinnarendasan
authored andcommitted
support leaky_relu converter/test_case
Signed-off-by: inocsin <[email protected]>
1 parent 885439c commit 6bbcecb

File tree

2 files changed

+37
-0
lines changed

2 files changed

+37
-0
lines changed

core/conversion/converters/impl/activation.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,11 +121,27 @@ auto acthardtanh TRTORCH_UNUSED =
121121
out_tensor = out_shuffle->getOutput(0);
122122
}
123123

124+
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
125+
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
126+
return true;
127+
}})
128+
.pattern(
129+
{"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)",
130+
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
131+
auto self = args[0].ITensorOrFreeze(ctx);
132+
auto negative_slopeScalar = args[1].unwrapToScalar().to<float>();
133+
134+
auto new_layer = ctx->net->addActivation(*self, nvinfer1::ActivationType::kLEAKY_RELU);
135+
new_layer->setAlpha(negative_slopeScalar);
136+
137+
new_layer->setName(util::node_info(n).c_str());
138+
auto out_tensor = new_layer->getOutput(0);
124139
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
125140
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
126141
return true;
127142
}});
128143

144+
129145
} // namespace
130146
} // namespace impl
131147
} // namespace converters

tests/core/conversion/converters/test_activation.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,24 @@ TEST(Converters, ATenPReLUMultiChannelConvertsCorrectly) {
156156

157157
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
158158
}
159+
160+
TEST(Converters, ATenLeakyReluConvertsCorrectly) {
161+
const auto graph = R"IR(
162+
graph(%0 : Tensor):
163+
%1 : float = prim::Constant[value=0.15]()
164+
%2 : Tensor = aten::leaky_relu(%0, %1)
165+
return (%2))IR";
166+
167+
auto g = std::make_shared<torch::jit::Graph>();
168+
torch::jit::parseIR(graph, &*g);
169+
170+
auto in = at::randint(-5, 5, {5}, {at::kCUDA});
171+
auto params = trtorch::core::conversion::get_named_params(g->inputs(), {});
172+
auto jit_results = trtorch::tests::util::RunGraph(g, params, {in});
173+
174+
in = at::clone(in);
175+
params = trtorch::core::conversion::get_named_params(g->inputs(), {});
176+
auto trt_results = trtorch::tests::util::RunGraphEngine(g, params, {in});
177+
178+
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
179+
}

0 commit comments

Comments
 (0)