@@ -217,6 +217,17 @@ def __init__(self, ph_key: Optional[str] = None):
217
217
self .ph_key = ph_key
218
218
219
219
220
+ def _transfer_attrs (fr , to ):
221
+ for attr_name in dir (fr ):
222
+ attr_val = getattr (fr , attr_name )
223
+ if (
224
+ not callable (attr_val )
225
+ and not attr_name .startswith ("__" )
226
+ and not hasattr (to , attr_name )
227
+ ):
228
+ setattr (to , attr_name , attr_val )
229
+
230
+
220
231
@compatibility (is_backward_compatible = True )
221
232
class Tracer (TracerBase ):
222
233
# Reference: https://github.com/pytorch/pytorch/issues/54354
@@ -597,16 +608,6 @@ def replace_ph(x):
597
608
"placeholder" , f"{ name } _{ str (cnt )} " , default , {}
598
609
)
599
610
if isinstance (x , PHBase ):
600
- def transfer_attrs (fr , to ):
601
- for attr_name in dir (fr ):
602
- attr_val = getattr (fr , attr_name )
603
- if (
604
- not callable (attr_val )
605
- and not attr_name .startswith ("__" )
606
- and not hasattr (to , attr_name )
607
- ):
608
- setattr (to , attr_name , attr_val )
609
-
610
611
if x != PH :
611
612
# Transfer attrs in the case where you're using a placeholder other
612
613
# than the singleton PH (PH has no attributes to transfer).
@@ -615,7 +616,7 @@ def transfer_attrs(fr, to):
615
616
# attributes set by the user) from the placeholder to the
616
617
# underlying nodes (the proxy is unwrapped by the user, but
617
618
# the metadata should hold).
618
- transfer_attrs (fr = x , to = out .node )
619
+ _transfer_attrs (fr = x , to = out .node )
619
620
620
621
return out
621
622
# Union[int, bool] == bool in Python <= 3.6
@@ -657,6 +658,30 @@ def transfer_attrs(fr, to):
657
658
type_expr = fn_for_analysis .__annotations__ .get (name , None )
658
659
)
659
660
661
+ # This covers the very specific case where we are passing in flat
662
+ # concrete_args as a tuple, but our traced fn takes (*args, **kwargs).
663
+ # In this case, just take the concrete_args and pass them through.
664
+ name_idx = 0
665
+ if isinstance (concrete_args , tuple ) and \
666
+ len (concrete_args ) > 0 and \
667
+ (co .co_flags & HAS_VARSTUFF ) and \
668
+ total_args == 1 :
669
+ for concrete_arg in concrete_args :
670
+ out = self .create_proxy ("placeholder" , f"input_{ name_idx } " , (), {})
671
+ if isinstance (concrete_arg , PHBase ):
672
+ if concrete_arg != PH :
673
+ # Transfer attrs in the case where you're using a placeholder other
674
+ # than the singleton PH (PH has no attributes to transfer).
675
+ # Proxies were created out of the placeholders.
676
+ # Transfer any metadata (put on the placeholders in the form of
677
+ # attributes set by the user) from the placeholder to the
678
+ # underlying nodes (the proxy is unwrapped by the user, but
679
+ # the metadata should hold).
680
+ _transfer_attrs (fr = concrete_arg , to = out .node )
681
+ args .append (out )
682
+ name_idx += 1
683
+ return root_fn , args
684
+
660
685
arg_names = [next (names_iter ) for idx in range (skip_arg_idx , total_args )]
661
686
if isinstance (concrete_args , tuple ):
662
687
if len (arg_names ) != len (concrete_args ):
0 commit comments