Skip to content

Commit 7672cce

Browse files
committed
feat: Add options kwargs for Torch compile
- Add ability to pass `options` dictionary to `kwargs` in `torch_tensorrt_backend`, for compatibility with updated torch compile API - The `options` dictionary is automatically parsed for specified fields and overwrites those fields in the `settings` object - Refactor code so that registered Dynamo backends accept keyword-args, while internal-only backends accept settings objects
1 parent 3751e32 commit 7672cce

File tree

2 files changed

+45
-7
lines changed

2 files changed

+45
-7
lines changed

py/torch_tensorrt/dynamo/backend/backends.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
partition,
1313
get_submod_inputs,
1414
)
15+
from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
1516
from torch_tensorrt.dynamo.backend.conversion import convert_module
1617

1718
from torch._dynamo.backends.common import fake_tensor_unsupported
@@ -25,22 +26,20 @@
2526
@td.register_backend(name="torch_tensorrt")
2627
@fake_tensor_unsupported
2728
def torch_tensorrt_backend(
28-
gm: torch.fx.GraphModule,
29-
sample_inputs: Sequence[torch.Tensor],
30-
settings: CompilationSettings = CompilationSettings(),
29+
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
3130
):
3231
DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend
3332

34-
return DEFAULT_BACKEND(gm, sample_inputs, settings=settings)
33+
return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)
3534

3635

3736
@td.register_backend(name="aot_torch_tensorrt_aten")
3837
@fake_tensor_unsupported
3938
def aot_torch_tensorrt_aten_backend(
40-
gm: torch.fx.GraphModule,
41-
sample_inputs: Sequence[torch.Tensor],
42-
settings: CompilationSettings = CompilationSettings(),
39+
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
4340
):
41+
settings = parse_dynamo_kwargs(kwargs)
42+
4443
custom_backend = partial(
4544
_pretraced_backend,
4645
settings=settings,

py/torch_tensorrt/dynamo/backend/utils.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
import torch
2+
import logging
3+
from dataclasses import replace, fields
24

5+
from torch_tensorrt.dynamo.backend._settings import CompilationSettings
36
from typing import Any, Union, Sequence, Dict
47
from torch_tensorrt import _Input, Device
58

69

10+
logger = logging.getLogger(__name__)
11+
12+
713
def prepare_inputs(
814
inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
915
device: torch.device = torch.device("cuda"),
@@ -66,3 +72,36 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
6672
)
6773

6874
return device
75+
76+
77+
def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
78+
"""Parses the kwargs field of a Dynamo backend
79+
80+
Args:
81+
kwargs: Keyword arguments dictionary provided to the backend
82+
Returns:
83+
CompilationSettings object with relevant kwargs
84+
"""
85+
86+
# Initialize an empty CompilationSettings object
87+
settings = CompilationSettings()
88+
89+
# If the user specifies keyword args, overwrite those fields in settings
90+
# Validate all specified kwargs to ensure they are true fields of the dataclass
91+
#
92+
# Note: kwargs provided by torch.compile are wrapped in the "options" key
93+
if kwargs:
94+
if "options" in kwargs and len(kwargs) == 1:
95+
kwargs = kwargs["options"]
96+
97+
valid_attrs = {attr.name for attr in fields(settings)}
98+
valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}
99+
settings = replace(settings, **valid_kwargs)
100+
101+
# Enable debug/verbose mode if requested
102+
if settings.debug:
103+
logger.setLevel(logging.DEBUG)
104+
105+
logger.debug(f"Compiling with Settings:\n{settings}")
106+
107+
return settings

0 commit comments

Comments
 (0)