@@ -198,13 +198,14 @@ def _define_dim_flags(x, axis):
198
198
"""
199
199
Define useful flags for the calculations in dpnp_matmul and dpnp_vecdot.
200
200
x_is_1D: `x` is 1D array or inherently 1D (all dimensions are equal to one
201
- except for one of them ), for instance, if x.shape = (1, 1, 1, 2),
202
- then x_is_1D = True
201
+ except for dimension at `axis` ), for instance, if x.shape = (1, 1, 1, 2),
202
+ and axis=-1, then x_is_1D = True.
203
203
x_is_2D: `x` is 2D array or inherently 2D (all dimensions are equal to one
204
204
except for the last two of them), for instance, if x.shape = (1, 1, 3, 2),
205
- then x_is_2D = True
205
+ then x_is_2D = True.
206
206
x_base_is_1D: `x` is 1D considering only its last two dimensions, for instance,
207
- if x.shape = (3, 4, 1, 2), then x_base_is_1D = True
207
+ if x.shape = (3, 4, 1, 2), then x_base_is_1D = True.
208
+
208
209
"""
209
210
210
211
x_shape = x .shape
@@ -331,11 +332,11 @@ def _get_result_shape_vecdot(x1, x2, x1_ndim, x2_ndim):
331
332
332
333
if x1_ndim == 1 and x2_ndim == 1 :
333
334
result_shape = ()
334
- elif x1_is_1D :
335
+ elif x1_is_1D and not x2_is_1D :
335
336
result_shape = x2_shape [:- 1 ]
336
- elif x2_is_1D :
337
+ elif x2_is_1D and not x1_is_1D :
337
338
result_shape = x1_shape [:- 1 ]
338
- else : # at least 2D
339
+ else :
339
340
if x1_ndim != x2_ndim :
340
341
diff = abs (x1_ndim - x2_ndim )
341
342
if x1_ndim < x2_ndim :
0 commit comments