Commit 9221c79

peterbell10 authored and facebook-github-bot committed
Use pytree.tree_leaves everywhere (#112324)
Summary: This changes all the instances I could find of `tree_flatten(...)[0]` or `x, _ = tree_flatten(...)` to use `tree_leaves`.

X-link: pytorch/pytorch#112324
Approved by: https://github.com/lezcano
ghstack dependencies: #112327, #112323
Reviewed By: ZainRizvi
Differential Revision: D50819663
fbshipit-source-id: 110cbd1295a752fb8b73fbd71009b5823a2cfa86
1 parent f13ea81 commit 9221c79
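
For illustration, a minimal sketch (not part of the commit itself) of the equivalence this change relies on: `pytree.tree_leaves(x)` returns the same flat list of leaves that `tree_flatten` produces, without the accompanying TreeSpec.

import torch.utils._pytree as pytree

x = {"a": [1, 2], "b": (3, {"c": 4})}

# Old idiom: flatten and discard the spec.
leaves_old, _spec = pytree.tree_flatten(x)
# New idiom used throughout this commit.
leaves_new = pytree.tree_leaves(x)

assert leaves_old == leaves_new  # both are [1, 2, 3, 4] here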

1 file changed: +6 −6 lines

userbenchmark/dynamo/dynamobench/common.py

Lines changed: 6 additions & 6 deletions
@@ -559,7 +559,7 @@ def _register_dataclass_output_as_pytree(example_outputs) -> None:
     # NOTE(angelayi): For huggingface benchmark, some example outputs are
     # formatted as a dataclass which pytree cannot consume. So we want
     # to register the pytree implementation here
-    example_outputs_flat, _ = pytree.tree_flatten(example_outputs)
+    example_outputs_flat = pytree.tree_leaves(example_outputs)
     output_dataclass_types = [
         type(out) for out in example_outputs_flat if dataclasses.is_dataclass(type(out))
     ]
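
As background for the NOTE above, a minimal sketch (not part of this diff) of registering a dataclass with torch.utils._pytree so that tree_leaves can traverse it; the ModelOutput class is hypothetical, and older PyTorch releases expose this hook as `_register_pytree_node` rather than `register_pytree_node`.

import dataclasses
import torch.utils._pytree as pytree

@dataclasses.dataclass
class ModelOutput:  # hypothetical stand-in for a huggingface-style output
    logits: object
    past_key_values: object

pytree.register_pytree_node(
    ModelOutput,
    # flatten: return (list of leaves, context needed to rebuild)
    lambda out: ([out.logits, out.past_key_values], None),
    # unflatten: rebuild the dataclass from leaves + context
    lambda values, context: ModelOutput(*values),
)

# tree_leaves now recurses into ModelOutput instances.
leaves = pytree.tree_leaves(ModelOutput(logits=1, past_key_values=(2, 3)))
assert leaves == [1, 2, 3]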
@@ -1496,7 +1496,7 @@ def format_pt_outputs(self, pt_outputs):
         if isinstance(pt_outputs, torch.Tensor):
             pt_outputs = (pt_outputs,)

-        pt_outputs, _ = pytree.tree_flatten(pt_outputs)
+        pt_outputs = pytree.tree_leaves(pt_outputs)

         # Hack for huggingface model outputs
         try:
@@ -1511,7 +1511,7 @@ def _to_tuple(x):
                 return x

             pt_outputs = pytree.tree_map(_to_tuple, pt_outputs)
-            pt_outputs, _ = pytree.tree_flatten(pt_outputs)
+            pt_outputs = pytree.tree_leaves(pt_outputs)

         return pt_outputs

@@ -1631,7 +1631,7 @@ def patch_non_tensor_outputs(cls, correct_result, new_result, fp64_outputs):
         )

         # Flatten nested tuple of tensors, i.e. past_key_values
-        correct_result = pytree.tree_flatten(correct_result)[0]
+        correct_result = pytree.tree_leaves(correct_result)
         # Hack to put results from different runs on same device.
         # This is needed for ONNX CPU fallback benchmark, where PyTorch eager is run on GPU.
         # Assuming outputs from a single run are always on same device!
@@ -1640,12 +1640,12 @@ def patch_non_tensor_outputs(cls, correct_result, new_result, fp64_outputs):
             x == devices[0] for x in devices
         ), "All tensors must be on same device!"
         device = devices[0]
-        new_result = pytree.tree_flatten(new_result)[0]
+        new_result = pytree.tree_leaves(new_result)
         new_result = pytree.tree_map(
             lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
             new_result,
         )
-        fp64_outputs = pytree.tree_flatten(fp64_outputs)[0]
+        fp64_outputs = pytree.tree_leaves(fp64_outputs)

         return correct_result, new_result, fp64_outputs

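For reference, a standalone sketch (assumed, not taken from the commit) of the flatten-then-remap pattern in the hunk above: collapse a nested output structure to its leaves, then move every tensor leaf onto one device with tree_map.

import torch
import torch.utils._pytree as pytree

# Made-up nested output structure for illustration.
outputs = {"logits": torch.randn(2, 3), "past": (torch.randn(1), 42)}

leaves = pytree.tree_leaves(outputs)  # [tensor, tensor, 42]
device = torch.device("cpu")
leaves = pytree.tree_map(
    lambda x: x.to(device=device) if isinstance(x, torch.Tensor) else x,
    leaves,
)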
