@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Dict, Any, Tuple, Optional, Union
+from typing import List, Dict, Any, Tuple, Optional, Union, Sequence

 import torch

@@ -27,13 +27,17 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1

-    shape_mode: Optional[_ShapeMode] = None  #: Is input statically or dynamically shaped
-    shape: Optional[Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype: _enums.dtype = (  # type: ignore[name-defined]
+    shape_mode: Optional[
+        _ShapeMode
+    ] = None  #: Is input statically or dynamically shaped
+    shape: Optional[
+        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
+    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    dtype: _enums.dtype = (  # type: ignore[name-defined]
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
-    format: _enums.TensorFormat = (  # type: ignore[name-defined]
+    format: _enums.TensorFormat = (  # type: ignore[name-defined]
         _enums.TensorFormat.contiguous
     )  #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)

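For reference, the `shape` field annotated above is populated from the two construction modes of `Input`. A minimal sketch (assuming the keyword constructor documented for this class, which is not shown in this diff; shapes are illustrative):

import torch
import torch_tensorrt

# Statically shaped input: the `shape` attribute ends up as a single tuple.
static_in = torch_tensorrt.Input(shape=(1, 3, 224, 224), dtype=torch.float32)

# Dynamically shaped input: the `shape` attribute ends up as the dict form
# {"min_shape": ..., "opt_shape": ..., "max_shape": ...} described in the comment above.
dynamic_in = torch_tensorrt.Input(
    min_shape=(1, 3, 224, 224),
    opt_shape=(8, 3, 224, 224),
    max_shape=(16, 3, 224, 224),
    dtype=torch.float32,
)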
@@ -187,7 +191,9 @@ def __str__(self) -> str:
                     str(self.tensor_domain[1]),
                 )
             else:
-                raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                raise RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                )
         else:
             raise RuntimeError("Unknown input shape mode")

@@ -203,7 +209,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
             return False

     @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
+    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
         if isinstance(dtype, torch.dtype):
             if dtype == torch.long:
                 return _enums.dtype.long
@@ -231,7 +237,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
             )

     @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
         if dtype == _enums.dtype.long:
             return torch.long
         elif dtype == _enums.dtype.int32:
@@ -250,7 +256,7 @@ def is_trt_dtype(self) -> bool:
         return bool(self.dtype != _enums.dtype.long)

     @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
+    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
         if isinstance(format, torch.memory_format):
             if format == torch.contiguous_format:
                 return _enums.TensorFormat.contiguous
@@ -270,7 +276,9 @@ def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defin
             )

     @staticmethod
-    def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple[float, float]:
+    def _parse_tensor_domain(
+        domain: Optional[Tuple[float, float]]
+    ) -> Tuple[float, float]:
         """
         Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)

@@ -349,7 +357,7 @@ def from_tensor(

     @classmethod
     def from_tensors(
-        cls, ts: torch.Tensor, disable_memory_format_check: bool = False
+        cls, ts: Sequence[torch.Tensor], disable_memory_format_check: bool = False
     ) -> List["Input"]:
         """
         Produce a list of Inputs which contain
@@ -369,7 +377,9 @@ def from_tensors(
             for t in ts
         ]

-    def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Optional[torch.Tensor]:
+    def example_tensor(
+        self, optimization_profile_field: Optional[str] = None
+    ) -> torch.Tensor:
         """
         Get an example tensor of the shape specified by the Input object

@@ -388,7 +398,9 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Op
             if isinstance(self.shape, tuple):
                 return torch.rand(self.shape).to(dtype=self.torch_dtype)
             else:
-                RuntimeError(f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})")
+                RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
+                )
         else:
             if optimization_profile_field is not None:
                 try:
@@ -408,11 +420,12 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Op
                         dtype=self.torch_dtype
                     )
                 else:
-                    raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                    raise RuntimeError(
+                        f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                    )

             else:
                 raise ValueError(
                     "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
                 )
-        return None
-
+        raise
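A short usage sketch of the two public signatures touched above, `from_tensors` and `example_tensor`; the tensor shapes are illustrative only:

import torch
import torch_tensorrt

# Build Input specs from a sequence of tensors, matching the new
# Sequence[torch.Tensor] annotation on from_tensors.
tensors = [torch.randn(1, 3, 224, 224), torch.randn(1, 10)]
inputs = torch_tensorrt.Input.from_tensors(tensors)

# For a statically shaped Input, example_tensor() now advertises a plain
# torch.Tensor return type instead of Optional[torch.Tensor]; dynamic inputs
# require an optimization_profile_field such as "opt_shape".
example = inputs[0].example_tensor()
print(example.shape)  # torch.Size([1, 3, 224, 224])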