Commit caf6cd4

Sam Lurye authored and facebook-github-bot committed
Enable remaining tests for rust backend in test_remote_functions (#16)
Summary:
Pull Request resolved: #16

As titled. This involved fixing two issues with pipes (allow pipe processes to create tensors from wire values, and explicitly flush the pipe buffer during pipe send) and fixing an issue with tensor serialization.

Reviewed By: dulinriley

Differential Revision: D74854607

fbshipit-source-id: b457a51feb18741a0b4535c405e1ca9c46a8f24e
1 parent 695c89f · commit caf6cd4

5 files changed: +70 -31 lines changed


monarch_worker/src/bootstrap.rs

Lines changed: 5 additions & 1 deletion
@@ -106,8 +106,12 @@ pub fn bootstrap_pipe() -> Result<(), anyhow::Error> {
     // Value of 4 is arbitrary as our side does not need to do buffering.
     let mut pipe = StreamPipe::new(std::io::stdin(), std::io::stdout(), 4);
     let init: OutOfProcessSetupParams = pipe.recv()?;
+    // Create a PyPipe that allows unsafe object conversion. This allows the pipe to
+    // receive tensors, which we know is safe because StreamPipe receives the serialized
+    // tensors from out-of-process, and they therefore can't be owned by anything except
+    // the pipe's python code.
     run_py_pipe(
-        PyPipe::new(Box::new(pipe), init.ranks, init.sizes),
+        PyPipe::new(Box::new(pipe), init.ranks, init.sizes, true),
         init.function,
         init.args,
         init.kwargs,

monarch_worker/src/pipe.rs

Lines changed: 2 additions & 1 deletion
@@ -235,6 +235,7 @@ impl<T: Serialize + DeserializeOwned> Pipe<T> for StreamPipe {
         let len = bytes.len();
         self.writer.write_all(&len.to_be_bytes())?;
         self.writer.write_all(&bytes)?;
+        self.writer.flush()?;
         Ok(())
     }

@@ -374,7 +375,7 @@ impl PipeMessageHandler for PipeActor {
         // TODO(agallagher): Propagate failures and use a timeout?
         tokio::select! {
             res = self.handle.wait() => bail!("pipe server exited: {:?}", res),
-            res = self.pipe.as_mut().unwrap().recv() => res,
+            res = self.pipe.as_mut().unwrap().recv() => res
         }
     }
 }
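
For context, here is a minimal, self-contained sketch of the length-prefixed send path shown above, assuming the underlying writer is buffered (e.g. wrapped in a `std::io::BufWriter`); the function name is hypothetical. Without the final `flush()`, a short frame can sit in the write buffer where the peer's `recv()` never sees it, which would explain why an explicit flush is needed during pipe send.

```rust
use std::io::{self, Write};

// Hypothetical sketch of StreamPipe-style framing: a big-endian length
// prefix followed by the payload. The final flush() pushes the frame out
// of any intermediate buffering so the peer can read it immediately.
fn send_frame<W: Write>(writer: &mut W, bytes: &[u8]) -> io::Result<()> {
    let len: u64 = bytes.len() as u64;
    writer.write_all(&len.to_be_bytes())?; // 8-byte length prefix
    writer.write_all(bytes)?;              // payload
    writer.flush()?;                       // don't leave the frame buffered
    Ok(())
}
```

With a `BufWriter` over stdout, dropping the `flush()` would delay delivery until the buffer fills or the writer is dropped.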

monarch_worker/src/py_pipe.rs

Lines changed: 23 additions & 4 deletions
@@ -11,6 +11,7 @@ use std::collections::HashMap;
 use monarch_messages::worker::ResolvableFunction;
 use monarch_types::PyTree;
 use monarch_types::TryIntoPyObject;
+use monarch_types::TryIntoPyObjectUnsafe;
 use pyo3::prelude::*;
 use pyo3::types::PyTuple;
 use torch_sys::RValue;

@@ -25,15 +26,22 @@ pub struct PyPipe {
     ranks: HashMap<String, usize>,
     #[pyo3(get)]
     sizes: HashMap<String, usize>,
+    allow_unsafe_obj_conversion: bool,
 }

 impl PyPipe {
     pub fn new(
         pipe: Box<dyn Pipe<PyTree<RValue>> + Send>,
         ranks: HashMap<String, usize>,
         sizes: HashMap<String, usize>,
+        allow_unsafe_obj_conversion: bool,
     ) -> Self {
-        Self { pipe, ranks, sizes }
+        Self {
+            pipe,
+            ranks,
+            sizes,
+            allow_unsafe_obj_conversion,
+        }
     }
 }

@@ -46,8 +54,14 @@ impl PyPipe {
     }

     fn recv<'a>(&mut self, py: Python<'a>) -> PyResult<Bound<'a, PyAny>> {
-        py.allow_threads(move || self.pipe.recv())?
-            .try_to_object(py)
+        let val = py.allow_threads(|| self.pipe.recv())?;
+        if self.allow_unsafe_obj_conversion {
+            // SAFETY: A caller who initialized this PyPipe with allow_unsafe_obj_conversion=True
+            // asserts that it is safe to use this unsafe method.
+            unsafe { val.try_to_object_unsafe(py) }
+        } else {
+            val.try_to_object(py)
+        }
     }
 }

@@ -122,7 +136,12 @@ mod tests {
         async move {
             tokio::task::spawn_blocking(move || {
                 run_py_pipe(
-                    PyPipe::new(Box::new(server), HashMap::new(), HashMap::new()),
+                    PyPipe::new(
+                        Box::new(server),
+                        HashMap::new(),
+                        HashMap::new(),
+                        false, // allow_unsafe_obj_conversion
+                    ),
                     "test_helpers.func".into(),
                     vec![],
                     HashMap::new(),
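
The `allow_unsafe_obj_conversion` flag makes the unsafe conversion an explicit opt-in at construction time, so each call site carries its own safety argument. A simplified, self-contained sketch of the pattern (hypothetical types, not the monarch API):

```rust
struct Wire(Vec<u8>);

impl Wire {
    // Safe path: copy the bytes into a fresh owned object.
    fn to_object(&self) -> Vec<u8> {
        self.0.clone()
    }

    /// # Safety
    /// The caller must guarantee nothing else aliases the underlying buffer.
    unsafe fn to_object_unsafe(self) -> Vec<u8> {
        self.0 // zero-copy path: take ownership of the buffer directly
    }
}

struct Pipe {
    allow_unsafe_obj_conversion: bool,
}

impl Pipe {
    fn convert(&self, val: Wire) -> Vec<u8> {
        if self.allow_unsafe_obj_conversion {
            // SAFETY: whoever constructed this Pipe with the flag set asserts
            // that values received over the wire are exclusively owned here.
            unsafe { val.to_object_unsafe() }
        } else {
            val.to_object()
        }
    }
}
```

The out-of-process bootstrap path above passes `true` (wire-deserialized tensors can't be aliased), while the in-process test passes `false`.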

python/tests/test_remote_functions.py

Lines changed: 0 additions & 25 deletions
@@ -518,8 +518,6 @@ def test_remote_function_isend(self, backend_type):
         assert local_finished_1.item() == 1.0

     def test_distributed_error(self, backend_type):
-        if backend_type == BackendType.RS:
-            pytest.skip("FIXME: Rust support for this function")
         with self.local_device_mesh(2, 2, backend_type) as _:
             x = torch.rand(3, 4).cuda()
             y = torch.rand(3, 4).cuda()

@@ -545,8 +543,6 @@ def test_distributed_error(self, backend_type):
             fetch_shard(2 * x, gpu=1, host=0).result()

     def test_pipe(self, backend_type):
-        if backend_type == BackendType.RS:
-            pytest.skip("FIXME: Rust support for this function")
         with self.local_device_mesh(2, 2, backend_type):
             p = example_echo_add()
             for _i in range(10):

@@ -566,8 +562,6 @@ def test_loader(self, backend_type):
             assert x.item() == i

     def test_loader_blocks_with_small_pipe(self, backend_type):
-        if backend_type == BackendType.RS:
-            pytest.skip("FIXME: Rust support for this function")
         with self.local_device_mesh(2, 2, backend_type):
             iters = 10
             p = example_data_loader_small_pipe(iters, (1000, 1000))

@@ -581,8 +575,6 @@ def test_loader_blocks_with_small_pipe(self, backend_type):
             assert t[0][0].item() == -1.0

     def test_streams_run_parallel(self, backend_type):
-        if backend_type == BackendType.RS:
-            pytest.skip("FIXME: Rust support for this function")
         with self.local_device_mesh(2, 2, backend_type):
             # test that these two streams do in fact run in parallel
             # on the worker by having each stream wait on a barrier.

@@ -643,8 +635,6 @@ def test_fetch_preprocess(self, backend_type):
         )

     def test_cached_remote_function(self, backend_type):
-        if backend_type == BackendType.RS:
-            pytest.skip("FIXME: Rust support for this function")
         fn = remote("monarch.worker._testing_function.how_many_of_these_do_you_want")
         start_hits = remote_module._hit
         with self.local_device_mesh(2, 2, backend_type):

@@ -713,9 +703,6 @@ def test_cached_remote_aliases(self, backend_type):
         assert outs[2]._fake.storage_offset() == 40

     def test_live_function(self, backend_type):
-        if backend_type == BackendType.RS:
-            pytest.skip("FIXME: Rust support for this function")
-
         def bar(x, y):
             return (
                 a_function_called_by_a_live_function(x)

@@ -1094,18 +1081,6 @@ def test_nccl_barrier(self, backend_type: BackendType) -> None:
             inspect(t, {"host": host, "gpu": gpu}),
         )

-    def test_nccl_barrier_device_ids(self, backend_type: BackendType) -> None:
-        if backend_type == BackendType.PY:
-            # pyre-ignore[29]: pytest.skip is callable.
-            pytest.skip("FIXME: Python support for this function")
-        with self.local_device_mesh(
-            self.N_HOSTS, self.N_GPUS, backend_type
-        ) as device_mesh:
-            pg = device_mesh.process_group(("host", "gpu"))
-            rank = device_mesh.rank("host") * self.N_GPUS + device_mesh.rank("gpu")
-            with pytest.raises(monarch.common.invocation.RemoteException):
-                inspect(barrier(device_ids=[rank], group=pg))
-
     def test_tensor_dtype_complex(self, backend_type: BackendType) -> None:
         self._test_tensor_dtype_complex(backend_type)

torch-sys/src/bridge.cpp

Lines changed: 40 additions & 0 deletions
@@ -508,6 +508,7 @@ Tensor tensor_from_py_object(PyObject* unowned) {
 // TODO: We can do better for IValue serde as we dont need pickle compat here.
 const char kIValueStart = '\x01';
 const char kTensorsStart = '\x02';
+const char kWrappedNumberStart = '\x03';
 rust::Vec<uint8_t> serialize_ivalue(const IValue& iv) {
   if (iv.isTensor() && !iv.toTensor().defined()) {
     // Special case for undefined tensors as pickle doesnt

@@ -529,6 +530,18 @@ rust::Vec<uint8_t> serialize_ivalue(const IValue& iv) {
     }
     std::copy(
         tensors_data.begin(), tensors_data.end(), std::back_inserter(out));
+    // Tensor serialization doesn't maintain the wrapped number flag, so we
+    // need to manually serialize it. This is important to maintain because
+    // it has implications for the output type of torch ops.
+    out.push_back(kWrappedNumberStart);
+    for (size_t i = 0; i < tensors.size(); ++i) {
+      uint8_t offset = i % sizeof(uint8_t);
+      if (offset == 0) {
+        out.push_back(0);
+      }
+      out.back() |= static_cast<uint8_t>(
+          tensors.at(i).unsafeGetTensorImpl()->is_wrapped_number() << offset);
+    }
   }
   out.push_back(kIValueStart);
   out.reserve(out.size() + pickle_data.size());

@@ -565,6 +578,33 @@ IValue deserialize_ivalue(rust::Slice<const uint8_t> buf) {
     rust::Slice<const uint8_t> tensor_data(buf.data() + i, tensors_size);
     tensors = load<std::vector<Tensor>>(tensor_data);
     i += tensors_size;
+    if (i >= buf.size() || buf.at(i) != kWrappedNumberStart) {
+      throw std::runtime_error(
+          "Invalid IValue serialization: missing wrapped number start byte");
+    }
+    for (size_t tensor_index = 0; tensor_index < tensors.size();
+         tensor_index++) {
+      uint8_t offset = tensor_index % sizeof(uint8_t);
+      if (offset == 0) {
+        i++;
+      }
+      if (i >= buf.size()) {
+        throw std::runtime_error(
+            "Invalid IValue serialization: wrapped number data truncated");
+      }
+      bool wrapped_number = (buf.at(i) >> offset) & 0x01;
+      if (wrapped_number) {
+        // You would think we could just call
+        // set_wrapped_number(wrapped_number), but you'd be wrong. Internally,
+        // set_wrapped_number asserts a 0-dim tensor regardless of whether its
+        // argument is true or false, so we can only call set_wrapped_number
+        // safely when wrapped_number == true.
+        tensors.at(tensor_index)
+            .unsafeGetTensorImpl()
+            ->set_wrapped_number(true);
+      }
+    }
+    i++;
   }
   if (i >= buf.size() || buf.at(i++) != kIValueStart) {
     throw std::runtime_error(
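
As a reading aid, here is a small Rust sketch (hypothetical, not part of the commit) of the wrapped-number flag framing used above. Note that the C++ strides by `sizeof(uint8_t)`, which is 1, so each flag effectively occupies its own byte; a stride of 8 would pack eight flags per byte with the same logic.

```rust
// Mirrors the C++ above: a start byte, then one buffer byte per
// FLAGS_PER_BYTE flags, each flag OR'd in at its bit offset.
const WRAPPED_NUMBER_START: u8 = 0x03;
const FLAGS_PER_BYTE: usize = 1; // mirrors `i % sizeof(uint8_t)` in the C++

fn serialize_flags(flags: &[bool], out: &mut Vec<u8>) {
    out.push(WRAPPED_NUMBER_START);
    for (i, &flag) in flags.iter().enumerate() {
        let offset = i % FLAGS_PER_BYTE;
        if offset == 0 {
            out.push(0); // start a new flag byte
        }
        *out.last_mut().unwrap() |= (flag as u8) << offset;
    }
}

fn deserialize_flags(buf: &[u8], n: usize) -> Result<Vec<bool>, String> {
    let mut i = 0;
    if buf.get(i) != Some(&WRAPPED_NUMBER_START) {
        return Err("missing wrapped number start byte".into());
    }
    let mut flags = Vec::with_capacity(n);
    for idx in 0..n {
        let offset = idx % FLAGS_PER_BYTE;
        if offset == 0 {
            i += 1; // advance to the next flag byte
        }
        let byte = *buf.get(i).ok_or("wrapped number data truncated")?;
        flags.push((byte >> offset) & 1 == 1);
    }
    Ok(flags)
}
```

Round-tripping `[true, false, true]` through `serialize_flags` and `deserialize_flags` returns the same flags, matching the serialize/deserialize pair in the diff.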
