Skip to content

Commit c3cdd32

Browse files
authored
Merge pull request #299 from NVIDIA/leaky_relu
support leaky_relu converter/test_case
2 parents 83cf1bf + bc53411 commit c3cdd32

File tree

2 files changed

+50
-1
lines changed

2 files changed

+50
-1
lines changed

core/conversion/converters/impl/activation.cpp

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,35 @@ auto acthardtanh TRTORCH_UNUSED =
124124
out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
125125
LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
126126
return true;
127-
}});
127+
}})
    .pattern({"aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)",
              // Lowers aten::leaky_relu to a TensorRT kLEAKY_RELU activation layer,
              // with the negative_slope scalar mapped onto the layer's alpha parameter.
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto self = args[0].ITensorOrFreeze(ctx);
                auto negative_slopeScalar = args[1].unwrapToScalar().to<float>();

                auto new_layer = ctx->net->addActivation(*self, nvinfer1::ActivationType::kLEAKY_RELU);
                // addActivation returns nullptr on failure; guard before dereferencing.
                TRTORCH_CHECK(new_layer, "Unable to create leaky_relu layer from node: " << *n);
                new_layer->setAlpha(negative_slopeScalar);

                new_layer->setName(util::node_info(n).c_str());
                auto out_tensor = new_layer->getOutput(0);
                out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
              }})
    .pattern({"aten::leaky_relu_(Tensor(a!) self, Scalar negative_slope=0.01) -> Tensor(a!)",
              // In-place variant of aten::leaky_relu. TensorRT graphs are functional,
              // so the lowering is identical to the out-of-place converter above.
              [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
                auto self = args[0].ITensorOrFreeze(ctx);
                auto negative_slopeScalar = args[1].unwrapToScalar().to<float>();

                auto new_layer = ctx->net->addActivation(*self, nvinfer1::ActivationType::kLEAKY_RELU);
                // addActivation returns nullptr on failure; guard before dereferencing.
                TRTORCH_CHECK(new_layer, "Unable to create leaky_relu layer from node: " << *n);
                new_layer->setAlpha(negative_slopeScalar);

                new_layer->setName(util::node_info(n).c_str());
                auto out_tensor = new_layer->getOutput(0);
                out_tensor = ctx->AssociateValueAndTensor(n->outputs()[0], out_tensor);
                LOG_DEBUG("Output shape: " << out_tensor->getDimensions());
                return true;
              }});
128156

129157
} // namespace
130158
} // namespace impl

tests/core/conversion/converters/test_activation.cpp

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,3 +156,24 @@ TEST(Converters, ATenPReLUMultiChannelConvertsCorrectly) {
156156

157157
ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
158158
}
159+
160+
TEST(Converters, ATenLeakyReluConvertsCorrectly) {
  // IR under test: leaky_relu with a constant negative_slope of 0.15.
  const auto graph = R"IR(
      graph(%0 : Tensor):
        %1 : float = prim::Constant[value=0.15]()
        %2 : Tensor = aten::leaky_relu(%0, %1)
        return (%2))IR";

  auto parsed = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(graph, &*parsed);

  // Mix of negative and positive integers exercises both sides of the kink.
  auto in = at::randint(-5, 5, {5}, {at::kCUDA});

  // Reference result from the TorchScript interpreter.
  auto params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto jit_results = trtorch::tests::util::RunGraph(parsed, params, {in});

  // Same input through the converted TensorRT engine.
  in = at::clone(in);
  params = trtorch::core::conversion::get_named_params(parsed->inputs(), {});
  auto trt_results = trtorch::tests::util::RunGraphEngine(parsed, params, {in});

  ASSERT_TRUE(trtorch::tests::util::almostEqual(jit_results[0], trt_results[0], 2e-6));
}

0 commit comments

Comments
 (0)