@@ -375,6 +375,11 @@ def __init__(
         self.kwargs = kwargs
         self.input_types = [inp.type for inp in inputs]
         self.output_types = [out.type for out in outputs]
+
+        self.lop_overrides = lop_overrides
+        self.grad_overrides = grad_overrides
+        self.rop_overrides = rop_overrides
+
         if lop_overrides != "default":
             if grad_overrides != "default":
                 raise ValueError(
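For context, the overrides stored here are the constructor arguments that let users replace an `OpFromGraph`'s autogenerated gradient/L-op/R-op implementations; keeping them on the instance is what allows the rebuilt `Op` in the next hunk to carry them over. A minimal sketch of how such an override might be supplied, assuming the Aesara-style `OpFromGraph` API (the `(inputs, output_grads)` callable signature follows the documented convention, but treat the details as an assumption rather than this patch's API):

import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.scalar("x")

def square_grad(inputs, output_grads):
    # Hand-written gradient for f(x) = x**2: df/dx = 2 * x * g
    (x_,) = inputs
    (g,) = output_grads
    return [2 * x_ * g]

# Storing `grad_overrides` on the instance (as this hunk does) is what lets
# `make_node` later construct an equivalent Op without losing the override.
op = OpFromGraph([x], [x**2], grad_overrides=square_grad)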
@@ -732,19 +737,71 @@ def R_op(self, inputs, eval_points):
             ]
         return ret_l
 
+    def __call__(self, *inputs, **kwargs):
+        # The user interface doesn't expect the shared variable inputs of the
+        # inner-graph, but, since `Op.make_node` does (and `Op.__call__`
+        # dispatches to `Op.make_node`), we need to compensate here
+        num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
+
+        if len(inputs) == num_expected_inps:
+            actual_inputs = inputs + tuple(self.shared_inputs)
+            return super().__call__(*actual_inputs, **kwargs)
+        elif len(inputs) == len(self.inner_inputs):
+            return super().__call__(*inputs, **kwargs)
+        else:
+            raise ValueError(f"Expected at least {num_expected_inps} input(s)")
+
     def make_node(self, *inputs):
+        # The `inputs` received here should correspond to the inputs in the
+        # `Apply` nodes we produce below
+        if len(inputs) != len(self.inner_inputs):
+            raise ValueError(f"Expected {len(self.inner_inputs)} input(s)")
+
         num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
-        if len(inputs) != num_expected_inps:
-            raise ValueError(
-                f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
-            )
-        inputs = [
-            inp_t.filter_variable(inp) for inp, inp_t in zip(inputs, self.input_types)
+        non_shared_inputs = inputs[:num_expected_inps]
+
+        non_shared_inputs = [
+            inp_t.filter_variable(inp)
+            for inp, inp_t in zip(non_shared_inputs, self.input_types)
         ]
+
+        shared_inputs = inputs[num_expected_inps:]
+        local_shared_inputs = self.inner_inputs[num_expected_inps:]
+
+        inner_and_input_shareds = list(zip(local_shared_inputs, shared_inputs))
+
+        if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
+            # The shared variables are not equal to the original shared
+            # variables, so we construct a new `Op` that uses the new shared
+            # variables instead
+            replace = {
+                old_inp: new_inp for old_inp, new_inp in zip(self.inner_inputs, inputs)
+            }
+            replace.update(inner_and_input_shareds)
+
+            # If the new shared variables are inconsistent with the inner-graph,
+            # such errors should arise in this step
+            new_outputs = clone_replace(
+                self.inner_outputs, replace=replace, share_inputs=True
+            )
+
+            new_op = type(self)(
+                inputs=non_shared_inputs,
+                outputs=new_outputs,
+                inline=self.is_inline,
+                lop_overrides=self.lop_overrides,
+                grad_overrides=self.grad_overrides,
+                rop_overrides=self.rop_overrides,
+                connection_pattern=self._connection_pattern,
+                name=self.name,
+            )
+        else:
+            new_op = self
+
         apply_node = Apply(
-            self,
-            list(inputs) + self.shared_inputs,
-            [type() for type in self.output_types],
+            new_op,
+            list(non_shared_inputs) + new_op.shared_inputs,
+            [type() for type in new_op.output_types],
         )
         return apply_node
 
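The upshot of the new `__call__`: callers can keep passing only the explicit (non-shared) inputs, and the inner graph's shared variables are appended automatically before dispatching to `make_node`. A rough usage sketch under the same API assumptions as above:

import numpy as np
import aesara
import aesara.tensor as at
from aesara.compile.builders import OpFromGraph

x = at.vector("x")
s = aesara.shared(np.asarray(2.0), name="s")

# `s` is not listed in `inputs`; it is collected as one of the op's
# implicit `shared_inputs`.
ofg = OpFromGraph([x], [x * s])

y = ofg(x)  # only the non-shared input; `__call__` appends `s` itself
f = aesara.function([x], y)
print(f(np.asarray([1.0, 2.0])))  # -> [2. 4.]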
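The second signature accepted by `__call__` (all inner-graph inputs, shared ones included) is what feeds the new clone-and-rebuild branch in `make_node`: passing a different shared variable should yield a fresh `Op` whose inner graph was rebuilt via `clone_replace`. Continuing the sketch above (again, a hedged reading of the patch, not a documented guarantee):

s_new = aesara.shared(np.asarray(3.0), name="s_new")

# Passing len(inner_inputs) arguments hits the `elif` branch of `__call__`;
# since `s_new != s`, `make_node` clones the inner graph with `s` replaced
# by `s_new` and wraps the result in a new `OpFromGraph`.
y2 = ofg(x, s_new)
assert y2.owner.op is not ofg

f2 = aesara.function([x], y2)
print(f2(np.asarray([1.0, 2.0])))  # -> [3. 6.]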