Skip to content

Commit a11d9f6

Browse files
_rust_bindings: expose Selection to python
Summary: this diff provides a Python binding for the `Selection` type. includes a class-method constructor `Selection.from_string(…)`. integrates into `monarch_extension` and adds a basic test. Reviewed By: dulinriley Differential Revision: D74971945 fbshipit-source-id: f63d6bdd9a9db23cef8f099f72a139bf046d079f
1 parent b21941e commit a11d9f6

File tree

11 files changed

+110
-0
lines changed

11 files changed

+110
-0
lines changed

monarch_extension/src/lib.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,11 @@ pub fn mod_init(module: &Bound<'_, PyModule>) -> PyResult<()> {
4747
"monarch_hyperactor.shape",
4848
)?)?;
4949

50+
monarch_hyperactor::selection::register_python_bindings(&get_or_add_new_module(
51+
module,
52+
"monarch_hyperactor.selection",
53+
)?)?;
54+
5055
client::register_python_bindings(&get_or_add_new_module(module, "monarch_extension.client")?)?;
5156
worker::register_python_bindings(&get_or_add_new_module(module, "monarch_extension.worker")?)?;
5257

monarch_hyperactor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ hyperactor_multiprocess = { version = "0.0.0", path = "../hyperactor_multiproces
1818
hyperactor_telemetry = { version = "0.0.0", path = "../hyperactor_telemetry" }
1919
monarch_types = { version = "0.0.0", path = "../monarch_types" }
2020
ndslice = { version = "0.0.0", path = "../ndslice" }
21+
nom = "7.1"
2122
pyo3 = { version = "0.22.6", features = ["anyhow"] }
2223
pyo3-async-runtimes = { git = "https://github.com/PyO3/pyo3-async-runtimes", rev = "f6bb9b471a5b7765dd770af36e83f26802459621", features = ["attributes", "tokio-runtime"] }
2324
serde = { version = "1.0.185", features = ["derive", "rc"] }

monarch_hyperactor/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ pub mod ndslice;
99
pub mod proc;
1010
pub mod proc_mesh;
1111
pub mod runtime;
12+
pub mod selection;
1213
pub mod shape;
1314

1415
use pyo3::Bound;
@@ -48,5 +49,7 @@ pub fn register_python_bindings(hyperactor_mod: &Bound<'_, PyModule>) -> PyResul
4849

4950
hyperactor_mod.add_class::<shape::PyShape>()?;
5051

52+
hyperactor_mod.add_class::<selection::PySelection>()?;
53+
5154
Ok(())
5255
}

monarch_hyperactor/src/selection.rs

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
use ndslice::selection::Selection;
2+
use pyo3::PyResult;
3+
use pyo3::prelude::*;
4+
use pyo3::types::PyType;
5+
6+
#[pyclass(name = "Selection", module = "monarch._monarch.selection", frozen)]
7+
pub struct PySelection {
8+
inner: Selection,
9+
}
10+
11+
#[pymethods]
12+
impl PySelection {
13+
#[getter]
14+
fn __repr__(&self) -> String {
15+
format!("{:?}", self.inner)
16+
}
17+
}
18+
19+
impl From<Selection> for PySelection {
20+
fn from(inner: Selection) -> Self {
21+
Self { inner }
22+
}
23+
}
24+
25+
#[pymethods]
26+
impl PySelection {
27+
#[classmethod]
28+
#[pyo3(name = "from_string")]
29+
pub fn parse(_cls: Bound<'_, PyType>, input: &str) -> PyResult<Self> {
30+
// TODO: Make this a utility in ndslice.
31+
use ndslice::selection::parse::expression;
32+
use nom::combinator::all_consuming;
33+
34+
let input: String = input.chars().filter(|c| !c.is_whitespace()).collect();
35+
let (_, selection) = all_consuming(expression)(&input).map_err(|err| {
36+
pyo3::exceptions::PyValueError::new_err(format!("parse error: {err}"))
37+
})?;
38+
39+
Ok(PySelection::from(selection))
40+
}
41+
}
42+
43+
pub fn register_python_bindings(module: &Bound<'_, PyModule>) -> PyResult<()> {
44+
module.add_class::<PySelection>()?;
45+
Ok(())
46+
}

python/monarch/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from monarch.common.opaque_ref import OpaqueRef
4040
from monarch.common.pipe import create_pipe, Pipe, remote_generator
4141
from monarch.common.remote import remote
42+
from monarch.common.selection import Selection
4243
from monarch.common.shape import NDSlice, Shape
4344
from monarch.common.stream import get_active_stream, Stream
4445
from monarch.common.tensor import reduce, reduce_, Tensor

python/monarch/_monarch/hyperactor/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@
2828
AllocSpec,
2929
)
3030

31+
from monarch._rust_bindings.monarch_hyperactor.selection import ( # @manual=//monarch/monarch_extension:monarch_extension
32+
Selection,
33+
)
34+
3135
from monarch._rust_bindings.monarch_hyperactor.shape import ( # @manual=//monarch/monarch_extension:monarch_extension
3236
Shape,
3337
)
@@ -72,5 +76,6 @@ async def handle_cast(
7276
"PythonActorMesh",
7377
"ProcessAllocatorBase",
7478
"Shape",
79+
"Selection",
7580
"LocalAllocatorBase",
7681
]
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from monarch._rust_bindings.monarch_hyperactor.selection import ( # @manual=//monarch/monarch_extension:monarch_extension
2+
Selection,
3+
)
4+
5+
__all__ = [
6+
"Selection",
7+
]

python/monarch/_rust_bindings/hyperactor.pyi

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import Dict, final, List, Optional, Protocol, Type
44

55
from monarch._rust_bindings.hyperactor_extension import Alloc, AllocSpec
6+
from monarch._rust_bindings.monarch_hyperactor.selection import Selection
67

78
from monarch._rust_bindings.monarch_hyperactor.shape import Shape
89

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from typing import final
2+
3+
@final
4+
class Selection:
5+
"""Opaque representation of a selection expression used to represent
6+
constraints over multidimensional shapes.
7+
8+
Construct via from_string()` and use with mesh APIs to filter,
9+
evaluate, or route over structured topologies.
10+
"""
11+
def __repr__(self) -> str: ...
12+
@classmethod
13+
def from_string(cls, s: str) -> Selection:
14+
"""Parse a selection expression from a string.
15+
16+
Accepts a compact string syntax such as `"(*, 0:4)"` or `"0 & (1 | 2)"`,
17+
and returns a structured Selection object.
18+
19+
Raises:
20+
ValueError: if the input string is not a valid selection expression.
21+
"""
22+
...

python/monarch/common/selection.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from monarch._monarch.selection import Selection
2+
3+
__all__ = ["Selection"]

python/tests/_monarch/test_ndslice.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import random
77
from unittest import TestCase
88

9+
from monarch._monarch.selection import Selection
10+
911
from monarch._monarch.shape import Shape, Slice
1012

1113

@@ -183,3 +185,17 @@ def test_shape_repr(self) -> None:
183185
repr(shape),
184186
'Shape { labels: ["label0", "label1"], slice: Slice { offset: 0, sizes: [2, 3], strides: [3, 1] } }',
185187
)
188+
189+
190+
class TestSelection(TestCase):
191+
def test_parse_repr(self) -> None:
192+
sel = Selection.from_string("(*, 1:3) & (0, *)")
193+
self.assertIsInstance(sel, Selection)
194+
self.assertEqual(
195+
repr(sel),
196+
"Intersection(All(Range(Range(1, Some(3), 1), True)), Range(Range(0, Some(1), 1), All(True)))",
197+
)
198+
199+
def test_parse_invalid(self) -> None:
200+
with self.assertRaises(ValueError):
201+
Selection.from_string("this is not valid selection syntax")

0 commit comments

Comments
 (0)