Skip to content

Commit e7a429a

Browse files
chuntlfacebook-github-bot
authored andcommitted
Qualcomm AI Engine Direct - Adapt to new IR capture flow (#2431)
Summary: - Change existent IR capture flow (exir.capture) to torch.export.export - Add custom decomposition table for mitigating maintaining effort - Fix breakages encountered and make sure all tests passed as well Pull Request resolved: #2431 Reviewed By: mergennachin Differential Revision: D55353449 Pulled By: cccclai fbshipit-source-id: aa2e27d0ae93aa62208fd03ec39b3891a70b954e
1 parent 3cf9f22 commit e7a429a

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

44 files changed

+76
-67
lines changed

backends/qualcomm/builders/node_visitor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -396,7 +396,8 @@ def register_node_visitor(visitor):
396396
and issubclass(visitor, NodeVisitor)
397397
and hasattr(visitor, "target")
398398
), f"Illformed NodeVisitor subclass, can't register!, got: {visitor}"
399-
_node_visitor_dict[visitor.target] = visitor
399+
for target in visitor.target:
400+
_node_visitor_dict[target] = visitor
400401

401402

402403
def generate_node_to_external_map(

backends/qualcomm/builders/op_add.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Add(NodeVisitor):
18-
target = "aten.add.Tensor"
18+
target = ["aten.add.Tensor"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_avg_pool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class AvgPool2d(NodeVisitor):
19-
target = "aten.avg_pool2d.default"
19+
target = ["aten.avg_pool2d.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_batch_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class BatchNorm(NodeVisitor):
19-
target = "aten._native_batch_norm_legit_no_training.default"
19+
target = ["aten._native_batch_norm_legit_no_training.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_bmm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class BMM(NodeVisitor):
18-
target = "aten.bmm.default"
18+
target = ["aten.bmm.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_cast.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Cast(NodeVisitor):
18-
target = "aten._to_copy.default"
18+
target = ["aten._to_copy.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class Cat(NodeVisitor):
19-
target = "aten.cat.default"
19+
target = ["aten.cat.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_ceil.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Ceil(NodeVisitor):
18-
target = "aten.ceil.default"
18+
target = ["aten.ceil.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_clamp.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class Clamp(NodeVisitor):
19-
target = "aten.clamp.default"
19+
target = ["aten.clamp.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
@register_node_visitor
2626
class Conv2d(NodeVisitor):
27-
target = "aten.convolution.default"
27+
target = ["aten.convolution.default"]
2828

2929
def __init__(self, *args) -> None:
3030
super().__init__(*args)

backends/qualcomm/builders/op_depth_to_space.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
@register_node_visitor
1919
class DepthToSpaceVisitor(NodeVisitor):
20-
target = "aten.pixel_shuffle.default"
20+
target = ["aten.pixel_shuffle.default"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_dequantize.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,19 +55,19 @@ def define_node(
5555

5656
@register_node_visitor
5757
class PerTensorDequantizeDefault(DequantizeOpBase):
58-
target = "quantized_decomposed.dequantize_per_tensor.default"
58+
target = ["quantized_decomposed.dequantize_per_tensor.default"]
5959

6060

6161
@register_node_visitor
6262
class PerTensorDequantizeTensor(DequantizeOpBase):
63-
target = "quantized_decomposed.dequantize_per_tensor.tensor"
63+
target = ["quantized_decomposed.dequantize_per_tensor.tensor"]
6464

6565

6666
@register_node_visitor
6767
class PerChannelDequantizeDefault(DequantizeOpBase):
68-
target = "quantized_decomposed.dequantize_per_channel.default"
68+
target = ["quantized_decomposed.dequantize_per_channel.default"]
6969

7070

7171
@register_node_visitor
7272
class PerChannelDequantizeTensor(DequantizeOpBase):
73-
target = "quantized_decomposed.dequantize_per_channel.tensor"
73+
target = ["quantized_decomposed.dequantize_per_channel.tensor"]

backends/qualcomm/builders/op_div.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Div(NodeVisitor):
18-
target = "aten.div.Tensor"
18+
target = ["aten.div.Tensor"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_embedding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
@register_node_visitor
1919
class Embedding(NodeVisitor):
20-
target = "aten.embedding.default"
20+
target = ["aten.embedding.default"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_expand.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class Expand(NodeVisitor):
19-
target = "aten.expand_copy.default"
19+
target = ["aten.expand_copy.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_gelu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class GeluVisitor(NodeVisitor):
19-
target = "aten.gelu.default"
19+
target = ["aten.gelu.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_hardswish.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class HardSwishVisitor(NodeVisitor):
19-
target = "aten.hardswish.default"
19+
target = ["aten.hardswish.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_hardtanh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
@register_node_visitor
1919
class HardTanhVisitor(NodeVisitor):
20-
target = "aten.hardtanh.default"
20+
target = ["aten.hardtanh.default"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_layer_norm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
@register_node_visitor
2020
class LayerNormVisitor(NodeVisitor):
21-
target = "aten.native_layer_norm.default"
21+
target = ["aten.native_layer_norm.default"]
2222

2323
def __init__(self, *args) -> None:
2424
super().__init__(*args)

backends/qualcomm/builders/op_linear.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
@register_node_visitor
1919
class LinearVisitor(NodeVisitor):
20-
target = "aten.linear.default"
20+
target = ["aten.linear.default"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_log_softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class LogSoftmax(NodeVisitor):
19-
target = "aten._log_softmax.default"
19+
target = ["aten._log_softmax.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_matmul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Matmul(NodeVisitor):
18-
target = "aten.matmul.default"
18+
target = ["aten.matmul.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_max_pool2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class MaxPool2d(NodeVisitor):
19-
target = "aten.max_pool2d_with_indices.default"
19+
target = ["aten.max_pool2d_with_indices.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_mean_dim.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
@register_node_visitor
1919
class MeanDim(NodeVisitor):
20-
target = "aten.mean.dim"
20+
target = ["aten.mean.dim"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_mul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Mul(NodeVisitor):
18-
target = "aten.mul.Tensor"
18+
target = ["aten.mul.Tensor"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_pad.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class Pad(NodeVisitor):
19-
target = "aten.constant_pad_nd.default"
19+
target = ["aten.constant_pad_nd.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_pow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
# TODO Add more class Like PowTensorTensor if needed
1818
@register_node_visitor
1919
class PowTensorScalar(NodeVisitor):
20-
target = "aten.pow.Tensor_Scalar"
20+
target = ["aten.pow.Tensor_Scalar"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_quantize.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,9 @@ def define_node(
6161

6262
@register_node_visitor
6363
class PerTensorQuantize(QuantizeOpBase):
64-
target = "quantized_decomposed.quantize_per_tensor.default"
64+
target = ["quantized_decomposed.quantize_per_tensor.default"]
6565

6666

6767
@register_node_visitor
6868
class PerChannelQuantize(QuantizeOpBase):
69-
target = "quantized_decomposed.quantize_per_channel.default"
69+
target = ["quantized_decomposed.quantize_per_channel.default"]

backends/qualcomm/builders/op_relu.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Relu(NodeVisitor):
18-
target = "aten.relu.default"
18+
target = ["aten.relu.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_reshape.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Reshape(NodeVisitor):
18-
target = "aten.view_copy.default"
18+
target = ["aten.view_copy.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_rsqrt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Rsqrt(NodeVisitor):
18-
target = "aten.rsqrt.default"
18+
target = ["aten.rsqrt.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_select_copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
@register_node_visitor
1919
class SelectCopy(NodeVisitor):
20-
target = "aten.select_copy.int"
20+
target = ["aten.select_copy.int", "aten.select.int"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_sigmoid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Sigmoid(NodeVisitor):
18-
target = "aten.sigmoid.default"
18+
target = ["aten.sigmoid.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_skip_ops.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ class OpGetItem(OpSkipOps):
3535
do nothing if node is getitem
3636
"""
3737

38-
target = "getitem"
38+
target = ["getitem"]
3939

4040
def define_node(
4141
self,

backends/qualcomm/builders/op_slice_copy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class StrideSlice(NodeVisitor):
19-
target = "aten.slice_copy.Tensor"
19+
target = ["aten.slice_copy.Tensor"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_softmax.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class Softmax(NodeVisitor):
19-
target = "aten._softmax.default"
19+
target = ["aten._softmax.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_squeeze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Squeeze(NodeVisitor):
18-
target = "aten.squeeze_copy.dims"
18+
target = ["aten.squeeze_copy.dims", "aten.squeeze.dims"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_sub.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Sub(NodeVisitor):
18-
target = "aten.sub.Tensor"
18+
target = ["aten.sub.Tensor"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_tanh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
@register_node_visitor
1818
class Tanh(NodeVisitor):
19-
target = "aten.tanh.default"
19+
target = ["aten.tanh.default"]
2020

2121
def __init__(self, *args) -> None:
2222
super().__init__(*args)

backends/qualcomm/builders/op_transpose.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
@register_node_visitor
1919
class TransposeVisitor(NodeVisitor):
20-
target = "aten.permute_copy.default"
20+
target = ["aten.permute_copy.default"]
2121

2222
def __init__(self, *args) -> None:
2323
super().__init__(*args)

backends/qualcomm/builders/op_unsqueeze.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class Unsqueeze(NodeVisitor):
18-
target = "aten.unsqueeze_copy.default"
18+
target = ["aten.unsqueeze_copy.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

backends/qualcomm/builders/op_upsample_bilinear2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515

1616
@register_node_visitor
1717
class ResizeBilinear(NodeVisitor):
18-
target = "aten.upsample_bilinear2d.default"
18+
target = ["aten.upsample_bilinear2d.default"]
1919

2020
def __init__(self, *args) -> None:
2121
super().__init__(*args)

0 commit comments

Comments
 (0)