@@ -741,7 +741,7 @@ def test_kron_input_dtype_matrix(self, dtype1, dtype2):
741
741
@pytest .mark .parametrize (
742
742
"stride" , [3 , - 1 , - 2 , - 4 ], ids = ["3" , "-1" , "-2" , "-4" ]
743
743
)
744
- def test_kron_strided (self , dtype , stride ):
744
+ def test_kron_strided1 (self , dtype , stride ):
745
745
a = numpy .arange (20 , dtype = dtype )
746
746
b = numpy .arange (20 , dtype = dtype )
747
747
ia = dpnp .array (a )
@@ -751,6 +751,32 @@ def test_kron_strided(self, dtype, stride):
751
751
expected = numpy .kron (a [::stride ], b [::stride ])
752
752
assert_dtype_allclose (result , expected )
753
753
754
+ @pytest .mark .parametrize ("stride" , [2 , - 1 , - 2 ], ids = ["2" , "-1" , "-2" ])
755
+ def test_kron_strided2 (self , stride ):
756
+ a = numpy .arange (48 ).reshape (6 , 8 )
757
+ b = numpy .arange (480 ).reshape (6 , 8 , 10 )
758
+ ia = dpnp .array (a )
759
+ ib = dpnp .array (b )
760
+
761
+ result = dpnp .kron (
762
+ ia [::stride , ::stride ], ib [::stride , ::stride , ::stride ]
763
+ )
764
+ expected = numpy .kron (
765
+ a [::stride , ::stride ], b [::stride , ::stride , ::stride ]
766
+ )
767
+ assert_dtype_allclose (result , expected )
768
+
769
+ @pytest .mark .parametrize ("order" , ["C" , "F" , "A" ])
770
+ def test_kron_order (self , order ):
771
+ a = numpy .arange (48 ).reshape (6 , 8 , order = order )
772
+ b = numpy .arange (480 ).reshape (6 , 8 , 10 , order = order )
773
+ ia = dpnp .array (a )
774
+ ib = dpnp .array (b )
775
+
776
+ result = dpnp .kron (ia , ib )
777
+ expected = numpy .kron (a , b )
778
+ assert_dtype_allclose (result , expected )
779
+
754
780
755
781
class TestMultiDot :
756
782
def setup_method (self ):
0 commit comments