@@ -159,39 +159,50 @@ def test_diff_no_op():
159
159
def test_diff_prepend_append_py_scalars (sh , axis ):
160
160
get_queue_or_skip ()
161
161
162
- arrs = [
163
- dpt .ones (sh , dtype = "?" ),
164
- dpt .ones (sh , dtype = "i4" ),
165
- dpt .ones (sh , dtype = "f4" ),
166
- dpt .ones (sh , dtype = "c8" ),
167
- ]
168
-
169
- py_zeros = [
170
- False ,
171
- 0 ,
172
- 0.0 ,
173
- complex (0 , 0 ),
174
- ]
175
-
176
- py_ones = [
177
- True ,
178
- 1 ,
179
- 1.0 ,
180
- complex (1 , 0 ),
181
- ]
182
-
183
- for zero , one , arr in zip (py_zeros , py_ones , arrs ):
184
- n = 1
185
- r = dpt .diff (arr , n = n , axis = axis , prepend = zero , append = one )
186
- assert isinstance (r , dpt .usm_ndarray )
187
- assert all (
188
- r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
189
- )
190
- assert r .shape [axis ] == arr .shape [axis ] + 2 - n
191
-
192
- r = dpt .diff (arr , n = n , axis = axis , prepend = zero )
193
- assert isinstance (r , dpt .usm_ndarray )
194
- assert all (
195
- r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis
196
- )
197
- assert r .shape [axis ] == arr .shape [axis ] + 1 - n
162
+ n = 1
163
+
164
+ arr = dpt .ones (sh , dtype = "i4" )
165
+ zero = 0
166
+
167
+ # first and last elements along axis
168
+ # will be checked for correctness
169
+ sl1 = [slice (None )] * arr .ndim
170
+ sl1 [axis ] = slice (1 )
171
+ sl1 = tuple (sl1 )
172
+
173
+ sl2 = [slice (None )] * arr .ndim
174
+ sl2 [axis ] = slice (- 1 , None , None )
175
+ sl2 = tuple (sl2 )
176
+
177
+ r = dpt .diff (arr , axis = axis , prepend = zero , append = zero )
178
+ assert isinstance (r , dpt .usm_ndarray )
179
+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
180
+ assert r .shape [axis ] == arr .shape [axis ] + 2 - n
181
+ assert dpt .all (r [sl1 ] == 1 )
182
+ assert dpt .all (r [sl2 ] == - 1 )
183
+
184
+ r = dpt .diff (arr , axis = axis , prepend = zero )
185
+ assert isinstance (r , dpt .usm_ndarray )
186
+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
187
+ assert r .shape [axis ] == arr .shape [axis ] + 1 - n
188
+ assert dpt .all (r [sl1 ] == 1 )
189
+
190
+ r = dpt .diff (arr , axis = axis , append = zero )
191
+ assert isinstance (r , dpt .usm_ndarray )
192
+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
193
+ assert r .shape [axis ] == arr .shape [axis ] + 1 - n
194
+ assert dpt .all (r [sl2 ] == - 1 )
195
+
196
+ r = dpt .diff (arr , axis = axis , prepend = dpt .asarray (zero ), append = zero )
197
+ assert isinstance (r , dpt .usm_ndarray )
198
+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
199
+ assert r .shape [axis ] == arr .shape [axis ] + 2 - n
200
+ assert dpt .all (r [sl1 ] == 1 )
201
+ assert dpt .all (r [sl2 ] == - 1 )
202
+
203
+ r = dpt .diff (arr , axis = axis , prepend = zero , append = dpt .asarray (zero ))
204
+ assert isinstance (r , dpt .usm_ndarray )
205
+ assert all (r .shape [i ] == arr .shape [i ] for i in range (arr .ndim ) if i != axis )
206
+ assert r .shape [axis ] == arr .shape [axis ] + 2 - n
207
+ assert dpt .all (r [sl1 ] == 1 )
208
+ assert dpt .all (r [sl2 ] == - 1 )
0 commit comments