Add map method

This commit is contained in:
Caio 2024-04-14 19:35:59 -03:00
parent 953e1cd6f6
commit 86706bdc14

View File

@ -3,6 +3,7 @@
//! This module provides a mutex that can be used to synchronize data between asynchronous tasks. //! This module provides a mutex that can be used to synchronize data between asynchronous tasks.
use core::cell::{RefCell, UnsafeCell}; use core::cell::{RefCell, UnsafeCell};
use core::future::poll_fn; use core::future::poll_fn;
use core::mem;
use core::ops::{Deref, DerefMut}; use core::ops::{Deref, DerefMut};
use core::task::Poll; use core::task::Poll;
@ -134,6 +135,7 @@ where
/// successfully locked the mutex, and grants access to the contents. /// successfully locked the mutex, and grants access to the contents.
/// ///
/// Dropping it unlocks the mutex. /// Dropping it unlocks the mutex.
#[clippy::has_significant_drop]
pub struct MutexGuard<'a, M, T> pub struct MutexGuard<'a, M, T>
where where
M: RawMutex, M: RawMutex,
@ -142,6 +144,25 @@ where
mutex: &'a Mutex<M, T>, mutex: &'a Mutex<M, T>,
} }
impl<'a, M, T> MutexGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
/// Returns a locked view over a portion of the locked data.
pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
let mutex = this.mutex;
let value = fun(unsafe { &mut *this.mutex.inner.get() });
// Don't run the `drop` method for MutexGuard. The ownership of the underlying
// locked state is being moved to the returned MappedMutexGuard.
mem::forget(this);
MappedMutexGuard {
state: &mutex.state,
value,
}
}
}
impl<'a, M, T> Drop for MutexGuard<'a, M, T> impl<'a, M, T> Drop for MutexGuard<'a, M, T>
where where
M: RawMutex, M: RawMutex,
@ -180,3 +201,115 @@ where
unsafe { &mut *(self.mutex.inner.get()) } unsafe { &mut *(self.mutex.inner.get()) }
} }
} }
/// A handle to a held `Mutex` that has had a function applied to it via [`MutexGuard::map`] or
/// [`MappedMutexGuard::map`].
///
/// This can be used to hold a subfield of the protected data.
#[clippy::has_significant_drop]
pub struct MappedMutexGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
state: &'a BlockingMutex<M, RefCell<State>>,
value: *mut T,
}
impl<'a, M, T> MappedMutexGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
/// Returns a locked view over a portion of the locked data.
pub fn map<U>(this: Self, fun: impl FnOnce(&mut T) -> &mut U) -> MappedMutexGuard<'a, M, U> {
let state = this.state;
let value = fun(unsafe { &mut *this.value });
// Don't run the `drop` method for MutexGuard. The ownership of the underlying
// locked state is being moved to the returned MappedMutexGuard.
mem::forget(this);
MappedMutexGuard { state, value }
}
}
impl<'a, M, T> Deref for MappedMutexGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
type Target = T;
fn deref(&self) -> &Self::Target {
// Safety: the MutexGuard represents exclusive access to the contents
// of the mutex, so it's OK to get it.
unsafe { &*self.value }
}
}
impl<'a, M, T> DerefMut for MappedMutexGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
fn deref_mut(&mut self) -> &mut Self::Target {
// Safety: the MutexGuard represents exclusive access to the contents
// of the mutex, so it's OK to get it.
unsafe { &mut *self.value }
}
}
impl<'a, M, T> Drop for MappedMutexGuard<'a, M, T>
where
M: RawMutex,
T: ?Sized,
{
fn drop(&mut self) {
self.state.lock(|s| {
let mut s = unwrap!(s.try_borrow_mut());
s.locked = false;
s.waker.wake();
})
}
}
unsafe impl<M, T> Send for MappedMutexGuard<'_, M, T>
where
M: RawMutex + Sync,
T: Send + ?Sized,
{
}
unsafe impl<M, T> Sync for MappedMutexGuard<'_, M, T>
where
M: RawMutex + Sync,
T: Sync + ?Sized,
{
}
#[cfg(test)]
mod tests {
use crate::blocking_mutex::raw::NoopRawMutex;
use crate::mutex::{Mutex, MutexGuard};
#[futures_test::test]
async fn mapped_guard_releases_lock_when_dropped() {
let mutex: Mutex<NoopRawMutex, [i32; 2]> = Mutex::new([0, 1]);
{
let guard = mutex.lock().await;
assert_eq!(*guard, [0, 1]);
let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
assert_eq!(*mapped, 1);
*mapped = 2;
}
{
let guard = mutex.lock().await;
assert_eq!(*guard, [0, 2]);
let mut mapped = MutexGuard::map(guard, |this| &mut this[1]);
assert_eq!(*mapped, 2);
*mapped = 3;
}
assert_eq!(*mutex.lock().await, [0, 3]);
}
}