@@ -235,36 +235,28 @@ def test_remove_zero_arg_cat(self):
235
235
)
236
236
237
237
def test_remove_clone(self):
238
- class Clone(torch.nn.Module):
239
- def forward(self, x, y):
240
- t1 = x.clone()
241
- t2 = y.clone()
242
- return t1 + t2
243
-
244
- x = torch.ones(3, 5)
245
- y = torch.ones(3, 5)
246
- graph_module = export_to_edge(Clone(), (x, y)).exported_program().graph_module
247
- new_graph_module = RemoveCloneOpPass()(graph_module).graph_module
248
- new_graph_module.graph.eliminate_dead_code()
249
- # Assert that t1 and t2 are optimized away
250
- self.assertEqual(count_node(new_graph_module, torch.ops.aten.clone.out), 0)
238
+ builder = GraphBuilder()
239
+ x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
240
+ clone = builder.call_operator(op=exir_ops.edge.aten.clone.default, args=(x,))
241
+ builder.output([clone])
242
+ original = builder.get_graph_module()
243
+ graph_after_passes = RemoveCloneOpPass()(original).graph_module
244
+ self.assertEqual(
245
+ count_node(graph_after_passes, torch.ops.aten.clone.default), 0
246
+ )
251
247
252
248
def test_remove_contiguous(self):
253
- class Contiguous(torch.nn.Module):
254
- def forward(self, x, y):
255
- t1 = x.contiguous()
256
- t2 = y. contiguous( )
257
- return t1 + t2
258
-
259
- x = torch.ones(3, 5 )
260
- y = torch.ones(3, 5)
261
- graph_module = (
262
- export_to_edge(Contiguous(), (x, y)).exported_program().graph_module
249
+ builder = GraphBuilder()
250
+ x = builder.placeholder("x", torch.randn([3, 5], dtype=torch.float32))
251
+ contiguous = builder.call_operator(
252
+ op=exir_ops.edge.aten. contiguous.default, args=(x, )
253
+ )
254
+ builder.output([contiguous])
255
+ original = builder.get_graph_module( )
256
+ graph_after_passes = RemoveContiguousOpPass()(original).graph_module
257
+ self.assertEqual (
258
+ count_node(graph_after_passes, torch.ops.aten.contiguous.default), 0
263
259
)
264
- new_graph_module = RemoveContiguousOpPass()(graph_module).graph_module
265
- new_graph_module.graph.eliminate_dead_code()
266
- # Assert that t1 and t2 are optimized away
267
- self.assertEqual(count_node(new_graph_module, torch.ops.aten.contiguous.out), 0)
268
260
269
261
@parameterized.expand(
270
262
[
@@ -274,119 +266,129 @@ def forward(self, x, y):
274
266
)
275
267
@torch.no_grad()
276
268
def test_remove_nop_view(self, shape, new_shape):
277
- class View(torch.nn.Module):
278
- def __init__(self, new_shape):
279
- super().__init__()
280
- self.new_shape = new_shape
281
-
282
- def forward(self, x: torch.Tensor):
283
- return x.view(self.new_shape)
284
-
285
- model = View(new_shape)
286
- x = torch.randn(shape)
287
- graph_module = export_to_edge(model, (x,)).exported_program().graph_module
288
- p = RemoveNopSliceOrViewOpPass()
289
- graph_after_passes = cast(PassResult, p(graph_module)).graph_module
290
- graph_after_passes.graph.eliminate_dead_code()
291
- # Assert that view op was removed
269
+ builder = GraphBuilder()
270
+ x = builder.placeholder("x", torch.randn(*shape, dtype=torch.float32))
271
+ view = builder.call_operator(
272
+ op=exir_ops.edge.aten.view_copy.default, args=(x, new_shape)
273
+ )
274
+ builder.output([view])
275
+ original = builder.get_graph_module()
276
+ graph_after_passes = cast(
277
+ PassResult, RemoveNopSliceOrViewOpPass()(original)
278
+ ).graph_module
292
279
self.assertEqual(
293
280
count_node(graph_after_passes, exir_ops.edge.aten.view_copy.default), 0
294
281
)
295
282
296
283
def test_remove_nop_slice(self):
297
- class Slice(torch.nn.Module):
298
- def forward(self, x):
299
- return torch.slice_copy(x, dim=0, start=0, step=1)
300
-
301
- x = torch.ones(3, 5)
302
- model = Slice()
303
- graph_module = export_to_edge(model, (x,)).exported_program().graph_module
304
- p = RemoveNopSliceOrViewOpPass()
305
- graph_after_passes = cast(PassResult, p(graph_module)).graph_module
306
- graph_after_passes.graph.eliminate_dead_code()
307
- # Assert that slice op was removed
284
+ builder = GraphBuilder()
285
+ x = builder.placeholder("x", torch.randn(3, 5, dtype=torch.float32))
286
+ slice_ = builder.call_operator(
287
+ op=exir_ops.edge.aten.slice_copy.Tensor,
288
+ args=(
289
+ x,
290
+ 0, # dim
291
+ 0, # start
292
+ 3, # end
293
+ ),
294
+ )
295
+ builder.output([slice_])
296
+ original = builder.get_graph_module()
297
+ graph_after_passes = cast(
298
+ PassResult, RemoveNopSliceOrViewOpPass()(original)
299
+ ).graph_module
308
300
self.assertEqual(
309
301
count_node(graph_after_passes, exir_ops.edge.aten.slice_copy.Tensor), 0
310
302
)
311
303
312
- def test_remove_nop_select (self):
313
- class SelectFeasible1(torch.nn.Module):
314
- def forward(self, x):
315
- y = x.select(0, 0)
316
- z = y.view([1, 5, 6])
317
- return z
318
-
319
- x = torch.ones(1, 5, 6)
320
- graph_module = (
321
- export_to_edge(SelectFeasible1(), (x,)).exported_program().graph_module
304
+ def test_remove_nop_select_before_view (self):
305
+ builder = GraphBuilder()
306
+ x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
307
+ select = builder.call_operator(
308
+ op=exir_ops.edge.aten.select_copy.int,
309
+ args=(
310
+ x,
311
+ 0, # dim
312
+ 0, # index
313
+ ),
322
314
)
323
- self.assertEqual(
324
- count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
315
+ view = builder.call_operator(
316
+ op=exir_ops.edge.aten.view_copy.default,
317
+ args=(select, [1, 5, 6]), # new shape
325
318
)
326
- graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
327
- # Assert that select op was removed
319
+ builder.output([view])
320
+ original = builder.get_graph_module()
321
+ graph_after_passes = cast(
322
+ PassResult, RemoveNopSelectOpPass()(original)
323
+ ).graph_module
328
324
self.assertEqual(
329
- count_node(graph_module , exir_ops.edge.aten.select_copy.int), 0
325
+ count_node(graph_after_passes , exir_ops.edge.aten.select_copy.int), 0
330
326
)
331
327
332
- class SelectFeasible2(torch.nn.Module):
333
- def forward(self, x, y):
334
- x = x.select(0, 0)
335
- z = x + y
336
- return z
337
-
338
- x = torch.ones(1, 5, 6)
339
- y = torch.ones(1, 5, 6)
340
- graph_module = (
341
- export_to_edge(SelectFeasible2(), (x, y)).exported_program().graph_module
342
- )
343
- self.assertEqual(
344
- count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
328
+ def test_remove_nop_select_before_add(self):
329
+ builder = GraphBuilder()
330
+ x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
331
+ y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
332
+ select = builder.call_operator(
333
+ op=exir_ops.edge.aten.select_copy.int,
334
+ args=(
335
+ x,
336
+ 0, # dim
337
+ 0, # index
338
+ ),
345
339
)
346
- graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
347
- # Assert that select op was removed
340
+ add = builder.call_operator(op=exir_ops.edge.aten.add.Tensor, args=(select, y))
341
+ builder.output([add])
342
+ original = builder.get_graph_module()
343
+ graph_after_passes = cast(
344
+ PassResult, RemoveNopSelectOpPass()(original)
345
+ ).graph_module
348
346
self.assertEqual(
349
- count_node(graph_module , exir_ops.edge.aten.select_copy.int), 0
347
+ count_node(graph_after_passes , exir_ops.edge.aten.select_copy.int), 0
350
348
)
351
349
352
- class SelectFeasible3(torch.nn.Module):
353
- def forward(self, x, y):
354
- x = x.select(0, 0)
355
- z = x * y
356
- return z
357
-
358
- x = torch.ones(1, 5, 6)
359
- y = torch.ones(1, 5, 6)
360
- graph_module = (
361
- export_to_edge(SelectFeasible3(), (x, y)).exported_program().graph_module
362
- )
363
- self.assertEqual(
364
- count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
350
+ def test_remove_nop_select_before_mul(self):
351
+ builder = GraphBuilder()
352
+ x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
353
+ y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
354
+ select = builder.call_operator(
355
+ op=exir_ops.edge.aten.select_copy.int,
356
+ args=(
357
+ x,
358
+ 0, # dim
359
+ 0, # index
360
+ ),
365
361
)
366
- graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
367
- # Assert that select op was removed
362
+ mul = builder.call_operator(op=exir_ops.edge.aten.mul.Tensor, args=(select, y))
363
+ builder.output([mul])
364
+ original = builder.get_graph_module()
365
+ graph_after_passes = cast(
366
+ PassResult, RemoveNopSelectOpPass()(original)
367
+ ).graph_module
368
368
self.assertEqual(
369
- count_node(graph_module , exir_ops.edge.aten.select_copy.int), 0
369
+ count_node(graph_after_passes , exir_ops.edge.aten.select_copy.int), 0
370
370
)
371
371
372
- class SelectFeasible4(torch.nn.Module):
373
- def forward(self, x, y):
374
- x = x.select(0, 0)
375
- z = x / y
376
- return z
377
-
378
- x = torch.ones(1, 5, 6)
379
- y = torch.ones(1, 5, 6)
380
- graph_module = (
381
- export_to_edge(SelectFeasible4(), (x, y)).exported_program().graph_module
382
- )
383
- self.assertEqual(
384
- count_node(graph_module, exir_ops.edge.aten.select_copy.int), 1
372
+ def test_remove_nop_select_before_div(self):
373
+ builder = GraphBuilder()
374
+ x = builder.placeholder("x", torch.randn(1, 5, 6, dtype=torch.float32))
375
+ y = builder.placeholder("y", torch.randn(1, 5, 6, dtype=torch.float32))
376
+ select = builder.call_operator(
377
+ op=exir_ops.edge.aten.select_copy.int,
378
+ args=(
379
+ x,
380
+ 0, # dim
381
+ 0, # index
382
+ ),
385
383
)
386
- graph_module = RemoveNopSelectOpPass()(graph_module).graph_module
387
- # Assert that select op was removed
384
+ div = builder.call_operator(op=exir_ops.edge.aten.div.Tensor, args=(select, y))
385
+ builder.output([div])
386
+ original = builder.get_graph_module()
387
+ graph_after_passes = cast(
388
+ PassResult, RemoveNopSelectOpPass()(original)
389
+ ).graph_module
388
390
self.assertEqual(
389
- count_node(graph_module , exir_ops.edge.aten.select_copy.int), 0
391
+ count_node(graph_after_passes , exir_ops.edge.aten.select_copy.int), 0
390
392
)
391
393
392
394
def test_remove_nop_quant_dequant(self):
0 commit comments