Got closures to work in async, added bunch of tests

This commit is contained in:
Peter Krull 2024-02-14 01:23:11 +01:00
parent 37f1c9ac27
commit 24a4379832

View File

@ -97,7 +97,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> MultiSignal<M, T, N> {
} }
/// Get a [`Receiver`] for the `MultiSignal`. /// Get a [`Receiver`] for the `MultiSignal`.
pub fn receiver(&'a self) -> Result<Receiver<'a, M, T, N>, Error> { pub fn receiver<'s>(&'a self) -> Result<Receiver<'a, M, T, N>, Error> {
self.mutex.lock(|state| { self.mutex.lock(|state| {
let mut s = state.borrow_mut(); let mut s = state.borrow_mut();
if s.receiver_count < N { if s.receiver_count < N {
@ -142,60 +142,36 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> MultiSignal<M, T, N> {
fn get_id(&self) -> u64 { fn get_id(&self) -> u64 {
self.mutex.lock(|state| state.borrow().current_id) self.mutex.lock(|state| state.borrow().current_id)
} }
/// Poll the `MultiSignal` with an optional context.
fn get_with_context(&self, rcv: &mut Rcv<'a, M, T, N>, cx: Option<&mut Context>) -> Poll<T> {
self.mutex.lock(|state| {
let mut s = state.borrow_mut();
match (s.current_id > rcv.at_id, rcv.predicate) {
(true, None) => {
rcv.at_id = s.current_id;
Poll::Ready(s.data.clone())
}
(true, Some(f)) if f(&s.data) => {
rcv.at_id = s.current_id;
Poll::Ready(s.data.clone())
}
_ => {
if let Some(cx) = cx {
s.wakers.register(cx.waker());
}
Poll::Pending
}
}
})
}
} }
/// A receiver is able to `.await` a changed `MultiSignal` value. /// A receiver is able to `.await` a changed `MultiSignal` value.
pub struct Rcv<'a, M: RawMutex, T: Clone, const N: usize> { pub struct Rcv<'a, M: RawMutex, T: Clone, const N: usize> {
multi_sig: &'a MultiSignal<M, T, N>, multi_sig: &'a MultiSignal<M, T, N>,
predicate: Option<fn(&T) -> bool>,
at_id: u64, at_id: u64,
} }
// f: Option<impl FnMut(&T) -> bool> impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
/// Create a new `Receiver` with a reference the given `MultiSignal`. /// Create a new `Receiver` with a reference the given `MultiSignal`.
fn new(multi_sig: &'a MultiSignal<M, T, N>) -> Self { fn new(multi_sig: &'a MultiSignal<M, T, N>) -> Self {
Self { Self { multi_sig, at_id: 0 }
multi_sig,
predicate: None,
at_id: 0,
}
} }
/// Wait for a change to the value of the corresponding `MultiSignal`. /// Wait for a change to the value of the corresponding `MultiSignal`.
pub fn changed<'s>(&'s mut self) -> ReceiverFuture<'s, 'a, M, T, N> { pub async fn changed(&mut self) -> T {
self.predicate = None; ReceiverWaitFuture { subscriber: self }.await
ReceiverFuture { subscriber: self }
} }
/// Wait for a change to the value of the corresponding `MultiSignal` which matches the predicate `f`. /// Wait for a change to the value of the corresponding `MultiSignal` which matches the predicate `f`.
// TODO: How do we make this work with a FnMut closure? // TODO: How do we make this work with a FnMut closure?
pub fn changed_and<'s>(&'s mut self, f: fn(&T) -> bool) -> ReceiverFuture<'s, 'a, M, T, N> { pub async fn changed_and<F>(&mut self, f: F) -> T
self.predicate = Some(f); where
ReceiverFuture { subscriber: self } F: FnMut(&T) -> bool,
{
ReceiverPredFuture {
subscriber: self,
predicate: f,
}
.await
} }
/// Try to get a changed value of the corresponding `MultiSignal`. /// Try to get a changed value of the corresponding `MultiSignal`.
@ -213,7 +189,10 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
} }
/// Try to get a changed value of the corresponding `MultiSignal` which matches the predicate `f`. /// Try to get a changed value of the corresponding `MultiSignal` which matches the predicate `f`.
pub fn try_changed_and(&mut self, mut f: impl FnMut(&T) -> bool) -> Option<T> { pub fn try_changed_and<F>(&mut self, mut f: F) -> Option<T>
where
F: FnMut(&T) -> bool,
{
self.multi_sig.mutex.lock(|state| { self.multi_sig.mutex.lock(|state| {
let s = state.borrow(); let s = state.borrow();
match s.current_id > self.at_id && f(&s.data) { match s.current_id > self.at_id && f(&s.data) {
@ -232,7 +211,10 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
} }
/// Peek the current value of the corresponding `MultiSignal` and check if it satisfies the predicate `f`. /// Peek the current value of the corresponding `MultiSignal` and check if it satisfies the predicate `f`.
pub fn peek_and(&self, f: impl FnMut(&T) -> bool) -> Option<T> { pub fn peek_and<F>(&self, f: F) -> Option<T>
where
F: FnMut(&T) -> bool,
{
self.multi_sig.peek_and(f) self.multi_sig.peek_and(f)
} }
@ -247,7 +229,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> {
/// A `Receiver` is able to `.await` a change to the corresponding [`MultiSignal`] value. /// A `Receiver` is able to `.await` a change to the corresponding [`MultiSignal`] value.
pub struct Receiver<'a, M: RawMutex, T: Clone, const N: usize>(Rcv<'a, M, T, N>); pub struct Receiver<'a, M: RawMutex, T: Clone, const N: usize>(Rcv<'a, M, T, N>);
impl<'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N> { impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N> {
type Target = Rcv<'a, M, T, N>; type Target = Rcv<'a, M, T, N>;
fn deref(&self) -> &Self::Target { fn deref(&self) -> &Self::Target {
@ -255,7 +237,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N>
} }
} }
impl<'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T, N> { impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T, N> {
fn deref_mut(&mut self) -> &mut Self::Target { fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0 &mut self.0
} }
@ -263,18 +245,280 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T,
/// Future for the `Receiver` wait action /// Future for the `Receiver` wait action
#[must_use = "futures do nothing unless you `.await` or poll them"] #[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReceiverFuture<'s, 'a, M: RawMutex, T: Clone, const N: usize> { pub struct ReceiverWaitFuture<'s, 'a, M: RawMutex, T: Clone, const N: usize> {
subscriber: &'s mut Rcv<'a, M, T, N>, subscriber: &'s mut Rcv<'a, M, T, N>,
} }
impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Future for ReceiverFuture<'s, 'a, M, T, N> { impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Unpin for ReceiverWaitFuture<'s, 'a, M, T, N> {}
impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Future for ReceiverWaitFuture<'s, 'a, M, T, N> {
type Output = T; type Output = T;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.subscriber self.get_with_context(Some(cx))
.multi_sig
.get_with_context(&mut self.subscriber, Some(cx))
} }
} }
impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Unpin for ReceiverFuture<'s, 'a, M, T, N> {} impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> ReceiverWaitFuture<'s, 'a, M, T, N> {
/// Poll the `MultiSignal` with an optional context.
fn get_with_context(&mut self, cx: Option<&mut Context>) -> Poll<T> {
self.subscriber.multi_sig.mutex.lock(|state| {
let mut s = state.borrow_mut();
match s.current_id > self.subscriber.at_id {
true => {
self.subscriber.at_id = s.current_id;
Poll::Ready(s.data.clone())
}
_ => {
if let Some(cx) = cx {
s.wakers.register(cx.waker());
}
Poll::Pending
}
}
})
}
}
/// Future for the `Receiver` wait action, with the ability to filter the value with a predicate.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct ReceiverPredFuture<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&'a T) -> bool, const N: usize> {
subscriber: &'s mut Rcv<'a, M, T, N>,
predicate: F,
}
impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> Unpin for ReceiverPredFuture<'s, 'a, M, T, F, N> {}
impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> Future for ReceiverPredFuture<'s, 'a, M, T, F, N>{
type Output = T;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
self.get_with_context_pred(Some(cx))
}
}
impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> ReceiverPredFuture<'s, 'a, M, T, F, N> {
/// Poll the `MultiSignal` with an optional context.
fn get_with_context_pred(&mut self, cx: Option<&mut Context>) -> Poll<T> {
self.subscriber.multi_sig.mutex.lock(|state| {
let mut s = state.borrow_mut();
match s.current_id > self.subscriber.at_id {
true if (self.predicate)(&s.data) => {
self.subscriber.at_id = s.current_id;
Poll::Ready(s.data.clone())
}
_ => {
if let Some(cx) = cx {
s.wakers.register(cx.waker());
}
Poll::Pending
}
}
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::blocking_mutex::raw::CriticalSectionRawMutex;
use futures_executor::block_on;
#[test]
fn multiple_writes() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
SOME_SIGNAL.write(10);
// Receive the new value
assert_eq!(rcv0.changed().await, 10);
assert_eq!(rcv1.changed().await, 10);
// No update
assert_eq!(rcv0.try_changed(), None);
assert_eq!(rcv1.try_changed(), None);
SOME_SIGNAL.write(20);
assert_eq!(rcv0.changed().await, 20);
assert_eq!(rcv1.changed().await, 20);
};
block_on(f);
}
#[test]
fn max_receivers() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let _ = SOME_SIGNAL.receiver().unwrap();
let _ = SOME_SIGNAL.receiver().unwrap();
assert!(SOME_SIGNAL.receiver().is_err());
};
block_on(f);
}
// Really weird edge case, but it's possible to have a receiver that never gets a value.
#[test]
fn receive_initial() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
assert_eq!(rcv0.try_changed(), Some(0));
assert_eq!(rcv1.try_changed(), Some(0));
assert_eq!(rcv0.try_changed(), None);
assert_eq!(rcv1.try_changed(), None);
};
block_on(f);
}
#[test]
fn count_ids() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
SOME_SIGNAL.write(10);
assert_eq!(rcv0.changed().await, 10);
assert_eq!(rcv1.changed().await, 10);
assert_eq!(rcv0.try_changed(), None);
assert_eq!(rcv1.try_changed(), None);
SOME_SIGNAL.write(20);
SOME_SIGNAL.write(20);
SOME_SIGNAL.write(20);
assert_eq!(rcv0.changed().await, 20);
assert_eq!(rcv1.changed().await, 20);
assert_eq!(rcv0.try_changed(), None);
assert_eq!(rcv1.try_changed(), None);
assert_eq!(SOME_SIGNAL.get_id(), 5);
};
block_on(f);
}
#[test]
fn peek_still_await() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
SOME_SIGNAL.write(10);
assert_eq!(rcv0.peek(), 10);
assert_eq!(rcv1.peek(), 10);
assert_eq!(rcv0.changed().await, 10);
assert_eq!(rcv1.changed().await, 10);
};
block_on(f);
}
#[test]
fn predicate() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
SOME_SIGNAL.write(20);
assert_eq!(rcv0.changed_and(|x| x > &10).await, 20);
assert_eq!(rcv1.try_changed_and(|x| x > &30), None);
};
block_on(f);
}
#[test]
fn mutable_predicate() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let mut rcv = SOME_SIGNAL.receiver().unwrap();
SOME_SIGNAL.write(10);
let mut largest = 0;
let mut predicate = |x: &u8| {
if *x > largest {
largest = *x;
}
true
};
assert_eq!(rcv.changed_and(&mut predicate).await, 10);
SOME_SIGNAL.write(20);
assert_eq!(rcv.changed_and(&mut predicate).await, 20);
SOME_SIGNAL.write(5);
assert_eq!(rcv.changed_and(&mut predicate).await, 5);
assert_eq!(largest, 20)
};
block_on(f);
}
#[test]
fn peek_and() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let mut rcv0 = SOME_SIGNAL.receiver().unwrap();
let mut rcv1 = SOME_SIGNAL.receiver().unwrap();
SOME_SIGNAL.write(20);
assert_eq!(rcv0.peek_and(|x| x > &10), Some(20));
assert_eq!(rcv1.peek_and(|x| x > &30), None);
assert_eq!(rcv0.changed().await, 20);
assert_eq!(rcv1.changed().await, 20);
};
block_on(f);
}
#[test]
fn peek_with_static() {
let f = async {
static SOME_SIGNAL: MultiSignal<CriticalSectionRawMutex, u8, 2> = MultiSignal::new(0);
// Obtain Receivers
let rcv0 = SOME_SIGNAL.receiver().unwrap();
let rcv1 = SOME_SIGNAL.receiver().unwrap();
SOME_SIGNAL.write(20);
assert_eq!(rcv0.peek(), 20);
assert_eq!(rcv1.peek(), 20);
assert_eq!(SOME_SIGNAL.peek(), 20);
assert_eq!(SOME_SIGNAL.peek_and(|x| x > &30), None);
};
block_on(f);
}
}