diff --git a/drivers/char/rust_example.rs b/drivers/char/rust_example.rs index e3582f7373ab49..5079b801f2cc01 100644 --- a/drivers/char/rust_example.rs +++ b/drivers/char/rust_example.rs @@ -10,10 +10,10 @@ use alloc::boxed::Box; use core::pin::Pin; use kernel::prelude::*; use kernel::{ - chrdev, cstr, + chrdev, condvar_init, cstr, file_operations::FileOperations, miscdev, mutex_init, spinlock_init, - sync::{Mutex, SpinLock}, + sync::{CondVar, Mutex, SpinLock}, }; module! { @@ -86,6 +86,20 @@ impl KernelModule for RustExample { mutex_init!(data.as_ref(), "RustExample::init::data1"); *data.lock() = 10; println!("Value: {}", *data.lock()); + + // SAFETY: `init` is called below. + let cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?); + condvar_init!(cv.as_ref(), "RustExample::init::cv1"); + { + let guard = data.lock(); + #[allow(clippy::while_immutable_condition)] + while *guard != 10 { + cv.wait(&guard); + } + } + cv.notify_one(); + cv.notify_all(); + cv.free_waiters(); } // Test spinlocks. @@ -95,13 +109,27 @@ impl KernelModule for RustExample { spinlock_init!(data.as_ref(), "RustExample::init::data2"); *data.lock() = 10; println!("Value: {}", *data.lock()); + + // SAFETY: `init` is called below. + let cv = Pin::from(Box::try_new(unsafe { CondVar::new() })?); + condvar_init!(cv.as_ref(), "RustExample::init::cv2"); + { + let guard = data.lock(); + #[allow(clippy::while_immutable_condition)] + while *guard != 10 { + cv.wait(&guard); + } + } + cv.notify_one(); + cv.notify_all(); + cv.free_waiters(); } // Including this large variable on the stack will trigger // stack probing on the supported archs. // This will verify that stack probing does not lead to // any errors if we need to link `__rust_probestack`. - let x: [u64; 1028] = core::hint::black_box([5; 1028]); + let x: [u64; 514] = core::hint::black_box([5; 514]); println!("Large array has length: {}", x.len()); let mut chrdev_reg = diff --git a/rust/helpers.c b/rust/helpers.c index fcf1dbbd6cdda4..51c18ca60c9f28 100644 --- a/rust/helpers.c +++ b/rust/helpers.c @@ -3,6 +3,7 @@ #include #include #include +#include void rust_helper_BUG(void) { @@ -47,6 +48,18 @@ void rust_helper_spin_unlock(spinlock_t *lock) } EXPORT_SYMBOL(rust_helper_spin_unlock); +void rust_helper_init_wait(struct wait_queue_entry *wq_entry) +{ + init_wait(wq_entry); +} +EXPORT_SYMBOL(rust_helper_init_wait); + +int rust_helper_signal_pending(void) +{ + return signal_pending(current); +} +EXPORT_SYMBOL(rust_helper_signal_pending); + // See https://github.com/rust-lang/rust-bindgen/issues/1671 static_assert(__builtin_types_compatible_p(size_t, uintptr_t), "size_t must match uintptr_t, what architecture is this??"); diff --git a/rust/kernel/bindings_helper.h b/rust/kernel/bindings_helper.h index ba1652004d5224..39b0cea37d6166 100644 --- a/rust/kernel/bindings_helper.h +++ b/rust/kernel/bindings_helper.h @@ -9,6 +9,7 @@ #include #include #include +#include // `bindgen` gets confused at certain things const gfp_t BINDINGS_GFP_KERNEL = GFP_KERNEL; diff --git a/rust/kernel/sync/condvar.rs b/rust/kernel/sync/condvar.rs new file mode 100644 index 00000000000000..6e670854336ea5 --- /dev/null +++ b/rust/kernel/sync/condvar.rs @@ -0,0 +1,137 @@ +// SPDX-License-Identifier: GPL-2.0 + +//! A condition variable. +//! +//! This module allows Rust code to use the kernel's [`struct wait_queue_head`] as a condition +//! variable. + +use super::{Guard, Lock, NeedsLockClass}; +use crate::{bindings, c_types, CStr}; +use core::{cell::UnsafeCell, marker::PhantomPinned, mem::MaybeUninit, pin::Pin}; + +extern "C" { + fn rust_helper_init_wait(wq: *mut bindings::wait_queue_entry); + fn rust_helper_signal_pending() -> c_types::c_int; +} + +/// Safely initialises a [`CondVar`] with the given name, generating a new lock class. +#[macro_export] +macro_rules! condvar_init { + ($condvar:expr, $name:literal) => { + $crate::init_with_lockdep!($condvar, $name) + }; +} + +// TODO: `bindgen` is not generating this constant. Figure out why. +const POLLFREE: u32 = 0x4000; + +/// Exposes the kernel's [`struct wait_queue_head`] as a condition variable. It allows the caller to +/// atomically release the given lock and go to sleep. It reacquires the lock when it wakes up. And +/// it wakes up when notified by another thread (via [`CondVar::notify_one`] or +/// [`CondVar::notify_all`]) or because the thread received a signal. +/// +/// [`struct wait_queue_head`]: ../../../include/linux/wait.h +pub struct CondVar { + pub(crate) wait_list: UnsafeCell, + + /// A condvar needs to be pinned because it contains a [`struct list_head`] that is + /// self-referential, so it cannot be safely moved once it is initialised. + _pin: PhantomPinned, +} + +// SAFETY: `CondVar` only uses a `struct wait_queue_head`, which is safe to use on any thread. +unsafe impl Send for CondVar {} + +// SAFETY: `CondVar` only uses a `struct wait_queue_head`, which is safe to use on multiple threads +// concurrently. +unsafe impl Sync for CondVar {} + +impl CondVar { + /// Constructs a new conditional variable. + /// + /// # Safety + /// + /// The caller must call `CondVar::init` before using the conditional variable. + pub unsafe fn new() -> Self { + Self { + wait_list: UnsafeCell::new(bindings::wait_queue_head::default()), + _pin: PhantomPinned, + } + } + + /// Atomically releases the given lock (whose ownership is proven by the guard) and puts the + /// thread to sleep. It wakes up when notified by [`CondVar::notify_one`] or + /// [`CondVar::notify_all`], or when the thread receives a signal. + /// + /// Returns whether there is a signal pending. + pub fn wait(&self, g: &Guard) -> bool { + let l = g.lock; + let mut wait = MaybeUninit::::uninit(); + + // SAFETY: `wait` points to valid memory. + unsafe { rust_helper_init_wait(wait.as_mut_ptr()) }; + + // SAFETY: Both `wait` and `wait_list` point to valid memory. + unsafe { + bindings::prepare_to_wait_exclusive( + self.wait_list.get(), + wait.as_mut_ptr(), + bindings::TASK_INTERRUPTIBLE as _, + ); + } + + // SAFETY: The guard is evidence that the caller owns the lock. + unsafe { l.unlock() }; + + // SAFETY: No arguments, switches to another thread. + unsafe { bindings::schedule() }; + + l.lock_noguard(); + + // SAFETY: Both `wait` and `wait_list` point to valid memory. + unsafe { bindings::finish_wait(self.wait_list.get(), wait.as_mut_ptr()) }; + + // SAFETY: No arguments, just checks `current` for pending signals. + unsafe { rust_helper_signal_pending() != 0 } + } + + /// Calls the kernel function to notify the appropriate number of threads with the given flags. + fn notify(&self, count: i32, flags: u32) { + // SAFETY: `wait_list` points to valid memory. + unsafe { + bindings::__wake_up( + self.wait_list.get(), + bindings::TASK_NORMAL, + count, + flags as _, + ) + }; + } + + /// Wakes a single waiter up, if any. This is not 'sticky' in the sense that if no thread is + /// waiting, the notification is lost completely (as opposed to automatically waking up the + /// next waiter). + pub fn notify_one(&self) { + self.notify(1, 0); + } + + /// Wakes all waiters up, if any. This is not 'sticky' in the sense that if no thread is + /// waiting, the notification is lost completely (as opposed to automatically waking up the + /// next waiter). + pub fn notify_all(&self) { + self.notify(0, 0); + } + + /// Wakes all waiters up. If they were added by `epoll`, they are also removed from the list of + /// waiters. This is useful when cleaning up a condition variable that may be waited on by + /// threads that use `epoll`. + pub fn free_waiters(&self) { + self.notify(1, bindings::POLLHUP | POLLFREE); + } +} + +impl NeedsLockClass for CondVar { + unsafe fn init(self: Pin<&Self>, name: CStr<'static>, key: *mut bindings::lock_class_key) { + bindings::__init_waitqueue_head(self.wait_list.get(), name.as_ptr() as _, key); + } +} diff --git a/rust/kernel/sync/mod.rs b/rust/kernel/sync/mod.rs index 7fb61af0270778..9f533ed025ea87 100644 --- a/rust/kernel/sync/mod.rs +++ b/rust/kernel/sync/mod.rs @@ -20,10 +20,12 @@ use crate::{bindings, CStr}; use core::pin::Pin; +mod condvar; mod guard; mod mutex; mod spinlock; +pub use condvar::CondVar; pub use guard::{Guard, Lock}; pub use mutex::Mutex; pub use spinlock::SpinLock;