Skip to content

Commit 2925c8a

Browse files
committed
chore: refactor
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 6a98dec commit 2925c8a

File tree

4 files changed

+276
-52
lines changed

4 files changed

+276
-52
lines changed

tools/perf/benchmark.sh

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#!/bin/bash
# Benchmark driver for perf_run.py.
#
# Download the Torchscript models first:
#   python hub.py

# Candidate batch sizes. NOTE(review): currently unused — only bs=1 is
# exercised below; kept for re-enabling the wider sweeps.
batch_sizes=(1 2 4 8 16 32 64 128 256)

# NOTE(review): benchmarks for VGG16, Resnet50, VIT and EfficientNet-B0 were
# present only as commented-out dead code and have been removed. To re-enable
# one, loop over batch sizes and run e.g.:
#   python perf_run.py --model models/<name>_scripted.jit.pt \
#       --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
#       --batch_size ${bs} \
#       --backends torch,torch_tensorrt,tensorrt \
#       --report "<name>_perf_bs${bs}.txt"

# Benchmark BERT model (two int32 inputs: token ids and segment ids).
for bs in 1
do
  python perf_run.py --model models/bert_base_uncased_traced.jit.pt \
    --precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \
    --batch_size ${bs} \
    --backends torch_tensorrt \
    --truncate \
    --report "bert_base_perf_bs${bs}.txt"
done

tools/perf/custom_models.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import torch
2+
import torch.nn as nn
3+
from transformers import BertModel, BertTokenizer, BertConfig
4+
import torch.nn.functional as F
5+
6+
def BertModule():
    """Return a TorchScript-traced ``bert-base-uncased`` model.

    Builds example (token ids, segment ids) tensors from a masked sample
    sentence, then traces the pretrained HuggingFace model (loaded with
    ``torchscript=True``) so hub.py can serialize it.

    Returns:
        torch.jit.ScriptModule: the traced BERT model.
    """
    model_name = "bert-base-uncased"
    enc = BertTokenizer.from_pretrained(model_name)
    text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
    tokenized_text = enc.tokenize(text)
    # Mask one token so the example input resembles a real masked-LM query.
    masked_index = 8
    tokenized_text[masked_index] = "[MASK]"
    indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
    # Sentence A -> segment 0, sentence B -> segment 1 (14 tokens total).
    segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
    tokens_tensor = torch.tensor([indexed_tokens])
    segments_tensors = torch.tensor([segments_ids])
    # BUG FIX: the original also built a BertModel from a fresh BertConfig and
    # called eval() on it, then discarded both by reassigning from_pretrained();
    # that dead construction is removed and eval() is now applied to the model
    # that is actually traced.
    model = BertModel.from_pretrained(model_name, torchscript=True)
    model.eval()
    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
    return traced_model

tools/perf/hub.py

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
import torch
2+
import torch.nn as nn
3+
import torch.nn.functional as F
4+
import torchvision.models as models
5+
import timm
6+
from transformers import BertModel, BertTokenizer, BertConfig
7+
import os
8+
import json
9+
import custom_models as cm
10+
11+
# Work around torch.hub's fork detection so local/cached repos load cleanly.
torch.hub._validate_not_a_forked_repo = lambda a, b, c: True

# Recorded in the manifest; a mismatch forces re-export of all models.
torch_version = torch.__version__

# Fail fast: the models below are moved to GPU before serialization, so a
# missing CUDA device would only fail later with a more confusing error.
if not torch.cuda.is_available():
    raise Exception("No GPU found. Please check if installed torch version is compatible with CUDA version")
18+
19+
# Downloads all model files again if manifest file is not present
MANIFEST_FILE = 'model_manifest.json'

# Models exported for benchmarking. "path" selects how each model is
# serialized by get(): "script" -> torch.jit.script, "trace" ->
# torch.jit.trace (BERT arrives pre-traced from custom_models).
BENCHMARK_MODELS = {
    "vgg16": {"model": models.vgg16(weights=None), "path": "script"},
    "resnet50": {"model": models.resnet50(weights=None), "path": "script"},
    "efficientnet_b0": {"model": timm.create_model('efficientnet_b0', pretrained=True), "path": "script"},
    "vit": {"model": timm.create_model('vit_base_patch16_224', pretrained=True), "path": "script"},
    "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
}
44+
45+
46+
def get(n, m, manifest):
    """Serialize benchmark model ``n`` to disk and record it in ``manifest``.

    Args:
        n: model name (key of BENCHMARK_MODELS).
        m: dict with "model" (the nn.Module / traced module) and "path"
           ("script", "trace", or "both") selecting the export mode.
        manifest: name -> list-of-files mapping, updated in place.

    Returns:
        The updated manifest (same object that was passed in).
    """
    print("Downloading {}".format(n))
    traced_filename = "models/" + n + '_traced.jit.pt'
    script_filename = "models/" + n + '_scripted.jit.pt'
    # Robustness: torch.jit.save fails if the target directory is missing.
    os.makedirs("models", exist_ok=True)
    # BUG FIX: the BENCHMARK_MODELS key is "bert_base_uncased" (underscores);
    # the original compared against "bert-base-uncased", so BERT never took
    # this branch and fell into the vision path (re-tracing an already-traced
    # model with a (1, 3, 224-style) image tensor).
    if n == "bert_base_uncased":
        # Model was already traced by custom_models.BertModule(); just save it.
        torch.jit.save(m["model"], traced_filename)
        manifest.update({n: [traced_filename]})
    else:
        # Dummy image input is only needed for tracing the vision models, so
        # allocate it lazily (avoids a pointless CUDA allocation on the BERT path).
        x = torch.ones((1, 3, 300, 300)).cuda()
        m["model"] = m["model"].eval().cuda()
        if m["path"] == "both" or m["path"] == "trace":
            trace_model = torch.jit.trace(m["model"], [x])
            torch.jit.save(trace_model, traced_filename)
            manifest.update({n: [traced_filename]})
        if m["path"] == "both" or m["path"] == "script":
            script_model = torch.jit.script(m["model"])
            torch.jit.save(script_model, script_filename)
            if n in manifest:
                # Preserve any file already recorded (e.g. the traced file
                # written just above in the "both" case).
                files = manifest[n] if isinstance(manifest[n], list) else list(manifest[n])
                files.append(script_filename)
                manifest.update({n: files})
            else:
                manifest.update({n: [script_filename]})
    return manifest
71+
72+
73+
def download_models(version_matches, manifest):
    """Ensure every BENCHMARK_MODELS entry has its serialized file(s) on disk.

    Args:
        version_matches: True if the manifest's torch version equals the
            running torch version; when False everything is re-exported.
        manifest: name -> files mapping, updated in place via get().

    Returns:
        The updated manifest (consistency fix: get() returns it, the original
        download_models dropped it; callers ignoring the return are unaffected).
    """
    if not version_matches:
        # Torch version changed: re-export all models unconditionally.
        for n, m in BENCHMARK_MODELS.items():
            manifest = get(n, m, manifest)
        return manifest

    for n, m in BENCHMARK_MODELS.items():
        scripted_filename = "models/" + n + "_scripted.jit.pt"
        traced_filename = "models/" + n + "_traced.jit.pt"
        # Skip a model only when every file its "path" mode produces exists.
        if (m["path"] == "both" and os.path.exists(scripted_filename) and os.path.exists(traced_filename)) or \
           (m["path"] == "script" and os.path.exists(scripted_filename)) or \
           (m["path"] == "trace" and os.path.exists(traced_filename)):
            print("Skipping {} ".format(n))
            continue
        manifest = get(n, m, manifest)
    return manifest
89+
90+
91+
def main():
    """Build or refresh the TorchScript model cache, keyed by torch version.

    Reads MANIFEST_FILE if present; re-exports all models when the recorded
    torch version differs from the running one, then writes the updated
    manifest back to disk.
    """
    manifest = None
    version_matches = False

    # Check if manifest file exists or is empty.
    if not os.path.exists(MANIFEST_FILE) or os.stat(MANIFEST_FILE).st_size == 0:
        manifest = {"version": torch_version}
        # Create an empty manifest file for overwriting post setup.
        # (Portable replacement for the original os.system('touch ...').)
        with open(MANIFEST_FILE, 'a'):
            pass
    else:
        # Load the existing manifest.
        with open(MANIFEST_FILE, 'r') as f:
            manifest = json.load(f)
        if manifest['version'] == torch_version:
            version_matches = True
        else:
            print("Torch version: {} mismatches with manifest's version: {}. "
                  "Re-downloading all models".format(torch_version, manifest['version']))
            # Overwrite the manifest version as current torch version.
            manifest['version'] = torch_version

    download_models(version_matches, manifest)

    # Write the updated manifest to disk, replacing any previous contents.
    with open(MANIFEST_FILE, 'w') as f:
        json.dump(manifest, f)
127+
128+
129+
# Standard script-entry guard: run only when executed directly, not on import.
if __name__ == "__main__":
    main()

0 commit comments

Comments
 (0)