Skip to content

Add support for mutations #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 10 commits into from
Aug 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea
- Added to/from i64 to i64 methods.
- Upgraded `egg-smol` dependency ([changes](https://github.com/saulshanabrook/egg-smol/compare/353c4387640019bd2066991ee0488dc6d5c54168...2ac80cb1162c61baef295d8e6d00351bfe84883f))

- Add support for functions which mutates their args, like `__setitem__` [#35](https://github.com/metadsl/egglog-python/pull/35)

## 0.5.1 (2023-07-18)

- Added support for negation on `f64` sort
Expand Down
41 changes: 40 additions & 1 deletion docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,43 @@ def baz(a: i64Like, b: i64Like=i64(0)) -> i64:
baz(1)
```

### Mutating arguments

In order to support Python functions and methods which mutate their arguments, you can pass in the `mutate_first_arg` keyword argument to the `@egraph.function` decorator and the `mutates_self` argument to the `@egraph.method` decorator. This will cause the first argument to be mutated in place, instead of being copied.

```{code-cell} python
from copy import copy
mutate_egraph = EGraph()

@mutate_egraph.class_
class Int(Expr):
def __init__(self, i: i64Like) -> None:
...

def __add__(self, other: Int) -> Int: # type: ignore[empty-body]
...

@mutate_egraph.function(mutates_first_arg=True)
def incr(x: Int) -> None:
...

i = var("i", Int)
incr_i = copy(i)
incr(incr_i)

x = Int(10)
incr(x)
mutate_egraph.register(rewrite(incr_i).to(i + Int(1)), x)
mutate_egraph.run(10)
mutate_egraph.check(eq(x).to(Int(10) + Int(1)))
mutate_egraph
```

Any function which mutates its first argument must return `None`. In egglog, this is translated into a function which
returns the type of its first argument.

Note that dunder methods such as `__setitem__` will automatically be marked as mutating their first argument.

### Datatype functions

In egglog, the `(datatype ...)` command can also be used to declare functions. All of the functions declared in this block return the type of the declared datatype. Similarily, in Python, we can use the `@egraph.class_` decorator on a class to define a number of functions associated with that class. These
Expand Down Expand Up @@ -534,7 +571,9 @@ egraph.register(
# (extract y :variants 2)
y = egraph.define("y", Math(6) + Math(2) * Math.var("x"))
egraph.run(10)
egraph.extract_multiple(y, 2)
# TODO: For some reason this is extracting temp vars
# egraph.extract_multiple(y, 2)
egraph
```

### Simplify
Expand Down
342 changes: 147 additions & 195 deletions docs/tutorials/array-api.ipynb

Large diffs are not rendered by default.

136 changes: 109 additions & 27 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
"ExprDecl",
"TypedExprDecl",
"ClassDecl",
"PrettyContext",
]
# Special methods which we might want to use as functions
# Mapping to the operator they represent for pretty printing them
Expand Down Expand Up @@ -288,7 +289,7 @@ def register_constant_callable(
self._decl.set_constant_type(ref, type_ref)
# Create a function decleartion for a constant function. This is similar to how egglog compiles
# the `declare` command.
return FunctionDecl((), (), (), type_ref.to_var()).to_commands(self, egg_name or ref.generate_egg_name())
return FunctionDecl((), (), (), type_ref.to_var(), False).to_commands(self, egg_name or ref.generate_egg_name())

def register_preserved_method(self, class_: str, method: str, fn: Callable) -> None:
self._decl._classes[class_].preserved_methods[method] = fn
Expand Down Expand Up @@ -337,7 +338,14 @@ def to_constant_function_decl(self) -> FunctionDecl:
Create a function declaration for a constant function. This is similar to how egglog compiles
the `constant` command.
"""
return FunctionDecl(arg_types=(), arg_names=(), arg_defaults=(), return_type=self.to_var(), var_arg_type=None)
return FunctionDecl(
arg_types=(),
arg_names=(),
arg_defaults=(),
return_type=self.to_var(),
mutates_first_arg=False,
var_arg_type=None,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -432,8 +440,14 @@ class FunctionDecl:
arg_names: Optional[tuple[str, ...]]
arg_defaults: tuple[Optional[ExprDecl], ...]
return_type: TypeOrVarRef
mutates_first_arg: bool
var_arg_type: Optional[TypeOrVarRef] = None

def __post_init__(self):
# If we mutate the first arg, then the first arg should be the same type as the return
if self.mutates_first_arg:
assert self.arg_types[0] == self.return_type

def to_signature(self, transform_default: Callable[[TypedExprDecl], object]) -> Signature:
arg_names = self.arg_names or tuple(f"__{i}" for i in range(len(self.arg_types)))
parameters = [
Expand Down Expand Up @@ -491,7 +505,7 @@ def from_egg(cls, var: bindings.Var) -> TypedExprDecl:
def to_egg(self, _decls: ModuleDeclarations) -> bindings.Var:
return bindings.Var(self.name)

def pretty(self, mod_decls: ModuleDeclarations, **kwargs) -> str:
def pretty(self, context: PrettyContext, **kwargs) -> str:
return self.name


Expand Down Expand Up @@ -525,7 +539,7 @@ def to_egg(self, _decls: ModuleDeclarations) -> bindings.Lit:
return bindings.Lit(bindings.String(self.value))
assert_never(self.value)

def pretty(self, mod_decls: ModuleDeclarations, wrap_lit=True, **kwargs) -> str:
def pretty(self, context: PrettyContext, wrap_lit=True, **kwargs) -> str:
"""
Returns a string representation of the literal.

Expand Down Expand Up @@ -581,7 +595,7 @@ def to_egg(self, mod_decls: ModuleDeclarations) -> bindings.Call:
egg_fn = mod_decls.get_egg_fn(self.callable)
return bindings.Call(egg_fn, [a.to_egg(mod_decls) for a in self.args])

def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
def pretty(self, context: PrettyContext, parens=True, **kwargs) -> str:
"""
Pretty print the call.

Expand All @@ -590,8 +604,13 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
ref, args = self.callable, [a.expr for a in self.args]
# Special case != since it doesn't have a decl
if isinstance(ref, MethodRef) and ref.method_name == "__ne__":
return f"{args[0].pretty(mod_decls, wrap_lit=True)} != {args[1].pretty(mod_decls, wrap_lit=True)}"
defaults = mod_decls.get_function_decl(ref).arg_defaults
return f"{args[0].pretty(context, wrap_lit=True)} != {args[1].pretty(context, wrap_lit=True)}"
function_decl = context.mod_decls.get_function_decl(ref)
defaults = function_decl.arg_defaults
if function_decl.mutates_first_arg:
mutated_arg_type = function_decl.arg_types[0].to_just().name
else:
mutated_arg_type = None
if isinstance(ref, FunctionRef):
fn_str = ref.name
elif isinstance(ref, ClassMethodRef):
Expand All @@ -605,23 +624,37 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
slf, *args = args
defaults = defaults[1:]
if name in UNARY_METHODS:
return f"{UNARY_METHODS[name]}{slf.pretty(mod_decls)}"
return f"{UNARY_METHODS[name]}{slf.pretty(context)}"
elif name in BINARY_METHODS:
assert len(args) == 1
expr = f"{slf.pretty(mod_decls )} {BINARY_METHODS[name]} {args[0].pretty(mod_decls, wrap_lit=False)}"
expr = f"{slf.pretty(context )} {BINARY_METHODS[name]} {args[0].pretty(context, wrap_lit=False)}"
return expr if not parens else f"({expr})"
elif name == "__getitem__":
assert len(args) == 1
return f"{slf.pretty(mod_decls)}[{args[0].pretty(mod_decls, wrap_lit=False)}]"
return f"{slf.pretty(context)}[{args[0].pretty(context, wrap_lit=False)}]"
elif name == "__call__":
return f"{slf.pretty(mod_decls)}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})"
fn_str = f"{slf.pretty(mod_decls)}.{name}"
return f"{slf.pretty(context)}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})"
elif name == "__delitem__":
assert len(args) == 1
assert mutated_arg_type
name = context.name_expr(mutated_arg_type, slf)
context.statements.append(f"del {name}[{args[0].pretty(context, parens=False, wrap_lit=False)}]")
return name
elif name == "__setitem__":
assert len(args) == 2
assert mutated_arg_type
name = context.name_expr(mutated_arg_type, slf)
context.statements.append(
f"{name}[{args[0].pretty(context, parens=False, wrap_lit=False)}] = {args[1].pretty(context, parens=False, wrap_lit=False)}"
)
return name
fn_str = f"{slf.pretty(context)}.{name}"
elif isinstance(ref, ConstantRef):
return ref.name
elif isinstance(ref, ClassVariableRef):
return f"{ref.class_name}.{ref.variable_name}"
elif isinstance(ref, PropertyRef):
return f"{args[0].pretty(mod_decls)}.{ref.property_name}"
return f"{args[0].pretty(context)}.{ref.property_name}"
else:
assert_never(ref)
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
Expand All @@ -632,36 +665,85 @@ def pretty(self, mod_decls: ModuleDeclarations, parens=True, **kwargs) -> str:
n_defaults += 1
if n_defaults:
args = args[:-n_defaults]
return f"{fn_str}({', '.join(a.pretty(mod_decls, wrap_lit=False) for a in args)})"
if mutated_arg_type:
name = context.name_expr(mutated_arg_type, args[0])
context.statements.append(
f"{fn_str}({', '.join({name}, *(a.pretty(context, wrap_lit=False) for a in args[1:]))})"
)
return name
return f"{fn_str}({', '.join(a.pretty(context, wrap_lit=False) for a in args)})"


@dataclass
class PrettyContext:
mod_decls: ModuleDeclarations
# List of statements of "context" setting variable for the expr
statements: list[str] = field(default_factory=list)

_gen_name_types: dict[str, int] = field(default_factory=lambda: defaultdict(lambda: 0))

def generate_name(self, typ: str) -> str:
self._gen_name_types[typ] += 1
return f"_{typ}_{self._gen_name_types[typ]}"

def name_expr(self, expr_type: str, expr: ExprDecl) -> str:
name = self.generate_name(expr_type)
self.statements.append(f"{name} = copy({expr.pretty(self, parens=False)})")
return name

def render(self, expr: str) -> str:
return "\n".join(self.statements + [expr])


def test_expr_pretty():
mod_decls = ModuleDeclarations(Declarations())
assert VarDecl("x").pretty(mod_decls) == "x"
assert LitDecl(42).pretty(mod_decls) == "i64(42)"
assert LitDecl("foo").pretty(mod_decls) == 'String("foo")'
assert LitDecl(None).pretty(mod_decls) == "unit()"
context = PrettyContext(ModuleDeclarations(Declarations()))
assert VarDecl("x").pretty(context) == "x"
assert LitDecl(42).pretty(context) == "i64(42)"
assert LitDecl("foo").pretty(context) == 'String("foo")'
assert LitDecl(None).pretty(context) == "unit()"

def v(x: str) -> TypedExprDecl:
return TypedExprDecl(JustTypeRef(""), VarDecl(x))

assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(mod_decls) == "foo(x)"
assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(mod_decls) == "foo(x, y, z)"
assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(mod_decls) == "x + y"
assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(mod_decls) == "x[y]"
assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(mod_decls) == "foo(x, y)"
assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(mod_decls) == "foo.bar(x, y)"
assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(mod_decls) == "x(y)"
assert CallDecl(FunctionRef("foo"), (v("x"),)).pretty(context) == "foo(x)"
assert CallDecl(FunctionRef("foo"), (v("x"), v("y"), v("z"))).pretty(context) == "foo(x, y, z)"
assert CallDecl(MethodRef("foo", "__add__"), (v("x"), v("y"))).pretty(context) == "x + y"
assert CallDecl(MethodRef("foo", "__getitem__"), (v("x"), v("y"))).pretty(context) == "x[y]"
assert CallDecl(ClassMethodRef("foo", "__init__"), (v("x"), v("y"))).pretty(context) == "foo(x, y)"
assert CallDecl(ClassMethodRef("foo", "bar"), (v("x"), v("y"))).pretty(context) == "foo.bar(x, y)"
assert CallDecl(MethodRef("foo", "__call__"), (v("x"), v("y"))).pretty(context) == "x(y)"
assert (
CallDecl(
ClassMethodRef("Map", "__init__"),
(),
(JustTypeRef("i64"), JustTypeRef("Unit")),
).pretty(mod_decls)
).pretty(context)
== "Map[i64, Unit]()"
)


def test_setitem_pretty():
context = PrettyContext(ModuleDeclarations(Declarations()))

def v(x: str) -> TypedExprDecl:
return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))

final_expr = CallDecl(MethodRef("foo", "__setitem__"), (v("x"), v("y"), v("z"))).pretty(context)
assert context.render(final_expr) == "_typ_1 = x\n_typ_1[y] = z\n_typ_1"


def test_delitem_pretty():
context = PrettyContext(ModuleDeclarations(Declarations()))

def v(x: str) -> TypedExprDecl:
return TypedExprDecl(JustTypeRef("typ"), VarDecl(x))

final_expr = CallDecl(MethodRef("foo", "__delitem__"), (v("x"), v("y"))).pretty(context)
assert context.render(final_expr) == "_typ_1 = x\ndel _typ_1[y]\n_typ_1"


# TODO: Multiple mutations,

ExprDecl = Union[VarDecl, LitDecl, CallDecl]


Expand Down
Loading