Skip to content

Commit 7a9570f

Browse files
committed
added the flattening function to util
1 parent c78491e commit 7a9570f

File tree

1 file changed

+21
-24
lines changed

1 file changed

+21
-24
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 21 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from dataclasses import fields, replace
55
from enum import Enum
6-
from typing import Any, Callable, Dict, Optional, Sequence, Union
6+
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union
77

88
import numpy as np
99
import tensorrt as trt
@@ -413,32 +413,29 @@ def check_output(
413413
return True
414414

415415

416-
def unified_dtype_converter(
417-
dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks
418-
) -> Union[np.dtype, torch.dtype, TRTDataType]:
416+
def flatten_dict_value(d: dict[Any, Any]) -> List[Any]:
419417
"""
420-
Convert TensorRT, Numpy, or Torch data types to any other of those data types.
418+
Flatten the values of a dictionary to a single list.
421419
422420
Args:
423-
dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type.
424-
to (Frameworks): The framework to convert the data type to.
421+
d (dict): The dictionary to flatten.
425422
426423
Returns:
427-
The equivalent data type in the requested framework.
424+
list: A list of all values flattened.
428425
"""
429-
assert to in Frameworks, f"Expected valid Framework for translation, got {to}"
430-
trt_major_version = int(trt.__version__.split(".")[0])
431-
if dtype in (np.int8, torch.int8, trt.int8):
432-
return DataTypeEquivalence[trt.int8][to]
433-
elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool):
434-
return DataTypeEquivalence[trt.bool][to]
435-
elif dtype in (np.int32, torch.int32, trt.int32):
436-
return DataTypeEquivalence[trt.int32][to]
437-
elif dtype in (np.int64, torch.int64, trt.int64):
438-
return DataTypeEquivalence[trt.int64][to]
439-
elif dtype in (np.float16, torch.float16, trt.float16):
440-
return DataTypeEquivalence[trt.float16][to]
441-
elif dtype in (np.float32, torch.float32, trt.float32):
442-
return DataTypeEquivalence[trt.float32][to]
443-
else:
444-
raise TypeError("%s is not a supported dtype" % dtype)
426+
427+
def flatten(value: Any) -> Generator[Any, Any, Any]:
428+
if isinstance(value, dict):
429+
for v in value.values():
430+
yield from flatten(v)
431+
elif isinstance(value, list):
432+
for item in value:
433+
yield from flatten(item)
434+
else:
435+
yield value
436+
437+
flat_list: List[Any] = []
438+
for v in d.values():
439+
flat_list.extend(flatten(v))
440+
441+
return flat_list

0 commit comments

Comments
 (0)