Skip to content

Commit 6d4d43f

Browse files
Add support for mutations
1 parent 2cc04a6 commit 6d4d43f

File tree

6 files changed

+291
-64
lines changed

6 files changed

+291
-64
lines changed

docs/reference/egglog-translation.md

Lines changed: 40 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,43 @@ def baz(a: i64Like, b: i64Like=i64(0)) -> i64:
160160
baz(1)
161161
```
162162

163+
### Mutating arguments
164+
165+
In order to support Python functions and methods which mutate their arguments, you can pass in the `mutate_first_arg` keyword argument to the `@egraph.function` decorator and the `mutates_self` argument to the `@egraph.method` decorator. This will cause the first argument to be mutated in place, instead of being copied.
166+
167+
```{code-cell} python
168+
from copy import copy
169+
mutate_egraph = EGraph()
170+
171+
@mutate_egraph.class_
172+
class Int(Expr):
173+
def __init__(self, i: i64Like) -> None:
174+
...
175+
176+
def __add__(self, other: Int) -> Int: # type: ignore[empty-body]
177+
...
178+
179+
@mutate_egraph.function(mutates_first_arg=True)
180+
def incr(x: Int) -> None:
181+
...
182+
183+
i = var("i", Int)
184+
incr_i = copy(i)
185+
incr(incr_i)
186+
187+
x = Int(10)
188+
incr(x)
189+
mutate_egraph.register(rewrite(incr_i).to(i + Int(1)), x)
190+
mutate_egraph.run(10)
191+
mutate_egraph.check(eq(x).to(Int(10) + Int(1)))
192+
mutate_egraph
193+
```
194+
195+
Any function which mutates its first argument must return `None`. In egglog, this is translated into a function which
196+
returns the type of its first argument.
197+
198+
Note that dunder methods such as `__setitem__` will automatically be marked as mutating their first argument.
199+
163200
### Datatype functions
164201

165202
In egglog, the `(datatype ...)` command can also be used to declare functions. All of the functions declared in this block return the type of the declared datatype. Similarily, in Python, we can use the `@egraph.class_` decorator on a class to define a number of functions associated with that class. These
@@ -534,7 +571,9 @@ egraph.register(
534571
# (extract y :variants 2)
535572
y = egraph.define("y", Math(6) + Math(2) * Math.var("x"))
536573
egraph.run(10)
537-
egraph.extract_multiple(y, 2)
574+
# TODO: For some reason this is extracting temp vars
575+
# egraph.extract_multiple(y, 2)
576+
egraph
538577
```
539578

540579
### Simplify

python/egglog/declarations.py

Lines changed: 109 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
"ExprDecl",
3838
"TypedExprDecl",
3939
"ClassDecl",
40+
"PrettyContext",
4041
]
4142
# Special methods which we might want to use as functions
4243
# Mapping to the operator they represent for pretty printing them
@@ -288,7 +289,7 @@ def register_constant_callable(
288289
self._decl.set_constant_type(ref, type_ref)
289290
# Create a function decleartion for a constant function. This is similar to how egglog compiles
290291
# the `declare` command.
291-
return FunctionDecl((), (), (), type_ref.to_var()).to_commands(self, egg_name or ref.generate_egg_name())
292+
return FunctionDecl((), (), (), type_ref.to_var(), False).to_commands(self, egg_name or ref.generate_egg_name())
292293

293294
def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None:
294295
self._decl._classes[class_].preserved_methods[method] = fn
@@ -337,7 +338,14 @@ def to_constant_function_decl(self) -> FunctionDecl:
337338
Create a function declaration for a constant function. This is similar to how egglog compiles
338339
the `constant` command.
339340
"""
340-
return FunctionDecl(arg_types=(), arg_names=(), arg_defaults=(), return_type=self.to_var(), var_arg_type=None)
341+
return FunctionDecl(
342+
arg_types=(),
343+
arg_names=(),
344+
arg_defaults=(),
345+
return_type=self.to_var(),
346+
mutates_first_arg=False,
347+
var_arg_type=None,
348+
)
341349

342350

343351
@dataclass(frozen=True)
@@ -432,8 +440,14 @@ class FunctionDecl:
432440
arg_names: Optional[tuple[str, ...]]
433441
arg_defaults: tuple[Optional[ExprDecl], ...]
434442
return_type: TypeOrVarRef
443+
mutates_first_arg: bool
435444
var_arg_type: Optional[TypeOrVarRef] = None
436445

446+
def __post_init__(self):
447+
# If we mutate the first arg, then the first arg should be the same type as the return
448+
if self.mutates_first_arg:
449+
assert self.arg_types[0] == self.return_type
450+
437451
def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> Signature:
438452
arg_names = self.arg_names or tuple(f"__{i}" for i in range(len(self.arg_types)))
439453
parameters = [
@@ -491,7 +505,7 @@ def from_egg(cls, var: bindings.Var) -> TypedExprDecl:
491505
def to_egg(self, _decls: ModuleDeclarations) -> bindings.Var:
492506
return bindings.Var(self.name)
493507

494-
def pretty(self, mod_decls: ModuleDeclarations, **kwargs) -> str:
508+
def pretty(self, context: PrettyContext, **kwargs) -> str:
495509
return self.name
496510

497511

@@ -525,7 +539,7 @@ def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit:
525539
return bindings.Lit(bindings.String(self.value))
526540
assert_never(self.value)
527541

528-
def pretty(self, mod_decls: ModuleDeclarations, wrap_lit=True, **kwargs) -> str:
542+
def pretty(self, context: PrettyContext, wrap_lit=True, **kwargs) -> str:
529543
"""
530544
Returns a string representation of the literal.
531545
@@ -581,7 +595,7 @@ def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call:
581595
egg_fn = mod_decls.get_egg_fn(self.callable)
582596
return bindings.Call(egg_fn, [a.to_egg(mod_decls) for a in self.args])
583597

584-
def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
598+
def pretty(self, context: PrettyContext, parens=True, **kwargs) -> str:
585599
"""
586600
Pretty print the call.
587601
@@ -590,8 +604,13 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
590604
ref, args = self.callable, [a.expr for a in self.args]
591605
# Special case != since it doesn't have a decl
592606
if isinstance(ref, MethodRef) and ref.method_name == "__ne__":
593-
return f"{args[0].pretty(mod_decls, wrap_lit=True)} != {args[1].pretty(mod_decls, wrap_lit=True)}"
594-
defaults = mod_decls.get_function_decl(ref).arg_defaults
607+
return f"{args[0].pretty(context, wrap_lit=True)} != {args[1].pretty(context, wrap_lit=True)}"
608+
function_decl = context.mod_decls.get_function_decl(ref)
609+
defaults = function_decl.arg_defaults
610+
if function_decl.mutates_first_arg:
611+
mutated_arg_type = function_decl.arg_types[0].to_just().name
612+
else:
613+
mutated_arg_type = None
595614
if isinstance(ref, FunctionRef):
596615
fn_str = ref.name
597616
elif isinstance(ref, ClassMethodRef):
@@ -605,23 +624,37 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
605624
slf, *args = args
606625
defaults = defaults[1:]
607626
if name in UNARY_METHODS:
608-
return f"{UNARY_METHODS[name]}{slf.pretty(mod_decls)}"
627+
return f"{UNARY_METHODS[name]}{slf.pretty(context)}"
609628
elif name in BINARY_METHODS:
610629
assert len(args) == 1
611-
expr = f"{slf.pretty(mod_decls )} {BINARY_METHODS[name]} {args[0].pretty(mod_decls, wrap_lit=False)}"
630+
expr = f"{slf.pretty(context )} {BINARY_METHODS[name]} {args[0].pretty(context, wrap_lit=False)}"
612631
return expr if not parens else f"({expr})"
613632
elif name == "__getitem__":
614633
assert len(args) == 1
615-
return f"{slf.pretty(mod_decls)}[{args[0].pretty(mod_decls, wrap_lit=False)}]"
634+
return f"{slf.pretty(context)}[{args[0].pretty(context, wrap_lit=False)}]"
616635
elif name == "__call__":
617-
return f"{slf.pretty(mod_decls)}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})"
618-
fn_str = f"{slf.pretty(mod_decls)}.{name}"
636+
return f"{slf.pretty(context)}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})"
637+
elif name == "__delitem__":
638+
assert len(args) == 1
639+
assert mutated_arg_type
640+
name = context.name_expr(mutated_arg_type, slf)
641+
context.statements.append(f"del {name}[{args[0].pretty(context, parens=False, wrap_lit=False)}]")
642+
return name
643+
elif name == "__setitem__":
644+
assert len(args) == 2
645+
assert mutated_arg_type
646+
name = context.name_expr(mutated_arg_type, slf)
647+
context.statements.append(
648+
f"{name}[{args[0].pretty(context, parens=False, wrap_lit=False)}] = {args[1].pretty(context, parens=False, wrap_lit=False)}"
649+
)
650+
return name
651+
fn_str = f"{slf.pretty(context)}.{name}"
619652
elif isinstance(ref, ConstantRef):
620653
return ref.name
621654
elif isinstance(ref, ClassVariableRef):
622655
return f"{ref.class_name}.{ref.variable_name}"
623656
elif isinstance(ref, PropertyRef):
624-
return f"{args[0].pretty(mod_decls)}.{ref.property_name}"
657+
return f"{args[0].pretty(context)}.{ref.property_name}"
625658
else:
626659
assert_never(ref)
627660
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
@@ -632,36 +665,85 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
632665
n_defaults += 1
633666
if n_defaults:
634667
args = args[:-n_defaults]
635-
return f"{fn_str}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})"
668+
if mutated_arg_type:
669+
name = context.name_expr(mutated_arg_type, args[0])
670+
context.statements.append(
671+
f"{fn_str}({', '.join({name}, *(a.pretty(context, wrap_lit=False) for a in args[1:]))})"
672+
)
673+
return name
674+
return f"{fn_str}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})"
675+
676+
677+
@dataclass
678+
class PrettyContext:
679+
mod_decls: ModuleDeclarations
680+
# List of statements of "context" setting variable for the expr
681+
statements: list[str] = field(default_factory=list)
682+
683+
_gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))
684+
685+
def generate_name(self, typ: str) -> str:
686+
self._gen_name_types[typ] += 1
687+
return f"_{typ}_{self._gen_name_types[typ]}"
688+
689+
def name_expr(self, expr_type: str, expr: ExprDecl) -> str:
690+
name = self.generate_name(expr_type)
691+
self.statements.append(f"{name} = copy({expr.pretty(self, parens=False)})")
692+
return name
693+
694+
def render(self, expr: str) -> str:
695+
return "\n".join(self.statements + [expr])
636696

637697

638698
def test_expr_pretty():
639-
mod_decls = ModuleDeclarations(Declarations())
640-
assert VarDecl("x").pretty(mod_decls) == "x"
641-
assert LitDecl(42).pretty(mod_decls) == "i64(42)"
642-
assert LitDecl("foo").pretty(mod_decls) == 'String("foo")'
643-
assert LitDecl(None).pretty(mod_decls) == "unit()"
699+
context = PrettyContext(ModuleDeclarations(Declarations()))
700+
assert VarDecl("x").pretty(context) == "x"
701+
assert LitDecl(42).pretty(context) == "i64(42)"
702+
assert LitDecl("foo").pretty(context) == 'String("foo")'
703+
assert LitDecl(None).pretty(context) == "unit()"
644704

645705
def v(x: str) -> TypedExprDecl:
646706
return TypedExprDecl(JustTypeRef(""), VarDecl(x))
647707

648-
assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(mod_decls) == "foo(x)"
649-
assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(mod_decls) == "foo(x, y, z)"
650-
assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(mod_decls) == "x + y"
651-
assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(mod_decls) == "x[y]"
652-
assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(mod_decls) == "foo(x, y)"
653-
assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(mod_decls) == "foo.bar(x, y)"
654-
assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(mod_decls) == "x(y)"
708+
assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)"
709+
assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)"
710+
assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y"
711+
assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]"
712+
assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)"
713+
assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)"
714+
assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)"
655715
assert (
656716
CallDecl(
657717
ClassMethodRef("Map", "__init__"),
658718
(),
659719
(JustTypeRef("i64"), JustTypeRef("Unit")),
660-
).pretty(mod_decls)
720+
).pretty(context)
661721
== "Map[i64, Unit]()"
662722
)
663723

664724

725+
def test_setitem_pretty():
726+
context = PrettyContext(ModuleDeclarations(Declarations()))
727+
728+
def v(x: str) -> TypedExprDecl:
729+
return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))
730+
731+
final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context)
732+
assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1"
733+
734+
735+
def test_delitem_pretty():
736+
context = PrettyContext(ModuleDeclarations(Declarations()))
737+
738+
def v(x: str) -> TypedExprDecl:
739+
return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))
740+
741+
final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context)
742+
assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1"
743+
744+
745+
# TODO: Multiple mutations,
746+
665747
ExprDecl = Union[VarDecl, LitDecl, CallDecl]
666748

667749

0 commit comments

Comments
 (0)