
Commit 1f9dc41

mcremon-meta authored and facebook-github-bot committed
Add vision transformer (#4077)
Summary: As titled. It will be useful for a few Cadence teams to be able to inspect at least the AoT graph. Differential Revision: D59097944
1 parent 4fcd903

3 files changed: +66, -5 lines

backends/cadence/aot/ops_registrations.py

Lines changed: 10 additions & 4 deletions

@@ -10,7 +10,7 @@
 from executorch.exir.scalar_type import ScalarType
 from torch.library import impl, Library
 
-from .utils import get_conv1d_output_size
+from .utils import get_conv1d_output_size, get_conv2d_output_size
 
 lib = Library("cadence", "DEF")
 
@@ -122,16 +122,22 @@ def quantized_conv_meta(
     out_multiplier: torch.Tensor,
     out_shift: torch.Tensor,
     channel_last: bool = False,
-):
+) -> torch.Tensor:
     out_channels, _in_channels, *kernel_size = weight.shape
     in_size = input.shape
     # Assert that the input tensor has at least 3 dimensions, and at most 6
     assert len(in_size) > 2
     assert len(in_size) < 6
 
     # Compute the output tensor size
-    output_size = get_conv1d_output_size(
-        in_size, out_channels, stride[0], padding[0], dilation[0], kernel_size[0]
+    output_size = (
+        get_conv1d_output_size(
+            in_size, out_channels, stride[1], padding[1], dilation[1], kernel_size[0]
+        )
+        if len(in_size) == 3
+        else get_conv2d_output_size(
+            in_size, out_channels, stride, padding, dilation, kernel_size, channel_last
+        )
     )
 
     return input.new_empty(output_size, dtype=input.dtype)
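For context on the hunk above: ops_registrations.py is built on the torch.library imports shown in the first hunk, giving each Cadence op a "Meta" implementation that computes only output shapes, which is what makes the AoT graph traceable without running real kernels. Below is a minimal self-contained sketch of that pattern; the mylib::scale op and its schema are hypothetical stand-ins, not Cadence ops.

```python
# Minimal sketch of the torch.library Meta-kernel pattern used by
# ops_registrations.py. The "mylib::scale" op is hypothetical; only
# the Library/impl mechanics mirror the real file.
import torch
from torch.library import impl, Library

lib = Library("mylib", "DEF")
lib.define("scale(Tensor x, float s) -> Tensor")


@impl(lib, "scale", "Meta")
def scale_meta(x: torch.Tensor, s: float) -> torch.Tensor:
    # Like quantized_conv_meta above: return an empty tensor with the
    # right shape and dtype; meta kernels never touch real data.
    return x.new_empty(x.shape, dtype=x.dtype)
```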

backends/cadence/aot/utils.py

Lines changed: 28 additions & 1 deletion

@@ -6,7 +6,7 @@
 
 import logging
 import operator
-from typing import Dict
+from typing import Dict, List, Tuple
 
 import torch
 from executorch.exir import memory
@@ -49,6 +49,33 @@ def get_conv1d_output_size(
     return torch.Size((in_size[0], out_channels, lout))
 
 
+# Get the output size of a 2D convolution given the input size and parameters
+def get_conv2d_output_size(
+    in_size: torch.Size,
+    out_channels: int,
+    stride: Tuple[int],
+    padding: Tuple[int],
+    dilation: Tuple[int],
+    kernel_size: List[int],
+    channel_last: bool,
+) -> torch.Size:
+    assert len(in_size) == 4
+    if channel_last:
+        N, H, W, C = in_size
+    else:
+        N, C, H, W = in_size
+
+    # Reference: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
+    hout = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[
+        0
+    ] + 1
+    wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[
+        1
+    ] + 1
+
+    return torch.Size((in_size[0], out_channels, hout, wout))
+
+
 # Return the overload packet for the edge op
 def get_edge_overload_packet(edge_op: EdgeOpOverload) -> EdgeOpOverloadPacket:
     edge_op_namespace, edge_op_name = (
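The hout/wout arithmetic in get_conv2d_output_size is the closed-form output-size formula from the linked torch.nn.Conv2d documentation. As a quick sanity check (ours, not part of the commit), the same formula can be compared against an actual Conv2d forward pass:

```python
# Sanity check: the closed-form hout/wout formula vs. torch.nn.Conv2d.
import torch

N, C, H, W = 1, 3, 224, 224
out_channels = 8
kernel_size, stride, padding, dilation = (3, 3), (2, 2), (1, 1), (1, 1)

hout = (H + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) // stride[0] + 1
wout = (W + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) // stride[1] + 1

conv = torch.nn.Conv2d(C, out_channels, kernel_size, stride, padding, dilation)
out = conv(torch.randn(N, C, H, W))
assert out.shape == (N, out_channels, hout, wout)  # torch.Size([1, 8, 112, 112])
```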
New file

Lines changed: 28 additions & 0 deletions

@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Example script for exporting simple models to flatbuffer
+
+import logging
+
+from executorch.backends.cadence.aot.ops_registrations import *  # noqa
+
+import torch
+import torchvision
+
+from executorch.backends.cadence.aot.export_example import export_model
+
+
+FORMAT = "[%(levelname)s %(asctime)s %(filename)s:%(lineno)s] %(message)s"
+logging.basicConfig(level=logging.INFO, format=FORMAT)
+
+
+if __name__ == "__main__":
+
+    model = torchvision.models.vit_b_16()
+    example_inputs = (torch.randn(1, 3, 224, 224),)
+
+    export_model(model, example_inputs)
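The new script is intentionally small: construct a torchvision model, pair it with a tuple holding one example tensor per forward() argument, and pass both to export_model. Assuming the same Cadence AoT environment as the script above, the pattern should carry over to other torchvision models; a hypothetical variant, not part of the commit:

```python
# Same export pattern with a different torchvision model; everything here
# assumes the imports above resolve in your ExecuTorch checkout.
from executorch.backends.cadence.aot.ops_registrations import *  # noqa

import torch
import torchvision

from executorch.backends.cadence.aot.export_example import export_model

model = torchvision.models.resnet18()
example_inputs = (torch.randn(1, 3, 224, 224),)  # one entry per forward() arg
export_model(model, example_inputs)
```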
