
Commit 1681439

feat/fix: Add new models, fix perf scripts
- Add new key models to benchmarking scripts
- Add fixes and improvements to existing benchmarking code
1 parent 82b402d commit 1681439

6 files changed: +179 −109 lines


tools/perf/benchmark.sh

Lines changed: 54 additions & 10 deletions
@@ -6,52 +6,96 @@ MODELS_DIR="models"
 python hub.py
 
 batch_sizes=(1 2 4 8 16 32 64 128 256)
+large_model_batch_sizes=(1 2 4 8 16 32 64)
 
-#Benchmark VGG16 model
+
+# Benchmark VGG16 model
 echo "Benchmarking VGG16 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/vgg16_pytorch.pt \
-                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
+                     --precision fp16 --inputs="(${bs}, 3, 224, 224)" \
+                     --batch_size ${bs} \
+                     --truncate \
+                     --backends torch,ts_trt,dynamo,torch_compile,inductor \
+                     --report "vgg16_perf_bs${bs}.txt"
+done
+
+# Benchmark AlexNet model
+echo "Benchmarking AlexNet model"
+for bs in ${batch_sizes[@]}
+do
+  python perf_run.py --model ${MODELS_DIR}/alexnet_scripted.jit.pt \
+                     --model_torch ${MODELS_DIR}/alexnet_pytorch.pt \
+                     --precision fp16 --inputs="(${bs}, 3, 227, 227)" \
                      --batch_size ${bs} \
+                     --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
-                     --report "vgg_perf_bs${bs}.txt"
+                     --report "alexnet_perf_bs${bs}.txt"
 done
 
-# # Benchmark Resnet50 model
+# Benchmark Resnet50 model
 echo "Benchmarking Resnet50 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/resnet50_pytorch.pt \
-                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
+                     --precision fp16 --inputs="(${bs}, 3, 224, 224)" \
                      --batch_size ${bs} \
+                     --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
-                     --report "rn50_perf_bs${bs}.txt"
+                     --report "resnet50_perf_bs${bs}.txt"
 done
 
 # # Benchmark VIT model
 echo "Benchmarking VIT model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \
-                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
+                     --model_torch ${MODELS_DIR}/vit_pytorch.pt \
+                     --precision fp16 --inputs="(${bs}, 3, 224, 224)" \
                      --batch_size ${bs} \
+                     --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
                      --report "vit_perf_bs${bs}.txt"
 done
 
+# Benchmark VIT Large model
+echo "Benchmarking VIT Large model"
+for bs in ${large_model_batch_sizes[@]}
+do
+  python perf_run.py --model ${MODELS_DIR}/vit_large_scripted.jit.pt \
+                     --model_torch ${MODELS_DIR}/vit_large_pytorch.pt \
+                     --precision fp16 --inputs="(${bs}, 3, 224, 224)" \
+                     --truncate \
+                     --batch_size ${bs} \
+                     --backends torch,ts_trt,dynamo,torch_compile,inductor \
+                     --report "vit_large_perf_bs${bs}.txt"
+
 # # Benchmark EfficientNet-B0 model
 echo "Benchmarking EfficientNet-B0 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/efficientnet_b0_pytorch.pt \
-                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
+                     --precision fp16 --inputs="(${bs}, 3, 224, 224)" \
                      --batch_size ${bs} \
+                     --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
-                     --report "eff_b0_perf_bs${bs}.txt"
+                     --report "efficientnet_b0_perf_bs${bs}.txt"
+done
+
+# Benchmark Stable Diffusion UNet model
+echo "Benchmarking SD UNet model"
+for bs in ${large_model_batch_sizes[@]}
+do
+  python perf_run.py --model_torch ${MODELS_DIR}/sd_unet_pytorch.pt \
+                     --precision fp16 --inputs="(${bs}, 4, 128, 128)@fp16;(${bs})@fp16;(${bs}, 1, 768)@fp16" \
+                     --batch_size ${bs} \
+                     --backends torch_compile \
+                     --truncate \
+                     --report "sd_unet_perf_bs1.txt"
 done
 
 # Benchmark BERT model

@@ -60,7 +104,7 @@ for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \
                      --model_torch "bert_base_uncased" \
-                     --precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \
+                     --precision fp16 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \
                      --batch_size ${bs} \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
                      --truncate \
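
The `--inputs` flag packs one example tensor spec per model input, separated by semicolons, with an optional dtype tag after `@` (`fp16`, `int32`); the Stable Diffusion UNet entry above passes a latent, a timestep, and a text-embedding shape this way. As a rough illustration of that convention only, the helper below is hypothetical and not the actual perf_run.py parser:

    import torch

    # Hypothetical sketch of the "--inputs" convention used above: each ";"-separated
    # entry is a shape, optionally tagged with a dtype after "@".
    DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "int32": torch.int32}

    def parse_inputs(spec, device="cuda"):
        tensors = []
        for entry in spec.split(";"):
            shape_str, _, dtype_str = entry.partition("@")
            shape = tuple(int(d) for d in shape_str.strip("() ").split(",") if d.strip())
            dtype = DTYPE_MAP[dtype_str or "fp32"]
            if dtype.is_floating_point:
                tensors.append(torch.rand(shape, dtype=dtype, device=device))
            else:
                tensors.append(torch.randint(0, 2, shape, dtype=dtype, device=device))
        return tensors

    # Example: the SD UNet spec from the loop above, at batch size 1
    sample_inputs = parse_inputs("(1, 4, 128, 128)@fp16;(1)@fp16;(1, 1, 768)@fp16")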

tools/perf/custom_models.py

Lines changed: 20 additions & 15 deletions
@@ -1,10 +1,18 @@
 import torch
-import torch.nn as nn
-from transformers import BertModel, BertTokenizer, BertConfig
-import torch.nn.functional as F
 
 
 def BertModule():
+    from transformers import BertModel
+
+    model_name = "bert-base-uncased"
+    model = BertModel.from_pretrained(model_name, torchscript=True)
+    model.eval()
+    return model
+
+
+def BertInputs():
+    from transformers import BertTokenizer
+
     model_name = "bert-base-uncased"
     enc = BertTokenizer.from_pretrained(model_name)
     text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"

@@ -15,16 +23,13 @@ def BertModule():
     segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
     tokens_tensor = torch.tensor([indexed_tokens])
     segments_tensors = torch.tensor([segments_ids])
-    config = BertConfig(
-        vocab_size_or_config_json_file=32000,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        torchscript=True,
+    return [tokens_tensor, segments_tensors]
+
+
+def StableDiffusionUnet():
+    from diffusers import DiffusionPipeline
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
     )
-    model = BertModel(config)
-    model.eval()
-    model = BertModel.from_pretrained(model_name, torchscript=True)
-    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
-    return traced_model
+    return pipe.unet
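
Since the BERT helper is now split into `BertModule()` (the eager `BertModel`) and `BertInputs()` (the example token/segment tensors), the traced TorchScript artifact that `BertModule()` previously returned can still be produced by combining the two. A minimal sketch, assuming `transformers` is installed and custom_models.py is importable:

    import torch
    from custom_models import BertModule, BertInputs

    # Reproduce the traced BERT artifact the old BertModule() returned directly.
    model = BertModule()     # eager BertModel loaded with torchscript=True
    inputs = BertInputs()    # [tokens_tensor, segments_tensors]
    traced = torch.jit.trace(model, inputs)
    torch.jit.save(traced, "models/bert_base_uncased_traced.jit.pt")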

tools/perf/hub.py

Lines changed: 32 additions & 60 deletions
@@ -1,13 +1,7 @@
 import json
 import os
 
-import custom_models as cm
-import timm
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.models as models
-from transformers import BertConfig, BertModel, BertTokenizer
 
 torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
 

@@ -26,68 +20,46 @@
 VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all")
 
 # Key models selected for benchmarking with their respective paths
-BENCHMARK_MODELS = {
-    "vgg16": {
-        "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
-        "path": ["script", "pytorch"],
-    },
-    "resnet50": {
-        "model": models.resnet50(weights=None),
-        "path": ["script", "pytorch"],
-    },
-    "efficientnet_b0": {
-        "model": timm.create_model("efficientnet_b0", pretrained=True),
-        "path": ["script", "pytorch"],
-    },
-    "vit": {
-        "model": timm.create_model("vit_base_patch16_224", pretrained=True),
-        "path": ["script", "pytorch"],
-    },
-    "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
-}
+from utils import BENCHMARK_MODELS
 
 
 def get(n, m, manifest):
     print("Downloading {}".format(n))
     traced_filename = "models/" + n + "_traced.jit.pt"
     script_filename = "models/" + n + "_scripted.jit.pt"
     pytorch_filename = "models/" + n + "_pytorch.pt"
-    x = torch.ones((1, 3, 300, 300)).cuda()
-    if n == "bert_base_uncased":
-        traced_model = m["model"]
-        torch.jit.save(traced_model, traced_filename)
+
+    m["model"] = m["model"].eval().cuda()
+
+    # Get all desired model save specifications as list
+    paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
+
+    # Depending on specified model save specifications, save desired model formats
+    if any(path in ("all", "torchscript", "trace") for path in paths):
+        # (TorchScript) Traced model
+        trace_model = torch.jit.trace(m["model"], [inp.cuda() for inp in m["inputs"]])
+        torch.jit.save(trace_model, traced_filename)
         manifest.update({n: [traced_filename]})
-    else:
-        m["model"] = m["model"].eval().cuda()
-
-        # Get all desired model save specifications as list
-        paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
-
-        # Depending on specified model save specifications, save desired model formats
-        if any(path in ("all", "torchscript", "trace") for path in paths):
-            # (TorchScript) Traced model
-            trace_model = torch.jit.trace(m["model"], [x])
-            torch.jit.save(trace_model, traced_filename)
-            manifest.update({n: [traced_filename]})
-        if any(path in ("all", "torchscript", "script") for path in paths):
-            # (TorchScript) Scripted model
-            script_model = torch.jit.script(m["model"])
-            torch.jit.save(script_model, script_filename)
-            if n in manifest.keys():
-                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
-                files.append(script_filename)
-                manifest.update({n: files})
-            else:
-                manifest.update({n: [script_filename]})
-        if any(path in ("all", "pytorch") for path in paths):
-            # (PyTorch Module) model
-            torch.save(m["model"], pytorch_filename)
-            if n in manifest.keys():
-                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
-                files.append(script_filename)
-                manifest.update({n: files})
-            else:
-                manifest.update({n: [script_filename]})
+    if any(path in ("all", "torchscript", "script") for path in paths):
+        # (TorchScript) Scripted model
+        script_model = torch.jit.script(m["model"])
+        torch.jit.save(script_model, script_filename)
+        if n in manifest.keys():
+            files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
+            files.append(script_filename)
+            manifest.update({n: files})
+        else:
+            manifest.update({n: [script_filename]})
+    if any(path in ("all", "pytorch") for path in paths):
+        # (PyTorch Module) model
+        torch.save(m["model"], pytorch_filename)
+        if n in manifest.keys():
+            files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
+            files.append(script_filename)
+            manifest.update({n: files})
+        else:
+            manifest.update({n: [script_filename]})
+
     return manifest
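
hub.py now imports `BENCHMARK_MODELS` from `utils` instead of defining it inline, and `get()` traces every model from per-entry example inputs (`m["inputs"]`) rather than a hard-coded `(1, 3, 300, 300)` tensor. utils.py is not part of this diff, so the entry layout below is only inferred from the keys `get()` reads and should be treated as an assumption:

    import torch
    import torchvision.models as models

    # Assumed shape of a utils.BENCHMARK_MODELS entry, based on how get() uses it:
    # "model" (an nn.Module), "inputs" (example tensors for torch.jit.trace),
    # and "path" (which save formats to produce).
    BENCHMARK_MODELS = {
        "resnet50": {
            "model": models.resnet50(weights=None),
            "inputs": [torch.rand(1, 3, 224, 224)],
            "path": ["script", "pytorch"],
        },
    }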
