"""
.. _torch_export_sam2:

Compiling SAM2 using the dynamo backend
==========================================================

This example illustrates the state-of-the-art model `Segment Anything Model 2 (SAM2) <https://arxiv.org/pdf/2408.00714>`_ optimized using
Torch-TensorRT.

**Segment Anything Model 2** is a foundation model for promptable visual segmentation in images and videos.
Install the following dependencies before compilation:

.. code-block:: bash

    pip install -r requirements.txt

Certain custom modifications are required to ensure the model is exported successfully. To apply these changes, please install SAM2 using the `following fork <https://github.com/chohk88/sam2/tree/torch-trt>`_ (`installation instructions <https://github.com/chohk88/sam2/tree/torch-trt?tab=readme-ov-file#installation>`_).

In the custom SAM2 fork, the following modifications have been applied to remove graph breaks and reduce latency, ensuring a more efficient Torch-TRT conversion:

- **Consistent Data Types:** Preserves input tensor dtypes, removing forced FP32 conversions.
- **Masked Operations:** Uses mask-based indexing instead of directly selecting data (see the sketch below), improving Torch-TRT compatibility.
- **Safe Initialization:** Initializes tensors conditionally rather than concatenating to empty tensors.
- **Standard Functions:** Avoids special contexts and custom LayerNorm, relying on built-in PyTorch functions for better stability.
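
As an illustrative sketch (not the actual fork diff), the "masked operations" change replaces
data-dependent selection with mask-based arithmetic, which keeps tensor shapes static during export:

.. code-block:: python

    import torch

    scores = torch.randn(8)
    threshold = 0.5

    # Data-dependent selection: the output shape depends on tensor values and can cause graph breaks.
    # selected = scores[scores > threshold]

    # Mask-based alternative: the same information is kept, but the output shape stays fixed at (8,).
    mask = scores > threshold
    selected = scores * mask
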
| 25 | +""" |

# %%
# Import the following libraries
# ------------------------------
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch_tensorrt
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor

matplotlib.use("Agg")

# %%
# Define the SAM2 model
# -----------------------------
# Load the ``facebook/sam2-hiera-large`` pretrained model using the ``SAM2ImagePredictor`` class.
# ``SAM2ImagePredictor`` provides utilities to preprocess images, store image features (via the ``set_image`` function)
# and predict the masks (via the ``predict`` function).

predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")

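# %%
# For reference, the usual eager workflow with ``SAM2ImagePredictor`` is a two-step process,
# sketched below for illustration only (it is not executed in this example)::
#
#     predictor.set_image(np.array(Image.open("./truck.jpg").convert("RGB")))
#     masks, scores, logits = predictor.predict(
#         point_coords=np.array([[500, 375]]),
#         point_labels=np.array([1]),
#         multimask_output=True,
#     )
#
# The rest of this example instead exports and compiles a single module that fuses both steps.
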
# %%
# To ensure we export the entire model (image encoder and mask predictor) successfully, we create a
# standalone module ``SAM2FullModel`` which uses these utilities from the ``SAM2ImagePredictor`` class.
# ``SAM2FullModel`` performs feature extraction and mask prediction in a single step, instead of the
# two-step process of ``SAM2ImagePredictor`` (the ``set_image`` and ``predict`` functions).


class SAM2FullModel(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self.image_encoder = model.forward_image
        self._prepare_backbone_features = model._prepare_backbone_features
        self.directly_add_no_mem_embed = model.directly_add_no_mem_embed
        self.no_mem_embed = model.no_mem_embed
        self._features = None

        self.prompt_encoder = model.sam_prompt_encoder
        self.mask_decoder = model.sam_mask_decoder

        self._bb_feat_sizes = [(256, 256), (128, 128), (64, 64)]

    def forward(self, image, point_coords, point_labels):
        backbone_out = self.image_encoder(image)
        _, vision_feats, _, _ = self._prepare_backbone_features(backbone_out)

        if self.directly_add_no_mem_embed:
            vision_feats[-1] = vision_feats[-1] + self.no_mem_embed

        feats = [
            feat.permute(1, 2, 0).view(1, -1, *feat_size)
            for feat, feat_size in zip(vision_feats[::-1], self._bb_feat_sizes[::-1])
        ][::-1]
        features = {"image_embed": feats[-1], "high_res_feats": feats[:-1]}

        high_res_features = [
            feat_level[-1].unsqueeze(0) for feat_level in features["high_res_feats"]
        ]

        sparse_embeddings, dense_embeddings = self.prompt_encoder(
            points=(point_coords, point_labels), boxes=None, masks=None
        )

        low_res_masks, iou_predictions, _, _ = self.mask_decoder(
            image_embeddings=features["image_embed"][-1].unsqueeze(0),
            image_pe=self.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=True,
            repeat_image=point_coords.shape[0] > 1,
            high_res_features=high_res_features,
        )

        out = {"low_res_masks": low_res_masks, "iou_predictions": iou_predictions}
        return out


# %%
# Initialize the SAM2 model with pretrained weights
# --------------------------------------------------
# Initialize the ``SAM2FullModel`` with the pretrained weights. Since we already initialized
# ``SAM2ImagePredictor``, we can directly use the model from it (``predictor.model``). We cast the model
# to FP16 precision for faster performance.
encoder = predictor.model.eval().cuda()
sam_model = SAM2FullModel(encoder.half()).eval().cuda()

# %%
# Load an input image
# --------------------------------------------------
# Load the sample image provided in the repository. Here's the input image we are going to use:
#
# .. image:: ./truck.jpg
#
input_image = Image.open("./truck.jpg").convert("RGB")

# %%
# In addition to the input image, we also provide prompts as inputs which are
# used to predict the masks. The prompts can be a box, a point, or masks from a
# previous prediction iteration. We use a point prompt in this demo, similar to
# the `original notebook in the SAM2 repository <https://github.com/facebookresearch/sam2/blob/main/notebooks/image_predictor_example.ipynb>`_.

# %%
# Preprocessing components
# -------------------------
# The following functions implement preprocessing components which apply transformations on the input image
# and transform the given point coordinates. We use the ``SAM2Transforms`` available via the ``SAM2ImagePredictor`` class.
# To read more about the transforms, refer to https://github.com/facebookresearch/sam2/blob/main/sam2/utils/transforms.py


def preprocess_inputs(image, predictor):
    w, h = image.size
    orig_hw = [(h, w)]
    input_image = predictor._transforms(np.array(image))[None, ...].to("cuda:0")

    point_coords = torch.tensor([[500, 375]], dtype=torch.float).to("cuda:0")
    point_labels = torch.tensor([1], dtype=torch.int).to("cuda:0")

    point_coords = torch.as_tensor(
        point_coords, dtype=torch.float, device=predictor.device
    )
    unnorm_coords = predictor._transforms.transform_coords(
        point_coords, normalize=True, orig_hw=orig_hw[0]
    )
    labels = torch.as_tensor(point_labels, dtype=torch.int, device=predictor.device)
    if len(unnorm_coords.shape) == 2:
        unnorm_coords, labels = unnorm_coords[None, ...], labels[None, ...]

    input_image = input_image.half()
    unnorm_coords = unnorm_coords.half()

    return (input_image, unnorm_coords, labels)


# %%
# Post-processing components
# ---------------------------
# The following functions implement post-processing components, which include plotting and visualizing masks and points.
# We use ``SAM2Transforms`` to post-process these masks and sort them by confidence score.


def postprocess_masks(out, predictor, image):
    """Postprocess low-resolution masks and convert them for visualization."""
    orig_hw = (image.size[1], image.size[0])  # (height, width)
    masks = predictor._transforms.postprocess_masks(out["low_res_masks"], orig_hw)
    masks = (masks > 0.0).squeeze(0).cpu().numpy()
    scores = out["iou_predictions"].squeeze(0).cpu().numpy()
    sorted_indices = np.argsort(scores)[::-1]
    return masks[sorted_indices], scores[sorted_indices]


def show_mask(mask, ax, random_color=False, borders=True):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    if borders:
        import cv2

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [
            cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours
        ]
        mask_image = cv2.drawContours(
            mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2
        )
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0],
        pos_points[:, 1],
        color="green",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )
    ax.scatter(
        neg_points[:, 0],
        neg_points[:, 1],
        color="red",
        marker="*",
        s=marker_size,
        edgecolor="white",
        linewidth=1.25,
    )


def visualize_masks(
    image, masks, scores, point_coords, point_labels, title_prefix="", save=True
):
    """Visualize and save masks overlaid on the original image."""
    for i, (mask, score) in enumerate(zip(masks, scores)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca())
        show_points(point_coords, point_labels, plt.gca())
        plt.title(f"{title_prefix} Mask {i + 1}, Score: {score:.3f}", fontsize=18)
        plt.axis("off")
        plt.savefig(f"{title_prefix}_output_mask_{i + 1}.png")
        plt.close()


# %%
# Preprocess the inputs
# ----------------------
# In the following snippet, ``torchtrt_inputs`` contains (input_image, unnormalized_coordinates, labels).
# The unnormalized coordinates represent the point prompt, and the label (= 1 in this demo) marks it as a foreground point.
torchtrt_inputs = preprocess_inputs(input_image, predictor)

# %%
# Torch-TensorRT compilation
# ---------------------------
# Export the model in non-strict mode and perform Torch-TensorRT compilation in FP16 precision.
# We enable FP32 matmul accumulation using ``use_fp32_acc=True`` to preserve the accuracy of the original PyTorch model.
exp_program = torch.export.export(sam_model, torchtrt_inputs, strict=False)
trt_model = torch_tensorrt.dynamo.compile(
    exp_program,
    inputs=torchtrt_inputs,
    min_block_size=1,
    enabled_precisions={torch.float16},
    use_fp32_acc=True,
)
trt_out = trt_model(*torchtrt_inputs)

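# %%
# Optionally, we can sanity-check the compiled module by comparing its outputs against the
# eager FP16 model on the same inputs. This is a minimal numerical check; small differences
# are expected due to FP16 arithmetic.

with torch.no_grad():
    pyt_out = sam_model(*torchtrt_inputs)
print(
    "Max abs difference (low_res_masks):",
    (pyt_out["low_res_masks"] - trt_out["low_res_masks"]).abs().max().item(),
)
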
# %%
# Output visualization
# ---------------------------
# Post-process the outputs of Torch-TensorRT and visualize the masks using the post-processing
# components defined above. The output images are saved to your current directory.

trt_masks, trt_scores = postprocess_masks(trt_out, predictor, input_image)
visualize_masks(
    input_image,
    trt_masks,
    trt_scores,
    torch.tensor([[500, 375]]),
    torch.tensor([1]),
    title_prefix="Torch-TRT",
)

# %%
# The predicted masks are shown below.
#
# .. image:: sam_mask1.png
#    :width: 50%
#
# .. image:: sam_mask2.png
#    :width: 50%
#
# .. image:: sam_mask3.png
#    :width: 50%

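# %%
# To reuse the compiled module in a later session, it can be serialized with ``torch_tensorrt.save``.
# The snippet below is a minimal sketch (not executed here); the exact save/load API may vary
# slightly across Torch-TensorRT versions, so consult the documentation of your installed release::
#
#     torch_tensorrt.save(trt_model, "trt_sam2.ep", inputs=torchtrt_inputs)
#     # In a later session, reload the exported program and call it like a regular module
#     reloaded = torch.export.load("trt_sam2.ep").module()
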
# %%
# References
# ---------------------------
# - `SAM 2: Segment Anything in Images and Videos <https://arxiv.org/pdf/2408.00714>`_
# - `SAM 2 GitHub Repository <https://github.com/facebookresearch/sam2/tree/main>`_