Skip to content

Commit efc4869

Browse files
committed
[ET-VK] Generalize MeanToSumDiv to any dtype
This change is required for fp16 models. Differential Revision: [D58040777](https://our.internmc.facebook.com/intern/diff/D58040777/) [ghstack-poisoned]
1 parent 68c822f commit efc4869

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

backends/transforms/mean_to_sum_div.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ def call_operator(self, op, args, kwargs, meta):
1919
)
2020
# args[0] is the input tensor
2121
shape = args[0].node.meta["val"].shape
22+
dtype = args[0].node.meta["val"].dtype
2223
dims_to_reduce = args[1]
2324
size = 1.0
2425
for dim in dims_to_reduce:
@@ -32,7 +33,7 @@ def call_operator(self, op, args, kwargs, meta):
3233
],
3334
size,
3435
),
35-
{"dtype": torch.float32},
36+
{"dtype": dtype},
3637
meta,
3738
)
3839

0 commit comments

Comments
 (0)