37
37
"ExprDecl" ,
38
38
"TypedExprDecl" ,
39
39
"ClassDecl" ,
40
+ "PrettyContext" ,
40
41
]
41
42
# Special methods which we might want to use as functions
42
43
# Mapping to the operator they represent for pretty printing them
@@ -288,7 +289,7 @@ def register_constant_callable(
288
289
self ._decl .set_constant_type (ref , type_ref )
289
290
# Create a function decleartion for a constant function. This is similar to how egglog compiles
290
291
# the `declare` command.
291
- return FunctionDecl ((), (), (), type_ref .to_var ()).to_commands (self , egg_name or ref .generate_egg_name ())
292
+ return FunctionDecl ((), (), (), type_ref .to_var (), False ).to_commands (self , egg_name or ref .generate_egg_name ())
292
293
293
294
def register_preserved_method (self , class_ : str , method : str , fn : Callable ) -> None :
294
295
self ._decl ._classes [class_ ].preserved_methods [method ] = fn
@@ -337,7 +338,14 @@ def to_constant_function_decl(self) -> FunctionDecl:
337
338
Create a function declaration for a constant function. This is similar to how egglog compiles
338
339
the `constant` command.
339
340
"""
340
- return FunctionDecl (arg_types = (), arg_names = (), arg_defaults = (), return_type = self .to_var (), var_arg_type = None )
341
+ return FunctionDecl (
342
+ arg_types = (),
343
+ arg_names = (),
344
+ arg_defaults = (),
345
+ return_type = self .to_var (),
346
+ mutates_first_arg = False ,
347
+ var_arg_type = None ,
348
+ )
341
349
342
350
343
351
@dataclass (frozen = True )
@@ -432,8 +440,14 @@ class FunctionDecl:
432
440
arg_names : Optional [tuple [str , ...]]
433
441
arg_defaults : tuple [Optional [ExprDecl ], ...]
434
442
return_type : TypeOrVarRef
443
+ mutates_first_arg : bool
435
444
var_arg_type : Optional [TypeOrVarRef ] = None
436
445
446
+ def __post_init__ (self ):
447
+ # If we mutate the first arg, then the first arg should be the same type as the return
448
+ if self .mutates_first_arg :
449
+ assert self .arg_types [0 ] == self .return_type
450
+
437
451
def to_signature (self , transform_default : Callable [[TypedExprDecl ], object ]) -> Signature :
438
452
arg_names = self .arg_names or tuple (f"__{ i } " for i in range (len (self .arg_types )))
439
453
parameters = [
@@ -491,7 +505,7 @@ def from_egg(cls, var: bindings.Var) -> TypedExprDecl:
491
505
def to_egg (self , _decls : ModuleDeclarations ) -> bindings .Var :
492
506
return bindings .Var (self .name )
493
507
494
- def pretty (self , mod_decls : ModuleDeclarations , ** kwargs ) -> str :
508
+ def pretty (self , context : PrettyContext , ** kwargs ) -> str :
495
509
return self .name
496
510
497
511
@@ -525,7 +539,7 @@ def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit:
525
539
return bindings .Lit (bindings .String (self .value ))
526
540
assert_never (self .value )
527
541
528
- def pretty (self , mod_decls : ModuleDeclarations , wrap_lit = True , ** kwargs ) -> str :
542
+ def pretty (self , context : PrettyContext , wrap_lit = True , ** kwargs ) -> str :
529
543
"""
530
544
Returns a string representation of the literal.
531
545
@@ -581,7 +595,7 @@ def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call:
581
595
egg_fn = mod_decls .get_egg_fn (self .callable )
582
596
return bindings .Call (egg_fn , [a .to_egg (mod_decls ) for a in self .args ])
583
597
584
- def pretty (self , mod_decls : ModuleDeclarations , parens = True , ** kwargs ) -> str :
598
+ def pretty (self , context : PrettyContext , parens = True , ** kwargs ) -> str :
585
599
"""
586
600
Pretty print the call.
587
601
@@ -590,8 +604,13 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
590
604
ref , args = self .callable , [a .expr for a in self .args ]
591
605
# Special case != since it doesn't have a decl
592
606
if isinstance (ref , MethodRef ) and ref .method_name == "__ne__" :
593
- return f"{ args [0 ].pretty (mod_decls , wrap_lit = True )} != { args [1 ].pretty (mod_decls , wrap_lit = True )} "
594
- defaults = mod_decls .get_function_decl (ref ).arg_defaults
607
+ return f"{ args [0 ].pretty (context , wrap_lit = True )} != { args [1 ].pretty (context , wrap_lit = True )} "
608
+ function_decl = context .mod_decls .get_function_decl (ref )
609
+ defaults = function_decl .arg_defaults
610
+ if function_decl .mutates_first_arg :
611
+ mutated_arg_type = function_decl .arg_types [0 ].to_just ().name
612
+ else :
613
+ mutated_arg_type = None
595
614
if isinstance (ref , FunctionRef ):
596
615
fn_str = ref .name
597
616
elif isinstance (ref , ClassMethodRef ):
@@ -605,23 +624,37 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
605
624
slf , * args = args
606
625
defaults = defaults [1 :]
607
626
if name in UNARY_METHODS :
608
- return f"{ UNARY_METHODS [name ]} { slf .pretty (mod_decls )} "
627
+ return f"{ UNARY_METHODS [name ]} { slf .pretty (context )} "
609
628
elif name in BINARY_METHODS :
610
629
assert len (args ) == 1
611
- expr = f"{ slf .pretty (mod_decls )} { BINARY_METHODS [name ]} { args [0 ].pretty (mod_decls , wrap_lit = False )} "
630
+ expr = f"{ slf .pretty (context )} { BINARY_METHODS [name ]} { args [0 ].pretty (context , wrap_lit = False )} "
612
631
return expr if not parens else f"({ expr } )"
613
632
elif name == "__getitem__" :
614
633
assert len (args ) == 1
615
- return f"{ slf .pretty (mod_decls )} [{ args [0 ].pretty (mod_decls , wrap_lit = False )} ]"
634
+ return f"{ slf .pretty (context )} [{ args [0 ].pretty (context , wrap_lit = False )} ]"
616
635
elif name == "__call__" :
617
- return f"{ slf .pretty (mod_decls )} ({ ', ' .join (a .pretty (mod_decls , wrap_lit = False ) for a in args )} )"
618
- fn_str = f"{ slf .pretty (mod_decls )} .{ name } "
636
+ return f"{ slf .pretty (context )} ({ ', ' .join (a .pretty (context , wrap_lit = False ) for a in args )} )"
637
+ elif name == "__delitem__" :
638
+ assert len (args ) == 1
639
+ assert mutated_arg_type
640
+ name = context .name_expr (mutated_arg_type , slf )
641
+ context .statements .append (f"del { name } [{ args [0 ].pretty (context , parens = False , wrap_lit = False )} ]" )
642
+ return name
643
+ elif name == "__setitem__" :
644
+ assert len (args ) == 2
645
+ assert mutated_arg_type
646
+ name = context .name_expr (mutated_arg_type , slf )
647
+ context .statements .append (
648
+ f"{ name } [{ args [0 ].pretty (context , parens = False , wrap_lit = False )} ] = { args [1 ].pretty (context , parens = False , wrap_lit = False )} "
649
+ )
650
+ return name
651
+ fn_str = f"{ slf .pretty (context )} .{ name } "
619
652
elif isinstance (ref , ConstantRef ):
620
653
return ref .name
621
654
elif isinstance (ref , ClassVariableRef ):
622
655
return f"{ ref .class_name } .{ ref .variable_name } "
623
656
elif isinstance (ref , PropertyRef ):
624
- return f"{ args [0 ].pretty (mod_decls )} .{ ref .property_name } "
657
+ return f"{ args [0 ].pretty (context )} .{ ref .property_name } "
625
658
else :
626
659
assert_never (ref )
627
660
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
@@ -632,36 +665,85 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
632
665
n_defaults += 1
633
666
if n_defaults :
634
667
args = args [:- n_defaults ]
635
- return f"{ fn_str } ({ ', ' .join (a .pretty (mod_decls , wrap_lit = False ) for a in args )} )"
668
+ if mutated_arg_type :
669
+ name = context .name_expr (mutated_arg_type , args [0 ])
670
+ context .statements .append (
671
+ f"{ fn_str } ({ ', ' .join ({name }, * (a .pretty (context , wrap_lit = False ) for a in args [1 :]))} )"
672
+ )
673
+ return name
674
+ return f"{ fn_str } ({ ', ' .join (a .pretty (context , wrap_lit = False ) for a in args )} )"
675
+
676
+
677
+ @dataclass
678
+ class PrettyContext :
679
+ mod_decls : ModuleDeclarations
680
+ # List of statements of "context" setting variable for the expr
681
+ statements : list [str ] = field (default_factory = list )
682
+
683
+ _gen_name_types : dict [str , int ] = field (default_factory = lambda : defaultdict (lambda : 0 ))
684
+
685
+ def generate_name (self , typ : str ) -> str :
686
+ self ._gen_name_types [typ ] += 1
687
+ return f"_{ typ } _{ self ._gen_name_types [typ ]} "
688
+
689
+ def name_expr (self , expr_type : str , expr : ExprDecl ) -> str :
690
+ name = self .generate_name (expr_type )
691
+ self .statements .append (f"{ name } = copy({ expr .pretty (self , parens = False )} )" )
692
+ return name
693
+
694
+ def render (self , expr : str ) -> str :
695
+ return "\n " .join (self .statements + [expr ])
636
696
637
697
638
698
def test_expr_pretty ():
639
- mod_decls = ModuleDeclarations (Declarations ())
640
- assert VarDecl ("x" ).pretty (mod_decls ) == "x"
641
- assert LitDecl (42 ).pretty (mod_decls ) == "i64(42)"
642
- assert LitDecl ("foo" ).pretty (mod_decls ) == 'String("foo")'
643
- assert LitDecl (None ).pretty (mod_decls ) == "unit()"
699
+ context = PrettyContext ( ModuleDeclarations (Declarations () ))
700
+ assert VarDecl ("x" ).pretty (context ) == "x"
701
+ assert LitDecl (42 ).pretty (context ) == "i64(42)"
702
+ assert LitDecl ("foo" ).pretty (context ) == 'String("foo")'
703
+ assert LitDecl (None ).pretty (context ) == "unit()"
644
704
645
705
def v (x : str ) -> TypedExprDecl :
646
706
return TypedExprDecl (JustTypeRef ("" ), VarDecl (x ))
647
707
648
- assert CallDecl (FunctionRef ("foo" ), (v ("x" ),)).pretty (mod_decls ) == "foo(x)"
649
- assert CallDecl (FunctionRef ("foo" ), (v ("x" ), v ("y" ), v ("z" ))).pretty (mod_decls ) == "foo(x, y, z)"
650
- assert CallDecl (MethodRef ("foo" , "__add__" ), (v ("x" ), v ("y" ))).pretty (mod_decls ) == "x + y"
651
- assert CallDecl (MethodRef ("foo" , "__getitem__" ), (v ("x" ), v ("y" ))).pretty (mod_decls ) == "x[y]"
652
- assert CallDecl (ClassMethodRef ("foo" , "__init__" ), (v ("x" ), v ("y" ))).pretty (mod_decls ) == "foo(x, y)"
653
- assert CallDecl (ClassMethodRef ("foo" , "bar" ), (v ("x" ), v ("y" ))).pretty (mod_decls ) == "foo.bar(x, y)"
654
- assert CallDecl (MethodRef ("foo" , "__call__" ), (v ("x" ), v ("y" ))).pretty (mod_decls ) == "x(y)"
708
+ assert CallDecl (FunctionRef ("foo" ), (v ("x" ),)).pretty (context ) == "foo(x)"
709
+ assert CallDecl (FunctionRef ("foo" ), (v ("x" ), v ("y" ), v ("z" ))).pretty (context ) == "foo(x, y, z)"
710
+ assert CallDecl (MethodRef ("foo" , "__add__" ), (v ("x" ), v ("y" ))).pretty (context ) == "x + y"
711
+ assert CallDecl (MethodRef ("foo" , "__getitem__" ), (v ("x" ), v ("y" ))).pretty (context ) == "x[y]"
712
+ assert CallDecl (ClassMethodRef ("foo" , "__init__" ), (v ("x" ), v ("y" ))).pretty (context ) == "foo(x, y)"
713
+ assert CallDecl (ClassMethodRef ("foo" , "bar" ), (v ("x" ), v ("y" ))).pretty (context ) == "foo.bar(x, y)"
714
+ assert CallDecl (MethodRef ("foo" , "__call__" ), (v ("x" ), v ("y" ))).pretty (context ) == "x(y)"
655
715
assert (
656
716
CallDecl (
657
717
ClassMethodRef ("Map" , "__init__" ),
658
718
(),
659
719
(JustTypeRef ("i64" ), JustTypeRef ("Unit" )),
660
- ).pretty (mod_decls )
720
+ ).pretty (context )
661
721
== "Map[i64, Unit]()"
662
722
)
663
723
664
724
725
+ def test_setitem_pretty ():
726
+ context = PrettyContext (ModuleDeclarations (Declarations ()))
727
+
728
+ def v (x : str ) -> TypedExprDecl :
729
+ return TypedExprDecl (JustTypeRef ("typ" ), VarDecl (x ))
730
+
731
+ final_expr = CallDecl (MethodRef ("foo" , "__setitem__" ), (v ("x" ), v ("y" ), v ("z" ))).pretty (context )
732
+ assert context .render (final_expr ) == "_typ_1 = x\n _typ_1[y] = z\n _typ_1"
733
+
734
+
735
+ def test_delitem_pretty ():
736
+ context = PrettyContext (ModuleDeclarations (Declarations ()))
737
+
738
+ def v (x : str ) -> TypedExprDecl :
739
+ return TypedExprDecl (JustTypeRef ("typ" ), VarDecl (x ))
740
+
741
+ final_expr = CallDecl (MethodRef ("foo" , "__delitem__" ), (v ("x" ), v ("y" ))).pretty (context )
742
+ assert context .render (final_expr ) == "_typ_1 = x\n del _typ_1[y]\n _typ_1"
743
+
744
+
745
+ # TODO: Multiple mutations,
746
+
665
747
ExprDecl = Union [VarDecl , LitDecl , CallDecl ]
666
748
667
749
0 commit comments