|
3 | 3 | import logging
|
4 | 4 | from dataclasses import fields, replace
|
5 | 5 | from enum import Enum
|
6 |
| -from typing import Any, Callable, Dict, Optional, Sequence, Union |
| 6 | +from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union |
7 | 7 |
|
8 | 8 | import numpy as np
|
9 | 9 | import tensorrt as trt
|
@@ -413,32 +413,29 @@ def check_output(
|
413 | 413 | return True
|
414 | 414 |
|
415 | 415 |
|
416 |
| -def unified_dtype_converter( |
417 |
| - dtype: Union[TRTDataType, torch.dtype, np.dtype], to: Frameworks |
418 |
| -) -> Union[np.dtype, torch.dtype, TRTDataType]: |
| 416 | +def flatten_dict_value(d: dict[Any, Any]) -> List[Any]: |
419 | 417 | """
|
420 |
| - Convert TensorRT, Numpy, or Torch data types to any other of those data types. |
| 418 | + Flatten the values of a dictionary to a single list. |
421 | 419 |
|
422 | 420 | Args:
|
423 |
| - dtype (TRTDataType, torch.dtype, np.dtype): A TensorRT, Numpy, or Torch data type. |
424 |
| - to (Frameworks): The framework to convert the data type to. |
| 421 | + d (dict): The dictionary to flatten. |
425 | 422 |
|
426 | 423 | Returns:
|
427 |
| - The equivalent data type in the requested framework. |
| 424 | + list: A list of all values flattened. |
428 | 425 | """
|
429 |
| - assert to in Frameworks, f"Expected valid Framework for translation, got {to}" |
430 |
| - trt_major_version = int(trt.__version__.split(".")[0]) |
431 |
| - if dtype in (np.int8, torch.int8, trt.int8): |
432 |
| - return DataTypeEquivalence[trt.int8][to] |
433 |
| - elif trt_major_version >= 7 and dtype in (np.bool_, torch.bool, trt.bool): |
434 |
| - return DataTypeEquivalence[trt.bool][to] |
435 |
| - elif dtype in (np.int32, torch.int32, trt.int32): |
436 |
| - return DataTypeEquivalence[trt.int32][to] |
437 |
| - elif dtype in (np.int64, torch.int64, trt.int64): |
438 |
| - return DataTypeEquivalence[trt.int64][to] |
439 |
| - elif dtype in (np.float16, torch.float16, trt.float16): |
440 |
| - return DataTypeEquivalence[trt.float16][to] |
441 |
| - elif dtype in (np.float32, torch.float32, trt.float32): |
442 |
| - return DataTypeEquivalence[trt.float32][to] |
443 |
| - else: |
444 |
| - raise TypeError("%s is not a supported dtype" % dtype) |
| 426 | + |
| 427 | + def flatten(value: Any) -> Generator[Any, Any, Any]: |
| 428 | + if isinstance(value, dict): |
| 429 | + for v in value.values(): |
| 430 | + yield from flatten(v) |
| 431 | + elif isinstance(value, list): |
| 432 | + for item in value: |
| 433 | + yield from flatten(item) |
| 434 | + else: |
| 435 | + yield value |
| 436 | + |
| 437 | + flat_list: List[Any] = [] |
| 438 | + for v in d.values(): |
| 439 | + flat_list.extend(flatten(v)) |
| 440 | + |
| 441 | + return flat_list |
0 commit comments