@@ -247,106 +247,106 @@ def test_local_subtensor_of_alloc():
247
247
assert xval .__getitem__ (slices ).shape == val .shape
248
248
249
249
250
- @ pytest . mark . parametrize (
251
- "x, s, idx, x_val, s_val" ,
252
- [
253
- (
254
- vector (),
255
- ( iscalar (), ),
256
- ( 1 ,),
257
- np . array ([ 1 , 2 ], dtype = config . floatX ),
258
- np .array ([2 ], dtype = np . int64 ),
259
- ),
260
- (
261
- matrix (),
262
- ( iscalar (), iscalar () ),
263
- ( 1 , ),
264
- np . array ([[ 1 , 2 ], [ 3 , 4 ]], dtype = config . floatX ),
265
- np .array ([2 , 2 ], dtype = np . int64 ),
266
- ),
267
- (
268
- matrix (),
269
- ( iscalar (), iscalar () ),
270
- ( 0 , ),
271
- np . array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
272
- np .array ([2 , 3 ], dtype = np . int64 ),
273
- ),
274
- (
275
- matrix (),
276
- ( iscalar (), iscalar () ),
277
- ( 1 , 1 ),
278
- np . array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
279
- np .array ([2 , 3 ], dtype = np . int64 ),
280
- ),
281
- (
282
- tensor3 (),
283
- ( iscalar (), iscalar (), iscalar () ),
284
- ( - 1 , ),
285
- np . arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
286
- np .array ([ 2 , 3 , 5 ] , dtype = np . int64 ),
287
- ),
288
- (
289
- tensor3 (),
290
- ( iscalar (), iscalar (), iscalar () ),
291
- ( - 1 , 0 ),
292
- np . arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
293
- np .array ([ 2 , 3 , 5 ] , dtype = np . int64 ),
294
- ),
295
- ] ,
296
- )
297
- def test_local_subtensor_SpecifyShape_lift ( x , s , idx , x_val , s_val ):
298
- y = specify_shape ( x , s )[ idx ]
299
- assert isinstance ( y . owner . inputs [ 0 ]. owner . op , SpecifyShape )
300
-
301
- rewrites = RewriteDatabaseQuery ( include = [ None ])
302
- no_rewrites_mode = Mode ( optimizer = rewrites )
303
-
304
- y_val_fn = function ([ x , * s ], y , on_unused_input = "ignore" , mode = no_rewrites_mode )
305
- y_val = y_val_fn ( * ([ x_val , * s_val ]) )
306
-
307
- # This optimization should appear in the canonicalizations
308
- y_opt = rewrite_graph ( y , clone = False )
309
-
310
- if y . ndim == 0 :
311
- # SpecifyShape should be removed altogether
312
- assert isinstance ( y_opt . owner . op , Subtensor )
313
- assert y_opt .owner .inputs [ 0 ] is x
314
- else :
315
- assert isinstance ( y_opt . owner . op , SpecifyShape )
316
-
317
- y_opt_fn = function ([ x , * s ], y_opt , on_unused_input = "ignore" )
318
- y_opt_val = y_opt_fn ( * ([ x_val , * s_val ]) )
319
-
320
- assert np . allclose ( y_val , y_opt_val )
321
-
322
-
323
- @pytest .mark .parametrize (
324
- "x, s, idx" ,
325
- [
326
- (
327
- matrix (),
328
- (iscalar (), iscalar ()),
329
- (slice (1 , None ),),
330
- ),
331
- (
332
- matrix (),
333
- (iscalar (), iscalar ()),
334
- (slicetype (),),
335
- ),
336
- (
337
- matrix (),
338
- (iscalar (), iscalar ()),
339
- (1 , 0 ),
340
- ),
341
- ],
342
- )
343
- def test_local_subtensor_SpecifyShape_lift_fail (x , s , idx ):
344
- y = specify_shape (x , s )[idx ]
345
-
346
- # This optimization should appear in the canonicalizations
347
- y_opt = rewrite_graph (y , clone = False )
348
-
349
- assert not isinstance (y_opt .owner .op , SpecifyShape )
250
+ class TestLocalSubtensorSpecifyShapeLift :
251
+ @ pytest . mark . parametrize (
252
+ "x, s, idx, x_val, s_val" ,
253
+ [
254
+ (
255
+ vector ( ),
256
+ ( iscalar () ,),
257
+ ( 1 , ),
258
+ np .array ([1 , 2 ], dtype = config . floatX ),
259
+ np . array ([ 2 ], dtype = np . int64 ),
260
+ ),
261
+ (
262
+ matrix ( ),
263
+ ( iscalar (), iscalar () ),
264
+ ( 1 , ),
265
+ np .array ([[ 1 , 2 ], [ 3 , 4 ]], dtype = config . floatX ),
266
+ np . array ([ 2 , 2 ], dtype = np . int64 ),
267
+ ),
268
+ (
269
+ matrix ( ),
270
+ ( iscalar (), iscalar () ),
271
+ ( 0 , ),
272
+ np .array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
273
+ np . array ([ 2 , 3 ], dtype = np . int64 ),
274
+ ),
275
+ (
276
+ matrix ( ),
277
+ ( iscalar (), iscalar () ),
278
+ ( 1 , 1 ),
279
+ np .array ([[ 1 , 2 , 3 ], [ 4 , 5 , 6 ]], dtype = config . floatX ),
280
+ np . array ([ 2 , 3 ], dtype = np . int64 ),
281
+ ),
282
+ (
283
+ tensor3 ( ),
284
+ ( iscalar (), iscalar (), iscalar () ),
285
+ ( - 1 , ),
286
+ np .arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
287
+ np . array ([ 2 , 3 , 5 ], dtype = np . int64 ),
288
+ ),
289
+ (
290
+ tensor3 ( ),
291
+ ( iscalar (), iscalar (), iscalar () ),
292
+ ( - 1 , 0 ),
293
+ np .arange ( 2 * 3 * 5 , dtype = config . floatX ). reshape (( 2 , 3 , 5 ) ),
294
+ np . array ([ 2 , 3 , 5 ], dtype = np . int64 ),
295
+ ) ,
296
+ ],
297
+ )
298
+ def test_local_subtensor_SpecifyShape_lift ( self , x , s , idx , x_val , s_val ):
299
+ y = specify_shape ( x , s )[ idx ]
300
+ assert isinstance ( y . owner . inputs [ 0 ]. owner . op , SpecifyShape )
301
+
302
+ rewrites = RewriteDatabaseQuery ( include = [ None ] )
303
+ no_rewrites_mode = Mode ( optimizer = rewrites )
304
+
305
+ y_val_fn = function ([ x , * s ], y , on_unused_input = "ignore" , mode = no_rewrites_mode )
306
+ y_val = y_val_fn ( * ([ x_val , * s_val ]))
307
+
308
+ # This optimization should appear in the canonicalizations
309
+ y_opt = rewrite_graph ( y , clone = False )
310
+
311
+ if y . ndim == 0 :
312
+ # SpecifyShape should be removed altogether
313
+ assert isinstance ( y_opt .owner .op , Subtensor )
314
+ assert y_opt . owner . inputs [ 0 ] is x
315
+ else :
316
+ assert isinstance ( y_opt . owner . op , SpecifyShape )
317
+
318
+ y_opt_fn = function ([ x , * s ], y_opt , on_unused_input = "ignore" )
319
+ y_opt_val = y_opt_fn ( * ([ x_val , * s_val ]))
320
+
321
+ assert np . allclose ( y_val , y_opt_val )
322
+
323
+ @pytest .mark .parametrize (
324
+ "x, s, idx" ,
325
+ [
326
+ (
327
+ matrix (),
328
+ (iscalar (), iscalar ()),
329
+ (slice (1 , None ),),
330
+ ),
331
+ (
332
+ matrix (),
333
+ (iscalar (), iscalar ()),
334
+ (slicetype (),),
335
+ ),
336
+ (
337
+ matrix (),
338
+ (iscalar (), iscalar ()),
339
+ (1 , 0 ),
340
+ ),
341
+ ],
342
+ )
343
+ def test_local_subtensor_SpecifyShape_lift_fail (self , x , s , idx ):
344
+ y = specify_shape (x , s )[idx ]
345
+
346
+ # This optimization should appear in the canonicalizations
347
+ y_opt = rewrite_graph (y , clone = False )
348
+
349
+ assert not isinstance (y_opt .owner .op , SpecifyShape )
350
350
351
351
352
352
class TestLocalSubtensorMakeVector :
0 commit comments