# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import PIL
import torch

from parameterized import parameterized
from PIL import Image

from torchtune.models.clip.inference._transforms import (
    _CLIPImageTransform,
    CLIPImageTransform,
)

from torchtune.modules.transforms import (
    find_supported_resolutions,
    get_canvas_best_fit,
    get_inscribed_size,
)
from torchvision.transforms.v2 import functional as F

from .export_preprocess_lib import export_preprocess


@dataclass
class PreprocessConfig:
    image_mean: Optional[List[float]] = None
    image_std: Optional[List[float]] = None
    resize_to_max_canvas: bool = True
    resample: str = "bilinear"
    antialias: bool = False
    tile_size: int = 224
    max_num_tiles: int = 4
    possible_resolutions: Optional[List[Tuple[int, int]]] = None


class TestImageTransform(unittest.TestCase):
    """
    This unittest checks that the exported image transform model produces the
    same output as the reference model.

    Reference model: CLIPImageTransform
    https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
    Eager and exported models: _CLIPImageTransform
    https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L26
    """

    def setUp(self):
        np.random.seed(0)

    def prepare_inputs(
        self, image: Image.Image, config: PreprocessConfig
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Prepare inputs for eager and exported models:
        - Convert PIL image to tensor.
        - Calculate the best resolution; a canvas with height and width divisible by tile_size.
        - Calculate the inscribed size; the size of the image inscribed within best_resolution,
          without distortion.

        These calculations are done by the reference model inside __init__ and __call__
        https://github.com/pytorch/torchtune/blob/main/torchtune/models/clip/inference/_transforms.py#L115
        """
        image_tensor = F.to_dtype(
            F.grayscale_to_rgb_image(F.to_image(image)), scale=True
        )

        # Calculate possible resolutions.
        possible_resolutions = config.possible_resolutions
        if possible_resolutions is None:
            possible_resolutions = find_supported_resolutions(
                max_num_tiles=config.max_num_tiles, tile_size=config.tile_size
            )
        possible_resolutions = torch.tensor(possible_resolutions).reshape(-1, 2)

        # Limit resizing.
        max_size = None if config.resize_to_max_canvas else config.tile_size

        # Find the best canvas to fit the image without distortion.
        best_resolution = get_canvas_best_fit(
            image=image_tensor,
            possible_resolutions=possible_resolutions,
            resize_to_max_canvas=config.resize_to_max_canvas,
        )
        best_resolution = torch.tensor(best_resolution)

        # Find the dimensions of the image, such that it is inscribed within best_resolution
        # without distortion.
        inscribed_size = get_inscribed_size(
            image_tensor.shape[-2:], best_resolution, max_size
        )
        inscribed_size = torch.tensor(inscribed_size)

        return image_tensor, inscribed_size, best_resolution

    # This test setup mirrors the one in torchtune:
    # https://github.com/pytorch/torchtune/blob/main/tests/torchtune/models/clip/test_clip_image_transform.py
    # The values are slightly different, as torchtune uses antialias=True,
    # and this test uses antialias=False, which is exportable (has a portable kernel).
    @parameterized.expand(
        [
            (
                (100, 400, 3),  # image_size
                torch.Size([2, 3, 224, 224]),  # expected shape
                False,  # resize_to_max_canvas
                [0.2230, 0.1763],  # expected_tile_means
                [1.0, 1.0],  # expected_tile_max
                [0.0, 0.0],  # expected_tile_min
                [1, 2],  # expected_aspect_ratio
            ),
            (
                (1000, 300, 3),  # image_size
                torch.Size([4, 3, 224, 224]),  # expected shape
                True,  # resize_to_max_canvas
                [0.5005, 0.4992, 0.5004, 0.1651],  # expected_tile_means
                [0.9976, 0.9940, 0.9936, 0.9906],  # expected_tile_max
                [0.0037, 0.0047, 0.0039, 0.0],  # expected_tile_min
                [4, 1],  # expected_aspect_ratio
            ),
            (
                (200, 200, 3),  # image_size
                torch.Size([4, 3, 224, 224]),  # expected shape
                True,  # resize_to_max_canvas
                [0.5012, 0.5020, 0.5010, 0.4991],  # expected_tile_means
                [0.9921, 0.9925, 0.9969, 0.9908],  # expected_tile_max
                [0.0056, 0.0069, 0.0059, 0.0032],  # expected_tile_min
                [2, 2],  # expected_aspect_ratio
            ),
            (
                (600, 200, 3),  # image_size
                torch.Size([3, 3, 224, 224]),  # expected shape
                False,  # resize_to_max_canvas
                [0.4472, 0.4468, 0.3031],  # expected_tile_means
                [1.0, 1.0, 1.0],  # expected_tile_max
                [0.0, 0.0, 0.0],  # expected_tile_min
                [3, 1],  # expected_aspect_ratio
            ),
        ]
    )
    def test_preprocess(
        self,
        image_size: Tuple[int, int, int],
        expected_shape: torch.Size,
        resize_to_max_canvas: bool,
        expected_tile_means: List[float],
        expected_tile_max: List[float],
        expected_tile_min: List[float],
        expected_ar: List[int],
    ) -> None:
        config = PreprocessConfig(resize_to_max_canvas=resize_to_max_canvas)

        reference_model = CLIPImageTransform(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resize_to_max_canvas=config.resize_to_max_canvas,
            resample=config.resample,
            antialias=config.antialias,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
            possible_resolutions=None,
        )

        eager_model = _CLIPImageTransform(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resample=config.resample,
            antialias=config.antialias,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
        )

        exported_model = export_preprocess(
            image_mean=config.image_mean,
            image_std=config.image_std,
            resample=config.resample,
            antialias=config.antialias,
            tile_size=config.tile_size,
            max_num_tiles=config.max_num_tiles,
        )
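
        # Note: all three transforms are built from the same PreprocessConfig, so
        # their outputs are expected to be directly comparable. Per the class
        # docstring, both the eager and exported models wrap _CLIPImageTransform;
        # export_preprocess (from export_preprocess_lib) produces the exported one.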

        # Prepare image input.
        image = (
            np.random.randint(0, 256, np.prod(image_size))
            .reshape(image_size)
            .astype(np.uint8)
        )
        image = PIL.Image.fromarray(image)

        # Run reference model.
        reference_output = reference_model(image=image)
        reference_image = reference_output["image"]
        reference_ar = reference_output["aspect_ratio"].tolist()

        # Check output shape and aspect ratio match expected values.
        self.assertEqual(reference_image.shape, expected_shape)
        self.assertEqual(reference_ar, expected_ar)

        # Check pixel values are within the expected range [0, 1].
        self.assertTrue(0 <= reference_image.min() <= reference_image.max() <= 1)

        # Check mean, max, and min values of the tiles match expected values.
        for i, tile in enumerate(reference_image):
            self.assertAlmostEqual(
                tile.mean().item(), expected_tile_means[i], delta=1e-4
            )
            self.assertAlmostEqual(tile.max().item(), expected_tile_max[i], delta=1e-4)
            self.assertAlmostEqual(tile.min().item(), expected_tile_min[i], delta=1e-4)

        # Check num tiles matches the product of the aspect ratio.
        expected_num_tiles = reference_ar[0] * reference_ar[1]
        self.assertEqual(expected_num_tiles, reference_image.shape[0])

        # Pre-work for eager and exported models. The reference model performs these
        # calculations and passes the result to _CLIPImageTransform, the exportable model.
        image_tensor, inscribed_size, best_resolution = self.prepare_inputs(
            image=image, config=config
        )

        # Run eager and exported models.
        eager_image, eager_ar = eager_model(
            image_tensor, inscribed_size, best_resolution
        )
        eager_ar = eager_ar.tolist()

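        # export_preprocess is expected to return a torch.export ExportedProgram;
        # .module() unwraps it into a callable GraphModule that takes the same
        # positional inputs as the eager model.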
        exported_image, exported_ar = exported_model.module()(
            image_tensor, inscribed_size, best_resolution
        )
        exported_ar = exported_ar.tolist()

        # Check eager and exported models match reference model.
        self.assertTrue(torch.allclose(reference_image, eager_image))
        self.assertTrue(torch.allclose(reference_image, exported_image))

        self.assertEqual(reference_ar, eager_ar)
        self.assertEqual(reference_ar, exported_ar)