File tree Expand file tree Collapse file tree 1 file changed +17
-3
lines changed
core/conversion/converters/impl Expand file tree Collapse file tree 1 file changed +17
-3
lines changed Original file line number Diff line number Diff line change @@ -82,11 +82,25 @@ auto acthardtanh TORCHTRT_UNUSED =
82
82
.pattern(
83
83
{" aten::prelu(Tensor self, Tensor weight) -> (Tensor)" ,
84
84
[](ConversionCtx* ctx, const torch::jit::Node* n, args& args) -> bool {
85
- auto in = args[0 ].ITensor ();
86
- auto slopes = args[1 ].unwrapToTensor ();
85
+ auto in = args[0 ].ITensor ();
86
+ auto slopes = args[1 ].unwrapToTensor ();
87
+ auto original_shape = in->getDimensions ();
88
+
89
+ // when the input dim is not equal to the slopes dim,the line output of ParametricReLU will be all zeros.
90
+ // since it necessary to avoid the input dim is not equal to the slopes dim.
91
+ auto in_shape = util::toVec (original_shape );
92
+ auto slopes_shape = slopes.sizes ().vec ();
93
+ if (in_shape.size ()!=slopes_shape.size () and slopes_shape.size ()==1 ){
94
+ std::vector<int64_t > slopes_new_shape ;
95
+ for (size_t i = 0 ;i<in_shape.size ();i++){
96
+ slopes_new_shape.push_back (
97
+ in_shape[i]==slopes_shape[0 ]?slopes_shape[0 ]:1
98
+ );
99
+ }
100
+ slopes = slopes.reshape (slopes_new_shape);
101
+ }
87
102
88
103
bool to_reshape = false ;
89
- auto original_shape = in->getDimensions ();
90
104
if (slopes.numel () != 1 &&
91
105
!util::broadcastable (
92
106
in->getDimensions (),
You can’t perform that action at this time.
0 commit comments