
feat/fix: Add new models, fix perf scripts #2426


Merged 1 commit on Nov 1, 2023
1 change: 1 addition & 0 deletions .github/workflows/docker_builder.yml
@@ -6,6 +6,7 @@ on:
     branches:
       - main
       - nightly
+      - release/2.1
 
 # If pushes to main are made in rapid succession,
 # cancel existing docker builds and use newer commits
60 changes: 52 additions & 8 deletions tools/perf/benchmark.sh
@@ -6,52 +6,96 @@ MODELS_DIR="models"
 python hub.py
 
 batch_sizes=(1 2 4 8 16 32 64 128 256)
+large_model_batch_sizes=(1 2 4 8 16 32 64)
 
-#Benchmark VGG16 model
+# Benchmark VGG16 model
 echo "Benchmarking VGG16 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/vgg16_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/vgg16_pytorch.pt \
                      --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                      --batch_size ${bs} \
                      --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
                      --report "vgg16_perf_bs${bs}.txt"
 done
 
 # Benchmark AlexNet model
 echo "Benchmarking AlexNet model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/alexnet_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/alexnet_pytorch.pt \
                      --precision fp32,fp16 --inputs="(${bs}, 3, 227, 227)" \
                      --batch_size ${bs} \
                      --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
-                     --report "vgg_perf_bs${bs}.txt"
+                     --report "alexnet_perf_bs${bs}.txt"
 done
 
-# # Benchmark Resnet50 model
+# Benchmark Resnet50 model
 echo "Benchmarking Resnet50 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/resnet50_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/resnet50_pytorch.pt \
                      --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                      --batch_size ${bs} \
                      --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
-                     --report "rn50_perf_bs${bs}.txt"
+                     --report "resnet50_perf_bs${bs}.txt"
 done
 
-# # Benchmark VIT model
+# Benchmark VIT model
 echo "Benchmarking VIT model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/vit_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/vit_pytorch.pt \
                      --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                      --batch_size ${bs} \
                      --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
                      --report "vit_perf_bs${bs}.txt"
 done
 
-# # Benchmark EfficientNet-B0 model
+# Benchmark VIT Large model
+echo "Benchmarking VIT Large model"
+for bs in ${large_model_batch_sizes[@]}
+do
+  python perf_run.py --model ${MODELS_DIR}/vit_large_scripted.jit.pt \
+                     --model_torch ${MODELS_DIR}/vit_large_pytorch.pt \
+                     --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
+                     --truncate \
+                     --batch_size ${bs} \
+                     --backends torch,ts_trt,dynamo,torch_compile,inductor \
+                     --report "vit_large_perf_bs${bs}.txt"
+done
+
+# Benchmark EfficientNet-B0 model
 echo "Benchmarking EfficientNet-B0 model"
 for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/efficientnet_b0_scripted.jit.pt \
                      --model_torch ${MODELS_DIR}/efficientnet_b0_pytorch.pt \
                      --precision fp32,fp16 --inputs="(${bs}, 3, 224, 224)" \
                      --batch_size ${bs} \
                      --truncate \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
-                     --report "eff_b0_perf_bs${bs}.txt"
+                     --report "efficientnet_b0_perf_bs${bs}.txt"
 done
 
+# Benchmark Stable Diffusion UNet model
+echo "Benchmarking SD UNet model"
+for bs in ${large_model_batch_sizes[@]}
+do
+  python perf_run.py --model_torch ${MODELS_DIR}/sd_unet_pytorch.pt \
+                     --precision fp32,fp16 --inputs="(${bs}, 4, 128, 128)@fp16;(${bs})@fp16;(${bs}, 1, 768)@fp16" \
+                     --batch_size ${bs} \
+                     --backends torch,dynamo,torch_compile,inductor \
+                     --truncate \
+                     --report "sd_unet_perf_bs${bs}.txt"
+done
 
 # Benchmark BERT model
@@ -60,7 +104,7 @@ for bs in ${batch_sizes[@]}
 do
   python perf_run.py --model ${MODELS_DIR}/bert_base_uncased_traced.jit.pt \
                      --model_torch "bert_base_uncased" \
-                     --precision fp32 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \
+                     --precision fp32,fp16 --inputs="(${bs}, 128)@int32;(${bs}, 128)@int32" \
                      --batch_size ${bs} \
                      --backends torch,ts_trt,dynamo,torch_compile,inductor \
                      --truncate \
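
Note on the --inputs flag used throughout benchmark.sh: each input is a parenthesized shape, multiple inputs are separated by ";", and an optional "@dtype" suffix overrides the default precision (the new SD UNet entry passes three fp16 tensors; BERT passes two int32 token tensors). Below is a minimal Python sketch of how such a spec string maps to dummy tensors. It only assumes the format visible above; the helper name, the fp32 default, and the random-fill choice are illustrative, not perf_run.py's actual parser.

import torch

# Hypothetical helper mirroring the spec strings above, e.g.
# "(32, 4, 128, 128)@fp16;(32)@fp16;(32, 1, 768)@fp16".
DTYPE_MAP = {"fp32": torch.float32, "fp16": torch.float16, "int32": torch.int32}


def specs_to_tensors(spec: str):
    tensors = []
    for part in spec.split(";"):
        shape_str, _, dtype_str = part.partition("@")
        shape = tuple(int(dim) for dim in shape_str.strip("() ").split(","))
        dtype = DTYPE_MAP[dtype_str] if dtype_str else torch.float32  # default is an assumption
        if dtype.is_floating_point:
            tensors.append(torch.randn(shape, dtype=dtype))  # random float inputs
        else:
            tensors.append(torch.randint(0, 100, shape, dtype=dtype))  # random int inputs
    return tensors


# Example: the SD UNet spec from the script, at batch size 32.
inputs = specs_to_tensors("(32, 4, 128, 128)@fp16;(32)@fp16;(32, 1, 768)@fp16")
print([tuple(t.shape) for t in inputs])  # [(32, 4, 128, 128), (32,), (32, 1, 768)]
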
35 changes: 20 additions & 15 deletions tools/perf/custom_models.py
@@ -1,10 +1,18 @@
 import torch
 import torch.nn as nn
-from transformers import BertModel, BertTokenizer, BertConfig
 import torch.nn.functional as F
 
 
 def BertModule():
+    from transformers import BertModel
+
     model_name = "bert-base-uncased"
+    model = BertModel.from_pretrained(model_name, torchscript=True)
+    model.eval()
+    return model
+
+
+def BertInputs():
+    from transformers import BertTokenizer
+
+    model_name = "bert-base-uncased"
     enc = BertTokenizer.from_pretrained(model_name)
     text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
@@ -15,16 +23,13 @@ def BertModule():
     segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
     tokens_tensor = torch.tensor([indexed_tokens])
     segments_tensors = torch.tensor([segments_ids])
-    config = BertConfig(
-        vocab_size_or_config_json_file=32000,
-        hidden_size=768,
-        num_hidden_layers=12,
-        num_attention_heads=12,
-        intermediate_size=3072,
-        torchscript=True,
-    )
-    model = BertModel(config)
-    model.eval()
-    model = BertModel.from_pretrained(model_name, torchscript=True)
-    traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
-    return traced_model
+    return [tokens_tensor, segments_tensors]
+
+
+def StableDiffusionUnet():
+    from diffusers import DiffusionPipeline
+
+    pipe = DiffusionPipeline.from_pretrained(
+        "CompVis/stable-diffusion-v1-4", revision="fp16", torch_dtype=torch.float16
+    )
+    return pipe.unet
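
Since BertModule() no longer traces the model internally, a caller now pairs it with BertInputs() and traces at the call site. A minimal sketch of that pattern, mirroring the tracing code deleted above (hub.py's actual call site may differ):

import torch

import custom_models as cm

# Build the eager model and its sample inputs separately, then trace.
model = cm.BertModule()  # eager BertModel with torchscript=True
inputs = cm.BertInputs()  # [tokens_tensor, segments_tensors]
traced = torch.jit.trace(model, inputs)
torch.jit.save(traced, "models/bert_base_uncased_traced.jit.pt")

Splitting model construction from input construction lets hub.py use one generic save path for every model instead of special-casing BERT, which is exactly what the hub.py diff below does.
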
92 changes: 32 additions & 60 deletions tools/perf/hub.py
@@ -1,13 +1,7 @@
 import json
 import os
 
 import custom_models as cm
-import timm
 import torch
-import torch.nn as nn
-import torch.nn.functional as F
-import torchvision.models as models
-from transformers import BertConfig, BertModel, BertTokenizer
 
 torch.hub._validate_not_a_forked_repo = lambda a, b, c: True
@@ -26,68 +20,46 @@
 VALID_PATHS = ("script", "trace", "torchscript", "pytorch", "all")
 
 # Key models selected for benchmarking with their respective paths
-BENCHMARK_MODELS = {
-    "vgg16": {
-        "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
-        "path": ["script", "pytorch"],
-    },
-    "resnet50": {
-        "model": models.resnet50(weights=None),
-        "path": ["script", "pytorch"],
-    },
-    "efficientnet_b0": {
-        "model": timm.create_model("efficientnet_b0", pretrained=True),
-        "path": ["script", "pytorch"],
-    },
-    "vit": {
-        "model": timm.create_model("vit_base_patch16_224", pretrained=True),
-        "path": ["script", "pytorch"],
-    },
-    "bert_base_uncased": {"model": cm.BertModule(), "path": "trace"},
-}
+from utils import BENCHMARK_MODELS
 
 
 def get(n, m, manifest):
     print("Downloading {}".format(n))
     traced_filename = "models/" + n + "_traced.jit.pt"
     script_filename = "models/" + n + "_scripted.jit.pt"
     pytorch_filename = "models/" + n + "_pytorch.pt"
-    x = torch.ones((1, 3, 300, 300)).cuda()
-    if n == "bert_base_uncased":
-        traced_model = m["model"]
-        torch.jit.save(traced_model, traced_filename)
-    else:
-        m["model"] = m["model"].eval().cuda()
-
-        # Get all desired model save specifications as list
-        paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
-
-        # Depending on specified model save specifications, save desired model formats
-        if any(path in ("all", "torchscript", "trace") for path in paths):
-            # (TorchScript) Traced model
-            trace_model = torch.jit.trace(m["model"], [x])
-            torch.jit.save(trace_model, traced_filename)
-            manifest.update({n: [traced_filename]})
-        if any(path in ("all", "torchscript", "script") for path in paths):
-            # (TorchScript) Scripted model
-            script_model = torch.jit.script(m["model"])
-            torch.jit.save(script_model, script_filename)
-            if n in manifest.keys():
-                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
-                files.append(script_filename)
-                manifest.update({n: files})
-            else:
-                manifest.update({n: [script_filename]})
-        if any(path in ("all", "pytorch") for path in paths):
-            # (PyTorch Module) model
-            torch.save(m["model"], pytorch_filename)
-            if n in manifest.keys():
-                files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
-                files.append(script_filename)
-                manifest.update({n: files})
-            else:
-                manifest.update({n: [script_filename]})
+
+    m["model"] = m["model"].eval().cuda()
+
+    # Get all desired model save specifications as list
+    paths = [m["path"]] if isinstance(m["path"], str) else m["path"]
+
+    # Depending on specified model save specifications, save desired model formats
+    if any(path in ("all", "torchscript", "trace") for path in paths):
+        # (TorchScript) Traced model
+        trace_model = torch.jit.trace(m["model"], [inp.cuda() for inp in m["inputs"]])
+        torch.jit.save(trace_model, traced_filename)
+        manifest.update({n: [traced_filename]})
+    if any(path in ("all", "torchscript", "script") for path in paths):
+        # (TorchScript) Scripted model
+        script_model = torch.jit.script(m["model"])
+        torch.jit.save(script_model, script_filename)
+        if n in manifest.keys():
+            files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
+            files.append(script_filename)
+            manifest.update({n: files})
+        else:
+            manifest.update({n: [script_filename]})
+    if any(path in ("all", "pytorch") for path in paths):
+        # (PyTorch Module) model
+        torch.save(m["model"], pytorch_filename)
+        if n in manifest.keys():
+            files = list(manifest[n]) if type(manifest[n]) != list else manifest[n]
+            files.append(script_filename)
+            manifest.update({n: files})
+        else:
+            manifest.update({n: [script_filename]})
 
     return manifest
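
BENCHMARK_MODELS now comes from utils.py, which is not part of this diff. From the way get() consumes it, each entry must expose "model", "inputs" (sample tensors fed to torch.jit.trace), and "path" keys. The sketch below shows that assumed structure only; the entries are illustrative and are not the contents of tools/perf/utils.py.

import torch
import torchvision.models as models

import custom_models as cm

# Assumed shape of utils.BENCHMARK_MODELS, inferred from get() above.
BENCHMARK_MODELS = {
    "vgg16": {
        "model": models.vgg16(weights=models.VGG16_Weights.DEFAULT),
        "inputs": [torch.randn(1, 3, 224, 224)],  # consumed by torch.jit.trace in get()
        "path": ["script", "pytorch"],
    },
    "sd_unet": {
        "model": cm.StableDiffusionUnet(),
        "inputs": [
            torch.randn(1, 4, 128, 128, dtype=torch.float16),  # latent sample
            torch.randn(1, dtype=torch.float16),  # timestep
            torch.randn(1, 1, 768, dtype=torch.float16),  # text embedding
        ],
        "path": ["pytorch"],
    },
}

One pre-existing quirk survives the rewrite: the "pytorch" branch of get() appends script_filename to the manifest rather than pytorch_filename, in both the old and new versions of the function.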