File tree Expand file tree Collapse file tree 2 files changed +37
-0
lines changed
core/conversion/converters/impl
tests/core/conversion/converters Expand file tree Collapse file tree 2 files changed +37
-0
lines changed Original file line number Diff line number Diff line change @@ -121,11 +121,27 @@ auto acthardtanh TRTORCH_UNUSED =
121
121
out_tensor = out_shuffle->getOutput (0 );
122
122
}
123
123
124
+ out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out_tensor);
125
+ LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
126
+ return true ;
127
+ }})
128
+ .pattern(
129
+ {" aten::leaky_relu(Tensor self, Scalar negative_slope=0.01) -> (Tensor)" ,
130
+ [](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
131
+ auto self = args[0 ].ITensorOrFreeze (ctx);
132
+ auto negative_slopeScalar = args[1 ].unwrapToScalar ().to <float >();
133
+
134
+ auto new_layer = ctx->net ->addActivation (*self, nvinfer1::ActivationType::kLEAKY_RELU );
135
+ new_layer->setAlpha (negative_slopeScalar);
136
+
137
+ new_layer->setName (util::node_info (n).c_str ());
138
+ auto out_tensor = new_layer->getOutput (0 );
124
139
out_tensor = ctx->AssociateValueAndTensor (n->outputs ()[0 ], out_tensor);
125
140
LOG_DEBUG (" Output shape: " << out_tensor->getDimensions ());
126
141
return true ;
127
142
}});
128
143
144
+
129
145
} // namespace
130
146
} // namespace impl
131
147
} // namespace converters
Original file line number Diff line number Diff line change @@ -156,3 +156,24 @@ TEST(Converters, ATenPReLUMultiChannelConvertsCorrectly) {
156
156
157
157
ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
158
158
}
159
+
160
+ TEST (Converters, ATenLeakyReluConvertsCorrectly) {
161
+ const auto graph = R"IR(
162
+ graph(%0 : Tensor):
163
+ %1 : float = prim::Constant[value=0.15]()
164
+ %2 : Tensor = aten::leaky_relu(%0, %1)
165
+ return (%2))IR" ;
166
+
167
+ auto g = std::make_shared<torch::jit::Graph>();
168
+ torch::jit::parseIR (graph, &*g);
169
+
170
+ auto in = at::randint (-5 , 5 , {5 }, {at::kCUDA });
171
+ auto params = trtorch::core::conversion::get_named_params (g->inputs (), {});
172
+ auto jit_results = trtorch::tests::util::RunGraph (g, params, {in});
173
+
174
+ in = at::clone (in);
175
+ params = trtorch::core::conversion::get_named_params (g->inputs (), {});
176
+ auto trt_results = trtorch::tests::util::RunGraphEngine (g, params, {in});
177
+
178
+ ASSERT_TRUE (trtorch::tests::util::almostEqual (jit_results[0 ], trt_results[0 ], 2e-6 ));
179
+ }
You can’t perform that action at this time.
0 commit comments