diff --git a/embassy-time/src/queue_generic.rs b/embassy-time/src/queue_generic.rs index cf7a986d5..4882afd3e 100644 --- a/embassy-time/src/queue_generic.rs +++ b/embassy-time/src/queue_generic.rs @@ -177,9 +177,10 @@ embassy_time_queue_driver::timer_queue_impl!(static QUEUE: Queue = Queue::new()) #[cfg(test)] #[cfg(feature = "mock-driver")] mod tests { - use core::cell::Cell; - use core::task::{RawWaker, RawWakerVTable, Waker}; - use std::rc::Rc; + use core::sync::atomic::{AtomicBool, Ordering}; + use core::task::Waker; + use std::sync::Arc; + use std::task::Wake; use serial_test::serial; @@ -188,42 +189,26 @@ mod tests { use crate::{Duration, Instant}; struct TestWaker { - pub awoken: Rc>, - pub waker: Waker, + pub awoken: AtomicBool, } - impl TestWaker { - fn new() -> Self { - let flag = Rc::new(Cell::new(false)); - - const VTABLE: RawWakerVTable = RawWakerVTable::new( - |data: *const ()| { - unsafe { - Rc::increment_strong_count(data as *const Cell); - } - - RawWaker::new(data as _, &VTABLE) - }, - |data: *const ()| unsafe { - let data = data as *const Cell; - data.as_ref().unwrap().set(true); - Rc::decrement_strong_count(data); - }, - |data: *const ()| unsafe { - (data as *const Cell).as_ref().unwrap().set(true); - }, - |data: *const ()| unsafe { - Rc::decrement_strong_count(data); - }, - ); - - let raw = RawWaker::new(Rc::into_raw(flag.clone()) as _, &VTABLE); - - Self { - awoken: flag.clone(), - waker: unsafe { Waker::from_raw(raw) }, - } + impl Wake for TestWaker { + fn wake(self: Arc) { + self.awoken.store(true, Ordering::Relaxed); } + + fn wake_by_ref(self: &Arc) { + self.awoken.store(true, Ordering::Relaxed); + } + } + + fn test_waker() -> (Arc, Waker) { + let arc = Arc::new(TestWaker { + awoken: AtomicBool::new(false), + }); + let waker = Waker::from(arc.clone()); + + (arc, waker) } fn setup() { @@ -249,11 +234,11 @@ mod tests { assert_eq!(queue_len(), 0); - let waker = TestWaker::new(); + let (flag, waker) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker); + QUEUE.schedule_wake(Instant::from_secs(1), &waker); - assert!(!waker.awoken.get()); + assert!(!flag.awoken.load(Ordering::Relaxed)); assert_eq!(queue_len(), 1); } @@ -262,23 +247,23 @@ mod tests { fn test_schedule_same() { setup(); - let waker = TestWaker::new(); + let (_flag, waker) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker); + QUEUE.schedule_wake(Instant::from_secs(1), &waker); assert_eq!(queue_len(), 1); - QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker); + QUEUE.schedule_wake(Instant::from_secs(1), &waker); assert_eq!(queue_len(), 1); - QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker); + QUEUE.schedule_wake(Instant::from_secs(100), &waker); assert_eq!(queue_len(), 1); - let waker2 = TestWaker::new(); + let (_flag2, waker2) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(100), &waker2.waker); + QUEUE.schedule_wake(Instant::from_secs(100), &waker2); assert_eq!(queue_len(), 2); } @@ -288,21 +273,21 @@ mod tests { fn test_trigger() { setup(); - let waker = TestWaker::new(); + let (flag, waker) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker); + QUEUE.schedule_wake(Instant::from_secs(100), &waker); - assert!(!waker.awoken.get()); + assert!(!flag.awoken.load(Ordering::Relaxed)); MockDriver::get().advance(Duration::from_secs(99)); - assert!(!waker.awoken.get()); + assert!(!flag.awoken.load(Ordering::Relaxed)); assert_eq!(queue_len(), 1); MockDriver::get().advance(Duration::from_secs(1)); - assert!(waker.awoken.get()); + assert!(flag.awoken.load(Ordering::Relaxed)); assert_eq!(queue_len(), 0); } @@ -312,18 +297,18 @@ mod tests { fn test_immediate_trigger() { setup(); - let waker = TestWaker::new(); + let (flag, waker) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker); + QUEUE.schedule_wake(Instant::from_secs(100), &waker); MockDriver::get().advance(Duration::from_secs(50)); - let waker2 = TestWaker::new(); + let (flag2, waker2) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(40), &waker2.waker); + QUEUE.schedule_wake(Instant::from_secs(40), &waker2); - assert!(!waker.awoken.get()); - assert!(waker2.awoken.get()); + assert!(!flag.awoken.load(Ordering::Relaxed)); + assert!(flag2.awoken.load(Ordering::Relaxed)); assert_eq!(queue_len(), 1); } @@ -333,30 +318,31 @@ mod tests { setup(); for i in 1..super::QUEUE_SIZE { - let waker = TestWaker::new(); + let (flag, waker) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(310), &waker.waker); + QUEUE.schedule_wake(Instant::from_secs(310), &waker); assert_eq!(queue_len(), i); - assert!(!waker.awoken.get()); + assert!(!flag.awoken.load(Ordering::Relaxed)); } - let first_waker = TestWaker::new(); + let (flag, waker) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(300), &first_waker.waker); + QUEUE.schedule_wake(Instant::from_secs(300), &waker); assert_eq!(queue_len(), super::QUEUE_SIZE); - assert!(!first_waker.awoken.get()); + assert!(!flag.awoken.load(Ordering::Relaxed)); - let second_waker = TestWaker::new(); + let (flag2, waker2) = test_waker(); - QUEUE.schedule_wake(Instant::from_secs(305), &second_waker.waker); + QUEUE.schedule_wake(Instant::from_secs(305), &waker2); assert_eq!(queue_len(), super::QUEUE_SIZE); - assert!(first_waker.awoken.get()); + assert!(flag.awoken.load(Ordering::Relaxed)); - QUEUE.schedule_wake(Instant::from_secs(320), &TestWaker::new().waker); + let (_flag3, waker3) = test_waker(); + QUEUE.schedule_wake(Instant::from_secs(320), &waker3); assert_eq!(queue_len(), super::QUEUE_SIZE); - assert!(second_waker.awoken.get()); + assert!(flag2.awoken.load(Ordering::Relaxed)); } }