diff --git a/embassy-futures/src/join.rs b/embassy-futures/src/join.rs new file mode 100644 index 000000000..39a78ccd3 --- /dev/null +++ b/embassy-futures/src/join.rs @@ -0,0 +1,252 @@ +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; +use core::{fmt, mem}; + +#[derive(Debug)] +enum MaybeDone { + /// A not-yet-completed future + Future(/* #[pin] */ Fut), + /// The output of the completed future + Done(Fut::Output), + /// The empty variant after the result of a [`MaybeDone`] has been + /// taken using the [`take_output`](MaybeDone::take_output) method. + Gone, +} + +impl MaybeDone { + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> bool { + let this = unsafe { self.get_unchecked_mut() }; + match this { + Self::Future(fut) => match unsafe { Pin::new_unchecked(fut) }.poll(cx) { + Poll::Ready(res) => { + *this = Self::Done(res); + true + } + Poll::Pending => false, + }, + _ => true, + } + } + + fn take_output(&mut self) -> Fut::Output { + match &*self { + Self::Done(_) => {} + Self::Future(_) | Self::Gone => panic!("take_output when MaybeDone is not done."), + } + match mem::replace(self, Self::Gone) { + MaybeDone::Done(output) => output, + _ => unreachable!(), + } + } +} + +impl Unpin for MaybeDone {} + +macro_rules! generate { + ($( + $(#[$doc:meta])* + ($Join:ident, <$($Fut:ident),*>), + )*) => ($( + $(#[$doc])* + #[must_use = "futures do nothing unless you `.await` or poll them"] + #[allow(non_snake_case)] + pub struct $Join<$($Fut: Future),*> { + $( + $Fut: MaybeDone<$Fut>, + )* + } + + impl<$($Fut),*> fmt::Debug for $Join<$($Fut),*> + where + $( + $Fut: Future + fmt::Debug, + $Fut::Output: fmt::Debug, + )* + { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct(stringify!($Join)) + $(.field(stringify!($Fut), &self.$Fut))* + .finish() + } + } + + impl<$($Fut: Future),*> $Join<$($Fut),*> { + #[allow(non_snake_case)] + fn new($($Fut: $Fut),*) -> Self { + Self { + $($Fut: MaybeDone::Future($Fut)),* + } + } + } + + impl<$($Fut: Future),*> Future for $Join<$($Fut),*> { + type Output = ($($Fut::Output),*); + + fn poll( + self: Pin<&mut Self>, cx: &mut Context<'_> + ) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + let mut all_done = true; + $( + all_done &= unsafe { Pin::new_unchecked(&mut this.$Fut) }.poll(cx); + )* + + if all_done { + Poll::Ready(($(this.$Fut.take_output()), *)) + } else { + Poll::Pending + } + } + } + )*) +} + +generate! { + /// Future for the [`join`](join()) function. + (Join, ), + + /// Future for the [`join3`] function. + (Join3, ), + + /// Future for the [`join4`] function. + (Join4, ), + + /// Future for the [`join5`] function. + (Join5, ), +} + +/// Joins the result of two futures, waiting for them both to complete. +/// +/// This function will return a new future which awaits both futures to +/// complete. The returned future will finish with a tuple of both results. +/// +/// Note that this function consumes the passed futures and returns a +/// wrapped version of it. +/// +/// # Examples +/// +/// ``` +/// # embassy_futures::block_on(async { +/// +/// let a = async { 1 }; +/// let b = async { 2 }; +/// let pair = embassy_futures::join(a, b).await; +/// +/// assert_eq!(pair, (1, 2)); +/// # }); +/// ``` +pub fn join(future1: Fut1, future2: Fut2) -> Join +where + Fut1: Future, + Fut2: Future, +{ + Join::new(future1, future2) +} + +/// Joins the result of three futures, waiting for them all to complete. +/// +/// This function will return a new future which awaits all futures to +/// complete. The returned future will finish with a tuple of all results. +/// +/// Note that this function consumes the passed futures and returns a +/// wrapped version of it. +/// +/// # Examples +/// +/// ``` +/// # embassy_futures::block_on(async { +/// +/// let a = async { 1 }; +/// let b = async { 2 }; +/// let c = async { 3 }; +/// let res = embassy_futures::join3(a, b, c).await; +/// +/// assert_eq!(res, (1, 2, 3)); +/// # }); +/// ``` +pub fn join3(future1: Fut1, future2: Fut2, future3: Fut3) -> Join3 +where + Fut1: Future, + Fut2: Future, + Fut3: Future, +{ + Join3::new(future1, future2, future3) +} + +/// Joins the result of four futures, waiting for them all to complete. +/// +/// This function will return a new future which awaits all futures to +/// complete. The returned future will finish with a tuple of all results. +/// +/// Note that this function consumes the passed futures and returns a +/// wrapped version of it. +/// +/// # Examples +/// +/// ``` +/// # embassy_futures::block_on(async { +/// +/// let a = async { 1 }; +/// let b = async { 2 }; +/// let c = async { 3 }; +/// let d = async { 4 }; +/// let res = embassy_futures::join4(a, b, c, d).await; +/// +/// assert_eq!(res, (1, 2, 3, 4)); +/// # }); +/// ``` +pub fn join4( + future1: Fut1, + future2: Fut2, + future3: Fut3, + future4: Fut4, +) -> Join4 +where + Fut1: Future, + Fut2: Future, + Fut3: Future, + Fut4: Future, +{ + Join4::new(future1, future2, future3, future4) +} + +/// Joins the result of five futures, waiting for them all to complete. +/// +/// This function will return a new future which awaits all futures to +/// complete. The returned future will finish with a tuple of all results. +/// +/// Note that this function consumes the passed futures and returns a +/// wrapped version of it. +/// +/// # Examples +/// +/// ``` +/// # embassy_futures::block_on(async { +/// +/// let a = async { 1 }; +/// let b = async { 2 }; +/// let c = async { 3 }; +/// let d = async { 4 }; +/// let e = async { 5 }; +/// let res = embassy_futures::join5(a, b, c, d, e).await; +/// +/// assert_eq!(res, (1, 2, 3, 4, 5)); +/// # }); +/// ``` +pub fn join5( + future1: Fut1, + future2: Fut2, + future3: Fut3, + future4: Fut4, + future5: Fut5, +) -> Join5 +where + Fut1: Future, + Fut2: Future, + Fut3: Future, + Fut4: Future, + Fut5: Future, +{ + Join5::new(future1, future2, future3, future4, future5) +} diff --git a/embassy-futures/src/lib.rs b/embassy-futures/src/lib.rs index 41e27047d..ea135b3ab 100644 --- a/embassy-futures/src/lib.rs +++ b/embassy-futures/src/lib.rs @@ -6,9 +6,11 @@ pub(crate) mod fmt; mod block_on; +mod join; mod select; mod yield_now; pub use block_on::*; +pub use join::*; pub use select::*; pub use yield_now::*;