Skip to content

Commit 726bd35

Browse files
committed
[ET-VK] Generalize MeanToSumDiv to any dtype
Pull Request resolved: #3794 This change is required for fp16 models. Differential Revision: [D58040777](https://our.internmc.facebook.com/intern/diff/D58040777/) ghstack-source-id: 228674119
1 parent 13ba3a7 commit 726bd35

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

backends/transforms/mean_to_sum_div.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import torch
87
from executorch.exir.dialects._ops import ops as exir_ops
98

109
from executorch.exir.pass_base import ExportPass
@@ -19,6 +18,7 @@ def call_operator(self, op, args, kwargs, meta):
1918
)
2019
# args[0] is the input tensor
2120
shape = args[0].node.meta["val"].shape
21+
dtype = args[0].node.meta["val"].dtype
2222
dims_to_reduce = args[1]
2323
size = 1.0
2424
for dim in dims_to_reduce:
@@ -32,7 +32,7 @@ def call_operator(self, op, args, kwargs, meta):
3232
],
3333
size,
3434
),
35-
{"dtype": torch.float32},
35+
{"dtype": dtype},
3636
meta,
3737
)
3838

0 commit comments

Comments
 (0)