Skip to content

Commit 5e9032b

Browse files
Merge pull request #38 from metadsl/next-sklearn
Make conversions transitive and make getitem more comprehensive
2 parents 2379731 + 2c0a918 commit 5e9032b

File tree

6 files changed

+227
-43
lines changed

6 files changed

+227
-43
lines changed

docs/changelog.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea
2121
- Upgraded `egg-smol` dependency ([changes](https://github.com/saulshanabrook/egg-smol/compare/353c4387640019bd2066991ee0488dc6d5c54168...2ac80cb1162c61baef295d8e6d00351bfe84883f))
2222

2323
- Add support for functions which mutates their args, like `__setitem__` [#35](https://github.com/metadsl/egglog-python/pull/35)
24+
- Makes conversions transitive [#38](https://github.com/metadsl/egglog-python/pull/38)
2425

2526
## 0.5.1 (2023-07-18)
2627

docs/reference/egglog-translation.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,8 @@ Math(2) + 30 + "x"
255255
Math(2) + Math(i64(30)) + Math.var(String("x"))
256256
```
257257

258+
Regstering a conversion from A to B will also register all transitively reachable conversions from A to B.
259+
258260
### Declarations
259261

260262
In egglog, the `(declare ...)` command is syntactic sugar for a nullary function. In Python, these can be declare either as class variables or with the toplevel `egraph.constant` function:

docs/tutorials/array-api.ipynb

Lines changed: 14 additions & 30 deletions
Large diffs are not rendered by default.

python/egglog/exp/array_api.py

Lines changed: 97 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# mypy: disable-error-code=empty-body
2-
31
from __future__ import annotations
42

53
import itertools
@@ -13,6 +11,9 @@
1311
# Pretend that exprs are numbers b/c scikit learn does isinstance checks
1412
from egglog.runtime import RuntimeExpr
1513

14+
# mypy: disable-error-code=empty-body
15+
16+
1617
numbers.Integral.register(RuntimeExpr)
1718

1819
egraph = EGraph()
@@ -111,7 +112,6 @@ def isdtype(dtype: DType, kind: IsDtypeKind) -> Bool:
111112
...
112113

113114

114-
converter(np.dtype, IsDtypeKind, lambda x: IsDtypeKind.dtype(convert(x, DType)))
115115
converter(DType, IsDtypeKind, lambda x: IsDtypeKind.dtype(x))
116116
converter(str, IsDtypeKind, lambda x: IsDtypeKind.string(x))
117117
converter(
@@ -286,23 +286,108 @@ def _tuple_int(ti: TupleInt, ti2: TupleInt, i: Int, i2: Int, k: i64):
286286
]
287287

288288

289-
# HANDLED_FUNCTIONS = {}
289+
@egraph.class_
290+
class OptionalInt(Expr):
291+
none: ClassVar[OptionalInt]
292+
293+
@classmethod
294+
def some(cls, value: Int) -> OptionalInt:
295+
...
296+
297+
298+
converter(type(None), OptionalInt, lambda x: OptionalInt.none)
299+
converter(Int, OptionalInt, OptionalInt.some)
290300

291301

292302
@egraph.class_
293-
class IndexKey(Expr):
303+
class Slice(Expr):
304+
def __init__(self, start: OptionalInt, stop: OptionalInt, step: OptionalInt) -> None:
305+
...
306+
307+
308+
converter(
309+
slice,
310+
Slice,
311+
lambda x: Slice(convert(x.start, OptionalInt), convert(x.stop, OptionalInt), convert(x.step, OptionalInt)),
312+
)
313+
314+
315+
@egraph.class_
316+
class MultiAxisIndexKeyItem(Expr):
317+
ELLIPSIS: ClassVar[MultiAxisIndexKeyItem]
318+
NONE: ClassVar[MultiAxisIndexKeyItem]
319+
294320
@classmethod
295-
def tuple_int(cls, ti: TupleInt) -> IndexKey:
321+
def int(cls, i: Int) -> MultiAxisIndexKeyItem:
296322
...
297323

324+
@classmethod
325+
def slice(cls, slice: Slice) -> MultiAxisIndexKeyItem:
326+
...
327+
328+
329+
converter(type(...), MultiAxisIndexKeyItem, lambda x: MultiAxisIndexKeyItem.ELLIPSIS)
330+
converter(type(None), MultiAxisIndexKeyItem, lambda x: MultiAxisIndexKeyItem.NONE)
331+
converter(Int, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.int)
332+
converter(Slice, MultiAxisIndexKeyItem, MultiAxisIndexKeyItem.slice)
333+
334+
335+
@egraph.class_
336+
class MultiAxisIndexKey(Expr):
337+
def __init__(self, item: MultiAxisIndexKeyItem) -> None:
338+
...
339+
340+
EMPTY: ClassVar[MultiAxisIndexKey]
341+
342+
def __add__(self, other: MultiAxisIndexKey) -> MultiAxisIndexKey:
343+
...
344+
345+
346+
converter(
347+
tuple,
348+
MultiAxisIndexKey,
349+
lambda x: MultiAxisIndexKey(convert(x[0], MultiAxisIndexKeyItem)) + convert(x[1:], MultiAxisIndexKey)
350+
if x
351+
else MultiAxisIndexKey.EMPTY,
352+
)
353+
354+
355+
@egraph.class_
356+
class IndexKey(Expr):
357+
"""
358+
A key for indexing into an array
359+
360+
https://data-apis.org/array-api/2022.12/API_specification/indexing.html
361+
362+
It is equivalent to the following type signature:
363+
364+
Union[int, slice, ellipsis, Tuple[Union[int, slice, ellipsis, None], ...], array]
365+
"""
366+
367+
ELLIPSIS: ClassVar[IndexKey]
368+
298369
@classmethod
299370
def int(cls, i: Int) -> IndexKey:
300371
...
301372

373+
@classmethod
374+
def slice(cls, slice: Slice) -> IndexKey:
375+
...
376+
377+
# Disabled until we support late binding
378+
# @classmethod
379+
# def boolean_array(cls, b: NDArray) -> IndexKey:
380+
# ...
381+
382+
@classmethod
383+
def multi_axis(cls, key: MultiAxisIndexKey) -> IndexKey:
384+
...
385+
302386

303-
converter(tuple, IndexKey, lambda x: IndexKey.tuple_int(convert(x, TupleInt)))
304-
converter(int, IndexKey, lambda x: IndexKey.int(Int(x)))
305-
converter(Int, IndexKey, lambda x: IndexKey.int(x))
387+
converter(type(...), IndexKey, lambda x: IndexKey.ELLIPSIS)
388+
converter(Int, IndexKey, IndexKey.int)
389+
converter(Slice, IndexKey, IndexKey.slice)
390+
converter(MultiAxisIndexKey, IndexKey, IndexKey.multi_axis)
306391

307392

308393
@egraph.class_
@@ -400,8 +485,8 @@ def ndarray_index(x: NDArray) -> IndexKey:
400485
converter(NDArray, IndexKey, ndarray_index)
401486

402487

403-
converter(float, NDArray, lambda x: NDArray.scalar_float(Float(x)))
404-
converter(int, NDArray, lambda x: NDArray.scalar_int(Int(x)))
488+
converter(Float, NDArray, NDArray.scalar_float)
489+
converter(Int, NDArray, NDArray.scalar_int)
405490

406491

407492
@egraph.register
@@ -478,7 +563,6 @@ def some(cls, value: Bool) -> OptionalBool:
478563

479564
converter(type(None), OptionalBool, lambda x: OptionalBool.none)
480565
converter(Bool, OptionalBool, lambda x: OptionalBool.some(x))
481-
converter(bool, OptionalBool, lambda x: OptionalBool.some(convert(x, Bool)))
482566

483567

484568
@egraph.class_
@@ -518,6 +602,7 @@ def some(cls, value: TupleInt) -> OptionalTupleInt:
518602

519603
converter(type(None), OptionalTupleInt, lambda x: OptionalTupleInt.none)
520604
converter(TupleInt, OptionalTupleInt, lambda x: OptionalTupleInt.some(x))
605+
# TODO: Don't allow ints to be converted to OptionalTupleInt, and have another type that also unions ints
521606
converter(int, OptionalTupleInt, lambda x: OptionalTupleInt.some(TupleInt(Int(x))))
522607

523608

python/egglog/runtime.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,41 @@ def converter(from_type: Type[T], to_type: Type[V], fn: Callable[[T], V]) -> Non
7070
to_type_name = process_tp(to_type)
7171
if not isinstance(to_type_name, JustTypeRef):
7272
raise TypeError(f"Expected return type to be a egglog type, got {to_type_name}")
73-
CONVERSIONS[(process_tp(from_type), to_type_name)] = fn
73+
_register_converter(process_tp(from_type), to_type_name, fn)
74+
75+
76+
def _register_converter(a: Type | JustTypeRef, b: JustTypeRef, a_b: Callable) -> None:
77+
"""
78+
Registers a converter from some type to an egglog type, if not already registered.
79+
80+
Also adds transitive converters, i.e. if registering A->B and there is already B->C, then A->C will be registered.
81+
Also, if registering A->B and there is already D->A, then D->B will be registered.
82+
"""
83+
if a == b or (a, b) in CONVERSIONS:
84+
return
85+
CONVERSIONS[(a, b)] = a_b
86+
for (c, d), c_d in list(CONVERSIONS.items()):
87+
if b == c:
88+
_register_converter(a, d, _ComposedConverter(a_b, c_d))
89+
if a == d:
90+
_register_converter(c, b, _ComposedConverter(c_d, a_b))
91+
92+
93+
@dataclass
94+
class _ComposedConverter:
95+
"""
96+
A converter which is composed of multiple converters.
97+
98+
_ComposeConverter(a_b, b_c) is equivalent to lambda x: b_c(a_b(x))
99+
100+
We use the dataclass instead of the lambda to make it easier to debug.
101+
"""
102+
103+
a_b: Callable
104+
b_c: Callable
105+
106+
def __call__(self, x: object) -> object:
107+
return self.b_c(self.a_b(x))
74108

75109

76110
def convert(source: object, target: type[V]) -> V:

python/tests/test_convert.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,17 @@
1+
import copy
2+
3+
import egglog.runtime
4+
import pytest
15
from egglog import *
26

37

8+
@pytest.fixture(autouse=True)
9+
def reset_conversions():
10+
old_conversions = copy.copy(egglog.runtime.CONVERSIONS)
11+
yield
12+
egglog.runtime.CONVERSIONS = old_conversions
13+
14+
415
def test_conversion_custom_metaclass():
516
class MyMeta(type):
617
pass
@@ -33,3 +44,70 @@ def __init__(self):
3344
converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
3445

3546
assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr())
47+
48+
49+
def test_conversion_transitive_forward():
50+
egraph = EGraph()
51+
52+
class MyType:
53+
pass
54+
55+
@egraph.class_
56+
class MyTypeExpr(Expr):
57+
def __init__(self):
58+
...
59+
60+
@egraph.class_
61+
class MyTypeExpr2(Expr):
62+
def __init__(self):
63+
...
64+
65+
converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
66+
converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2())
67+
68+
assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2())
69+
70+
71+
def test_conversion_transitive_backward():
72+
egraph = EGraph()
73+
74+
class MyType:
75+
pass
76+
77+
@egraph.class_
78+
class MyTypeExpr(Expr):
79+
def __init__(self):
80+
...
81+
82+
@egraph.class_
83+
class MyTypeExpr2(Expr):
84+
def __init__(self):
85+
...
86+
87+
converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2())
88+
converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
89+
assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2())
90+
91+
92+
def test_conversion_transitive_cycle():
93+
egraph = EGraph()
94+
95+
class MyType:
96+
pass
97+
98+
@egraph.class_
99+
class MyTypeExpr(Expr):
100+
def __init__(self):
101+
...
102+
103+
@egraph.class_
104+
class MyTypeExpr2(Expr):
105+
def __init__(self):
106+
...
107+
108+
converter(MyType, MyTypeExpr, lambda x: MyTypeExpr())
109+
converter(MyTypeExpr, MyTypeExpr2, lambda x: MyTypeExpr2())
110+
converter(MyTypeExpr2, MyTypeExpr, lambda x: MyTypeExpr())
111+
112+
assert expr_parts(convert(MyType(), MyTypeExpr2)) == expr_parts(MyTypeExpr2())
113+
assert expr_parts(convert(MyType(), MyTypeExpr)) == expr_parts(MyTypeExpr())

0 commit comments

Comments
 (0)