Skip to content

Commit 589fa5e

Browse files
eigen-kfacebook-github-bot
authored andcommitted
Use GraphBuilder in unit tests for ops removal pytorch#2. (pytorch#11011)
Summary: Pull Request resolved: pytorch#11011 Reviewed By: zonglinpeng Differential Revision: D75034439
1 parent 936ac2b commit 589fa5e

File tree

1 file changed

+118
-116
lines changed

1 file changed

+118
-116
lines changed

backends/cadence/aot/tests/test_remove_ops_passes.py

Lines changed: 118 additions & 116 deletions
Original file line numberDiff line numberDiff line change
@@ -301,36 +301,28 @@ def forward(self, tensors):
301301
self.assertEqual(count_node(new_graph_module, torch.ops.aten.cat.out), 0)
302302

303303
def test_remove_clone(self):
304-
class Clone(torch.nn.Module):
305-
def forward(self, x, y):
306-
t1 = x.clone()
307-
t2 = y.clone()
308-
return t1 + t2
309-
310-
x = torch.ones(3, 5)
311-
y = torch.ones(3, 5)
312-
graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module
313-
new_graph_module = RemoveCloneOpPass()(graph_module).graph_module
314-
new_graph_module.graph.eliminate_dead_code()
315-
# Assert that t1 and t2 are optimized away
316-
self.assertEqual(count_node(new_graph_module, torch.ops.aten.clone.out), 0)
304+
builder = GraphBuilder()
305+
x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
306+
clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
307+
builder.output([clone])
308+
original = builder.get_graph_module()
309+
graph_after_passes = RemoveCloneOpPass()(original).graph_module
310+
self.assertEqual(
311+
count_node(graph_after_passes, torch.ops.aten.clone.default), 0
312+
)
317313

318314
def test_remove_contiguous(self):
319-
class Contiguous(torch.nn.Module):
320-
def forward(self, x, y):
321-
t1 = x.contiguous()
322-
t2 = y.contiguous()
323-
return t1 + t2
324-
325-
x = torch.ones(3, 5)
326-
y = torch.ones(3, 5)
327-
graph_module = (
328-
export_to_edge(Contiguous(), (x, y)).exported_program().graph_module
315+
builder = GraphBuilder()
316+
x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
317+
contiguous = builder.call_operator(
318+
op=exir_ops.edge.aten.contiguous.default, args=(x,)
319+
)
320+
builder.output([contiguous])
321+
original = builder.get_graph_module()
322+
graph_after_passes = RemoveContiguousOpPass()(original).graph_module
323+
self.assertEqual(
324+
count_node(graph_after_passes, torch.ops.aten.contiguous.default), 0
329325
)
330-
new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module
331-
new_graph_module.graph.eliminate_dead_code()
332-
# Assert that t1 and t2 are optimized away
333-
self.assertEqual(count_node(new_graph_module, torch.ops.aten.contiguous.out), 0)
334326

335327
@parameterized.expand(
336328
[
@@ -340,119 +332,129 @@ def forward(self, x, y):
340332
)
341333
@torch.no_grad()
342334
def test_remove_nop_view(self, shape, new_shape):
343-
class View(torch.nn.Module):
344-
def __init__(self, new_shape):
345-
super().__init__()
346-
self.new_shape = new_shape
347-
348-
def forward(self, x: torch.Tensor):
349-
return x.view(self.new_shape)
350-
351-
model = View(new_shape)
352-
x = torch.randn(shape)
353-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
354-
p = RemoveNopSliceOrViewOpPass()
355-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
356-
graph_after_passes.graph.eliminate_dead_code()
357-
# Assert that view op was removed
335+
builder = GraphBuilder()
336+
x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
337+
view = builder.call_operator(
338+
op=exir_ops.edge.aten.view_copy.default, args=(x, new_shape)
339+
)
340+
builder.output([view])
341+
original = builder.get_graph_module()
342+
graph_after_passes = cast(
343+
PassResult, RemoveNopSliceOrViewOpPass()(original)
344+
).graph_module
358345
self.assertEqual(
359346
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
360347
)
361348

362349
def test_remove_nop_slice(self):
363-
class Slice(torch.nn.Module):
364-
def forward(self, x):
365-
return torch.slice_copy(x, dim=0, start=0, step=1)
366-
367-
x = torch.ones(3, 5)
368-
model = Slice()
369-
graph_module = export_to_edge(model, (x,)).exported_program().graph_module
370-
p = RemoveNopSliceOrViewOpPass()
371-
graph_after_passes = cast(PassResult, p(graph_module)).graph_module
372-
graph_after_passes.graph.eliminate_dead_code()
373-
# Assert that slice op was removed
350+
builder = GraphBuilder()
351+
x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
352+
slice_ = builder.call_operator(
353+
op=exir_ops.edge.aten.slice_copy.Tensor,
354+
args=(
355+
x,
356+
0, # dim
357+
0, # start
358+
3, # end
359+
),
360+
)
361+
builder.output([slice_])
362+
original = builder.get_graph_module()
363+
graph_after_passes = cast(
364+
PassResult, RemoveNopSliceOrViewOpPass()(original)
365+
).graph_module
374366
self.assertEqual(
375367
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
376368
)
377369

378-
def test_remove_nop_select(self):
379-
class SelectFeasible1(torch.nn.Module):
380-
def forward(self, x):
381-
y = x.select(0, 0)
382-
z = y.view([1, 5, 6])
383-
return z
384-
385-
x = torch.ones(1, 5, 6)
386-
graph_module = (
387-
export_to_edge(SelectFeasible1(), (x,)).exported_program().graph_module
370+
def test_remove_nop_select_before_view(self):
371+
builder = GraphBuilder()
372+
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
373+
select = builder.call_operator(
374+
op=exir_ops.edge.aten.select_copy.int,
375+
args=(
376+
x,
377+
0, # dim
378+
0, # index
379+
),
388380
)
389-
self.assertEqual(
390-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
381+
view = builder.call_operator(
382+
op=exir_ops.edge.aten.view_copy.default,
383+
args=(select, [1, 5, 6]), # new shape
391384
)
392-
graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
393-
# Assert that select op was removed
385+
builder.output([view])
386+
original = builder.get_graph_module()
387+
graph_after_passes = cast(
388+
PassResult, RemoveNopSelectOpPass()(original)
389+
).graph_module
394390
self.assertEqual(
395-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
391+
count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
396392
)
397393

398-
class SelectFeasible2(torch.nn.Module):
399-
def forward(self, x, y):
400-
x = x.select(0, 0)
401-
z = x + y
402-
return z
403-
404-
x = torch.ones(1, 5, 6)
405-
y = torch.ones(1, 5, 6)
406-
graph_module = (
407-
export_to_edge(SelectFeasible2(), (x, y)).exported_program().graph_module
408-
)
409-
self.assertEqual(
410-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
394+
def test_remove_nop_select_before_add(self):
395+
builder = GraphBuilder()
396+
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
397+
y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
398+
select = builder.call_operator(
399+
op=exir_ops.edge.aten.select_copy.int,
400+
args=(
401+
x,
402+
0, # dim
403+
0, # index
404+
),
411405
)
412-
graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
413-
# Assert that select op was removed
406+
add = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(select, y))
407+
builder.output([add])
408+
original = builder.get_graph_module()
409+
graph_after_passes = cast(
410+
PassResult, RemoveNopSelectOpPass()(original)
411+
).graph_module
414412
self.assertEqual(
415-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
413+
count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
416414
)
417415

418-
class SelectFeasible3(torch.nn.Module):
419-
def forward(self, x, y):
420-
x = x.select(0, 0)
421-
z = x * y
422-
return z
423-
424-
x = torch.ones(1, 5, 6)
425-
y = torch.ones(1, 5, 6)
426-
graph_module = (
427-
export_to_edge(SelectFeasible3(), (x, y)).exported_program().graph_module
428-
)
429-
self.assertEqual(
430-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
416+
def test_remove_nop_select_before_mul(self):
417+
builder = GraphBuilder()
418+
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
419+
y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
420+
select = builder.call_operator(
421+
op=exir_ops.edge.aten.select_copy.int,
422+
args=(
423+
x,
424+
0, # dim
425+
0, # index
426+
),
431427
)
432-
graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
433-
# Assert that select op was removed
428+
mul = builder.call_operator(op=exir_ops.edge.aten.mul.Tensor, args=(select, y))
429+
builder.output([mul])
430+
original = builder.get_graph_module()
431+
graph_after_passes = cast(
432+
PassResult, RemoveNopSelectOpPass()(original)
433+
).graph_module
434434
self.assertEqual(
435-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
435+
count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
436436
)
437437

438-
class SelectFeasible4(torch.nn.Module):
439-
def forward(self, x, y):
440-
x = x.select(0, 0)
441-
z = x / y
442-
return z
443-
444-
x = torch.ones(1, 5, 6)
445-
y = torch.ones(1, 5, 6)
446-
graph_module = (
447-
export_to_edge(SelectFeasible4(), (x, y)).exported_program().graph_module
448-
)
449-
self.assertEqual(
450-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
438+
def test_remove_nop_select_before_div(self):
439+
builder = GraphBuilder()
440+
x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
441+
y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
442+
select = builder.call_operator(
443+
op=exir_ops.edge.aten.select_copy.int,
444+
args=(
445+
x,
446+
0, # dim
447+
0, # index
448+
),
451449
)
452-
graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
453-
# Assert that select op was removed
450+
div = builder.call_operator(op=exir_ops.edge.aten.div.Tensor, args=(select, y))
451+
builder.output([div])
452+
original = builder.get_graph_module()
453+
graph_after_passes = cast(
454+
PassResult, RemoveNopSelectOpPass()(original)
455+
).graph_module
454456
self.assertEqual(
455-
count_node(graph_module, exir_ops.edge.aten.select_copy.int), 0
457+
count_node(graph_after_passes, exir_ops.edge.aten.select_copy.int), 0
456458
)
457459

458460
def test_remove_nop_quant_dequant(self):

0 commit comments

Comments
 (0)