Skip to content

Commit 695c89f

Browse files
Sam Lurye authored and facebook-github-bot committed
Fix errant cuda stream syncs and tensor serializations in remote function impl (#14)
Summary: Pull Request resolved: #14. This diff fixes two issues related to the remote function implementation inside `StreamActor`: 1. There were two calls to `tracing::debug!` that used the function inputs and outputs as arguments. Even if the log level was set such that the message wasn't printed to the console, this still incurred the overhead of serialization (and therefore, when tensors were involved, synchronizing the host with the CUDA stream). 2. The code to extract an `RValue` from a `PyObject` worked by attempting to do a conversion in C++, and if that failed, handling the exception via Rust and trying the next `RValue` variant. Each time the conversion failed, the code created a `PyValueError` that serialized the `PyObject` as part of its message. Importantly, the case for tensors was only handled after the cases for `ScalarType`, `Layout`, and `MemoryFormat`. So if the `RValue` was a tensor, we would incur 3 C++ exceptions, 3 CUDA stream syncs, and 3 tensor serializations. And this would happen for every tensor in the output of a remote function. The fixes in this diff are: 1. Remove the `tracing::debug!` calls. 2. Instead of trying an unchecked type conversion and handling the exception, actually check the type before trying to do the conversion. Reviewed By: colin2328 Differential Revision: D75035191 fbshipit-source-id: 9acf397c0ea9a9af0ffef215f39f2ee5c50b7395
1 parent 2b3f2c6 commit 695c89f

File tree

9 files changed

+68
-8
lines changed

9 files changed

+68
-8
lines changed

monarch_worker/src/stream.rs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -830,9 +830,6 @@ impl StreamActor {
830830
// Execute the borrow.
831831
let _borrow = multiborrow.borrow()?;
832832

833-
tracing::debug!(
834-
"calling python function: {function:?} with args: {py_args:?} and kwargs: {py_kwargs:?}"
835-
);
836833
// Call function.
837834
// Use custom subscriber to route Worker messages to stdout.
838835
let scoped_subscriber = Subscriber::builder().with_writer(std::io::stdout).finish();
@@ -858,7 +855,6 @@ impl StreamActor {
858855
)
859856
.map_err(SerializablePyErr::from_fn(py))
860857
})?;
861-
tracing::debug!("python function: {function:?} result: {result:?}");
862858

863859
// Parse the python result as an `Object`, which should preserve the
864860
// original Python object structure, while providing access to the

torch-sys/src/bridge.cpp

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -408,6 +408,18 @@ IValue ivalue_from_arbitrary_py_object(PyObject* unowned) {
408408
return torch::jit::toIValue(obj, inferredType.type());
409409
}
410410

411+
bool py_object_is_ivalue(PyObject* unowned) {
412+
auto obj = py::reinterpret_steal<py::object>(unowned);
413+
auto inferredType = torch::jit::tryToInferType(obj);
414+
if (!inferredType.success()) {
415+
return false;
416+
}
417+
// TODO(agallagher): Arbitrary Python objects -- which we can't and don't want
418+
// to package into an `IValue` -- will be inferred as a `Class` type. Return
419+
// false here so that we'll fall back to parsing into `RValue::PyObject`.
420+
return !inferredType.type()->cast<at::ClassType>();
421+
}
422+
411423
c10::Device device_from_py_object(PyObject* unowned) {
412424
auto obj = py::reinterpret_steal<py::object>(unowned);
413425
if (!THPDevice_Check(obj.ptr())) {
@@ -435,6 +447,11 @@ PyObject* scalar_type_to_py_object(c10::ScalarType scalar_type) {
435447
return Py_NewRef(dtype);
436448
}
437449

450+
bool py_object_is_scalar_type(PyObject* unowned) {
451+
auto obj = py::reinterpret_steal<py::object>(unowned);
452+
return THPDtype_Check(obj.ptr());
453+
}
454+
438455
c10::Layout layout_from_py_object(PyObject* unowned) {
439456
auto obj = py::reinterpret_steal<py::object>(unowned);
440457
if (!THPLayout_Check(obj.ptr())) {
@@ -449,6 +466,11 @@ PyObject* layout_to_py_object(c10::Layout layout) {
449466
return Py_NewRef(thp_layout);
450467
}
451468

469+
bool py_object_is_layout(PyObject* unowned) {
470+
auto obj = py::reinterpret_steal<py::object>(unowned);
471+
return THPLayout_Check(obj.ptr());
472+
}
473+
452474
c10::MemoryFormat memory_format_from_py_object(PyObject* unowned) {
453475
auto obj = py::reinterpret_steal<py::object>(unowned);
454476
if (!THPMemoryFormat_Check(obj.ptr())) {
@@ -463,6 +485,11 @@ PyObject* memory_format_to_py_object(c10::MemoryFormat memory_format) {
463485
return Py_NewRef(thp_memory_format);
464486
}
465487

488+
bool py_object_is_memory_format(PyObject* unowned) {
489+
auto obj = py::reinterpret_steal<py::object>(unowned);
490+
return THPMemoryFormat_Check(obj.ptr());
491+
}
492+
466493
PyObject* tensor_to_py_object(Tensor tensor) {
467494
torch::jit::guardAgainstNamedTensor<Tensor>(tensor);
468495
return py::cast(std::move(tensor)).release().ptr();

torch-sys/src/bridge.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ rust::Vec<int32_t> sizes(const Tensor& tensor);
9292

9393
FFIPyObject arbitrary_ivalue_to_py_object(IValue val);
9494
IValue ivalue_from_arbitrary_py_object(FFIPyObject obj);
95+
bool py_object_is_ivalue(FFIPyObject obj);
9596

9697
inline IValue ivalue_from_py_object_with_type(
9798
FFIPyObject obj,
@@ -111,12 +112,15 @@ FFIPyObject device_to_py_object(c10::Device device);
111112

112113
c10::ScalarType scalar_type_from_py_object(FFIPyObject obj);
113114
FFIPyObject scalar_type_to_py_object(c10::ScalarType scalar_type);
115+
bool py_object_is_scalar_type(FFIPyObject obj);
114116

115117
c10::Layout layout_from_py_object(FFIPyObject obj);
116118
FFIPyObject layout_to_py_object(c10::Layout layout);
119+
bool py_object_is_layout(FFIPyObject obj);
117120

118121
c10::MemoryFormat memory_format_from_py_object(FFIPyObject obj);
119122
FFIPyObject memory_format_to_py_object(c10::MemoryFormat memory_format);
123+
bool py_object_is_memory_format(FFIPyObject obj);
120124

121125
FFIPyObject tensor_to_py_object(Tensor tensor);
122126
Tensor tensor_from_py_object(FFIPyObject obj);

torch-sys/src/bridge.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,10 +217,12 @@ pub(crate) mod ffi {
217217
// Layout
218218
fn layout_from_py_object(obj: FFIPyObject) -> Result<Layout>;
219219
fn layout_to_py_object(layout: Layout) -> FFIPyObject;
220+
fn py_object_is_layout(obj: FFIPyObject) -> bool;
220221

221222
// MemoryFormat
222223
fn memory_format_from_py_object(obj: FFIPyObject) -> Result<MemoryFormat>;
223224
fn memory_format_to_py_object(memory_format: MemoryFormat) -> FFIPyObject;
225+
fn py_object_is_memory_format(obj: FFIPyObject) -> bool;
224226

225227
// Tensor
226228
fn tensor_from_py_object(obj: FFIPyObject) -> Result<Tensor>;
@@ -281,6 +283,7 @@ pub(crate) mod ffi {
281283
// Convert to Python object.
282284
fn scalar_type_from_py_object(obj: FFIPyObject) -> Result<ScalarType>;
283285
fn scalar_type_to_py_object(scalar_type: ScalarType) -> FFIPyObject;
286+
fn py_object_is_scalar_type(obj: FFIPyObject) -> bool;
284287

285288
/// # Safety
286289
/// - **Mutability**:
@@ -330,6 +333,7 @@ pub(crate) mod ffi {
330333
// Interop with Python object.
331334
fn arbitrary_ivalue_to_py_object(val: IValue) -> Result<FFIPyObject>;
332335
fn ivalue_from_arbitrary_py_object(obj: FFIPyObject) -> Result<IValue>;
336+
fn py_object_is_ivalue(obj: FFIPyObject) -> bool;
333337
/// Converts the provided Python object to an `IValue` with the provided
334338
/// type. If the object is not convertible to the provided type, an
335339
/// exception will be thrown.

torch-sys/src/ivalue.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,11 @@ impl IValue {
309309
))
310310
})
311311
}
312+
313+
pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<IValue> {
314+
ffi::py_object_is_ivalue(obj.clone().into())
315+
.then(|| ffi::ivalue_from_arbitrary_py_object(obj.into()).unwrap())
316+
}
312317
}
313318

314319
// impl `From` for all IValue kinds.

torch-sys/src/layout.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ unsafe impl ExternType for Layout {
2222
type Kind = cxx::kind::Trivial;
2323
}
2424

25+
impl Layout {
26+
pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<Self> {
27+
ffi::py_object_is_layout(obj.clone().into())
28+
.then(|| ffi::layout_from_py_object(obj.into()).unwrap())
29+
}
30+
}
31+
2532
impl FromPyObject<'_> for Layout {
2633
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
2734
ffi::layout_from_py_object(obj.into()).map_err(|e| {

torch-sys/src/memory_format.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,13 @@ unsafe impl ExternType for MemoryFormat {
2222
type Kind = cxx::kind::Trivial;
2323
}
2424

25+
impl MemoryFormat {
26+
pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<Self> {
27+
ffi::py_object_is_memory_format(obj.clone().into())
28+
.then(|| ffi::memory_format_from_py_object(obj.into()).unwrap())
29+
}
30+
}
31+
2532
impl FromPyObject<'_> for MemoryFormat {
2633
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
2734
ffi::memory_format_from_py_object(obj.into()).map_err(|e| {

torch-sys/src/rvalue.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -167,13 +167,16 @@ impl TryIntoPyObjectUnsafe<PyAny> for &RValue {
167167

168168
impl FromPyObject<'_> for RValue {
169169
fn extract_bound(obj: &pyo3::Bound<'_, pyo3::PyAny>) -> pyo3::PyResult<Self> {
170-
if let Ok(val) = obj.extract::<ScalarType>() {
170+
// It's crucial for correctness to try converting to IValue after we've
171+
// tried the other non-PyObject variants, because the IValue conversion
172+
// will actually succeed when obj is a ScalarType, Layout, or MemoryFormat.
173+
if let Some(val) = ScalarType::from_py_object_or_none(obj) {
171174
Ok(RValue::ScalarType(val))
172-
} else if let Ok(val) = obj.extract::<Layout>() {
175+
} else if let Some(val) = Layout::from_py_object_or_none(obj) {
173176
Ok(RValue::Layout(val))
174-
} else if let Ok(val) = obj.extract::<MemoryFormat>() {
177+
} else if let Some(val) = MemoryFormat::from_py_object_or_none(obj) {
175178
Ok(RValue::MemoryFormat(val))
176-
} else if let Ok(val) = IValue::extract_bound(obj) {
179+
} else if let Some(val) = IValue::from_py_object_or_none(obj) {
177180
Ok(val.into())
178181
} else {
179182
Ok(RValue::PyObject(PickledPyObject::pickle(obj)?))

torch-sys/src/scalar_type.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,13 @@ unsafe impl ExternType for ScalarType {
2323
type Kind = cxx::kind::Trivial;
2424
}
2525

26+
impl ScalarType {
27+
pub(crate) fn from_py_object_or_none(obj: &Bound<'_, PyAny>) -> Option<Self> {
28+
ffi::py_object_is_scalar_type(obj.clone().into())
29+
.then(|| ffi::scalar_type_from_py_object(obj.into()).unwrap())
30+
}
31+
}
32+
2633
impl FromPyObject<'_> for ScalarType {
2734
fn extract_bound(obj: &Bound<'_, PyAny>) -> PyResult<Self> {
2835
ffi::scalar_type_from_py_object(obj.into()).map_err(|e| {

0 commit comments

Comments
 (0)