@@ -27,20 +27,20 @@ class _ShapeMode(Enum):
27
27
STATIC = 0
28
28
DYNAMIC = 1
29
29
30
- shape_mode = None #: (torch_tensorrt.Input._ShapeMode) : Is input statically or dynamically shaped
31
- shape : Union [Tuple , Dict ] = None #: (Tuple or Dict) : Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
32
- dtype = (
30
+ shape_mode : Optional [ _ShapeMode ] = None #: Is input statically or dynamically shaped
31
+ shape : Optional [ Union [Tuple [ int , ...], Dict [ str , Tuple [ int , ...]]]] = None #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
32
+ dtype : _enums . dtype = ( # type: ignore[name-defined]
33
33
_enums .dtype .unknown
34
34
) #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
35
- _explicit_set_dtype = False
36
- format = (
35
+ _explicit_set_dtype : bool = False
36
+ format : _enums . TensorFormat = ( # type: ignore[name-defined]
37
37
_enums .TensorFormat .contiguous
38
38
) #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)
39
39
40
- DOMAIN_OFFSET = 2.0
41
- low_tensor_domain_incl = 0.0
42
- high_tensor_domain_excl = low_tensor_domain_incl + DOMAIN_OFFSET
43
- torch_dtype = torch .float32
40
+ DOMAIN_OFFSET : float = 2.0
41
+ low_tensor_domain_incl : float = 0.0
42
+ high_tensor_domain_excl : float = low_tensor_domain_incl + DOMAIN_OFFSET
43
+ torch_dtype : torch . dtype = torch .float32
44
44
45
45
def __init__ (self , * args : Any , ** kwargs : Any ) -> None :
46
46
"""__init__ Method for torch_tensorrt.Input
@@ -165,15 +165,18 @@ def __str__(self) -> str:
165
165
str (self .tensor_domain [1 ]),
166
166
)
167
167
elif self .shape_mode == Input ._ShapeMode .DYNAMIC :
168
- return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))" .format (
169
- self .shape ["min_shape" ],
170
- self .shape ["opt_shape" ],
171
- self .shape ["max_shape" ],
172
- str (self .dtype ),
173
- str (self .format ),
174
- str (self .tensor_domain [0 ]),
175
- str (self .tensor_domain [1 ]),
176
- )
168
+ if isinstance (self .shape , dict ):
169
+ return "Input(min_shape={}, opt_shape={}, max_shape={}, dtype={}, format={}, domain=[{}, {}))" .format (
170
+ self .shape ["min_shape" ],
171
+ self .shape ["opt_shape" ],
172
+ self .shape ["max_shape" ],
173
+ str (self .dtype ),
174
+ str (self .format ),
175
+ str (self .tensor_domain [0 ]),
176
+ str (self .tensor_domain [1 ]),
177
+ )
178
+ else :
179
+ raise RuntimeError (f"Input shape is dynamic but shapes are not provided as dictionary (found: { self .shape } )" )
177
180
else :
178
181
raise RuntimeError ("Unknown input shape mode" )
179
182
@@ -189,7 +192,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
189
192
return False
190
193
191
194
@staticmethod
192
- def _parse_dtype (dtype : Any ) -> _enums .dtype :
195
+ def _parse_dtype (dtype : Any ) -> _enums .dtype : # type: ignore[name-defined]
193
196
if isinstance (dtype , torch .dtype ):
194
197
if dtype == torch .long :
195
198
return _enums .dtype .long
@@ -217,7 +220,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:
217
220
)
218
221
219
222
@staticmethod
220
- def _to_torch_dtype (dtype : _enums .dtype ) -> torch .dtype :
223
+ def _to_torch_dtype (dtype : _enums .dtype ) -> torch .dtype : # type: ignore[name-defined]
221
224
if dtype == _enums .dtype .long :
222
225
return torch .long
223
226
elif dtype == _enums .dtype .int32 :
@@ -233,10 +236,10 @@ def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:
233
236
return torch .float32
234
237
235
238
def is_trt_dtype (self ) -> bool :
236
- return self .dtype != _enums .dtype .long
239
+ return bool ( self .dtype != _enums .dtype .long )
237
240
238
241
@staticmethod
239
- def _parse_format (format : Any ) -> _enums .TensorFormat :
242
+ def _parse_format (format : Any ) -> _enums .TensorFormat : # type: ignore[name-defined]
240
243
if isinstance (format , torch .memory_format ):
241
244
if format == torch .contiguous_format :
242
245
return _enums .TensorFormat .contiguous
@@ -256,7 +259,7 @@ def _parse_format(format: Any) -> _enums.TensorFormat:
256
259
)
257
260
258
261
@staticmethod
259
- def _parse_tensor_domain (domain : Optional [Tuple [float , float ]]) -> Tuple :
262
+ def _parse_tensor_domain (domain : Optional [Tuple [float , float ]]) -> Tuple [ float , float ] :
260
263
"""
261
264
Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)
262
265
@@ -355,7 +358,7 @@ def from_tensors(
355
358
for t in ts
356
359
]
357
360
358
- def example_tensor (self , optimization_profile_field : str = None ) -> torch .Tensor :
361
+ def example_tensor (self , optimization_profile_field : Optional [ str ] = None ) -> Optional [ torch .Tensor ] :
359
362
"""
360
363
Get an example tensor of the shape specified by the Input object
361
364
@@ -365,38 +368,40 @@ def example_tensor(self, optimization_profile_field: str = None) -> torch.Tensor
365
368
Returns:
366
369
A PyTorch Tensor
367
370
"""
368
- if optimization_profile_field is not None :
369
- try :
370
- assert any (
371
- [
372
- optimization_profile_field == field_name
373
- for field_name in ["min_shape" , "opt_shape" , "max_shape" ]
374
- ]
375
- )
376
- except :
371
+ if self .shape_mode == Input ._ShapeMode .STATIC :
372
+ if optimization_profile_field is not None :
377
373
raise ValueError (
378
- "Invalid field name, expected one of min_shape, opt_shape, max_shape "
374
+ "Specified a optimization profile field but the input is static "
379
375
)
376
+ else :
377
+ if isinstance (self .shape , tuple ):
378
+ return torch .rand (self .shape ).to (dtype = self .torch_dtype )
379
+ else :
380
+ RuntimeError (f"Input shape is dynamic but shapes are not provided as sequence (found: { self .shape } )" )
381
+ else :
382
+ if optimization_profile_field is not None :
383
+ try :
384
+ assert any (
385
+ [
386
+ optimization_profile_field == field_name
387
+ for field_name in ["min_shape" , "opt_shape" , "max_shape" ]
388
+ ]
389
+ )
390
+ except :
391
+ raise ValueError (
392
+ "Invalid field name, expected one of min_shape, opt_shape, max_shape"
393
+ )
380
394
381
- if (
382
- optimization_profile_field is not None
383
- and self .shape_mode == Input ._ShapeMode .STATIC
384
- ):
385
- raise ValueError (
386
- "Specified a optimization profile field but the input is static"
387
- )
395
+ if isinstance (self .shape , dict ):
396
+ return torch .rand (self .shape [optimization_profile_field ]).to (
397
+ dtype = self .torch_dtype
398
+ )
399
+ else :
400
+ raise RuntimeError (f"Input shape is dynamic but shapes are not provided as dictionary (found: { self .shape } )" )
401
+
402
+ else :
403
+ raise ValueError (
404
+ "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
405
+ )
388
406
389
- if (
390
- optimization_profile_field is None
391
- and self .shape_mode == Input ._ShapeMode .DYNAMIC
392
- ):
393
- raise ValueError (
394
- "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
395
- )
396
407
397
- if self .shape_mode == Input ._ShapeMode .STATIC :
398
- return torch .rand (self .shape ).to (dtype = self .torch_dtype )
399
- else :
400
- return torch .rand (self .shape [optimization_profile_field ]).to (
401
- dtype = self .torch_dtype
402
- )
0 commit comments