Skip to content

Commit 8239b74

Browse files
committed
Aligns vecdot with array API spec changes
Only negative values for `axis` are permitted to avoid ambiguity Now separately checks that the `axis` parameter is valid for each array before broadcasting occurs
1 parent da065b5 commit 8239b74

File tree

1 file changed

+20
-17
lines changed

1 file changed

+20
-17
lines changed

dpctl/tensor/_linear_algebra_functions.py

Lines changed: 20 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -351,17 +351,19 @@ def vecdot(x1, x2, axis=-1):
351351
x2_nd = x2.ndim
352352
x1_shape = x1.shape
353353
x2_shape = x2.shape
354+
if axis >= 0:
355+
raise ValueError("`axis` must be negative")
356+
axis = operator.index(axis)
357+
x1_axis = normalize_axis_index(axis, x1_nd)
358+
x2_axis = normalize_axis_index(axis, x2_nd)
359+
if x1_shape[x1_axis] != x2_shape[x2_axis]:
360+
raise ValueError(
361+
"given axis must have the same shape for `x1` and `x2`"
362+
)
354363
if x1_nd > x2_nd:
355364
x2_shape = (1,) * (x1_nd - x2_nd) + x2_shape
356-
x2_nd = len(x2_shape)
357365
elif x2_nd > x1_nd:
358366
x1_shape = (1,) * (x2_nd - x1_nd) + x1_shape
359-
x1_nd = len(x1_shape)
360-
axis = normalize_axis_index(operator.index(axis), min(x1_nd, x2_nd))
361-
if x1_shape[axis] != x2_shape[axis]:
362-
raise ValueError(
363-
"given axis must have the same shape for `x1` and `x2`"
364-
)
365367
try:
366368
broadcast_sh = _broadcast_shape_impl(
367369
[
@@ -371,8 +373,10 @@ def vecdot(x1, x2, axis=-1):
371373
)
372374
except ValueError:
373375
raise ValueError("mismatch in `vecdot` dimensions")
376+
broadcast_nd = len(broadcast_sh)
377+
contracted_axis = normalize_axis_index(axis, broadcast_nd)
374378
res_sh = tuple(
375-
[broadcast_sh[i] for i in range(len(broadcast_sh)) if i != axis]
379+
[broadcast_sh[i] for i in range(broadcast_nd) if i != contracted_axis]
376380
)
377381
# type validation
378382
sycl_dev = exec_q.sycl_device
@@ -410,9 +414,8 @@ def vecdot(x1, x2, axis=-1):
410414
x1 = dpt.broadcast_to(x1, broadcast_sh)
411415
if x2.shape != broadcast_sh:
412416
x2 = dpt.broadcast_to(x2, broadcast_sh)
413-
x1 = dpt.moveaxis(x1, axis, -1)
414-
x2 = dpt.moveaxis(x2, axis, -1)
415-
417+
x1 = dpt.moveaxis(x1, contracted_axis, -1)
418+
x2 = dpt.moveaxis(x2, contracted_axis, -1)
416419
out = dpt.empty(
417420
res_sh,
418421
dtype=res_dt,
@@ -455,8 +458,8 @@ def vecdot(x1, x2, axis=-1):
455458
x1 = dpt.broadcast_to(x1, broadcast_sh)
456459
if buf2.shape != broadcast_sh:
457460
buf2 = dpt.broadcast_to(buf2, broadcast_sh)
458-
x1 = dpt.moveaxis(x1, axis, -1)
459-
buf2 = dpt.moveaxis(buf2, axis, -1)
461+
x1 = dpt.moveaxis(x1, contracted_axis, -1)
462+
buf2 = dpt.moveaxis(buf2, contracted_axis, -1)
460463
out = dpt.empty(
461464
res_sh,
462465
dtype=res_dt,
@@ -497,8 +500,8 @@ def vecdot(x1, x2, axis=-1):
497500
buf1 = dpt.broadcast_to(buf1, broadcast_sh)
498501
if x2.shape != broadcast_sh:
499502
x2 = dpt.broadcast_to(x2, broadcast_sh)
500-
buf1 = dpt.moveaxis(buf1, axis, -1)
501-
x2 = dpt.moveaxis(x2, axis, -1)
503+
buf1 = dpt.moveaxis(buf1, contracted_axis, -1)
504+
x2 = dpt.moveaxis(x2, contracted_axis, -1)
502505
out = dpt.empty(
503506
res_sh,
504507
dtype=res_dt,
@@ -544,8 +547,8 @@ def vecdot(x1, x2, axis=-1):
544547
buf1 = dpt.broadcast_to(buf1, broadcast_sh)
545548
if buf2.shape != broadcast_sh:
546549
buf2 = dpt.broadcast_to(buf2, broadcast_sh)
547-
buf1 = dpt.moveaxis(buf1, axis, -1)
548-
buf2 = dpt.moveaxis(buf2, axis, -1)
550+
buf1 = dpt.moveaxis(buf1, contracted_axis, -1)
551+
buf2 = dpt.moveaxis(buf2, contracted_axis, -1)
549552
out = dpt.empty(
550553
res_sh,
551554
dtype=res_dt,

0 commit comments

Comments
 (0)