time/generic-queue: fix ub in tests.

This commit is contained in:
Dario Nieuwenhuis 2024-05-13 00:35:46 +02:00
parent a8f578751f
commit 1c9bb7c2e1

View File

@ -177,9 +177,10 @@ embassy_time_queue_driver::timer_queue_impl!(static QUEUE: Queue = Queue::new())
#[cfg(test)] #[cfg(test)]
#[cfg(feature = "mock-driver")] #[cfg(feature = "mock-driver")]
mod tests { mod tests {
use core::cell::Cell; use core::sync::atomic::{AtomicBool, Ordering};
use core::task::{RawWaker, RawWakerVTable, Waker}; use core::task::Waker;
use std::rc::Rc; use std::sync::Arc;
use std::task::Wake;
use serial_test::serial; use serial_test::serial;
@ -188,42 +189,26 @@ mod tests {
use crate::{Duration, Instant}; use crate::{Duration, Instant};
struct TestWaker { struct TestWaker {
pub awoken: Rc<Cell<bool>>, pub awoken: AtomicBool,
pub waker: Waker,
} }
impl TestWaker { impl Wake for TestWaker {
fn new() -> Self { fn wake(self: Arc<Self>) {
let flag = Rc::new(Cell::new(false)); self.awoken.store(true, Ordering::Relaxed);
const VTABLE: RawWakerVTable = RawWakerVTable::new(
|data: *const ()| {
unsafe {
Rc::increment_strong_count(data as *const Cell<bool>);
}
RawWaker::new(data as _, &VTABLE)
},
|data: *const ()| unsafe {
let data = data as *const Cell<bool>;
data.as_ref().unwrap().set(true);
Rc::decrement_strong_count(data);
},
|data: *const ()| unsafe {
(data as *const Cell<bool>).as_ref().unwrap().set(true);
},
|data: *const ()| unsafe {
Rc::decrement_strong_count(data);
},
);
let raw = RawWaker::new(Rc::into_raw(flag.clone()) as _, &VTABLE);
Self {
awoken: flag.clone(),
waker: unsafe { Waker::from_raw(raw) },
}
} }
fn wake_by_ref(self: &Arc<Self>) {
self.awoken.store(true, Ordering::Relaxed);
}
}
fn test_waker() -> (Arc<TestWaker>, Waker) {
let arc = Arc::new(TestWaker {
awoken: AtomicBool::new(false),
});
let waker = Waker::from(arc.clone());
(arc, waker)
} }
fn setup() { fn setup() {
@ -249,11 +234,11 @@ mod tests {
assert_eq!(queue_len(), 0); assert_eq!(queue_len(), 0);
let waker = TestWaker::new(); let (flag, waker) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker); QUEUE.schedule_wake(Instant::from_secs(1), &waker);
assert!(!waker.awoken.get()); assert!(!flag.awoken.load(Ordering::Relaxed));
assert_eq!(queue_len(), 1); assert_eq!(queue_len(), 1);
} }
@ -262,23 +247,23 @@ mod tests {
fn test_schedule_same() { fn test_schedule_same() {
setup(); setup();
let waker = TestWaker::new(); let (_flag, waker) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker); QUEUE.schedule_wake(Instant::from_secs(1), &waker);
assert_eq!(queue_len(), 1); assert_eq!(queue_len(), 1);
QUEUE.schedule_wake(Instant::from_secs(1), &waker.waker); QUEUE.schedule_wake(Instant::from_secs(1), &waker);
assert_eq!(queue_len(), 1); assert_eq!(queue_len(), 1);
QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker); QUEUE.schedule_wake(Instant::from_secs(100), &waker);
assert_eq!(queue_len(), 1); assert_eq!(queue_len(), 1);
let waker2 = TestWaker::new(); let (_flag2, waker2) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(100), &waker2.waker); QUEUE.schedule_wake(Instant::from_secs(100), &waker2);
assert_eq!(queue_len(), 2); assert_eq!(queue_len(), 2);
} }
@ -288,21 +273,21 @@ mod tests {
fn test_trigger() { fn test_trigger() {
setup(); setup();
let waker = TestWaker::new(); let (flag, waker) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker); QUEUE.schedule_wake(Instant::from_secs(100), &waker);
assert!(!waker.awoken.get()); assert!(!flag.awoken.load(Ordering::Relaxed));
MockDriver::get().advance(Duration::from_secs(99)); MockDriver::get().advance(Duration::from_secs(99));
assert!(!waker.awoken.get()); assert!(!flag.awoken.load(Ordering::Relaxed));
assert_eq!(queue_len(), 1); assert_eq!(queue_len(), 1);
MockDriver::get().advance(Duration::from_secs(1)); MockDriver::get().advance(Duration::from_secs(1));
assert!(waker.awoken.get()); assert!(flag.awoken.load(Ordering::Relaxed));
assert_eq!(queue_len(), 0); assert_eq!(queue_len(), 0);
} }
@ -312,18 +297,18 @@ mod tests {
fn test_immediate_trigger() { fn test_immediate_trigger() {
setup(); setup();
let waker = TestWaker::new(); let (flag, waker) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(100), &waker.waker); QUEUE.schedule_wake(Instant::from_secs(100), &waker);
MockDriver::get().advance(Duration::from_secs(50)); MockDriver::get().advance(Duration::from_secs(50));
let waker2 = TestWaker::new(); let (flag2, waker2) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(40), &waker2.waker); QUEUE.schedule_wake(Instant::from_secs(40), &waker2);
assert!(!waker.awoken.get()); assert!(!flag.awoken.load(Ordering::Relaxed));
assert!(waker2.awoken.get()); assert!(flag2.awoken.load(Ordering::Relaxed));
assert_eq!(queue_len(), 1); assert_eq!(queue_len(), 1);
} }
@ -333,30 +318,31 @@ mod tests {
setup(); setup();
for i in 1..super::QUEUE_SIZE { for i in 1..super::QUEUE_SIZE {
let waker = TestWaker::new(); let (flag, waker) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(310), &waker.waker); QUEUE.schedule_wake(Instant::from_secs(310), &waker);
assert_eq!(queue_len(), i); assert_eq!(queue_len(), i);
assert!(!waker.awoken.get()); assert!(!flag.awoken.load(Ordering::Relaxed));
} }
let first_waker = TestWaker::new(); let (flag, waker) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(300), &first_waker.waker); QUEUE.schedule_wake(Instant::from_secs(300), &waker);
assert_eq!(queue_len(), super::QUEUE_SIZE); assert_eq!(queue_len(), super::QUEUE_SIZE);
assert!(!first_waker.awoken.get()); assert!(!flag.awoken.load(Ordering::Relaxed));
let second_waker = TestWaker::new(); let (flag2, waker2) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(305), &second_waker.waker); QUEUE.schedule_wake(Instant::from_secs(305), &waker2);
assert_eq!(queue_len(), super::QUEUE_SIZE); assert_eq!(queue_len(), super::QUEUE_SIZE);
assert!(first_waker.awoken.get()); assert!(flag.awoken.load(Ordering::Relaxed));
QUEUE.schedule_wake(Instant::from_secs(320), &TestWaker::new().waker); let (_flag3, waker3) = test_waker();
QUEUE.schedule_wake(Instant::from_secs(320), &waker3);
assert_eq!(queue_len(), super::QUEUE_SIZE); assert_eq!(queue_len(), super::QUEUE_SIZE);
assert!(second_waker.awoken.get()); assert!(flag2.awoken.load(Ordering::Relaxed));
} }
} }