
Commit 2c8dc34

angelayi authored and facebook-github-bot committed
Update how we input kwargs (#314)
Summary:
Previously, the code for passing inputs to the exported program was:

```
if kwargs:
    return (args, kwargs)
else:
    return args
```

However, this caused an inconsistency: if the original input contained both args and kwargs, the treespec would be a tuple containing a tuple of arguments and a dictionary of keyword arguments, but if the original input contained only args, the treespec would just be a tuple of arguments. So I updated the code to always keep the kwargs around, making the flattened input consistently an `(args, kwargs)` pair.

cc voznesenskym penguinwu EikanWang jgong5 Guobing-Chen XiaobingSuper zhuhaozhe blzheng Xia-Weiwen wenzhe-nrv jiayisunx chenyang78 aakhundov kadeng

X-link: pytorch/pytorch#109160

Reviewed By: zhxchen17

Differential Revision: D49218534

Pulled By: angelayi
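A minimal sketch of the inconsistency described above, using `torch.utils._pytree` directly (whether exir uses this exact helper is an assumption here); it shows how flattening bare `args` versus an `(args, {})` pair produces differently shaped treespecs:

```python
# Sketch only: illustrates the treespec inconsistency, assuming torch.utils._pytree
# behaves like the pytree helpers used by exir.
import torch
import torch.utils._pytree as pytree

args = (torch.randn(2),)
kwargs = {"scale": 2.0}

# Old behavior: kwargs were only wrapped in when present, so these two calls
# produce treespecs with different top-level structures.
_, spec_args_only = pytree.tree_flatten(args)              # tuple of args
_, spec_args_kwargs = pytree.tree_flatten((args, kwargs))  # (args tuple, kwargs dict)

# New behavior: always flatten an (args, kwargs) pair, even when kwargs is empty,
# so the input treespec has the same shape whether or not kwargs were passed.
_, spec_consistent = pytree.tree_flatten((args, {}))
print(spec_args_only)
print(spec_consistent)
```

The fixture and test updates below reflect the same change: the encoded input spec now always carries an empty dict slot for kwargs.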
1 parent e6e1898 commit 2c8dc34

File tree

6 files changed: +7 −6 lines changed


exir/capture/_capture.py

Lines changed: 1 addition & 1 deletion

@@ -163,7 +163,7 @@ def capture( # noqa: C901
     # input spec, but due to some limitations in pytree implementation, it doesn't
     # recognize the make_fx graph with torchdynamo input spec. We workaround it
     # by getting the input spec directly from user argument.
-    in_spec = pytree.tree_flatten(args)[1]
+    in_spec = pytree.tree_flatten((args, {}))[1]

     if config.enable_functionalization and not config.enable_aot:
         args = copy.deepcopy(args)

exir/emit/test/test_emit.py

Lines changed: 2 additions & 1 deletion

@@ -173,7 +173,8 @@ def f(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, List[torch.Tensor]]:
         )

         self.assertEqual(
-            program.execution_plan[0].container_meta_type.encoded_inp_str, "T1#1($)"
+            program.execution_plan[0].container_meta_type.encoded_inp_str,
+            "T2#1#0(T1#1($),D0())",
         )

     def test_buffers_with_perfect_alignment(self) -> None:
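For context on the new expected string, a possible reading of the encoding (my own interpretation of the format, not documented in this diff): `T2#1#0(T1#1($),D0())` appears to describe a 2-element tuple whose children hold 1 and 0 leaves respectively, i.e. a 1-element tuple with a single tensor leaf (`$`) plus an empty kwargs dict (`D0()`). That matches the structure `((x,), {})`:

```python
# Hypothetical cross-check with torch.utils._pytree (not the ex_pytree module that
# produces the encoded string): ((x,), {}) flattens to one leaf inside a
# (tuple-of-1, empty-dict) container, mirroring T2#1#0(T1#1($),D0()).
import torch
import torch.utils._pytree as pytree

leaves, spec = pytree.tree_flatten(((torch.randn(2),), {}))
print(len(leaves))  # 1
print(spec)
```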

exir/program/_program.py

Lines changed: 1 addition & 1 deletion

@@ -75,7 +75,7 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
         from torch._export import combine_args_kwargs

         if self.call_spec.in_spec is not None:
-            user_args = combine_args_kwargs(args, kwargs)
+            user_args = args
             try:
                 args = fx_pytree.tree_flatten_spec(user_args, self.call_spec.in_spec)  # type: ignore[assignment]
             except Exception:

exir/tests/fixtures/basic_sin_max.txt

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 {
   "name": "forward",
   "container_meta_type": {
-    "encoded_inp_str": "T1#1($)",
+    "encoded_inp_str": "T2#1#0(T1#1($),D0())",
     "encoded_out_str": "$"
   },
   "values": [

exir/tests/fixtures/composite_delegate.txt

Lines changed: 1 addition & 1 deletion

@@ -6,7 +6,7 @@
 {
   "name": "forward",
   "container_meta_type": {
-    "encoded_inp_str": "T3#1#1#1($,$,$)",
+    "encoded_inp_str": "T2#3#0(T3#1#1#1($,$,$),D0())",
     "encoded_out_str": "$"
   },
   "values": [

exir/verification/interpreter.py

Lines changed: 1 addition & 1 deletion

@@ -369,7 +369,7 @@ def run(self, *raw_args: torch.Tensor) -> PyTree:
         """

         # pyre-fixme[16]: Module `pytree` has no attribute `tree_flatten`.
-        args, pytree = ex_pytree.tree_flatten(raw_args)
+        args, pytree = ex_pytree.tree_flatten((raw_args, {}))

         if pytree.to_str() != self.container_metatype.encoded_inp_str:
             raise TypeError(
