diff --git a/embassy-sync/src/mutex.rs b/embassy-sync/src/mutex.rs index 72459d660..b48a408c4 100644 --- a/embassy-sync/src/mutex.rs +++ b/embassy-sync/src/mutex.rs @@ -3,6 +3,7 @@ //! This module provides a mutex that can be used to synchronize data between asynchronous tasks. use core::cell::{RefCell, UnsafeCell}; use core::future::poll_fn; +use core::mem; use core::ops::{Deref, DerefMut}; use core::task::Poll; @@ -134,6 +135,7 @@ where /// successfully locked the mutex, and grants access to the contents. /// /// Dropping it unlocks the mutex. +#[clippy::has_significant_drop] pub struct MutexGuard<'a, M, T> where M: RawMutex, @@ -142,6 +144,25 @@ where mutex: &'a Mutex, } +impl<'a, M, T> MutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + /// Returns a locked view over a portion of the locked data. + pub fn map(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> { + let mutex = this.mutex; + let value = fun(unsafe { &mut *this.mutex.inner.get() }); + // Don't run the `drop` method for MutexGuard. The ownership of the underlying + // locked state is being moved to the returned MappedMutexGuard. + mem::forget(this); + MappedMutexGuard { + state: &mutex.state, + value, + } + } +} + impl<'a, M, T> Drop for MutexGuard<'a, M, T> where M: RawMutex, @@ -180,3 +201,115 @@ where unsafe { &mut *(self.mutex.inner.get()) } } } + +/// A handle to a held `Mutex` that has had a function applied to it via [`MutexGuard::map`] or +/// [`MappedMutexGuard::map`]. +/// +/// This can be used to hold a subfield of the protected data. +#[clippy::has_significant_drop] +pub struct MappedMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + state: &'a BlockingMutex>, + value: *mut T, +} + +impl<'a, M, T> MappedMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + /// Returns a locked view over a portion of the locked data. + pub fn map(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> { + let state = this.state; + let value = fun(unsafe { &mut *this.value }); + // Don't run the `drop` method for MutexGuard. The ownership of the underlying + // locked state is being moved to the returned MappedMutexGuard. + mem::forget(this); + MappedMutexGuard { state, value } + } +} + +impl<'a, M, T> Deref for MappedMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + type Target = T; + fn deref(&self) -> &Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &*self.value } + } +} + +impl<'a, M, T> DerefMut for MappedMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn deref_mut(&mut self) -> &mut Self::Target { + // Safety: the MutexGuard represents exclusive access to the contents + // of the mutex, so it's OK to get it. + unsafe { &mut *self.value } + } +} + +impl<'a, M, T> Drop for MappedMutexGuard<'a, M, T> +where + M: RawMutex, + T: ?Sized, +{ + fn drop(&mut self) { + self.state.lock(|s| { + let mut s = unwrap!(s.try_borrow_mut()); + s.locked = false; + s.waker.wake(); + }) + } +} + +unsafe impl Send for MappedMutexGuard<'_, M, T> +where + M: RawMutex + Sync, + T: Send + ?Sized, +{ +} + +unsafe impl Sync for MappedMutexGuard<'_, M, T> +where + M: RawMutex + Sync, + T: Sync + ?Sized, +{ +} + +#[cfg(test)] +mod tests { + use crate::blocking_mutex::raw::NoopRawMutex; + use crate::mutex::{Mutex, MutexGuard}; + + #[futures_test::test] + async fn mapped_guard_releases_lock_when_dropped() { + let mutex: Mutex = Mutex::new([0, 1]); + + { + let guard = mutex.lock().await; + assert_eq!(*guard, [0, 1]); + let mut mapped = MutexGuard::map(guard, |this| &mut this[1]); + assert_eq!(*mapped, 1); + *mapped = 2; + } + + { + let guard = mutex.lock().await; + assert_eq!(*guard, [0, 2]); + let mut mapped = MutexGuard::map(guard, |this| &mut this[1]); + assert_eq!(*mapped, 2); + *mapped = 3; + } + + assert_eq!(*mutex.lock().await, [0, 3]); + } +}