|
1 | 1 | import torch
|
| 2 | +import logging |
| 3 | +from dataclasses import replace, fields |
2 | 4 |
|
| 5 | +from torch_tensorrt.dynamo.backend._settings import CompilationSettings |
3 | 6 | from typing import Any, Union, Sequence, Dict
|
4 | 7 | from torch_tensorrt import _Input, Device
|
5 | 8 |
|
6 | 9 |
|
| 10 | +logger = logging.getLogger(__name__) |
| 11 | + |
| 12 | + |
7 | 13 | def prepare_inputs(
|
8 | 14 | inputs: Union[_Input.Input, torch.Tensor, Sequence, Dict],
|
9 | 15 | device: torch.device = torch.device("cuda"),
|
@@ -66,3 +72,36 @@ def prepare_device(device: Union[Device, torch.device]) -> torch.device:
|
66 | 72 | )
|
67 | 73 |
|
68 | 74 | return device
|
| 75 | + |
| 76 | + |
| 77 | +def parse_dynamo_kwargs(kwargs: Dict) -> CompilationSettings: |
| 78 | + """Parses the kwargs field of a Dynamo backend |
| 79 | +
|
| 80 | + Args: |
| 81 | + kwargs: Keyword arguments dictionary provided to the backend |
| 82 | + Returns: |
| 83 | + CompilationSettings object with relevant kwargs |
| 84 | + """ |
| 85 | + |
| 86 | + # Initialize an empty CompilationSettings object |
| 87 | + settings = CompilationSettings() |
| 88 | + |
| 89 | + # If the user specifies keyword args, overwrite those fields in settings |
| 90 | + # Validate all specified kwargs to ensure they are true fields of the dataclass |
| 91 | + # |
| 92 | + # Note: kwargs provided by torch.compile are wrapped in the "options" key |
| 93 | + if kwargs: |
| 94 | + if "options" in kwargs and len(kwargs) == 1: |
| 95 | + kwargs = kwargs["options"] |
| 96 | + |
| 97 | + valid_attrs = {attr.name for attr in fields(settings)} |
| 98 | + valid_kwargs = {k: v for k, v in kwargs.items() if k in valid_attrs} |
| 99 | + settings = replace(settings, **valid_kwargs) |
| 100 | + |
| 101 | + # Enable debug/verbose mode if requested |
| 102 | + if settings.debug: |
| 103 | + logger.setLevel(logging.DEBUG) |
| 104 | + |
| 105 | + logger.debug(f"Compiling with Settings:\n{settings}") |
| 106 | + |
| 107 | + return settings |
0 commit comments