
Commit 3dded9d

Authored by peri044, Chengzhe Xu, and chohk88
feat: Add SAM2 to our model zoo (#3318)
Co-authored-by: Chengzhe Xu <[email protected]>
Co-authored-by: Hoonkyung Cho <[email protected]>
1 parent c27fee0 commit 3dded9d

File tree: 8 files changed, +306 −2 lines changed

docsrc/index.rst

Lines changed: 2 additions & 0 deletions
@@ -136,6 +136,7 @@ Model Zoo
 * :ref:`torch_compile_stable_diffusion`
 * :ref:`torch_export_gpt2`
 * :ref:`torch_export_llama2`
+* :ref:`torch_export_sam2`
 * :ref:`notebooks`

 .. toctree::
@@ -150,6 +151,7 @@ Model Zoo
 tutorials/_rendered_examples/distributed_inference/data_parallel_stable_diffusion
 tutorials/_rendered_examples/dynamo/torch_export_gpt2
 tutorials/_rendered_examples/dynamo/torch_export_llama2
+tutorials/_rendered_examples/dynamo/torch_export_sam2
 tutorials/notebooks

 Python API Documentation

examples/dynamo/README.rst

Lines changed: 2 additions & 1 deletion
@@ -18,4 +18,5 @@ Model Zoo
 * :ref:`torch_compile_transformer`: Compiling a Transformer model using ``torch.compile``
 * :ref:`torch_compile_stable_diffusion`: Compiling a Stable Diffusion model using ``torch.compile``
 * :ref:`_torch_export_gpt2`: Compiling a GPT2 model using AOT workflow (`ir=dynamo`)
-* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_llama2`: Compiling a Llama2 model using AOT workflow (`ir=dynamo`)
+* :ref:`_torch_export_sam2`: Compiling SAM2 model using AOT workflow (`ir=dynamo`)

examples/dynamo/requirements.txt

Lines changed: 5 additions & 1 deletion
@@ -1,4 +1,8 @@
 cupy==13.1.0
 triton==2.3.0
 diffusers==0.30.3
-transformers==4.44.2
+transformers==4.44.2
+matplotlib
+pandas
+huggingface_hub
+opencv-python

examples/dynamo/torch_export_sam2.py

Lines changed: 297 additions & 0 deletions
@@ -0,0 +1,297 @@
"""
.. _torch_export_sam2:

Compiling SAM2 using the dynamo backend
==========================================================

This example illustrates the state-of-the-art `Segment Anything Model 2 (SAM2) <https://arxiv.org/pdf/2408.00714>`_ optimized using
Torch-TensorRT.

**Segment Anything Model 2** is a foundation model for promptable visual segmentation in images and videos.
Install the following dependencies before compilation:

.. code-block:: sh

    pip install -r requirements.txt

Certain custom modifications are required to ensure the model is exported successfully. To apply these changes, please install SAM2 using the `following fork <https://github.com/chohk88/sam2/tree/torch-trt>`_ (`installation instructions <https://github.com/chohk88/sam2/tree/torch-trt?tab=readme-ov-file#installation>`_).

In the custom SAM2 fork, the following modifications have been applied to remove graph breaks and reduce latency, ensuring a more efficient Torch-TRT conversion:

- **Consistent Data Types:** Preserves input tensor dtypes, removing forced FP32 conversions.
- **Masked Operations:** Uses mask-based indexing instead of directly selecting data, improving Torch-TRT compatibility (a small sketch of this idea follows below).
- **Safe Initialization:** Initializes tensors conditionally rather than concatenating to empty tensors.
- **Standard Functions:** Avoids special contexts and custom LayerNorm, relying on built-in PyTorch functions for better stability.
"""
# %%
# Import the following libraries
# -----------------------------
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

matplotlib.use("Agg")

# %%
# Define the SAM2 model
# -----------------------------
# Load the ``facebook/sam2-hiera-large`` pretrained model using the ``SAM2ImagePredictor`` class.
# ``SAM2ImagePredictor`` provides utilities to preprocess images, store image features (via the ``set_image`` function)
# and predict masks (via the ``predict`` function).

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

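For reference, the usual eager-mode workflow with ``SAM2ImagePredictor`` is the two-step ``set_image`` / ``predict`` sequence sketched below. This is illustrative only and follows the upstream SAM2 predictor API; it is not required for the compiled path in this example.

# Illustrative eager-mode usage of SAM2ImagePredictor (not used by the compiled path).
image_np = np.array(Image.open("./truck.jpg").convert("RGB"))
predictor.set_image(image_np)  # runs the image encoder and caches the features
masks, scores, low_res_masks = predictor.predict(
    point_coords=np.array([[500, 375]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
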
# %%
# To ensure we export the entire model (the image encoder and mask predictor components) successfully, we create a
# standalone module ``SAM2FullModel`` which uses these utilities from the ``SAM2ImagePredictor`` class.
# ``SAM2FullModel`` performs feature extraction and mask prediction in a single step, instead of the two-step process of
# ``SAM2ImagePredictor`` (the ``set_image`` and ``predict`` functions).


class SAM2FullModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.image_encoder = model.forward_image
        self._prepare_backbone_features = model._prepare_backbone_features
        self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
        self.no_mem_embed = model.no_mem_embed
        self._features = None

        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

        self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]

    def forward(self, image, point_coords, point_labels):
        backbone_out = self.image_encoder(image)
        _, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)

        if self.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

        high_res_features = [
            feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
        ]

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels), boxes=None, masks=None
        )

        low_res_masks, iou_predictions, _, _ = self.mask_decoder(
            image_embeddings=features["image_embed"][-1].unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=point_coords.shape[0] > 1,
            high_res_features=high_res_features,
        )

        out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
        return out

# %%
# Initialize the SAM2 model with pretrained weights
# --------------------------------------------------
# Initialize the ``SAM2FullModel`` with the pretrained weights. Since we already initialized
# ``SAM2ImagePredictor``, we can directly use the model from it (``predictor.model``). We cast the model
# to FP16 precision for faster performance.
encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()

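Before exporting, you can optionally sanity-check the standalone module in eager mode. The snippet below is an illustrative sketch only (it is not part of the original example) and assumes the default 1024x1024 input resolution used by the SAM2 image transforms.

# Optional eager-mode smoke test (illustrative; assumes 1024x1024 inputs).
dummy_image = torch.randn(1, 3, 1024, 1024, dtype=torch.half, device="cuda")
dummy_coords = torch.tensor([[[500.0, 375.0]]], dtype=torch.half, device="cuda")
dummy_labels = torch.tensor([[1]], dtype=torch.int, device="cuda")
with torch.no_grad():
    eager_out = sam_model(dummy_image, dummy_coords, dummy_labels)
print(eager_out["low_res_masks"].shape, eager_out["iou_predictions"].shape)
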
# %%
# Load an input image
# --------------------------------------------------
# Load the sample image provided in the repository. Here's the input image we are going to use
#
# .. image:: ./truck.jpg
#
input_image = Image.open("./truck.jpg").convert("RGB")

# %%
# In addition to the input image, we also provide prompts as inputs, which are
# used to predict the masks. A prompt can be a box, a point, or masks from a
# previous prediction iteration. We use a point prompt in this demo, similar to
# the `original notebook in the SAM2 repository <https://github.com/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb>`_.

# %%
# Preprocessing components
# -------------------------
# The following functions implement preprocessing components which apply transformations on the input image
# and transform the given point coordinates. We use the ``SAM2Transforms`` available via the ``SAM2ImagePredictor`` class.
# To read more about the transforms, refer to https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py


def preprocess_inputs(image, predictor):
    w, h = image.size
    orig_hw = [(h, w)]
    input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")

    point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
    point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")

    point_coords = torch.as_tensor(
        point_coords, dtype=torch.float, device=predictor.device
    )
    unnorm_coords = predictor._transforms.transform_coords(
        point_coords, normalize=True, orig_hw=orig_hw[0]
    )
    labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
    if len(unnorm_coords.shape) == 2:
        unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]

    input_image = input_image.half()
    unnorm_coords = unnorm_coords.half()

    return (input_image, unnorm_coords, labels)

# %%
# Post Processing components
# ---------------------------
# The following functions implement postprocessing components which include plotting and visualizing masks and points.
# We use the ``SAM2Transforms`` to post-process these masks and sort them by confidence score.


def postprocess_masks(out, predictor, image):
    """Postprocess low-resolution masks and convert them for visualization."""
    orig_hw = (image.size[1], image.size[0])  # (height, width)
    masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
    masks = (masks > 0.0).squeeze(0).cpu().numpy()
    scores = out["iou_predictions"].squeeze(0).cpu().numpy()
    sorted_indices = np.argsort(scores)[::-1]
    return masks[sorted_indices], scores[sorted_indices]


def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
        ]
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def visualize_masks(
    image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
    """Visualize and save masks overlaid on the original image."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(point_coords, point_labels, plt.gca())
        plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
        plt.close()

# %%
# Preprocess the inputs
# ----------------------
# Preprocess the inputs. In the following snippet, ``torchtrt_inputs`` contains (input_image, unnormalized_coordinates, labels).
# The unnormalized coordinates represent the prompt point, and the label (= 1 in this demo) marks it as a foreground point.
torchtrt_inputs = preprocess_inputs(input_image, predictor)

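As a quick sanity check (not part of the original example), you can print the shape and dtype of each preprocessed input before exporting; the image should be an FP16 tensor, and the coordinates and labels should carry a batch dimension.

# Illustrative only: inspect the preprocessed inputs passed to torch.export below.
for name, tensor in zip(("input_image", "unnorm_coords", "labels"), torchtrt_inputs):
    print(f"{name}: shape={tuple(tensor.shape)}, dtype={tensor.dtype}")
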
# %%
# Torch-TensorRT compilation
# ---------------------------
# Export the model in non-strict mode and perform Torch-TensorRT compilation in FP16 precision.
# We enable FP32 matmul accumulation using ``use_fp32_acc=True`` to preserve accuracy with the original PyTorch model.
exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=torchtrt_inputs,
    min_block_size=1,
    enabled_precisions={torch.float16},
    use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)

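To see how well ``use_fp32_acc=True`` preserves accuracy, a simple spot-check (illustrative, not part of the original example) is to run the eager FP16 model on the same inputs and compare the mask logits and IoU predictions.

# Illustrative accuracy spot-check against the eager FP16 PyTorch model.
with torch.no_grad():
    pyt_out = sam_model(*torchtrt_inputs)
max_mask_diff = (trt_out["low_res_masks"] - pyt_out["low_res_masks"]).abs().max().item()
max_iou_diff = (trt_out["iou_predictions"] - pyt_out["iou_predictions"]).abs().max().item()
print(f"max abs diff -- low_res_masks: {max_mask_diff:.4f}, iou_predictions: {max_iou_diff:.4f}")
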
# %%
# Output visualization
# ---------------------------
# Post-process the outputs of Torch-TensorRT and visualize the masks using the post-processing
# components provided above. The output images are saved in your current directory.

trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
    input_image,
    trt_masks,
    trt_scores,
    torch.tensor([[500, 375]]),
    torch.tensor([1]),
    title_prefix="Torch-TRT",
)

# %%
# The predicted masks are shown below
#
# .. image:: sam_mask1.png
#    :width: 50%
#
# .. image:: sam_mask2.png
#    :width: 50%
#
# .. image:: sam_mask3.png
#    :width: 50%

# %%
# References
# ---------------------------
# - `SAM 2: Segment Anything in Images and Videos <https://arxiv.org/pdf/2408.00714>`_
# - `SAM 2 GitHub Repository <https://github.com/facebookresearch/sam2/tree/main>`_
