Skip to content

Commit 81ed8e8

Browse files
committed
added the flattening function to util
1 parent cca4374 commit 81ed8e8

File tree

1 file changed

+29
-1
lines changed

1 file changed

+29
-1
lines changed

py/torch_tensorrt/dynamo/utils.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
from dataclasses import fields, replace
55
from enum import Enum
6-
from typing import Any, Callable, Dict, Optional, Sequence, Union
6+
from typing import Any, Callable, Dict, Generator, List, Optional, Sequence, Union
77

88
import numpy as np
99
import torch
@@ -411,3 +411,31 @@ def check_output(
411411
return False
412412

413413
return True
414+
415+
416+
def flatten_dict_value(d: dict[Any, Any]) -> List[Any]:
417+
"""
418+
Flatten the values of a dictionary to a single list.
419+
420+
Args:
421+
d (dict): The dictionary to flatten.
422+
423+
Returns:
424+
list: A list of all values flattened.
425+
"""
426+
427+
def flatten(value: Any) -> Generator[Any, Any, Any]:
428+
if isinstance(value, dict):
429+
for v in value.values():
430+
yield from flatten(v)
431+
elif isinstance(value, list):
432+
for item in value:
433+
yield from flatten(item)
434+
else:
435+
yield value
436+
437+
flat_list: List[Any] = []
438+
for v in d.values():
439+
flat_list.extend(flatten(v))
440+
441+
return flat_list

0 commit comments

Comments
 (0)