From 24a4379832d387754d407b77ff7aac5e55401eb3 Mon Sep 17 00:00:00 2001 From: Peter Krull Date: Wed, 14 Feb 2024 01:23:11 +0100 Subject: [PATCH] Got closures to work in async, added bunch of tests --- embassy-sync/src/multi_signal.rs | 340 ++++++++++++++++++++++++++----- 1 file changed, 292 insertions(+), 48 deletions(-) diff --git a/embassy-sync/src/multi_signal.rs b/embassy-sync/src/multi_signal.rs index 5f724c76b..1481dc8f8 100644 --- a/embassy-sync/src/multi_signal.rs +++ b/embassy-sync/src/multi_signal.rs @@ -97,7 +97,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> MultiSignal { } /// Get a [`Receiver`] for the `MultiSignal`. - pub fn receiver(&'a self) -> Result, Error> { + pub fn receiver<'s>(&'a self) -> Result, Error> { self.mutex.lock(|state| { let mut s = state.borrow_mut(); if s.receiver_count < N { @@ -142,60 +142,36 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> MultiSignal { fn get_id(&self) -> u64 { self.mutex.lock(|state| state.borrow().current_id) } - - /// Poll the `MultiSignal` with an optional context. - fn get_with_context(&self, rcv: &mut Rcv<'a, M, T, N>, cx: Option<&mut Context>) -> Poll { - self.mutex.lock(|state| { - let mut s = state.borrow_mut(); - match (s.current_id > rcv.at_id, rcv.predicate) { - (true, None) => { - rcv.at_id = s.current_id; - Poll::Ready(s.data.clone()) - } - (true, Some(f)) if f(&s.data) => { - rcv.at_id = s.current_id; - Poll::Ready(s.data.clone()) - } - _ => { - if let Some(cx) = cx { - s.wakers.register(cx.waker()); - } - Poll::Pending - } - } - }) - } } /// A receiver is able to `.await` a changed `MultiSignal` value. pub struct Rcv<'a, M: RawMutex, T: Clone, const N: usize> { multi_sig: &'a MultiSignal, - predicate: Option bool>, at_id: u64, } -// f: Option bool> -impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> { +impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> { /// Create a new `Receiver` with a reference the given `MultiSignal`. fn new(multi_sig: &'a MultiSignal) -> Self { - Self { - multi_sig, - predicate: None, - at_id: 0, - } + Self { multi_sig, at_id: 0 } } /// Wait for a change to the value of the corresponding `MultiSignal`. - pub fn changed<'s>(&'s mut self) -> ReceiverFuture<'s, 'a, M, T, N> { - self.predicate = None; - ReceiverFuture { subscriber: self } + pub async fn changed(&mut self) -> T { + ReceiverWaitFuture { subscriber: self }.await } /// Wait for a change to the value of the corresponding `MultiSignal` which matches the predicate `f`. // TODO: How do we make this work with a FnMut closure? - pub fn changed_and<'s>(&'s mut self, f: fn(&T) -> bool) -> ReceiverFuture<'s, 'a, M, T, N> { - self.predicate = Some(f); - ReceiverFuture { subscriber: self } + pub async fn changed_and(&mut self, f: F) -> T + where + F: FnMut(&T) -> bool, + { + ReceiverPredFuture { + subscriber: self, + predicate: f, + } + .await } /// Try to get a changed value of the corresponding `MultiSignal`. @@ -213,7 +189,10 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> { } /// Try to get a changed value of the corresponding `MultiSignal` which matches the predicate `f`. - pub fn try_changed_and(&mut self, mut f: impl FnMut(&T) -> bool) -> Option { + pub fn try_changed_and(&mut self, mut f: F) -> Option + where + F: FnMut(&T) -> bool, + { self.multi_sig.mutex.lock(|state| { let s = state.borrow(); match s.current_id > self.at_id && f(&s.data) { @@ -232,7 +211,10 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> { } /// Peek the current value of the corresponding `MultiSignal` and check if it satisfies the predicate `f`. - pub fn peek_and(&self, f: impl FnMut(&T) -> bool) -> Option { + pub fn peek_and(&self, f: F) -> Option + where + F: FnMut(&T) -> bool, + { self.multi_sig.peek_and(f) } @@ -247,7 +229,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Rcv<'a, M, T, N> { /// A `Receiver` is able to `.await` a change to the corresponding [`MultiSignal`] value. pub struct Receiver<'a, M: RawMutex, T: Clone, const N: usize>(Rcv<'a, M, T, N>); -impl<'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N> { +impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N> { type Target = Rcv<'a, M, T, N>; fn deref(&self) -> &Self::Target { @@ -255,7 +237,7 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> Deref for Receiver<'a, M, T, N> } } -impl<'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T, N> { +impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T, N> { fn deref_mut(&mut self) -> &mut Self::Target { &mut self.0 } @@ -263,18 +245,280 @@ impl<'a, M: RawMutex, T: Clone, const N: usize> DerefMut for Receiver<'a, M, T, /// Future for the `Receiver` wait action #[must_use = "futures do nothing unless you `.await` or poll them"] -pub struct ReceiverFuture<'s, 'a, M: RawMutex, T: Clone, const N: usize> { +pub struct ReceiverWaitFuture<'s, 'a, M: RawMutex, T: Clone, const N: usize> { subscriber: &'s mut Rcv<'a, M, T, N>, } -impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Future for ReceiverFuture<'s, 'a, M, T, N> { +impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Unpin for ReceiverWaitFuture<'s, 'a, M, T, N> {} +impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Future for ReceiverWaitFuture<'s, 'a, M, T, N> { type Output = T; fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.subscriber - .multi_sig - .get_with_context(&mut self.subscriber, Some(cx)) + self.get_with_context(Some(cx)) } } -impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> Unpin for ReceiverFuture<'s, 'a, M, T, N> {} +impl<'s, 'a, M: RawMutex, T: Clone, const N: usize> ReceiverWaitFuture<'s, 'a, M, T, N> { + /// Poll the `MultiSignal` with an optional context. + fn get_with_context(&mut self, cx: Option<&mut Context>) -> Poll { + self.subscriber.multi_sig.mutex.lock(|state| { + let mut s = state.borrow_mut(); + match s.current_id > self.subscriber.at_id { + true => { + self.subscriber.at_id = s.current_id; + Poll::Ready(s.data.clone()) + } + _ => { + if let Some(cx) = cx { + s.wakers.register(cx.waker()); + } + Poll::Pending + } + } + }) + } +} + +/// Future for the `Receiver` wait action, with the ability to filter the value with a predicate. +#[must_use = "futures do nothing unless you `.await` or poll them"] +pub struct ReceiverPredFuture<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&'a T) -> bool, const N: usize> { + subscriber: &'s mut Rcv<'a, M, T, N>, + predicate: F, +} + +impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> Unpin for ReceiverPredFuture<'s, 'a, M, T, F, N> {} +impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> Future for ReceiverPredFuture<'s, 'a, M, T, F, N>{ + type Output = T; + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.get_with_context_pred(Some(cx)) + } +} + +impl<'s, 'a, M: RawMutex, T: Clone, F: FnMut(&T) -> bool, const N: usize> ReceiverPredFuture<'s, 'a, M, T, F, N> { + /// Poll the `MultiSignal` with an optional context. + fn get_with_context_pred(&mut self, cx: Option<&mut Context>) -> Poll { + self.subscriber.multi_sig.mutex.lock(|state| { + let mut s = state.borrow_mut(); + match s.current_id > self.subscriber.at_id { + true if (self.predicate)(&s.data) => { + self.subscriber.at_id = s.current_id; + Poll::Ready(s.data.clone()) + } + _ => { + if let Some(cx) = cx { + s.wakers.register(cx.waker()); + } + Poll::Pending + } + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::blocking_mutex::raw::CriticalSectionRawMutex; + use futures_executor::block_on; + + #[test] + fn multiple_writes() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let mut rcv0 = SOME_SIGNAL.receiver().unwrap(); + let mut rcv1 = SOME_SIGNAL.receiver().unwrap(); + + SOME_SIGNAL.write(10); + + // Receive the new value + assert_eq!(rcv0.changed().await, 10); + assert_eq!(rcv1.changed().await, 10); + + // No update + assert_eq!(rcv0.try_changed(), None); + assert_eq!(rcv1.try_changed(), None); + + SOME_SIGNAL.write(20); + + assert_eq!(rcv0.changed().await, 20); + assert_eq!(rcv1.changed().await, 20); + }; + block_on(f); + } + + #[test] + fn max_receivers() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let _ = SOME_SIGNAL.receiver().unwrap(); + let _ = SOME_SIGNAL.receiver().unwrap(); + assert!(SOME_SIGNAL.receiver().is_err()); + }; + block_on(f); + } + + // Really weird edge case, but it's possible to have a receiver that never gets a value. + #[test] + fn receive_initial() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let mut rcv0 = SOME_SIGNAL.receiver().unwrap(); + let mut rcv1 = SOME_SIGNAL.receiver().unwrap(); + + assert_eq!(rcv0.try_changed(), Some(0)); + assert_eq!(rcv1.try_changed(), Some(0)); + + assert_eq!(rcv0.try_changed(), None); + assert_eq!(rcv1.try_changed(), None); + }; + block_on(f); + } + + #[test] + fn count_ids() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let mut rcv0 = SOME_SIGNAL.receiver().unwrap(); + let mut rcv1 = SOME_SIGNAL.receiver().unwrap(); + + SOME_SIGNAL.write(10); + + assert_eq!(rcv0.changed().await, 10); + assert_eq!(rcv1.changed().await, 10); + + assert_eq!(rcv0.try_changed(), None); + assert_eq!(rcv1.try_changed(), None); + + SOME_SIGNAL.write(20); + SOME_SIGNAL.write(20); + SOME_SIGNAL.write(20); + + assert_eq!(rcv0.changed().await, 20); + assert_eq!(rcv1.changed().await, 20); + + assert_eq!(rcv0.try_changed(), None); + assert_eq!(rcv1.try_changed(), None); + + assert_eq!(SOME_SIGNAL.get_id(), 5); + }; + block_on(f); + } + + #[test] + fn peek_still_await() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let mut rcv0 = SOME_SIGNAL.receiver().unwrap(); + let mut rcv1 = SOME_SIGNAL.receiver().unwrap(); + + SOME_SIGNAL.write(10); + + assert_eq!(rcv0.peek(), 10); + assert_eq!(rcv1.peek(), 10); + + assert_eq!(rcv0.changed().await, 10); + assert_eq!(rcv1.changed().await, 10); + }; + block_on(f); + } + + #[test] + fn predicate() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let mut rcv0 = SOME_SIGNAL.receiver().unwrap(); + let mut rcv1 = SOME_SIGNAL.receiver().unwrap(); + + SOME_SIGNAL.write(20); + + assert_eq!(rcv0.changed_and(|x| x > &10).await, 20); + assert_eq!(rcv1.try_changed_and(|x| x > &30), None); + }; + block_on(f); + } + + #[test] + fn mutable_predicate() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let mut rcv = SOME_SIGNAL.receiver().unwrap(); + + SOME_SIGNAL.write(10); + + let mut largest = 0; + let mut predicate = |x: &u8| { + if *x > largest { + largest = *x; + } + true + }; + + assert_eq!(rcv.changed_and(&mut predicate).await, 10); + + SOME_SIGNAL.write(20); + + assert_eq!(rcv.changed_and(&mut predicate).await, 20); + + SOME_SIGNAL.write(5); + + assert_eq!(rcv.changed_and(&mut predicate).await, 5); + + assert_eq!(largest, 20) + }; + block_on(f); + } + + #[test] + fn peek_and() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let mut rcv0 = SOME_SIGNAL.receiver().unwrap(); + let mut rcv1 = SOME_SIGNAL.receiver().unwrap(); + + SOME_SIGNAL.write(20); + + assert_eq!(rcv0.peek_and(|x| x > &10), Some(20)); + assert_eq!(rcv1.peek_and(|x| x > &30), None); + + assert_eq!(rcv0.changed().await, 20); + assert_eq!(rcv1.changed().await, 20); + }; + block_on(f); + } + + #[test] + fn peek_with_static() { + let f = async { + static SOME_SIGNAL: MultiSignal = MultiSignal::new(0); + + // Obtain Receivers + let rcv0 = SOME_SIGNAL.receiver().unwrap(); + let rcv1 = SOME_SIGNAL.receiver().unwrap(); + + SOME_SIGNAL.write(20); + + assert_eq!(rcv0.peek(), 20); + assert_eq!(rcv1.peek(), 20); + assert_eq!(SOME_SIGNAL.peek(), 20); + assert_eq!(SOME_SIGNAL.peek_and(|x| x > &30), None); + }; + block_on(f); + } +}