@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import List, Dict, Any, Tuple, Optional, Union
+from typing import List, Dict, Any, Tuple, Optional, Union, Sequence

 import torch

@@ -27,13 +27,17 @@ class _ShapeMode(Enum):
         STATIC = 0
         DYNAMIC = 1

-    shape_mode: Optional[_ShapeMode] = None  #: Is input statically or dynamically shaped
-    shape: Optional[Union[Tuple[int, ...], Dict[str, Tuple[int, ...]]]] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
-    dtype: _enums.dtype = ( # type: ignore[name-defined]
+    shape_mode: Optional[
+        _ShapeMode
+    ] = None  #: Is input statically or dynamically shaped
+    shape: Optional[
+        Tuple[int, ...] | Dict[str, Tuple[int, ...]]
+    ] = None  #: Either a single Tuple or a dict of tuples defining the input shape. Static shaped inputs will have a single tuple. Dynamic inputs will have a dict of the form ``{ "min_shape": Tuple, "opt_shape": Tuple, "max_shape": Tuple }``
+    dtype: _enums.dtype = (  # type: ignore[name-defined]
         _enums.dtype.unknown
     )  #: The expected data type of the input tensor (default: torch_tensorrt.dtype.float32)
     _explicit_set_dtype: bool = False
-    format: _enums.TensorFormat = ( # type: ignore[name-defined]
+    format: _enums.TensorFormat = (  # type: ignore[name-defined]
         _enums.TensorFormat.contiguous
     )  #: The expected format of the input tensor (default: torch_tensorrt.TensorFormat.NCHW)

@@ -176,7 +180,9 @@ def __str__(self) -> str:
                     str(self.tensor_domain[1]),
                 )
             else:
-                raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                raise RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                )
         else:
             raise RuntimeError("Unknown input shape mode")

@@ -192,7 +198,7 @@ def _supported_input_size_type(input_size: Any) -> bool:
             return False

     @staticmethod
-    def _parse_dtype(dtype: Any) -> _enums.dtype: # type: ignore[name-defined]
+    def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
         if isinstance(dtype, torch.dtype):
             if dtype == torch.long:
                 return _enums.dtype.long
@@ -220,7 +226,7 @@ def _parse_dtype(dtype: Any) -> _enums.dtype:  # type: ignore[name-defined]
             )

     @staticmethod
-    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype: # type: ignore[name-defined]
+    def _to_torch_dtype(dtype: _enums.dtype) -> torch.dtype:  # type: ignore[name-defined]
         if dtype == _enums.dtype.long:
             return torch.long
         elif dtype == _enums.dtype.int32:
@@ -239,7 +245,7 @@ def is_trt_dtype(self) -> bool:
         return bool(self.dtype != _enums.dtype.long)

     @staticmethod
-    def _parse_format(format: Any) -> _enums.TensorFormat: # type: ignore[name-defined]
+    def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
         if isinstance(format, torch.memory_format):
             if format == torch.contiguous_format:
                 return _enums.TensorFormat.contiguous
@@ -259,7 +265,9 @@ def _parse_format(format: Any) -> _enums.TensorFormat:  # type: ignore[name-defined]
             )

     @staticmethod
-    def _parse_tensor_domain(domain: Optional[Tuple[float, float]]) -> Tuple[float, float]:
+    def _parse_tensor_domain(
+        domain: Optional[Tuple[float, float]]
+    ) -> Tuple[float, float]:
         """
         Produce a tuple of integers which specifies a tensor domain in the interval format: [lo, hi)

@@ -338,7 +346,7 @@ def from_tensor(

     @classmethod
     def from_tensors(
-        cls, ts: torch.Tensor, disable_memory_format_check: bool = False
+        cls, ts: Sequence[torch.Tensor], disable_memory_format_check: bool = False
     ) -> List["Input"]:
         """
         Produce a list of Inputs which contain
@@ -358,7 +366,9 @@ def from_tensors(
             for t in ts
         ]

-    def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Optional[torch.Tensor]:
+    def example_tensor(
+        self, optimization_profile_field: Optional[str] = None
+    ) -> torch.Tensor:
         """
         Get an example tensor of the shape specified by the Input object

@@ -377,7 +387,9 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Optional[torch.Tensor]:
             if isinstance(self.shape, tuple):
                 return torch.rand(self.shape).to(dtype=self.torch_dtype)
             else:
-                RuntimeError(f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})")
+                RuntimeError(
+                    f"Input shape is dynamic but shapes are not provided as sequence (found: {self.shape})"
+                )
         else:
             if optimization_profile_field is not None:
                 try:
@@ -397,11 +409,12 @@ def example_tensor(self, optimization_profile_field: Optional[str] = None) -> Optional[torch.Tensor]:
                         dtype=self.torch_dtype
                     )
                 else:
-                    raise RuntimeError(f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})")
+                    raise RuntimeError(
+                        f"Input shape is dynamic but shapes are not provided as dictionary (found: {self.shape})"
+                    )

             else:
                 raise ValueError(
                     "Requested an example tensor from a dynamic shaped input but did not specific which profile field to use."
                 )
-        return None
-
+        raise
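
For context on the two public entry points touched above: from_tensors is now annotated to take a Sequence[torch.Tensor], and example_tensor is annotated to return torch.Tensor (its fallthrough now raises rather than returning None). Below is a minimal usage sketch of those calls, assuming the torch_tensorrt.Input class this diff patches is importable; the tensor shapes are illustrative only and not part of the commit:

    import torch
    from torch_tensorrt import Input

    # from_tensors builds one statically shaped Input per tensor in the sequence
    inputs = Input.from_tensors([torch.randn(1, 3, 224, 224), torch.randn(8, 16)])

    # example_tensor synthesizes a random tensor matching a static Input's shape
    t = inputs[0].example_tensor()
    assert t.shape == (1, 3, 224, 224)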