Skip to content

feat: Add options kwargs for Torch compile [3 / x] #2005

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 6 additions & 7 deletions py/torch_tensorrt/dynamo/backend/backends.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
partition,
get_submod_inputs,
)
from torch_tensorrt.dynamo.backend.utils import parse_dynamo_kwargs
from torch_tensorrt.dynamo.backend.conversion import convert_module

from torch._dynamo.backends.common import fake_tensor_unsupported
Expand All @@ -25,22 +26,20 @@
@td.register_backend(name="torch_tensorrt")
@fake_tensor_unsupported
def torch_tensorrt_backend(
    gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
):
    """Top-level Torch-TensorRT Dynamo backend entry point.

    Args:
        gm: Traced FX GraphModule to compile
        sample_inputs: Sample tensor inputs for the graph
        **kwargs: Backend keyword arguments; torch.compile wraps user options
            under the "options" key — parsing is delegated to the AOT backend
    Returns:
        Whatever the delegated AOT ATen backend returns (compiled callable)
    """
    # Currently the only supported path; kept as a named constant so an
    # alternate default backend can be swapped in at one place.
    DEFAULT_BACKEND = aot_torch_tensorrt_aten_backend

    return DEFAULT_BACKEND(gm, sample_inputs, **kwargs)


@td.register_backend(name="aot_torch_tensorrt_aten")
@fake_tensor_unsupported
def aot_torch_tensorrt_aten_backend(
gm: torch.fx.GraphModule,
sample_inputs: Sequence[torch.Tensor],
settings: CompilationSettings = CompilationSettings(),
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor], **kwargs
):
settings = parse_dynamo_kwargs(kwargs)

custom_backend = partial(
_pretraced_backend,
settings=settings,
Expand Down
39 changes: 39 additions & 0 deletions py/torch_tensorrt/dynamo/backend/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
import torch
import logging
from dataclasses import replace, fields

from torch_tensorrt.dynamo.backend._settings import CompilationSettings
from typing import Any, Union, Sequence, Dict
from torch_tensorrt import _Input, Device


logger = logging.getLogger(__name__)


def prepare_inputs(
inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
device: torch.device = torch.device("cuda"),
Expand Down Expand Up @@ -66,3 +72,36 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
)

return device


def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings:
    """Parses the kwargs field of a Dynamo backend

    Args:
        kwargs: Keyword arguments dictionary provided to the backend
    Returns:
        CompilationSettings object with relevant kwargs
    """

    # Initialize an empty CompilationSettings object
    settings = CompilationSettings()

    # If the user specifies keyword args, overwrite those fields in settings
    # Validate all specified kwargs to ensure they are true fields of the dataclass
    #
    # Note: kwargs provided by torch.compile are wrapped in the "options" key
    if kwargs:
        if "options" in kwargs and len(kwargs) == 1:
            kwargs = kwargs["options"]

        valid_attrs = {attr.name for attr in fields(settings)}
        valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs}

        # Surface typos / unsupported options instead of dropping them silently
        ignored_kwargs = set(kwargs) - valid_attrs
        if ignored_kwargs:
            logger.warning(
                f"Ignoring unsupported backend kwargs: {sorted(ignored_kwargs)}"
            )

        settings = replace(settings, **valid_kwargs)

    # Enable debug/verbose mode if requested
    if settings.debug:
        logger.setLevel(logging.DEBUG)

    logger.debug(f"Compiling with Settings:\n{settings}")

    return settings