Skip to content

Commit 8953279

Browse files
authored
Broadcast implementation in quantized_add
Differential Revision: D74773433 Pull Request resolved: #10903
1 parent fd87e98 commit 8953279

File tree

1 file changed

+3
-6
lines changed

1 file changed

+3
-6
lines changed

backends/cadence/aot/ops_registrations.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -334,10 +334,9 @@ def quantized_add_meta(
334334
out_scale: float,
335335
out_zero_point: int,
336336
) -> torch.Tensor:
337-
out_size = X.size()
338-
if list(X.size()) == [1]:
339-
out_size = Y.size()
340337

338+
# Determine output shape by broadcasting X and Y
339+
out_size = torch.broadcast_shapes(X.size(), Y.size())
341340
return X.new_empty(out_size, dtype=X.dtype)
342341

343342

@@ -352,10 +351,8 @@ def quantized_add_per_tensor_meta(
352351
out_scale: float,
353352
out_zero_point: int,
354353
) -> torch.Tensor:
355-
out_size = X.size()
356-
if list(X.size()) == [1]:
357-
out_size = Y.size()
358354

355+
out_size = torch.broadcast_shapes(X.size(), Y.size())
359356
return X.new_empty(out_size, dtype=X.dtype)
360357

361358

0 commit comments

Comments
 (0)