-
Notifications
You must be signed in to change notification settings - Fork 364
feat: Implement Input class support for FX backend. #1763
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,7 +4,6 @@ | |
import torch | ||
|
||
from torch_tensorrt import _enums | ||
from torch_tensorrt import _C | ||
|
||
|
||
class Input(object): | ||
|
@@ -41,6 +40,7 @@ class _ShapeMode(Enum): | |
DOMAIN_OFFSET = 2.0 | ||
low_tensor_domain_incl = 0.0 | ||
high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET | ||
torch_dtype = None | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Should we derive torch_dtype from self.dtype? |
||
|
||
def __init__(self, *args, **kwargs): | ||
"""__init__ Method for torch_tensorrt.Input | ||
|
@@ -138,6 +138,9 @@ def __init__(self, *args, **kwargs): | |
) | ||
|
||
if "dtype" in kwargs: | ||
if isinstance(kwargs["dtype"], torch.dtype): | ||
self.torch_dtype = kwargs["dtype"] | ||
|
||
self.dtype = Input._parse_dtype(kwargs["dtype"]) | ||
self._explicit_set_dtype = True | ||
|
||
|
@@ -173,59 +176,6 @@ def __str__(self) -> str: | |
else: | ||
raise RuntimeError("Unknown input shape mode") | ||
|
||
def _to_internal(self) -> _C.Input: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why was this taken out? |
||
internal_in = _C.Input() | ||
if self.shape_mode == Input._ShapeMode.DYNAMIC: | ||
if not Input._supported_input_size_type(self.shape["min_shape"]): | ||
raise TypeError( | ||
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " | ||
+ str(type(self.shape["min_shape"])) | ||
+ " for min_shape" | ||
) | ||
else: | ||
internal_in.min = self.shape["min_shape"] | ||
|
||
if not Input._supported_input_size_type(self.shape["opt_shape"]): | ||
raise TypeError( | ||
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " | ||
+ str(type(self.shape["opt_shape"])) | ||
+ " for opt_shape" | ||
) | ||
else: | ||
internal_in.opt = self.shape["opt_shape"] | ||
|
||
if not Input._supported_input_size_type(self.shape["max_shape"]): | ||
raise TypeError( | ||
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " | ||
+ str(type(self.shape["max_shape"])) | ||
+ " for max_shape" | ||
) | ||
else: | ||
internal_in.max = self.shape["max_shape"] | ||
internal_in.input_is_dynamic = True | ||
else: | ||
if not Input._supported_input_size_type(self.shape): | ||
raise TypeError( | ||
"Input shape specifications for inputs are required to be a List, tuple or torch.Size, found type: " | ||
+ str(type(self.shape)) | ||
+ " for shape" | ||
) | ||
else: | ||
internal_in.opt = self.shape | ||
internal_in.input_is_dynamic = False | ||
|
||
if self.dtype != _enums.dtype.unknown: | ||
self._explicit_set_dtype = True | ||
else: | ||
self._explicit_set_dtype = False | ||
|
||
internal_in.dtype = Input._parse_dtype(self.dtype) | ||
internal_in._explicit_set_dtype = self._explicit_set_dtype | ||
internal_in.format = Input._parse_format(self.format) | ||
|
||
internal_in.tensor_domain = Input._parse_tensor_domain(self.tensor_domain) | ||
return internal_in | ||
|
||
@staticmethod | ||
def _supported_input_size_type(input_size: Any) -> bool: | ||
if isinstance(input_size, torch.Size): | ||
|
@@ -304,6 +254,7 @@ def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple: | |
Input.low_tensor_domain_incl, | ||
Input.high_tensor_domain_excl, | ||
) | ||
|
||
elif len(domain) == 2: | ||
domain_lo, domain_hi = domain | ||
|
||
|
@@ -416,8 +367,10 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor | |
) | ||
|
||
if self.shape_mode == Input._ShapeMode.STATIC: | ||
return torch.randn(self.shape).to(dtype=self.dtype) | ||
return torch.randn(self.shape).to( | ||
dtype=self.dtype if not self.torch_dtype else self.torch_dtype | ||
) | ||
else: | ||
return torch.randn(self.shape[optimization_profile_field]).to( | ||
dtype=self.dtype | ||
dtype=self.dtype if not self.torch_dtype else self.torch_dtype | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -153,7 +153,6 @@ def validate_conversion(self): | |
|
||
def run( | ||
self, | ||
max_batch_size=64, | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I am afraid we cannot make this change. We have to maintain backward compatibility on the API; otherwise, it will break our internal product. |
||
max_workspace_size=1 << 25, | ||
lower_precision=LowerPrecision.FP16, | ||
sparse_weights=False, | ||
|
@@ -167,7 +166,6 @@ def run( | |
""" | ||
Build TensorRT engine with some configs. | ||
Args: | ||
max_batch_size: set accordingly for maximum batch size you will use. | ||
max_workspace_size: set to the maximum size we can afford for temporary buffer | ||
lower_precision: the precision model layers are running on (TensorRT will choose the best performance precision). | ||
sparse_weights: allow the builder to examine weights and use optimized functions when weights have suitable sparsity | ||
|
@@ -207,7 +205,6 @@ def run( | |
) | ||
build_engine_start_time = datetime.now() | ||
|
||
self.builder.max_batch_size = max_batch_size | ||
builder_config = self.builder.create_builder_config() | ||
builder_config.max_workspace_size = max_workspace_size | ||
|
||
|
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. What are these comments for?