Skip to content

Commit 804a367

Browse files
committed
Align in-place operation with explicit ufunc call
1 parent 49d4339 commit 804a367

File tree

2 files changed

+26
-23
lines changed

2 files changed

+26
-23
lines changed

dpnp/dpnp_algo/dpnp_elementwise_common.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,20 @@ def __call__(
335335
"as an argument, but both were provided."
336336
)
337337

338+
x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
339+
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
340+
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
341+
342+
if (
343+
isinstance(x1, dpnp_array)
344+
and x1 is out
345+
and order == "K"
346+
and dtype is None
347+
):
348+
# in-place operation
349+
super()._inplace_op(x1_usm, x2_usm)
350+
return x1
351+
338352
if order is None:
339353
order = "K"
340354
elif order in "afkcAFKC":
@@ -344,9 +358,6 @@ def __call__(
344358
"order must be one of 'C', 'F', 'A', or 'K' (got '{order}')"
345359
)
346360

347-
x1_usm = dpnp.get_usm_ndarray_or_scalar(x1)
348-
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
349-
350361
if dtype is not None:
351362
if dpnp.isscalar(x1):
352363
x1_usm = dpt.asarray(
@@ -368,20 +379,12 @@ def __call__(
368379
x1_usm = dpt.astype(x1_usm, dtype, copy=False)
369380
x2_usm = dpt.astype(x2_usm, dtype, copy=False)
370381

371-
out_usm = None if out is None else dpnp.get_usm_ndarray(out)
372382
res_usm = super().__call__(x1_usm, x2_usm, out=out_usm, order=order)
373383

374384
if out is not None and isinstance(out, dpnp_array):
375385
return out
376386
return dpnp_array._create_from_usm_ndarray(res_usm)
377387

378-
def _inplace_op(self, x1, x2):
379-
x1_usm = dpnp.get_usm_ndarray(x1)
380-
x2_usm = dpnp.get_usm_ndarray_or_scalar(x2)
381-
382-
super()._inplace_op(x1_usm, x2_usm)
383-
return x1
384-
385388
def outer(
386389
self,
387390
x1,

dpnp/dpnp_array.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -347,22 +347,22 @@ def __gt__(self, other):
347347

348348
def __iadd__(self, other):
349349
"""Return ``self+=value``."""
350-
dpnp.add._inplace_op(self, other)
350+
dpnp.add(self, other, out=self)
351351
return self
352352

353353
def __iand__(self, other):
354354
"""Return ``self&=value``."""
355-
dpnp.bitwise_and._inplace_op(self, other)
355+
dpnp.bitwise_and(self, other, out=self)
356356
return self
357357

358358
def __ifloordiv__(self, other):
359359
"""Return ``self//=value``."""
360-
dpnp.floor_divide._inplace_op(self, other)
360+
dpnp.floor_divide(self, other, out=self)
361361
return self
362362

363363
def __ilshift__(self, other):
364364
"""Return ``self<<=value``."""
365-
dpnp.left_shift._inplace_op(self, other)
365+
dpnp.left_shift(self, other, out=self)
366366
return self
367367

368368
def __imatmul__(self, other):
@@ -393,12 +393,12 @@ def __imatmul__(self, other):
393393

394394
def __imod__(self, other):
395395
"""Return ``self%=value``."""
396-
dpnp.remainder._inplace_op(self, other)
396+
dpnp.remainder(self, other, out=self)
397397
return self
398398

399399
def __imul__(self, other):
400400
"""Return ``self*=value``."""
401-
dpnp.multiply._inplace_op(self, other)
401+
dpnp.multiply(self, other, out=self)
402402
return self
403403

404404
def __index__(self):
@@ -416,22 +416,22 @@ def __invert__(self):
416416

417417
def __ior__(self, other):
418418
"""Return ``self|=value``."""
419-
dpnp.bitwise_or._inplace_op(self, other)
419+
dpnp.bitwise_or(self, other, out=self)
420420
return self
421421

422422
def __ipow__(self, other):
423423
"""Return ``self**=value``."""
424-
dpnp.power._inplace_op(self, other)
424+
dpnp.power(self, other, out=self)
425425
return self
426426

427427
def __irshift__(self, other):
428428
"""Return ``self>>=value``."""
429-
dpnp.right_shift._inplace_op(self, other)
429+
dpnp.right_shift(self, other, out=self)
430430
return self
431431

432432
def __isub__(self, other):
433433
"""Return ``self-=value``."""
434-
dpnp.subtract._inplace_op(self, other)
434+
dpnp.subtract(self, other, out=self)
435435
return self
436436

437437
def __iter__(self):
@@ -442,12 +442,12 @@ def __iter__(self):
442442

443443
def __itruediv__(self, other):
444444
"""Return ``self/=value``."""
445-
dpnp.true_divide._inplace_op(self, other)
445+
dpnp.true_divide(self, other, out=self)
446446
return self
447447

448448
def __ixor__(self, other):
449449
"""Return ``self^=value``."""
450-
dpnp.bitwise_xor._inplace_op(self, other)
450+
dpnp.bitwise_xor(self, other, out=self)
451451
return self
452452

453453
def __le__(self, other):

0 commit comments

Comments
 (0)