Skip to content

Commit c6f0b07

Browse files
guangy10facebook-github-bot
authored andcommitted
Add deeplab_v3 model to examples (#60)
Summary: Pull Request resolved: #60 Add image segmentation model **deeplabv3_resnet101** to `executorch/examples`. Info about the model: https://pytorch.org/hub/pytorch_vision_deeplabv3_resnet101/ Differential Revision: D48136966 fbshipit-source-id: 42eb6592eeb7dac93cd959037d5a70aeb87ead10
1 parent f4ed04f commit c6f0b07

File tree

6 files changed

+76
-0
lines changed

6 files changed

+76
-0
lines changed

examples/export/test/test_export.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,3 +111,11 @@ def test_resnet50_export_to_executorch(self):
111111
self._assert_eager_lowered_same_result(
112112
eager_model, example_inputs, self.validate_tensor_allclose
113113
)
114+
115+
def test_dl3_export_to_executorch(self):
116+
eager_model, example_inputs = MODEL_NAME_TO_MODEL["dl3"]()
117+
eager_model = eager_model.eval()
118+
119+
self._assert_eager_lowered_same_result(
120+
eager_model, example_inputs, self.validate_tensor_allclose
121+
)

examples/models/TARGETS

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ python_library(
88
],
99
deps = [
1010
"//caffe2:torch",
11+
"//executorch/examples/models/deeplab_v3:dl3_export",
1112
"//executorch/examples/models/inception_v3:ic3_export",
1213
"//executorch/examples/models/inception_v4:ic4_export",
1314
"//executorch/examples/models/mobilenet_v2:mv2_export",

examples/models/deeplab_v3/TARGETS

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
load("@fbcode_macros//build_defs:python_library.bzl", "python_library")
2+
3+
python_library(
4+
name = "dl3_export",
5+
srcs = [
6+
"__init__.py",
7+
"export.py",
8+
],
9+
base_module = "executorch.examples.models.deeplab_v3",
10+
deps = [
11+
"//caffe2:torch",
12+
"//pytorch/vision:torchvision",
13+
],
14+
)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .export import DeepLabV3ResNet101Model
8+
9+
__all__ = [
10+
DeepLabV3ResNet101Model,
11+
]

examples/models/deeplab_v3/export.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import logging
8+
9+
import torch
10+
from torchvision import models
11+
12+
FORMAT = "[%(filename)s:%(lineno)s] %(message)s"
13+
logging.basicConfig(format=FORMAT)
14+
15+
16+
class DeepLabV3ResNet101Model:
17+
def __init__(self):
18+
pass
19+
20+
@staticmethod
21+
def get_model():
22+
logging.info("loading deeplabv3_resnet101 model")
23+
deeplabv3_resnet101 = models.segmentation.deeplabv3_resnet101(
24+
weights=models.segmentation.deeplabv3.DeepLabV3_ResNet101_Weights.COCO_WITH_VOC_LABELS_V1
25+
)
26+
logging.info("loaded deeplabv3_resnet101 model")
27+
return deeplabv3_resnet101
28+
29+
@staticmethod
30+
def get_example_inputs():
31+
input_shape = (1, 3, 224, 224)
32+
return (torch.randn(input_shape),)

examples/models/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,15 @@ def gen_resnet50_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
126126
return ResNet50Model.get_model(), ResNet50Model.get_example_inputs()
127127

128128

129+
def gen_deeplab_v3_resnet_101_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
130+
from ..models.deeplab_v3 import DeepLabV3ResNet101Model
131+
132+
return (
133+
DeepLabV3ResNet101Model.get_model(),
134+
DeepLabV3ResNet101Model.get_example_inputs(),
135+
)
136+
137+
129138
MODEL_NAME_TO_MODEL = {
130139
"mul": lambda: (MulModule(), MulModule.get_example_inputs()),
131140
"linear": lambda: (LinearModule(), LinearModule.get_example_inputs()),
@@ -139,4 +148,5 @@ def gen_resnet50_model_and_inputs() -> Tuple[torch.nn.Module, Any]:
139148
"ic4": gen_inception_v4_model_and_inputs,
140149
"resnet18": gen_resnet18_model_and_inputs,
141150
"resnet50": gen_resnet50_model_and_inputs,
151+
"dl3": gen_deeplab_v3_resnet_101_model_and_inputs,
142152
}

0 commit comments

Comments
 (0)