Skip to content

Commit 0770600

Browse files
committed
Group subtensor specify_shape lift tests in class
1 parent ccbab65 commit 0770600

File tree

1 file changed

+100
-100
lines changed

1 file changed

+100
-100
lines changed

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 100 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -247,106 +247,106 @@ def test_local_subtensor_of_alloc():
247247
assert xval.__getitem__(slices).shape == val.shape
248248

249249

250-
@pytest.mark.parametrize(
251-
"x, s, idx, x_val, s_val",
252-
[
253-
(
254-
vector(),
255-
(iscalar(),),
256-
(1,),
257-
np.array([1, 2], dtype=config.floatX),
258-
np.array([2], dtype=np.int64),
259-
),
260-
(
261-
matrix(),
262-
(iscalar(), iscalar()),
263-
(1,),
264-
np.array([[1, 2], [3, 4]], dtype=config.floatX),
265-
np.array([2, 2], dtype=np.int64),
266-
),
267-
(
268-
matrix(),
269-
(iscalar(), iscalar()),
270-
(0,),
271-
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
272-
np.array([2, 3], dtype=np.int64),
273-
),
274-
(
275-
matrix(),
276-
(iscalar(), iscalar()),
277-
(1, 1),
278-
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
279-
np.array([2, 3], dtype=np.int64),
280-
),
281-
(
282-
tensor3(),
283-
(iscalar(), iscalar(), iscalar()),
284-
(-1,),
285-
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
286-
np.array([2, 3, 5], dtype=np.int64),
287-
),
288-
(
289-
tensor3(),
290-
(iscalar(), iscalar(), iscalar()),
291-
(-1, 0),
292-
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
293-
np.array([2, 3, 5], dtype=np.int64),
294-
),
295-
],
296-
)
297-
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
298-
y = specify_shape(x, s)[idx]
299-
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
300-
301-
rewrites = RewriteDatabaseQuery(include=[None])
302-
no_rewrites_mode = Mode(optimizer=rewrites)
303-
304-
y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
305-
y_val = y_val_fn(*([x_val, *s_val]))
306-
307-
# This optimization should appear in the canonicalizations
308-
y_opt = rewrite_graph(y, clone=False)
309-
310-
if y.ndim == 0:
311-
# SpecifyShape should be removed altogether
312-
assert isinstance(y_opt.owner.op, Subtensor)
313-
assert y_opt.owner.inputs[0] is x
314-
else:
315-
assert isinstance(y_opt.owner.op, SpecifyShape)
316-
317-
y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
318-
y_opt_val = y_opt_fn(*([x_val, *s_val]))
319-
320-
assert np.allclose(y_val, y_opt_val)
321-
322-
323-
@pytest.mark.parametrize(
324-
"x, s, idx",
325-
[
326-
(
327-
matrix(),
328-
(iscalar(), iscalar()),
329-
(slice(1, None),),
330-
),
331-
(
332-
matrix(),
333-
(iscalar(), iscalar()),
334-
(slicetype(),),
335-
),
336-
(
337-
matrix(),
338-
(iscalar(), iscalar()),
339-
(1, 0),
340-
),
341-
],
342-
)
343-
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
344-
y = specify_shape(x, s)[idx]
345-
346-
# This optimization should appear in the canonicalizations
347-
y_opt = rewrite_graph(y, clone=False)
348-
349-
assert not isinstance(y_opt.owner.op, SpecifyShape)
250+
class TestLocalSubtensorSpecifyShapeLift:
251+
@pytest.mark.parametrize(
252+
"x, s, idx, x_val, s_val",
253+
[
254+
(
255+
vector(),
256+
(iscalar(),),
257+
(1,),
258+
np.array([1, 2], dtype=config.floatX),
259+
np.array([2], dtype=np.int64),
260+
),
261+
(
262+
matrix(),
263+
(iscalar(), iscalar()),
264+
(1,),
265+
np.array([[1, 2], [3, 4]], dtype=config.floatX),
266+
np.array([2, 2], dtype=np.int64),
267+
),
268+
(
269+
matrix(),
270+
(iscalar(), iscalar()),
271+
(0,),
272+
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
273+
np.array([2, 3], dtype=np.int64),
274+
),
275+
(
276+
matrix(),
277+
(iscalar(), iscalar()),
278+
(1, 1),
279+
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
280+
np.array([2, 3], dtype=np.int64),
281+
),
282+
(
283+
tensor3(),
284+
(iscalar(), iscalar(), iscalar()),
285+
(-1,),
286+
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
287+
np.array([2, 3, 5], dtype=np.int64),
288+
),
289+
(
290+
tensor3(),
291+
(iscalar(), iscalar(), iscalar()),
292+
(-1, 0),
293+
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
294+
np.array([2, 3, 5], dtype=np.int64),
295+
),
296+
],
297+
)
298+
def test_local_subtensor_SpecifyShape_lift(self, x, s, idx, x_val, s_val):
299+
y = specify_shape(x, s)[idx]
300+
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
301+
302+
rewrites = RewriteDatabaseQuery(include=[None])
303+
no_rewrites_mode = Mode(optimizer=rewrites)
304+
305+
y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
306+
y_val = y_val_fn(*([x_val, *s_val]))
307+
308+
# This optimization should appear in the canonicalizations
309+
y_opt = rewrite_graph(y, clone=False)
310+
311+
if y.ndim == 0:
312+
# SpecifyShape should be removed altogether
313+
assert isinstance(y_opt.owner.op, Subtensor)
314+
assert y_opt.owner.inputs[0] is x
315+
else:
316+
assert isinstance(y_opt.owner.op, SpecifyShape)
317+
318+
y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
319+
y_opt_val = y_opt_fn(*([x_val, *s_val]))
320+
321+
assert np.allclose(y_val, y_opt_val)
322+
323+
@pytest.mark.parametrize(
324+
"x, s, idx",
325+
[
326+
(
327+
matrix(),
328+
(iscalar(), iscalar()),
329+
(slice(1, None),),
330+
),
331+
(
332+
matrix(),
333+
(iscalar(), iscalar()),
334+
(slicetype(),),
335+
),
336+
(
337+
matrix(),
338+
(iscalar(), iscalar()),
339+
(1, 0),
340+
),
341+
],
342+
)
343+
def test_local_subtensor_SpecifyShape_lift_fail(self, x, s, idx):
344+
y = specify_shape(x, s)[idx]
345+
346+
# This optimization should appear in the canonicalizations
347+
y_opt = rewrite_graph(y, clone=False)
348+
349+
assert not isinstance(y_opt.owner.op, SpecifyShape)
350350

351351

352352
class TestLocalSubtensorMakeVector:

0 commit comments

Comments
 (0)