Skip to content

Commit bebff52

Browse files
Jerry-Gefacebook-github-bot
authored andcommitted
Add lowering for aten.clone and aten.view_copy (#602)
Summary: Pull Request resolved: #602 Reviewed By: cccclai Differential Revision: D49897893 Pulled By: digantdesai fbshipit-source-id: d08d3e90e45b7741162ca083ca3da5e8ca4251ee
1 parent 7c68ec6 commit bebff52

File tree

2 files changed

+38
-0
lines changed

2 files changed

+38
-0
lines changed

backends/arm/arm_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def is_node_supported(self, submodules, node: torch.fx.Node) -> bool:
7474
exir_ops.edge.aten._native_batch_norm_legit_no_training.default,
7575
exir_ops.edge.aten.avg_pool2d.default,
7676
exir_ops.edge.aten._softmax.default,
77+
exir_ops.edge.aten.view_copy.default,
78+
exir_ops.edge.aten.clone.default,
7779
operator.getitem,
7880
exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
7981
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
@@ -762,10 +764,18 @@ def preprocess( # noqa: C901
762764
[outp.name],
763765
attr_mul,
764766
)
767+
elif exir_ops.edge.aten.view_copy.default == node.target:
768+
attr = ts.TosaSerializerAttribute()
769+
new_shape = inputs[1].special
770+
attr.ReshapeAttribute(new_shape)
771+
tosa_fb.addOperator(
772+
TosaOp.Op().RESHAPE, [inputs[0].name], [outp.name], attr
773+
)
765774
elif node.target in [
766775
operator.getitem,
767776
tosa_quant_utils.q_op,
768777
tosa_quant_utils.dq_op,
778+
exir_ops.edge.aten.clone.default,
769779
]:
770780
item_name = inputs[0].name
771781
## Simply add an identityOp

backends/arm/test/test_models.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,34 @@ class TorchBuilder:
3434
def __init__(self):
3535
pass
3636

37+
@register_test
38+
class simple_clone(torch.nn.Module):
39+
inputs = {
40+
TosaProfile.BI: (torch.ones(10),),
41+
TosaProfile.MI: (torch.ones(10),),
42+
}
43+
44+
def __init__(self):
45+
super().__init__()
46+
47+
def forward(self, x):
48+
x = x.clone()
49+
return x
50+
51+
@register_test
52+
class simple_view(torch.nn.Module):
53+
inputs = {
54+
TosaProfile.BI: (torch.ones(10),),
55+
TosaProfile.MI: (torch.ones(10),),
56+
}
57+
58+
def __init__(self):
59+
super().__init__()
60+
61+
def forward(self, x):
62+
x = x.view(2, 5)
63+
return x
64+
3765
@register_test
3866
class simple_add(torch.nn.Module):
3967
inputs = {

0 commit comments

Comments
 (0)