Skip to content

Commit cc9e2d1

Browse files
jorgep31415facebook-github-bot
authored andcommitted
Generalize MeanToSumDiv to any dtype (#3794)
Summary: Pull Request resolved: #3794 This change is required for fp16 models. ghstack-source-id: 228674119 Reviewed By: copyrightly Differential Revision: D58040777 fbshipit-source-id: e28f5d285da7b9c0b639671745a341981ac683a8
1 parent bb4f761 commit cc9e2d1

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

backends/transforms/mean_to_sum_div.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torch
87
from executorch.exir.dialects._ops import ops as exir_ops
98

109
from executorch.exir.pass_base import ExportPass
@@ -19,6 +18,7 @@ def call_operator(self, op, args, kwargs, meta):
1918
)
2019
# args[0] is the input tensor
2120
shape = args[0].node.meta["val"].shape
21+
dtype = args[0].node.meta["val"].dtype
2222
dims_to_reduce = args[1]
2323
size = 1.0
2424
for dim in dims_to_reduce:
@@ -32,7 +32,7 @@ def call_operator(self, op, args, kwargs, meta):
3232
],
3333
size,
3434
),
35-
{"dtype": torch.float32},
35+
{"dtype": dtype},
3636
meta,
3737
)
3838

0 commit comments

Comments
 (0)