Skip to content

Commit 7add8ff

Browse files
authored
Improve implementation of dpnp.kron to avoid unnecessary copy for non-contiguous arrays (#2059) (#2064)
* remove contiguity check in dpnp.kron * add additional assert check * update CHANGELOG.md
1 parent e2a6c6d commit 7add8ff

File tree

3 files changed

+31
-6
lines changed

3 files changed

+31
-6
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ In addition, this release completes implementation of `dpnp.fft` module and adds
104104
* `dpnp` uses pybind11 2.13.6 [#2041](https://github.com/IntelPython/dpnp/pull/2041)
105105
* Updated `dpnp.fft` backend to depend on the `INTEL_MKL_VERSION` flag to ensure that the appropriate code segment is executed based on the version of OneMKL [#2035](https://github.com/IntelPython/dpnp/pull/2035)
106106
* Use `dpctl::tensor::alloc_utils::sycl_free_noexcept` instead of `sycl::free` in `host_task` tasks associated with life-time management of temporary USM allocations [#2058](https://github.com/IntelPython/dpnp/pull/2058)
107+
* Improved implementation of `dpnp.kron` to avoid unnecessary copy for non-contiguous arrays [#2059](https://github.com/IntelPython/dpnp/pull/2059)
107108

108109
### Fixed
109110

dpnp/dpnp_utils/dpnp_utils_linearalgebra.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -676,10 +676,6 @@ def dpnp_kron(a, b, a_ndim, b_ndim):
676676

677677
a_shape = a.shape
678678
b_shape = b.shape
679-
if not a.flags.contiguous:
680-
a = dpnp.reshape(a, a_shape)
681-
if not b.flags.contiguous:
682-
b = dpnp.reshape(b, b_shape)
683679

684680
# Equalise the shapes by prepending smaller one with 1s
685681
a_shape = (1,) * max(0, b_ndim - a_ndim) + a_shape
@@ -693,7 +689,7 @@ def dpnp_kron(a, b, a_ndim, b_ndim):
693689
ndim = max(b_ndim, a_ndim)
694690
a_arr = dpnp.expand_dims(a_arr, axis=tuple(range(1, 2 * ndim, 2)))
695691
b_arr = dpnp.expand_dims(b_arr, axis=tuple(range(0, 2 * ndim, 2)))
696-
result = dpnp.multiply(a_arr, b_arr, order="C")
692+
result = dpnp.multiply(a_arr, b_arr)
697693

698694
# Reshape back
699695
return result.reshape(tuple(numpy.multiply(a_shape, b_shape)))

tests/test_product.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ def test_kron_input_dtype_matrix(self, dtype1, dtype2):
741741
@pytest.mark.parametrize(
742742
"stride", [3, -1, -2, -4], ids=["3", "-1", "-2", "-4"]
743743
)
744-
def test_kron_strided(self, dtype, stride):
744+
def test_kron_strided1(self, dtype, stride):
745745
a = numpy.arange(20, dtype=dtype)
746746
b = numpy.arange(20, dtype=dtype)
747747
ia = dpnp.array(a)
@@ -751,6 +751,34 @@ def test_kron_strided(self, dtype, stride):
751751
expected = numpy.kron(a[::stride], b[::stride])
752752
assert_dtype_allclose(result, expected)
753753

754+
@pytest.mark.parametrize("stride", [2, -1, -2], ids=["2", "-1", "-2"])
755+
def test_kron_strided2(self, stride):
756+
a = numpy.arange(48).reshape(6, 8)
757+
b = numpy.arange(480).reshape(6, 8, 10)
758+
ia = dpnp.array(a)
759+
ib = dpnp.array(b)
760+
761+
result = dpnp.kron(
762+
ia[::stride, ::stride], ib[::stride, ::stride, ::stride]
763+
)
764+
expected = numpy.kron(
765+
a[::stride, ::stride], b[::stride, ::stride, ::stride]
766+
)
767+
assert_dtype_allclose(result, expected)
768+
769+
@pytest.mark.parametrize("order", ["C", "F", "A"])
770+
def test_kron_order(self, order):
771+
a = numpy.arange(48).reshape(6, 8, order=order)
772+
b = numpy.arange(480).reshape(6, 8, 10, order=order)
773+
ia = dpnp.array(a)
774+
ib = dpnp.array(b)
775+
776+
result = dpnp.kron(ia, ib)
777+
expected = numpy.kron(a, b)
778+
assert result.flags["C_CONTIGUOUS"] == expected.flags["C_CONTIGUOUS"]
779+
assert result.flags["F_CONTIGUOUS"] == expected.flags["F_CONTIGUOUS"]
780+
assert_dtype_allclose(result, expected)
781+
754782

755783
class TestMultiDot:
756784
def setup_method(self):

0 commit comments

Comments
 (0)