|
1 | 1 | import json
|
2 | 2 | import os
|
3 | 3 |
|
4 |
| -import custom_models as cm |
5 |
| -import timm |
6 | 4 | import torch
|
7 |
| -import torch.nn as nn |
8 |
| -import torch.nn.functional as F |
9 |
| -import torchvision.models as models |
10 |
| -from transformers import BertConfig, BertModel, BertTokenizer |
11 | 5 |
|
12 | 6 | torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
|
13 | 7 |
|
|
26 | 20 | VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all")
|
27 | 21 |
|
28 | 22 | # Key models selected for benchmarking with their respective paths
|
29 |
| -BENCHMARK_MODELS = { |
30 |
| - "vgg16": { |
31 |
| - "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT), |
32 |
| - "path": ["script", "pytorch"], |
33 |
| - }, |
34 |
| - "resnet50": { |
35 |
| - "model": models.resnet50(weights=None), |
36 |
| - "path": ["script", "pytorch"], |
37 |
| - }, |
38 |
| - "efficientnet_b0": { |
39 |
| - "model": timm.create_model("efficientnet_b0", pretrained=True), |
40 |
| - "path": ["script", "pytorch"], |
41 |
| - }, |
42 |
| - "vit": { |
43 |
| - "model": timm.create_model("vit_base_patch16_224", pretrained=True), |
44 |
| - "path": ["script", "pytorch"], |
45 |
| - }, |
46 |
| - "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"}, |
47 |
| -} |
| 23 | +from utils import BENCHMARK_MODELS |
48 | 24 |
|
49 | 25 |
|
50 | 26 | def get(n, m, manifest):
|
51 | 27 | print("Downloading {}".format(n))
|
52 | 28 | traced_filename = "models/" + n + "_traced.jit.pt"
|
53 | 29 | script_filename = "models/" + n + "_scripted.jit.pt"
|
54 | 30 | pytorch_filename = "models/" + n + "_pytorch.pt"
|
55 |
| - x = torch.ones((1, 3, 300, 300)).cuda() |
56 |
| - if n == "bert_base_uncased": |
57 |
| - traced_model = m["model"] |
58 |
| - torch.jit.save(traced_model, traced_filename) |
| 31 | + |
| 32 | + m["model"] = m["model"].eval().cuda() |
| 33 | + |
| 34 | + # Get all desired model save specifications as list |
| 35 | + paths = [m["path"]] if isinstance(m["path"], str) else m["path"] |
| 36 | + |
| 37 | + # Depending on specified model save specifications, save desired model formats |
| 38 | + if any(path in ("all", "torchscript", "trace") for path in paths): |
| 39 | + # (TorchScript) Traced model |
| 40 | + trace_model = torch.jit.trace(m["model"], [inp.cuda() for inp in m["inputs"]]) |
| 41 | + torch.jit.save(trace_model, traced_filename) |
59 | 42 | manifest.update({n: [traced_filename]})
|
60 |
| - else: |
61 |
| - m["model"] = m["model"].eval().cuda() |
62 |
| - |
63 |
| - # Get all desired model save specifications as list |
64 |
| - paths = [m["path"]] if isinstance(m["path"], str) else m["path"] |
65 |
| - |
66 |
| - # Depending on specified model save specifications, save desired model formats |
67 |
| - if any(path in ("all", "torchscript", "trace") for path in paths): |
68 |
| - # (TorchScript) Traced model |
69 |
| - trace_model = torch.jit.trace(m["model"], [x]) |
70 |
| - torch.jit.save(trace_model, traced_filename) |
71 |
| - manifest.update({n: [traced_filename]}) |
72 |
| - if any(path in ("all", "torchscript", "script") for path in paths): |
73 |
| - # (TorchScript) Scripted model |
74 |
| - script_model = torch.jit.script(m["model"]) |
75 |
| - torch.jit.save(script_model, script_filename) |
76 |
| - if n in manifest.keys(): |
77 |
| - files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
78 |
| - files.append(script_filename) |
79 |
| - manifest.update({n: files}) |
80 |
| - else: |
81 |
| - manifest.update({n: [script_filename]}) |
82 |
| - if any(path in ("all", "pytorch") for path in paths): |
83 |
| - # (PyTorch Module) model |
84 |
| - torch.save(m["model"], pytorch_filename) |
85 |
| - if n in manifest.keys(): |
86 |
| - files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
87 |
| - files.append(script_filename) |
88 |
| - manifest.update({n: files}) |
89 |
| - else: |
90 |
| - manifest.update({n: [script_filename]}) |
| 43 | + if any(path in ("all", "torchscript", "script") for path in paths): |
| 44 | + # (TorchScript) Scripted model |
| 45 | + script_model = torch.jit.script(m["model"]) |
| 46 | + torch.jit.save(script_model, script_filename) |
| 47 | + if n in manifest.keys(): |
| 48 | + files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
| 49 | + files.append(script_filename) |
| 50 | + manifest.update({n: files}) |
| 51 | + else: |
| 52 | + manifest.update({n: [script_filename]}) |
| 53 | + if any(path in ("all", "pytorch") for path in paths): |
| 54 | + # (PyTorch Module) model |
| 55 | + torch.save(m["model"], pytorch_filename) |
| 56 | + if n in manifest.keys(): |
| 57 | + files = list(manifest[n]) if type(manifest[n]) != list else manifest[n] |
| 58 | + files.append(script_filename) |
| 59 | + manifest.update({n: files}) |
| 60 | + else: |
| 61 | + manifest.update({n: [script_filename]}) |
| 62 | + |
91 | 63 | return manifest
|
92 | 64 |
|
93 | 65 |
|
|
0 commit comments