Skip to content

Commit 9ee9f86

Browse files
committed
fix: avoid out_tensor of ParametricReLU being all zeros
1 parent a1880d4 commit 9ee9f86

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

core/conversion/converters/impl/activation.cpp

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,11 +82,25 @@ auto acthardtanh TORCHTRT_UNUSED =
8282
.pattern(
8383
{"aten::prelu(Tensor self, Tensor weight) -> (Tensor)",
8484
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
85-
auto in = args[0].ITensor();
86-
auto slopes = args[1].unwrapToTensor();
85+
auto in = args[0].ITensor();
86+
auto slopes = args[1].unwrapToTensor();
87+
auto original_shape = in->getDimensions();
88+
89+
// When the input dim is not equal to the slopes dim, the output of ParametricReLU will be all zeros.
90+
// Hence it is necessary to reshape the slopes so their rank matches the input's.
91+
auto in_shape = util::toVec(original_shape );
92+
auto slopes_shape = slopes.sizes().vec();
93+
if(in_shape.size()!=slopes_shape.size() and slopes_shape.size()==1){
94+
std::vector<int64_t> slopes_new_shape ;
95+
for(size_t i = 0;i<in_shape.size();i++){
96+
slopes_new_shape.push_back(
97+
in_shape[i]==slopes_shape[0]?slopes_shape[0]:1
98+
);
99+
}
100+
slopes = slopes.reshape(slopes_new_shape);
101+
}
87102

88103
bool to_reshape = false;
89-
auto original_shape = in->getDimensions();
90104
if (slopes.numel() != 1 &&
91105
!util::broadcastable(
92106
in->getDimensions(),

0 commit comments

Comments
 (0)