|
14 | 14 |
|
15 | 15 |
|
16 | 16 | class TestResNet18(unittest.TestCase):
|
17 |
| - def test_fp32_resnet18(self): |
18 |
| - inputs = (torch.ones(1, 3, 224, 224),) |
| 17 | + inputs = (torch.ones(1, 3, 224, 224),) |
| 18 | + dynamic_shapes = ( |
| 19 | + { |
| 20 | + 2: torch.export.Dim("height", min=224, max=455), |
| 21 | + 3: torch.export.Dim("width", min=224, max=455), |
| 22 | + }, |
| 23 | + ) |
| 24 | + |
| 25 | + class DynamicResNet(torch.nn.Module): |
| 26 | + def __init__(self): |
| 27 | + super().__init__() |
| 28 | + self.model = torchvision.models.resnet18() |
| 29 | + |
| 30 | + def forward(self, x): |
| 31 | + x = torch.nn.functional.interpolate( |
| 32 | + x, |
| 33 | + size=(224, 224), |
| 34 | + mode="bilinear", |
| 35 | + align_corners=True, |
| 36 | + antialias=False, |
| 37 | + ) |
| 38 | + return self.model(x) |
| 39 | + |
| 40 | + def _test_exported_resnet(self, tester): |
19 | 41 | (
|
20 |
| - Tester(torchvision.models.resnet18(), inputs) |
21 |
| - .export() |
| 42 | + tester.export() |
22 | 43 | .to_edge()
|
23 | 44 | .partition()
|
| 45 | + .check_not( |
| 46 | + [ |
| 47 | + "executorch_exir_dialects_edge__ops_aten_convolution_default", |
| 48 | + "executorch_exir_dialects_edge__ops_aten_mean_dim", |
| 49 | + ] |
| 50 | + ) |
| 51 | + .check(["torch.ops.higher_order.executorch_call_delegate"]) |
24 | 52 | .to_executorch()
|
25 | 53 | .serialize()
|
26 | 54 | .run_method_and_compare_outputs()
|
27 | 55 | )
|
28 | 56 |
|
| 57 | + def test_fp32_resnet18(self): |
| 58 | + self._test_exported_resnet(Tester(torchvision.models.resnet18(), self.inputs)) |
| 59 | + |
29 | 60 | def test_qs8_resnet18(self):
|
30 |
| - inputs = (torch.ones(1, 3, 224, 224),) |
31 |
| - ( |
32 |
| - Tester(torchvision.models.resnet18(), inputs) |
33 |
| - .quantize(Quantize(calibrate=False)) |
34 |
| - .export() |
35 |
| - .to_edge() |
36 |
| - .partition() |
37 |
| - .to_executorch() |
38 |
| - .serialize() |
39 |
| - .run_method_and_compare_outputs() |
| 61 | + quantized_tester = Tester(torchvision.models.resnet18(), self.inputs).quantize( |
| 62 | + Quantize(calibrate=False) |
| 63 | + ) |
| 64 | + self._test_exported_resnet(quantized_tester) |
| 65 | + |
| 66 | + def test_fp32_resnet18_dynamic(self): |
| 67 | + self._test_exported_resnet( |
| 68 | + Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes) |
| 69 | + ) |
| 70 | + |
| 71 | + def test_qs8_resnet18_dynamic(self): |
| 72 | + self._test_exported_resnet( |
| 73 | + Tester(self.DynamicResNet(), self.inputs, self.dynamic_shapes).quantize( |
| 74 | + Quantize(calibrate=False) |
| 75 | + ) |
40 | 76 | )
|
0 commit comments