Skip to content

Commit 65418d2

Browse files
suoguilhermeleobas
authored andcommitted
[sigmoid] fix for FX tracing unflattened modules (pytorch#115708)
Differential Revision: [D52095387](https://our.internmc.facebook.com/intern/diff/D52095387/) Pull Request resolved: pytorch#115708 Approved by: https://github.com/zhxchen17
1 parent b5197f1 commit 65418d2

File tree

2 files changed

+51
-11
lines changed

2 files changed

+51
-11
lines changed

test/test_fx.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,21 @@ def forward(self, *args, **kwargs):
273273
t = T()
274274
self.checkGraphModule(t, (torch.rand(1), torch.rand(1)), {'foo': torch.rand(1)})
275275

276+
def test_varargs_concrete(self):
277+
class T(torch.nn.Module):
278+
def forward(self, *args, **kwargs):
279+
x = args[0] + args[1]
280+
return x
281+
282+
args = (torch.rand(1), torch.rand(1))
283+
284+
t = T()
285+
ref_outs = t(*args)
286+
gm = symbolic_trace(t, concrete_args=(torch.fx.PH, torch.fx.PH))
287+
gm.graph.lint()
288+
test_outs = gm(*args)
289+
self.assertEqual(ref_outs, test_outs)
290+
276291
def test_args_kwargs_no_self(self):
277292
class T(torch.nn.Module):
278293
def forward(*args, **kwargs): # noqa: B902

torch/fx/_symbolic_trace.py

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,17 @@ def __init__(self, ph_key: Optional[str] = None):
217217
self.ph_key = ph_key
218218

219219

220+
def _transfer_attrs(fr, to):
221+
for attr_name in dir(fr):
222+
attr_val = getattr(fr, attr_name)
223+
if (
224+
not callable(attr_val)
225+
and not attr_name.startswith("__")
226+
and not hasattr(to, attr_name)
227+
):
228+
setattr(to, attr_name, attr_val)
229+
230+
220231
@compatibility(is_backward_compatible=True)
221232
class Tracer(TracerBase):
222233
# Reference: https://github.com/pytorch/pytorch/issues/54354
@@ -597,16 +608,6 @@ def replace_ph(x):
597608
"placeholder", f"{name}_{str(cnt)}", default, {}
598609
)
599610
if isinstance(x, PHBase):
600-
def transfer_attrs(fr, to):
601-
for attr_name in dir(fr):
602-
attr_val = getattr(fr, attr_name)
603-
if (
604-
not callable(attr_val)
605-
and not attr_name.startswith("__")
606-
and not hasattr(to, attr_name)
607-
):
608-
setattr(to, attr_name, attr_val)
609-
610611
if x != PH:
611612
# Transfer attrs in the case where you're using a placeholder other
612613
# than the singleton PH (PH has no attributes to transfer).
@@ -615,7 +616,7 @@ def transfer_attrs(fr, to):
615616
# attributes set by the user) from the placeholder to the
616617
# underlying nodes (the proxy is unwrapped by the user, but
617618
# the metadata should hold).
618-
transfer_attrs(fr=x, to=out.node)
619+
_transfer_attrs(fr=x, to=out.node)
619620

620621
return out
621622
# Union[int, bool] == bool in Python <= 3.6
@@ -657,6 +658,30 @@ def transfer_attrs(fr, to):
657658
type_expr=fn_for_analysis.__annotations__.get(name, None)
658659
)
659660

661+
# This covers the very specific case where we are passing in flat
662+
# concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
663+
# In this case, just take the concrete_args and pass them through.
664+
name_idx = 0
665+
if isinstance(concrete_args, tuple) and \
666+
len(concrete_args) > 0 and \
667+
(co.co_flags & HAS_VARSTUFF) and \
668+
total_args == 1:
669+
for concrete_arg in concrete_args:
670+
out = self.create_proxy("placeholder", f"input_{name_idx}", (), {})
671+
if isinstance(concrete_arg, PHBase):
672+
if concrete_arg != PH:
673+
# Transfer attrs in the case where you're using a placeholder other
674+
# than the singleton PH (PH has no attributes to transfer).
675+
# Proxies were created out of the placeholders.
676+
# Transfer any metadata (put on the placeholders in the form of
677+
# attributes set by the user) from the placeholder to the
678+
# underlying nodes (the proxy is unwrapped by the user, but
679+
# the metadata should hold).
680+
_transfer_attrs(fr=concrete_arg, to=out.node)
681+
args.append(out)
682+
name_idx += 1
683+
return root_fn, args
684+
660685
arg_names = [next(names_iter) for idx in range(skip_arg_idx, total_args)]
661686
if isinstance(concrete_args, tuple):
662687
if len(arg_names) != len(concrete_args):

0 commit comments

Comments
 (0)