From 89bad07e817dec482d385f765da5be3b6d2d0e4c Mon Sep 17 00:00:00 2001 From: Nathan Perry Date: Fri, 20 Sep 2024 01:52:59 -0400 Subject: [PATCH] embassy_sync: `Sink` adapter for `pubsub::Pub` Corresponding to the `Stream` impl for `pubsub::Sub`. Notable difference is that we need a separate adapter type to store the pending item, i.e. we can't `impl Sink for Pub` directly. Instead a method `Pub::sink(&self)` is exposed, which constructs a `PubSink`. --- embassy-sync/Cargo.toml | 3 +- embassy-sync/src/pubsub/mod.rs | 26 +++++++++++ embassy-sync/src/pubsub/publisher.rs | 67 ++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+), 1 deletion(-) diff --git a/embassy-sync/Cargo.toml b/embassy-sync/Cargo.toml index 7b7d2bf8e..2f049b6bc 100644 --- a/embassy-sync/Cargo.toml +++ b/embassy-sync/Cargo.toml @@ -27,6 +27,7 @@ turbowakers = [] defmt = { version = "0.3", optional = true } log = { version = "0.4.14", optional = true } +futures-sink = { version = "0.3", default-features = false, features = [] } futures-util = { version = "0.3.17", default-features = false } critical-section = "1.1" heapless = "0.8" @@ -37,7 +38,7 @@ embedded-io-async = { version = "0.6.1" } futures-executor = { version = "0.3.17", features = [ "thread-pool" ] } futures-test = "0.3.17" futures-timer = "3.0.2" -futures-util = { version = "0.3.17", features = [ "channel" ] } +futures-util = { version = "0.3.17", features = [ "channel", "sink" ] } # Enable critical-section implementation for std, for tests critical-section = { version = "1.1", features = ["std"] } diff --git a/embassy-sync/src/pubsub/mod.rs b/embassy-sync/src/pubsub/mod.rs index 812302e2b..1c0f68bd0 100644 --- a/embassy-sync/src/pubsub/mod.rs +++ b/embassy-sync/src/pubsub/mod.rs @@ -755,4 +755,30 @@ mod tests { assert_eq!(1, sub0.try_next_message_pure().unwrap().0); assert_eq!(0, sub1.try_next_message_pure().unwrap().0); } + + #[futures_test::test] + async fn publisher_sink() { + use futures_util::{SinkExt, StreamExt}; + + let channel = PubSubChannel::::new(); + + let mut sub = channel.subscriber().unwrap(); + + let publ = channel.publisher().unwrap(); + let mut sink = publ.sink(); + + sink.send(0).await.unwrap(); + assert_eq!(0, sub.try_next_message_pure().unwrap()); + + sink.send(1).await.unwrap(); + assert_eq!(1, sub.try_next_message_pure().unwrap()); + + sink.send_all(&mut futures_util::stream::iter(0..4).map(Ok)) + .await + .unwrap(); + assert_eq!(0, sub.try_next_message_pure().unwrap()); + assert_eq!(1, sub.try_next_message_pure().unwrap()); + assert_eq!(2, sub.try_next_message_pure().unwrap()); + assert_eq!(3, sub.try_next_message_pure().unwrap()); + } } diff --git a/embassy-sync/src/pubsub/publisher.rs b/embassy-sync/src/pubsub/publisher.rs index e66b3b1db..7a1ab66de 100644 --- a/embassy-sync/src/pubsub/publisher.rs +++ b/embassy-sync/src/pubsub/publisher.rs @@ -74,6 +74,12 @@ impl<'a, PSB: PubSubBehavior + ?Sized, T: Clone> Pub<'a, PSB, T> { pub fn is_full(&self) -> bool { self.channel.is_full() } + + /// Create a [`futures::Sink`] adapter for this publisher. + #[inline] + pub const fn sink(&self) -> PubSink<'a, '_, PSB, T> { + PubSink { publ: self, fut: None } + } } impl<'a, PSB: PubSubBehavior + ?Sized, T: Clone> Drop for Pub<'a, PSB, T> { @@ -221,6 +227,67 @@ impl<'a, M: RawMutex, T: Clone, const CAP: usize, const SUBS: usize, const PUBS: } } +#[must_use = "Sinks do nothing unless polled"] +/// [`futures_sink::Sink`] adapter for [`Pub`]. +pub struct PubSink<'a, 'p, PSB, T> +where + T: Clone, + PSB: PubSubBehavior + ?Sized, +{ + publ: &'p Pub<'a, PSB, T>, + fut: Option>, +} + +impl<'a, 'p, PSB, T> PubSink<'a, 'p, PSB, T> +where + PSB: PubSubBehavior + ?Sized, + T: Clone, +{ + /// Try to make progress on the pending future if we have one. + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> { + let Some(mut fut) = self.fut.take() else { + return Poll::Ready(()); + }; + + if Pin::new(&mut fut).poll(cx).is_pending() { + self.fut = Some(fut); + return Poll::Pending; + } + + Poll::Ready(()) + } +} + +impl<'a, 'p, PSB, T> futures_sink::Sink for PubSink<'a, 'p, PSB, T> +where + PSB: PubSubBehavior + ?Sized, + T: Clone, +{ + type Error = core::convert::Infallible; + + #[inline] + fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll(cx).map(Ok) + } + + #[inline] + fn start_send(mut self: Pin<&mut Self>, item: T) -> Result<(), Self::Error> { + self.fut = Some(self.publ.publish(item)); + + Ok(()) + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll(cx).map(Ok) + } + + #[inline] + fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + self.poll(cx).map(Ok) + } +} + /// Future for the publisher wait action #[must_use = "futures do nothing unless you `.await` or poll them"] pub struct PublisherWaitFuture<'s, 'a, PSB: PubSubBehavior + ?Sized, T: Clone> {