Skip to content

Commit 49d2e68

Browse files
angelayi authored and facebook-github-bot committed
Update how we input kwargs (#314)
Summary: Pull Request resolved: #314 Previously, the code for passing inputs to exported program was: ``` if kwargs: return (args, kwargs) else: return args ``` However, this causes some inconsistency where if the original input contains args and kwargs, the treespec would be a tuple containing a tuple of arguments, and a dictionary of keyword arguments. But if the original input only contained args, the treespec would just be a tuple of arguments. So I updated the code to just always keep the kwargs around cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng X-link: pytorch/pytorch#109160 Reviewed By: zhxchen17 Differential Revision: D49218534 Pulled By: angelayi fbshipit-source-id: dbc481aaa93df349dc3171a1d66a306ec4dd2370
1 parent b78576e commit 49d2e68

File tree

6 files changed

+7
-7
lines changed

6 files changed

+7
-7
lines changed

exir/capture/_capture.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,7 @@ def capture( # noqa: C901
163163
# input spec, but due to some limitations in pytree implementation, it doesn't
164164
# recognize the make_fx graph with torchdynamo input spec. We workaround it
165165
# by getting the input spec directly from user argument.
166-
in_spec = pytree.tree_flatten(args)[1]
166+
in_spec = pytree.tree_flatten((args, {}))[1]
167167

168168
if config.enable_functionalization and not config.enable_aot:
169169
args = copy.deepcopy(args)

exir/emit/test/test_emit.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,8 @@ def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
176176
)
177177

178178
self.assertEqual(
179-
program.execution_plan[0].container_meta_type.encoded_inp_str, "T1#1($)"
179+
program.execution_plan[0].container_meta_type.encoded_inp_str,
180+
"T2#1#0(T1#1($),D0())",
180181
)
181182

182183
def test_buffers_with_perfect_alignment(self) -> None:

exir/program/_program.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,9 @@ def __init__(
7171

7272
def __call__(self, *args: Any, **kwargs: Any) -> Any:
7373
import torch._export.error as error
74-
from torch._export import combine_args_kwargs
7574

7675
if self.call_spec.in_spec is not None:
77-
user_args = combine_args_kwargs(args, kwargs)
76+
user_args = args
7877
try:
7978
args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec) # type: ignore[assignment]
8079
except Exception:

exir/tests/fixtures/basic_sin_max.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
{
77
"name": "forward",
88
"container_meta_type": {
9-
"encoded_inp_str": "T1#1($)",
9+
"encoded_inp_str": "T2#1#0(T1#1($),D0())",
1010
"encoded_out_str": "$"
1111
},
1212
"values": [

exir/tests/fixtures/composite_delegate.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
{
77
"name": "forward",
88
"container_meta_type": {
9-
"encoded_inp_str": "T3#1#1#1($,$,$)",
9+
"encoded_inp_str": "T2#3#0(T3#1#1#1($,$,$),D0())",
1010
"encoded_out_str": "$"
1111
},
1212
"values": [

exir/verification/interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,7 +369,7 @@ def run(self, *raw_args: torch.Tensor) -> PyTree:
369369
"""
370370

371371
# pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
372-
args, pytree = ex_pytree.tree_flatten(raw_args)
372+
args, pytree = ex_pytree.tree_flatten((raw_args, {}))
373373

374374
if pytree.to_str() != self.container_metatype.encoded_inp_str:
375375
raise TypeError(

0 commit comments

Comments (0)