Skip to content

Commit 31c1cfd

Browse files
committed
fix: Add temporary workaround for precisions
- torch compile precisions are currently not being reflected due to recent API changes. This update honors specified precisions
1 parent 61e716e commit 31c1cfd

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,11 @@
55
from typing import Any, Callable, Dict, Optional, Sequence
66

77
import torch
8+
import torch_tensorrt
89
from torch_tensorrt._Device import Device
910
from torch_tensorrt._Input import Input
1011
from torch_tensorrt.dynamo import CompilationSettings
12+
from torch_tensorrt.dynamo._defaults import PRECISION
1113

1214
from packaging import version
1315

@@ -161,6 +163,28 @@ def parse_dynamo_kwargs(kwargs: Any) -> CompilationSettings:
161163
if settings.debug:
162164
logger.setLevel(logging.DEBUG)
163165

166+
# TODO: Remove once Dynamo precisions refactoring is complete
167+
if "enabled_precisions" in kwargs:
168+
enabled_precisions = kwargs["enabled_precisions"]
169+
170+
if (
171+
torch.float16 in enabled_precisions
172+
or torch_tensorrt.dtype.half in enabled_precisions
173+
):
174+
settings.precision = torch.float16
175+
elif (
176+
torch.float32 in enabled_precisions
177+
or torch_tensorrt.dtype.float in enabled_precisions
178+
):
179+
settings.precision = torch.float32
180+
elif len(enabled_precisions) == 0:
181+
logger.info(f"No precision specified, defaulting to {PRECISION}")
182+
settings.precision = PRECISION
183+
else:
184+
raise ValueError(
185+
f"Precision {enabled_precisions} not supported in the Dynamo Path"
186+
)
187+
164188
# Parse input runtime specification
165189
settings.use_python_runtime = use_python_runtime_parser(settings.use_python_runtime)
166190

0 commit comments

Comments
 (0)