@@ -351,17 +351,19 @@ def vecdot(x1, x2, axis=-1):
351
351
x2_nd = x2 .ndim
352
352
x1_shape = x1 .shape
353
353
x2_shape = x2 .shape
354
+ if axis >= 0 :
355
+ raise ValueError ("`axis` must be negative" )
356
+ axis = operator .index (axis )
357
+ x1_axis = normalize_axis_index (axis , x1_nd )
358
+ x2_axis = normalize_axis_index (axis , x2_nd )
359
+ if x1_shape [x1_axis ] != x2_shape [x2_axis ]:
360
+ raise ValueError (
361
+ "given axis must have the same shape for `x1` and `x2`"
362
+ )
354
363
if x1_nd > x2_nd :
355
364
x2_shape = (1 ,) * (x1_nd - x2_nd ) + x2_shape
356
- x2_nd = len (x2_shape )
357
365
elif x2_nd > x1_nd :
358
366
x1_shape = (1 ,) * (x2_nd - x1_nd ) + x1_shape
359
- x1_nd = len (x1_shape )
360
- axis = normalize_axis_index (operator .index (axis ), min (x1_nd , x2_nd ))
361
- if x1_shape [axis ] != x2_shape [axis ]:
362
- raise ValueError (
363
- "given axis must have the same shape for `x1` and `x2`"
364
- )
365
367
try :
366
368
broadcast_sh = _broadcast_shape_impl (
367
369
[
@@ -371,8 +373,10 @@ def vecdot(x1, x2, axis=-1):
371
373
)
372
374
except ValueError :
373
375
raise ValueError ("mismatch in `vecdot` dimensions" )
376
+ broadcast_nd = len (broadcast_sh )
377
+ contracted_axis = normalize_axis_index (axis , broadcast_nd )
374
378
res_sh = tuple (
375
- [broadcast_sh [i ] for i in range (len ( broadcast_sh )) if i != axis ]
379
+ [broadcast_sh [i ] for i in range (broadcast_nd ) if i != contracted_axis ]
376
380
)
377
381
# type validation
378
382
sycl_dev = exec_q .sycl_device
@@ -410,9 +414,8 @@ def vecdot(x1, x2, axis=-1):
410
414
x1 = dpt .broadcast_to (x1 , broadcast_sh )
411
415
if x2 .shape != broadcast_sh :
412
416
x2 = dpt .broadcast_to (x2 , broadcast_sh )
413
- x1 = dpt .moveaxis (x1 , axis , - 1 )
414
- x2 = dpt .moveaxis (x2 , axis , - 1 )
415
-
417
+ x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
418
+ x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
416
419
out = dpt .empty (
417
420
res_sh ,
418
421
dtype = res_dt ,
@@ -455,8 +458,8 @@ def vecdot(x1, x2, axis=-1):
455
458
x1 = dpt .broadcast_to (x1 , broadcast_sh )
456
459
if buf2 .shape != broadcast_sh :
457
460
buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
458
- x1 = dpt .moveaxis (x1 , axis , - 1 )
459
- buf2 = dpt .moveaxis (buf2 , axis , - 1 )
461
+ x1 = dpt .moveaxis (x1 , contracted_axis , - 1 )
462
+ buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
460
463
out = dpt .empty (
461
464
res_sh ,
462
465
dtype = res_dt ,
@@ -497,8 +500,8 @@ def vecdot(x1, x2, axis=-1):
497
500
buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
498
501
if x2 .shape != broadcast_sh :
499
502
x2 = dpt .broadcast_to (x2 , broadcast_sh )
500
- buf1 = dpt .moveaxis (buf1 , axis , - 1 )
501
- x2 = dpt .moveaxis (x2 , axis , - 1 )
503
+ buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
504
+ x2 = dpt .moveaxis (x2 , contracted_axis , - 1 )
502
505
out = dpt .empty (
503
506
res_sh ,
504
507
dtype = res_dt ,
@@ -544,8 +547,8 @@ def vecdot(x1, x2, axis=-1):
544
547
buf1 = dpt .broadcast_to (buf1 , broadcast_sh )
545
548
if buf2 .shape != broadcast_sh :
546
549
buf2 = dpt .broadcast_to (buf2 , broadcast_sh )
547
- buf1 = dpt .moveaxis (buf1 , axis , - 1 )
548
- buf2 = dpt .moveaxis (buf2 , axis , - 1 )
550
+ buf1 = dpt .moveaxis (buf1 , contracted_axis , - 1 )
551
+ buf2 = dpt .moveaxis (buf2 , contracted_axis , - 1 )
549
552
out = dpt .empty (
550
553
res_sh ,
551
554
dtype = res_dt ,
0 commit comments