Skip to content

Commit ae7f616

Browse files
committed
fix: Segfault fix for Benchmarks
- Segfault fix for benchmarking on Docker container with CUDNN 8.8 - Likely due to Torch 2.1.0 based on CUDNN 8.9
1 parent bcb13c7 commit ae7f616

File tree

2 files changed

+2
-9
lines changed

2 files changed

+2
-9
lines changed

tools/perf/perf_run.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@
77
import time
88
import timeit
99
import warnings
10+
from functools import wraps
1011

1112
import numpy as np
1213
import pandas as pd
1314
import tensorrt as trt
1415

1516
# Importing supported Backends
1617
import torch
17-
import torch.backends.cudnn as cudnn
1818
from utils import (
1919
BENCHMARK_MODELS,
2020
parse_backends,
@@ -30,6 +30,7 @@
3030

3131

3232
def run_with_try_except(func):
33+
@wraps(func)
3334
def wrapper_func(*args, **kwargs):
3435
try:
3536
return func(*args, **kwargs)
@@ -527,7 +528,6 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None):
527528
)
528529
args = arg_parser.parse_args()
529530

530-
cudnn.benchmark = True
531531
# Create random input tensor of certain size
532532
torch.manual_seed(12345)
533533
model_name = "Model"
@@ -542,9 +542,6 @@ def recordStats(backend, timings, precision, batch_size=1, compile_time_s=None):
542542
if os.path.exists(model_name):
543543
print("Loading user provided torchscript model: ", model_name)
544544
model = torch.jit.load(model_name).cuda().eval()
545-
elif model_name in BENCHMARK_MODELS:
546-
print("Loading torchscript model from BENCHMARK_MODELS for: ", model_name)
547-
model = BENCHMARK_MODELS[model_name]["model"].eval().cuda()
548545

549546
# Load PyTorch Model, if provided
550547
if len(model_name_torch) > 0 and os.path.exists(model_name_torch):

tools/perf/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
1-
from typing import Optional, Sequence, Union
2-
31
import custom_models as cm
42
import timm
53
import torch
64
import torchvision.models as models
75

8-
import torch_tensorrt
9-
106
BENCHMARK_MODEL_NAMES = {
117
"vgg16",
128
"alexnet",

0 commit comments

Comments
 (0)