diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index 163b2c794..874bfe841 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -1,7 +1,8 @@ #![macro_use] use core::marker::PhantomData; -use core::sync::atomic::{compiler_fence, Ordering}; +use core::mem::MaybeUninit; +use core::sync::atomic::{compiler_fence, AtomicU32, Ordering}; use core::task::Poll; use embassy::interrupt::InterruptExt; use embassy::time::{with_timeout, Duration}; @@ -23,6 +24,7 @@ const NEW_AW: AtomicWaker = AtomicWaker::new(); static BUS_WAKER: AtomicWaker = NEW_AW; static EP_IN_WAKERS: [AtomicWaker; 9] = [NEW_AW; 9]; static EP_OUT_WAKERS: [AtomicWaker; 9] = [NEW_AW; 9]; +static READY_ENDPOINTS: AtomicU32 = AtomicU32::new(0); pub struct Driver<'d, T: Instance> { phantom: PhantomData<&'d mut T>, @@ -84,6 +86,8 @@ impl<'d, T: Instance> Driver<'d, T> { regs.events_epdata.reset(); let r = regs.epdatastatus.read().bits(); + regs.epdatastatus.write(|w| unsafe { w.bits(r) }); + READY_ENDPOINTS.fetch_or(r, Ordering::AcqRel); for i in 1..=7 { if r & (1 << i) != 0 { EP_IN_WAKERS[i].wake(); @@ -143,15 +147,12 @@ impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> { .alloc_in .allocate(ep_addr, ep_type, max_packet_size, interval)?; let ep_addr = EndpointAddress::from_parts(index, UsbDirection::In); - Ok(Endpoint { - _phantom: PhantomData, - info: EndpointInfo { - addr: ep_addr, - ep_type, - max_packet_size, - interval, - }, - }) + Ok(Endpoint::new(EndpointInfo { + addr: ep_addr, + ep_type, + max_packet_size, + interval, + })) } fn alloc_endpoint_out( @@ -165,15 +166,12 @@ impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> { .alloc_out .allocate(ep_addr, ep_type, max_packet_size, interval)?; let ep_addr = EndpointAddress::from_parts(index, UsbDirection::Out); - Ok(Endpoint { - _phantom: PhantomData, - info: EndpointInfo { - addr: ep_addr, - ep_type, - max_packet_size, - interval, - }, - }) + Ok(Endpoint::new(EndpointInfo { + addr: ep_addr, + ep_type, + max_packet_size, + interval, + })) } fn enable(self) -> Self::Bus { @@ -284,7 +282,9 @@ impl<'d, T: Instance> driver::Bus for Bus<'d, T> { } } - //self.busy_in_endpoints = 0; + // IN endpoints (low bits) default to ready. + // OUT endpoints (high bits) default to NOT ready, they become ready when data comes in. + READY_ENDPOINTS.store(0x0000FFFF, Ordering::Release); } #[inline] @@ -324,6 +324,15 @@ pub struct Endpoint<'d, T: Instance, Dir> { info: EndpointInfo, } +impl<'d, T: Instance, Dir> Endpoint<'d, T, Dir> { + fn new(info: EndpointInfo) -> Self { + Self { + info, + _phantom: PhantomData, + } + } +} + impl<'d, T: Instance, Dir> driver::Endpoint for Endpoint<'d, T, Dir> { fn info(&self) -> &EndpointInfo { &self.info @@ -368,7 +377,6 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { } }) .await; - info!("got SETUP"); if buf.len() < 8 { return Err(ReadError::BufferOverflow); @@ -385,10 +393,10 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { Ok(8) } else { + // Wait until ready poll_fn(|cx| { EP_OUT_WAKERS[i].register(cx.waker()); - let regs = T::regs(); - let r = regs.epdatastatus.read().bits(); + let r = READY_ENDPOINTS.load(Ordering::Acquire); if r & (1 << (i + 16)) != 0 { Poll::Ready(()) } else { @@ -397,9 +405,8 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { }) .await; - // Clear status - regs.epdatastatus - .write(|w| unsafe { w.bits(1 << (i + 16)) }); + // Mark as not ready + READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); // Check that the packet fits into the buffer let size = regs.size.epout[i].read().bits(); @@ -448,48 +455,83 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { async move { - info!("write: {:x}", buf); - let regs = T::regs(); + let i = self.info.addr.index(); - let ptr = buf.as_ptr() as u32; - let len = buf.len() as u32; - regs.epin0.ptr.write(|w| unsafe { w.bits(ptr) }); - regs.epin0.maxcnt.write(|w| unsafe { w.bits(len) }); - - regs.events_ep0datadone.reset(); - regs.events_endepin[0].reset(); - - dma_start(); - - regs.tasks_startepin[0].write(|w| unsafe { w.bits(1) }); - info!("write: waiting for endepin..."); - while regs.events_endepin[0].read().bits() == 0 {} - - dma_end(); - - info!("write: waiting for ep0datadone..."); - regs.intenset.write(|w| w.ep0datadone().set()); - let res = with_timeout( - Duration::from_millis(10), + // Wait until ready. + if i != 0 { poll_fn(|cx| { - EP_IN_WAKERS[0].register(cx.waker()); - let regs = T::regs(); - if regs.events_ep0datadone.read().bits() != 0 { + EP_IN_WAKERS[i].register(cx.waker()); + let r = READY_ENDPOINTS.load(Ordering::Acquire); + if r & (1 << i) != 0 { Poll::Ready(()) } else { Poll::Pending } - }), - ) - .await; + }) + .await; - if res.is_err() { - // todo wrong error - return Err(driver::WriteError::BufferOverflow); + // Mark as not ready + READY_ENDPOINTS.fetch_and(!(1 << i), Ordering::AcqRel); } - info!("write done"); + if i == 0 { + regs.events_ep0datadone.reset(); + } + + assert!(buf.len() <= 64); + + // EasyDMA can't read FLASH, so we copy through RAM + let mut ram_buf: MaybeUninit<[u8; 64]> = MaybeUninit::uninit(); + let ptr = ram_buf.as_mut_ptr() as *mut u8; + unsafe { core::ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len()) }; + + let epin = [ + ®s.epin0, + ®s.epin1, + ®s.epin2, + ®s.epin3, + ®s.epin4, + ®s.epin5, + ®s.epin6, + ®s.epin7, + ]; + + // Set the buffer length so the right number of bytes are transmitted. + // Safety: `buf.len()` has been checked to be <= the max buffer length. + unsafe { + epin[i].ptr.write(|w| w.bits(ptr as u32)); + epin[i].maxcnt.write(|w| w.maxcnt().bits(buf.len() as u8)); + } + + regs.events_endepin[i].reset(); + + dma_start(); + regs.tasks_startepin[i].write(|w| unsafe { w.bits(1) }); + while regs.events_endepin[i].read().bits() == 0 {} + dma_end(); + + if i == 0 { + regs.intenset.write(|w| w.ep0datadone().set()); + let res = with_timeout( + Duration::from_millis(10), + poll_fn(|cx| { + EP_IN_WAKERS[0].register(cx.waker()); + let regs = T::regs(); + if regs.events_ep0datadone.read().bits() != 0 { + Poll::Ready(()) + } else { + Poll::Pending + } + }), + ) + .await; + + if res.is_err() { + // todo wrong error + return Err(driver::WriteError::BufferOverflow); + } + } Ok(()) } diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index 33f3d4712..95f78804d 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -198,6 +198,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } fn control_reject(&mut self, req: Request) { + info!("control reject"); self.control_out.set_stalled(true); } diff --git a/examples/nrf/Cargo.toml b/examples/nrf/Cargo.toml index fb846b3a9..59e5de026 100644 --- a/examples/nrf/Cargo.toml +++ b/examples/nrf/Cargo.toml @@ -1,6 +1,6 @@ [package] authors = ["Dario Nieuwenhuis "] -edition = "2018" +edition = "2021" name = "embassy-nrf-examples" version = "0.1.0" diff --git a/examples/nrf/src/bin/usb/main.rs b/examples/nrf/src/bin/usb/main.rs index d175766bb..014ad5c6e 100644 --- a/examples/nrf/src/bin/usb/main.rs +++ b/examples/nrf/src/bin/usb/main.rs @@ -10,13 +10,14 @@ mod cdc_acm; use core::mem; use defmt::*; use embassy::executor::Spawner; +use embassy::time::{Duration, Timer}; use embassy_nrf::interrupt; use embassy_nrf::pac; -use embassy_nrf::usb::{self, Driver}; +use embassy_nrf::usb::Driver; use embassy_nrf::Peripherals; -use embassy_usb::driver::EndpointOut; +use embassy_usb::driver::{EndpointIn, EndpointOut}; use embassy_usb::{Config, UsbDeviceBuilder}; -use futures::future::{join, select}; +use futures::future::join3; use crate::cdc_acm::CdcAcmClass; @@ -61,6 +62,15 @@ async fn main(_spawner: Spawner, p: Peripherals) { info!("data: {:x}", data); } }; + let fut3 = async { + loop { + info!("writing..."); + class.write_ep.write(b"Hello World!\r\n").await.unwrap(); + info!("written"); - join(fut1, fut2).await; + Timer::after(Duration::from_secs(1)).await; + } + }; + + join3(fut1, fut2, fut3).await; }