File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -334,10 +334,9 @@ def quantized_add_meta(
334
334
out_scale : float ,
335
335
out_zero_point : int ,
336
336
) -> torch .Tensor :
337
- out_size = X .size ()
338
- if list (X .size ()) == [1 ]:
339
- out_size = Y .size ()
340
337
338
+ # Determine output shape by broadcasting X and Y
339
+ out_size = torch .broadcast_shapes (X .size (), Y .size ())
341
340
return X .new_empty (out_size , dtype = X .dtype )
342
341
343
342
@@ -352,10 +351,8 @@ def quantized_add_per_tensor_meta(
352
351
out_scale : float ,
353
352
out_zero_point : int ,
354
353
) -> torch .Tensor :
355
- out_size = X .size ()
356
- if list (X .size ()) == [1 ]:
357
- out_size = Y .size ()
358
354
355
+ out_size = torch .broadcast_shapes (X .size (), Y .size ())
359
356
return X .new_empty (out_size , dtype = X .dtype )
360
357
361
358
You can’t perform that action at this time.
0 commit comments