Skip to content

rust: require mutable references when initialising sync primitives. #392

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 27, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion drivers/android/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ impl Context {
let ctx = Arc::get_mut(&mut ctx_ref).unwrap();

// SAFETY: `manager` is also pinned when `ctx` is.
let manager = unsafe { Pin::new_unchecked(&ctx.manager) };
let manager = unsafe { Pin::new_unchecked(&mut ctx.manager) };
kernel::mutex_init!(manager, "Context::manager");

// SAFETY: `ctx_ref` is pinned behind the `Arc` reference.
Expand Down
4 changes: 2 additions & 2 deletions drivers/android/node.rs
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ impl NodeDeath {
}
}

pub(crate) fn init(self: Pin<&Self>) {
pub(crate) fn init(self: Pin<&mut Self>) {
// SAFETY: `inner` is pinned when `self` is.
let inner = unsafe { self.map_unchecked(|s| &s.inner) };
let inner = unsafe { self.map_unchecked_mut(|n| &mut n.inner) };
kernel::spinlock_init!(inner, "NodeDeath::inner");
}

Expand Down
21 changes: 13 additions & 8 deletions drivers/android/process.rs
Original file line number Diff line number Diff line change
Expand Up @@ -300,12 +300,12 @@ impl Process {
// SAFETY: `node_refs` is initialised in the call to `mutex_init` below.
node_refs: unsafe { Mutex::new(ProcessNodeRefs::new()) },
},
|process| {
// SAFETY: `inner` is pinned behind the `Ref` reference.
let pinned = unsafe { Pin::new_unchecked(&process.inner) };
|mut process| {
// SAFETY: `inner` is pinned when `Process` is.
let pinned = unsafe { process.as_mut().map_unchecked_mut(|p| &mut p.inner) };
kernel::mutex_init!(pinned, "Process::inner");
// SAFETY: `node_refs` is pinned behind the `Ref` reference.
let pinned = unsafe { Pin::new_unchecked(&process.node_refs) };
// SAFETY: `node_refs` is pinned when `Process` is.
let pinned = unsafe { process.as_mut().map_unchecked_mut(|p| &mut p.node_refs) };
kernel::mutex_init!(pinned, "Process::node_refs");
},
)
Expand Down Expand Up @@ -720,10 +720,15 @@ impl Process {
}

// SAFETY: `init` is called below.
let death = death
let mut death = death
.commit(unsafe { NodeDeath::new(info.node_ref.node.clone(), self.clone(), cookie) });
// SAFETY: `death` is pinned behind the `Arc` reference.
unsafe { Pin::new_unchecked(death.as_ref()) }.init();

{
let mutable = Arc::get_mut(&mut death).ok_or(Error::EINVAL)?;
// SAFETY: `mutable` is pinned behind the `Arc` reference.
unsafe { Pin::new_unchecked(mutable) }.init();
}

info.death = Some(death.clone());

// Register the death notification.
Expand Down
25 changes: 11 additions & 14 deletions drivers/android/thread.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,12 @@ impl Thread {
})?;
let thread = Arc::get_mut(&mut arc).unwrap();
// SAFETY: `inner` is pinned behind the `Arc` reference.
let inner = unsafe { Pin::new_unchecked(&thread.inner) };
let inner = unsafe { Pin::new_unchecked(&mut thread.inner) };
kernel::spinlock_init!(inner, "Thread::inner");
kernel::condvar_init!(thread.pinned_condvar(), "Thread::work_condvar");

// SAFETY: `work_condvar` is pinned behind the `Arc` reference.
let condvar = unsafe { Pin::new_unchecked(&mut thread.work_condvar) };
kernel::condvar_init!(condvar, "Thread::work_condvar");
{
let mut inner = arc.inner.lock();
inner.set_reply_work(reply_work);
Expand All @@ -258,10 +261,6 @@ impl Thread {
Ok(arc)
}

fn pinned_condvar(&self) -> Pin<&CondVar> {
unsafe { Pin::new_unchecked(&self.work_condvar) }
}

pub(crate) fn set_current_transaction(&self, transaction: Arc<Transaction>) {
self.inner.lock().current_transaction = Some(transaction);
}
Expand All @@ -276,15 +275,14 @@ impl Thread {
}

// Loop waiting only on the local queue (i.e., not registering with the process queue).
let cv = self.pinned_condvar();
let mut inner = self.inner.lock();
loop {
if let Some(work) = inner.pop_work() {
return Ok(work);
}

inner.looper_flags |= LOOPER_WAITING;
let signal_pending = cv.wait(&mut inner);
let signal_pending = self.work_condvar.wait(&mut inner);
inner.looper_flags &= !LOOPER_WAITING;

if signal_pending {
Expand Down Expand Up @@ -321,15 +319,14 @@ impl Thread {
Either::Right(reg) => reg,
};

let cv = self.pinned_condvar();
let mut inner = self.inner.lock();
loop {
if let Some(work) = inner.pop_work() {
return Ok(work);
}

inner.looper_flags |= LOOPER_WAITING;
let signal_pending = cv.wait(&mut inner);
let signal_pending = self.work_condvar.wait(&mut inner);
inner.looper_flags &= !LOOPER_WAITING;

if signal_pending {
Expand All @@ -352,7 +349,7 @@ impl Thread {
}
inner.push_work(work);
}
self.pinned_condvar().notify_one();
self.work_condvar.notify_one();
Ok(())
}

Expand Down Expand Up @@ -535,7 +532,7 @@ impl Thread {
}

// Notify the thread now that we've released the inner lock.
self.pinned_condvar().notify_one();
self.work_condvar.notify_one();
false
}

Expand Down Expand Up @@ -779,7 +776,7 @@ impl Thread {

// Now that the lock is no longer held, notify the waiters if we have to.
if notify {
self.pinned_condvar().notify_one();
self.work_condvar.notify_one();
}
}

Expand All @@ -802,7 +799,7 @@ impl Thread {
// Remove epoll items if polling was ever used on the thread.
let poller = self.inner.lock().looper_flags & LOOPER_POLL != 0;
if poller {
self.pinned_condvar().free_waiters();
self.work_condvar.free_waiters();

unsafe { bindings::synchronize_rcu() };
}
Expand Down
4 changes: 2 additions & 2 deletions drivers/android/transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ impl Transaction {
let mut_tr = Arc::get_mut(&mut tr).ok_or(Error::EINVAL)?;

// SAFETY: `inner` is pinned behind `Arc`.
let pinned = unsafe { Pin::new_unchecked(&mut_tr.inner) };
let pinned = unsafe { Pin::new_unchecked(&mut mut_tr.inner) };
kernel::spinlock_init!(pinned, "Transaction::inner");
Ok(tr)
}
Expand Down Expand Up @@ -112,7 +112,7 @@ impl Transaction {
let mut_tr = Arc::get_mut(&mut tr).ok_or(Error::EINVAL)?;

// SAFETY: `inner` is pinned behind `Arc`.
let pinned = unsafe { Pin::new_unchecked(&mut_tr.inner) };
let pinned = unsafe { Pin::new_unchecked(&mut mut_tr.inner) };
kernel::spinlock_init!(pinned, "Transaction::inner");
Ok(tr)
}
Expand Down
2 changes: 1 addition & 1 deletion rust/kernel/sync/condvar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ impl CondVar {
}

impl NeedsLockClass for CondVar {
unsafe fn init(self: Pin<&Self>, name: &'static CStr, key: *mut bindings::lock_class_key) {
unsafe fn init(self: Pin<&mut Self>, name: &'static CStr, key: *mut bindings::lock_class_key) {
unsafe { bindings::__init_waitqueue_head(self.wait_list.get(), name.as_char_ptr(), key) };
}
}
8 changes: 5 additions & 3 deletions rust/kernel/sync/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,11 @@
//! # use kernel::prelude::*;
//! # use kernel::mutex_init;
//! # use kernel::sync::Mutex;
//! # use alloc::boxed::Box;
//! # use core::pin::Pin;
//! // SAFETY: `init` is called below.
//! let data = alloc::sync::Arc::pin(unsafe { Mutex::new(0) });
//! mutex_init!(data.as_ref(), "test::data");
//! let mut data = Pin::from(Box::new(unsafe { Mutex::new(0) }));
//! mutex_init!(data.as_mut(), "test::data");
//! *data.lock() = 10;
//! pr_info!("{}\n", *data.lock());
//! ```
Expand Down Expand Up @@ -73,7 +75,7 @@ pub trait NeedsLockClass {
/// # Safety
///
/// `key` must point to a valid memory location as it will be used by the kernel.
unsafe fn init(self: Pin<&Self>, name: &'static CStr, key: *mut bindings::lock_class_key);
unsafe fn init(self: Pin<&mut Self>, name: &'static CStr, key: *mut bindings::lock_class_key);
}

/// Determines if a signal is pending on the current process.
Expand Down
2 changes: 1 addition & 1 deletion rust/kernel/sync/mutex.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ impl<T: ?Sized> Mutex<T> {
}

impl<T: ?Sized> NeedsLockClass for Mutex<T> {
unsafe fn init(self: Pin<&Self>, name: &'static CStr, key: *mut bindings::lock_class_key) {
unsafe fn init(self: Pin<&mut Self>, name: &'static CStr, key: *mut bindings::lock_class_key) {
unsafe { bindings::__mutex_init(self.mutex.get(), name.as_char_ptr(), key) };
}
}
Expand Down
2 changes: 1 addition & 1 deletion rust/kernel/sync/spinlock.rs
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ impl<T: ?Sized> SpinLock<T> {
}

impl<T: ?Sized> NeedsLockClass for SpinLock<T> {
unsafe fn init(self: Pin<&Self>, name: &'static CStr, key: *mut bindings::lock_class_key) {
unsafe fn init(self: Pin<&mut Self>, name: &'static CStr, key: *mut bindings::lock_class_key) {
unsafe { rust_helper_spin_lock_init(self.spin_lock.get(), name.as_char_ptr(), key) };
}
}
Expand Down
10 changes: 5 additions & 5 deletions samples/rust/rust_miscdev.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,13 @@ impl SharedState {
// SAFETY: `mutex_init!` is called below.
inner: unsafe { Mutex::new(SharedStateInner { token_count: 0 }) },
},
|state| {
|mut state| {
// SAFETY: `state_changed` is pinned when `state` is.
let state_changed = unsafe { Pin::new_unchecked(&state.state_changed) };
kernel::condvar_init!(state_changed, "SharedState::state_changed");
let pinned = unsafe { state.as_mut().map_unchecked_mut(|s| &mut s.state_changed) };
kernel::condvar_init!(pinned, "SharedState::state_changed");
// SAFETY: `inner` is pinned when `state` is.
let inner = unsafe { Pin::new_unchecked(&state.inner) };
kernel::mutex_init!(inner, "SharedState::inner");
let pinned = unsafe { state.as_mut().map_unchecked_mut(|s| &mut s.inner) };
kernel::mutex_init!(pinned, "SharedState::inner");
},
)?))
}
Expand Down
6 changes: 3 additions & 3 deletions samples/rust/rust_semaphore.rs
Original file line number Diff line number Diff line change
Expand Up @@ -131,13 +131,13 @@ impl KernelModule for RustSemaphore {
})
},
},
|sema| {
|mut sema| {
// SAFETY: `changed` is pinned when `sema` is.
let pinned = unsafe { Pin::new_unchecked(&sema.changed) };
let pinned = unsafe { sema.as_mut().map_unchecked_mut(|s| &mut s.changed) };
condvar_init!(pinned, "Semaphore::changed");

// SAFETY: `inner` is pinned when `sema` is.
let pinned = unsafe { Pin::new_unchecked(&sema.inner) };
let pinned = unsafe { sema.as_mut().map_unchecked_mut(|s| &mut s.inner) };
mutex_init!(pinned, "Semaphore::inner");
},
)?;
Expand Down
17 changes: 9 additions & 8 deletions samples/rust/rust_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,15 @@ impl KernelModule for RustSync {
// Test mutexes.
{
// SAFETY: `init` is called below.
let data = Pin::from(Box::try_new(unsafe { Mutex::new(0) })?);
mutex_init!(data.as_ref(), "RustSync::init::data1");
let mut data = Pin::from(Box::try_new(unsafe { Mutex::new(0) })?);
mutex_init!(data.as_mut(), "RustSync::init::data1");
*data.lock() = 10;
pr_info!("Value: {}\n", *data.lock());

// SAFETY: `init` is called below.
let cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?);
condvar_init!(cv.as_ref(), "RustSync::init::cv1");
let mut cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?);
condvar_init!(cv.as_mut(), "RustSync::init::cv1");

{
let mut guard = data.lock();
while *guard != 10 {
Expand All @@ -52,14 +53,14 @@ impl KernelModule for RustSync {
// Test spinlocks.
{
// SAFETY: `init` is called below.
let data = Pin::from(Box::try_new(unsafe { SpinLock::new(0) })?);
spinlock_init!(data.as_ref(), "RustSync::init::data2");
let mut data = Pin::from(Box::try_new(unsafe { SpinLock::new(0) })?);
spinlock_init!(data.as_mut(), "RustSync::init::data2");
*data.lock() = 10;
pr_info!("Value: {}\n", *data.lock());

// SAFETY: `init` is called below.
let cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?);
condvar_init!(cv.as_ref(), "RustSync::init::cv2");
let mut cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?);
condvar_init!(cv.as_mut(), "RustSync::init::cv2");
{
let mut guard = data.lock();
while *guard != 10 {
Expand Down