diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index 0f7d68d8c..163b2c794 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -8,7 +8,7 @@ use embassy::time::{with_timeout, Duration}; use embassy::util::Unborrow; use embassy::waitqueue::AtomicWaker; use embassy_hal_common::unborrow; -use embassy_usb::driver::{self, ReadError, WriteError}; +use embassy_usb::driver::{self, Event, ReadError, WriteError}; use embassy_usb::types::{EndpointAddress, EndpointInfo, EndpointType, UsbDirection}; use futures::future::poll_fn; use futures::Future; @@ -19,7 +19,10 @@ pub use embassy_usb; use crate::interrupt::Interrupt; use crate::pac; -static EP0_WAKER: AtomicWaker = AtomicWaker::new(); +const NEW_AW: AtomicWaker = AtomicWaker::new(); +static BUS_WAKER: AtomicWaker = NEW_AW; +static EP_IN_WAKERS: [AtomicWaker; 9] = [NEW_AW; 9]; +static EP_OUT_WAKERS: [AtomicWaker; 9] = [NEW_AW; 9]; pub struct Driver<'d, T: Instance> { phantom: PhantomData<&'d mut T>, @@ -47,13 +50,48 @@ impl<'d, T: Instance> Driver<'d, T> { fn on_interrupt(_: *mut ()) { let regs = T::regs(); + if regs.events_usbreset.read().bits() != 0 { + regs.intenclr.write(|w| w.usbreset().clear()); + BUS_WAKER.wake(); + } + if regs.events_ep0setup.read().bits() != 0 { regs.intenclr.write(|w| w.ep0setup().clear()); - EP0_WAKER.wake(); + EP_OUT_WAKERS[0].wake(); } + if regs.events_ep0datadone.read().bits() != 0 { regs.intenclr.write(|w| w.ep0datadone().clear()); - EP0_WAKER.wake(); + EP_IN_WAKERS[0].wake(); + } + + // USBEVENT and EPDATA events are weird. They're the "aggregate" + // of individual bits in EVENTCAUSE and EPDATASTATUS. We handle them + // differently than events normally. + // + // They seem to be edge-triggered, not level-triggered: when an + // individual bit goes 0->1, the event fires *just once*. + // Therefore, it's fine to clear just the event, and let main thread + // check the individual bits in EVENTCAUSE and EPDATASTATUS. It + // doesn't cause an infinite irq loop. + if regs.events_usbevent.read().bits() != 0 { + regs.events_usbevent.reset(); + //regs.intenclr.write(|w| w.usbevent().clear()); + BUS_WAKER.wake(); + } + + if regs.events_epdata.read().bits() != 0 { + regs.events_epdata.reset(); + + let r = regs.epdatastatus.read().bits(); + for i in 1..=7 { + if r & (1 << i) != 0 { + EP_IN_WAKERS[i].wake(); + } + if r & (1 << (i + 16)) != 0 { + EP_OUT_WAKERS[i].wake(); + } + } } } @@ -153,6 +191,12 @@ impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> { unsafe { NVIC::unmask(pac::Interrupt::USBD) }; + regs.intenset.write(|w| { + w.usbreset().set_bit(); + w.usbevent().set_bit(); + w.epdata().set_bit(); + w + }); // Enable the USB pullup, allowing enumeration. regs.usbpullup.write(|w| w.connect().enabled()); info!("enabled"); @@ -172,6 +216,49 @@ pub struct Bus<'d, T: Instance> { } impl<'d, T: Instance> driver::Bus for Bus<'d, T> { + type PollFuture<'a> + where + Self: 'a, + = impl Future + 'a; + + fn poll<'a>(&'a mut self) -> Self::PollFuture<'a> { + poll_fn(|cx| { + BUS_WAKER.register(cx.waker()); + let regs = T::regs(); + + if regs.events_usbreset.read().bits() != 0 { + regs.events_usbreset.reset(); + regs.intenset.write(|w| w.usbreset().set()); + return Poll::Ready(Event::Reset); + } + + let r = regs.eventcause.read(); + + if r.isooutcrc().bit() { + regs.eventcause.write(|w| w.isooutcrc().set_bit()); + info!("USB event: isooutcrc"); + } + if r.usbwuallowed().bit() { + regs.eventcause.write(|w| w.usbwuallowed().set_bit()); + info!("USB event: usbwuallowed"); + } + if r.suspend().bit() { + regs.eventcause.write(|w| w.suspend().set_bit()); + info!("USB event: suspend"); + } + if r.resume().bit() { + regs.eventcause.write(|w| w.resume().set_bit()); + info!("USB event: resume"); + } + if r.ready().bit() { + regs.eventcause.write(|w| w.ready().set_bit()); + info!("USB event: ready"); + } + + Poll::Pending + }) + } + #[inline] fn reset(&mut self) { let regs = T::regs(); @@ -260,40 +347,95 @@ impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { async move { let regs = T::regs(); + let i = self.info.addr.index(); - if buf.len() == 0 { - regs.tasks_ep0status.write(|w| unsafe { w.bits(1) }); - return Ok(0); - } - - // Wait for SETUP packet - regs.events_ep0setup.reset(); - regs.intenset.write(|w| w.ep0setup().set()); - poll_fn(|cx| { - EP0_WAKER.register(cx.waker()); - if regs.events_ep0setup.read().bits() != 0 { - Poll::Ready(()) - } else { - Poll::Pending + if i == 0 { + if buf.len() == 0 { + regs.tasks_ep0status.write(|w| unsafe { w.bits(1) }); + return Ok(0); } - }) - .await; - info!("got SETUP"); - if buf.len() < 8 { - return Err(ReadError::BufferOverflow); + // Wait for SETUP packet + regs.events_ep0setup.reset(); + regs.intenset.write(|w| w.ep0setup().set()); + poll_fn(|cx| { + EP_OUT_WAKERS[0].register(cx.waker()); + let regs = T::regs(); + if regs.events_ep0setup.read().bits() != 0 { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + info!("got SETUP"); + + if buf.len() < 8 { + return Err(ReadError::BufferOverflow); + } + + buf[0] = regs.bmrequesttype.read().bits() as u8; + buf[1] = regs.brequest.read().brequest().bits(); + buf[2] = regs.wvaluel.read().wvaluel().bits(); + buf[3] = regs.wvalueh.read().wvalueh().bits(); + buf[4] = regs.windexl.read().windexl().bits(); + buf[5] = regs.windexh.read().windexh().bits(); + buf[6] = regs.wlengthl.read().wlengthl().bits(); + buf[7] = regs.wlengthh.read().wlengthh().bits(); + + Ok(8) + } else { + poll_fn(|cx| { + EP_OUT_WAKERS[i].register(cx.waker()); + let regs = T::regs(); + let r = regs.epdatastatus.read().bits(); + if r & (1 << (i + 16)) != 0 { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + + // Clear status + regs.epdatastatus + .write(|w| unsafe { w.bits(1 << (i + 16)) }); + + // Check that the packet fits into the buffer + let size = regs.size.epout[i].read().bits(); + if size as usize > buf.len() { + return Err(ReadError::BufferOverflow); + } + + let epout = [ + ®s.epout0, + ®s.epout1, + ®s.epout2, + ®s.epout3, + ®s.epout4, + ®s.epout5, + ®s.epout6, + ®s.epout7, + ]; + epout[i] + .ptr + .write(|w| unsafe { w.bits(buf.as_ptr() as u32) }); + // MAXCNT must match SIZE + epout[i].maxcnt.write(|w| unsafe { w.bits(size) }); + + dma_start(); + regs.events_endepout[i].reset(); + regs.tasks_startepout[i].write(|w| w.tasks_startepout().set_bit()); + while regs.events_endepout[i] + .read() + .events_endepout() + .bit_is_clear() + {} + regs.events_endepout[i].reset(); + dma_end(); + + Ok(size as usize) } - - buf[0] = regs.bmrequesttype.read().bits() as u8; - buf[1] = regs.brequest.read().brequest().bits(); - buf[2] = regs.wvaluel.read().wvaluel().bits(); - buf[3] = regs.wvalueh.read().wvalueh().bits(); - buf[4] = regs.windexl.read().windexl().bits(); - buf[5] = regs.windexh.read().windexh().bits(); - buf[6] = regs.wlengthl.read().wlengthl().bits(); - buf[7] = regs.wlengthh.read().wlengthh().bits(); - - Ok(8) } } } @@ -331,7 +473,8 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { let res = with_timeout( Duration::from_millis(10), poll_fn(|cx| { - EP0_WAKER.register(cx.waker()); + EP_IN_WAKERS[0].register(cx.waker()); + let regs = T::regs(); if regs.events_ep0datadone.read().bits() != 0 { Poll::Ready(()) } else { diff --git a/embassy-usb/Cargo.toml b/embassy-usb/Cargo.toml index dfdc8fbac..5a5a6d7ab 100644 --- a/embassy-usb/Cargo.toml +++ b/embassy-usb/Cargo.toml @@ -1,7 +1,7 @@ [package] name = "embassy-usb" version = "0.1.0" -edition = "2018" +edition = "2021" [dependencies] embassy = { version = "0.1.0", path = "../embassy" } diff --git a/embassy-usb/src/driver.rs b/embassy-usb/src/driver.rs index ed4edb576..a7b16efa5 100644 --- a/embassy-usb/src/driver.rs +++ b/embassy-usb/src/driver.rs @@ -2,40 +2,6 @@ use core::future::Future; use super::types::*; -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub struct EndpointAllocError; - -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] - -/// Operation is unsupported by the driver. -pub struct Unsupported; - -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] - -/// Errors returned by [`EndpointIn::write`] -pub enum WriteError { - /// The packet is too long to fit in the - /// transmission buffer. This is generally an error in the class implementation, because the - /// class shouldn't provide more data than the `max_packet_size` it specified when allocating - /// the endpoint. - BufferOverflow, -} - -#[derive(Copy, Clone, Eq, PartialEq, Debug)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] - -/// Errors returned by [`EndpointOut::read`] -pub enum ReadError { - /// The received packet is too long to - /// fit in `buf`. This is generally an error in the class implementation, because the class - /// should use a buffer that is large enough for the `max_packet_size` it specified when - /// allocating the endpoint. - BufferOverflow, -} - /// Driver for a specific USB peripheral. Implement this to add support for a new hardware /// platform. pub trait Driver<'a> { @@ -82,6 +48,12 @@ pub trait Driver<'a> { } pub trait Bus { + type PollFuture<'a>: Future + 'a + where + Self: 'a; + + fn poll<'a>(&'a mut self) -> Self::PollFuture<'a>; + /// Called when the host resets the device. This will be soon called after /// [`poll`](crate::device::UsbDevice::poll) returns [`PollResult::Reset`]. This method should /// reset the state of all endpoints and peripheral flags back to a state suitable for @@ -158,3 +130,50 @@ pub trait EndpointIn: Endpoint { /// Writes a single packet of data to the endpoint. fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a>; } + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +/// Event returned by [`Bus::poll`]. +pub enum Event { + /// The USB reset condition has been detected. + Reset, + + /// A USB suspend request has been detected or, in the case of self-powered devices, the device + /// has been disconnected from the USB bus. + Suspend, + + /// A USB resume request has been detected after being suspended or, in the case of self-powered + /// devices, the device has been connected to the USB bus. + Resume, +} + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct EndpointAllocError; + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +/// Operation is unsupported by the driver. +pub struct Unsupported; + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +/// Errors returned by [`EndpointIn::write`] +pub enum WriteError { + /// The packet is too long to fit in the + /// transmission buffer. This is generally an error in the class implementation, because the + /// class shouldn't provide more data than the `max_packet_size` it specified when allocating + /// the endpoint. + BufferOverflow, +} + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +/// Errors returned by [`EndpointOut::read`] +pub enum ReadError { + /// The received packet is too long to + /// fit in `buf`. This is generally an error in the class implementation, because the class + /// should use a buffer that is large enough for the `max_packet_size` it specified when + /// allocating the endpoint. + BufferOverflow, +} diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index 397db96c4..33f3d4712 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -9,11 +9,13 @@ mod control; pub mod descriptor; pub mod driver; pub mod types; +mod util; use self::control::*; use self::descriptor::*; use self::driver::*; use self::types::*; +use self::util::*; pub use self::builder::Config; pub use self::builder::UsbDeviceBuilder; @@ -47,7 +49,7 @@ pub const CONFIGURATION_VALUE: u8 = 1; pub const DEFAULT_ALTERNATE_SETTING: u8 = 0; pub struct UsbDevice<'d, D: Driver<'d>> { - driver: D::Bus, + bus: D::Bus, control_in: D::EndpointIn, control_out: D::EndpointOut, @@ -93,7 +95,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { let driver = driver.enable(); Self { - driver, + bus: driver, config, control_in, control_out, @@ -108,20 +110,47 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } pub async fn run(&mut self) { + let mut buf = [0; 8]; + loop { - let mut buf = [0; 8]; - let n = self.control_out.read(&mut buf).await.unwrap(); - assert_eq!(n, 8); - let req = Request::parse(&buf).unwrap(); - info!("setup request: {:x}", req); + let control_fut = self.control_out.read(&mut buf); + let bus_fut = self.bus.poll(); + match select(bus_fut, control_fut).await { + Either::Left(evt) => match evt { + Event::Reset => { + self.bus.reset(); - // Now that we have properly parsed the setup packet, ensure the end-point is no longer in - // a stalled state. - self.control_out.set_stalled(false); + self.device_state = UsbDeviceState::Default; + self.remote_wakeup_enabled = false; + self.pending_address = 0; - match req.direction { - UsbDirection::In => self.handle_control_in(req).await, - UsbDirection::Out => self.handle_control_out(req).await, + // TODO + //self.control.reset(); + //for cls in classes { + // cls.reset(); + //} + } + Event::Resume => {} + Event::Suspend => { + self.bus.suspend(); + self.device_state = UsbDeviceState::Suspend; + } + }, + Either::Right(n) => { + let n = n.unwrap(); + assert_eq!(n, 8); + let req = Request::parse(&buf).unwrap(); + info!("control request: {:x}", req); + + // Now that we have properly parsed the setup packet, ensure the end-point is no longer in + // a stalled state. + self.control_out.set_stalled(false); + + match req.direction { + UsbDirection::In => self.handle_control_in(req).await, + UsbDirection::Out => self.handle_control_out(req).await, + } + } } } } @@ -205,7 +234,8 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } (Recipient::Endpoint, Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { - //self.bus.set_stalled(((req.index as u8) & 0x8f).into(), true); + self.bus + .set_stalled(((req.index as u8) & 0x8f).into(), true); self.control_out_accept(req).await; } @@ -266,7 +296,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { (Recipient::Endpoint, Request::GET_STATUS) => { let ep_addr: EndpointAddress = ((req.index as u8) & 0x8f).into(); let mut status: u16 = 0x0000; - if self.driver.is_stalled(ep_addr) { + if self.bus.is_stalled(ep_addr) { status |= 0x0001; } self.control_in_accept(req, &status.to_le_bytes()).await; diff --git a/embassy-usb/src/util.rs b/embassy-usb/src/util.rs new file mode 100644 index 000000000..18cc875c6 --- /dev/null +++ b/embassy-usb/src/util.rs @@ -0,0 +1,45 @@ +use core::future::Future; +use core::pin::Pin; +use core::task::{Context, Poll}; + +#[derive(Debug, Clone)] +pub enum Either { + Left(A), + Right(B), +} + +pub fn select(a: A, b: B) -> Select +where + A: Future, + B: Future, +{ + Select { a, b } +} + +pub struct Select { + a: A, + b: B, +} + +impl Unpin for Select {} + +impl Future for Select +where + A: Future, + B: Future, +{ + type Output = Either; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let this = unsafe { self.get_unchecked_mut() }; + let a = unsafe { Pin::new_unchecked(&mut this.a) }; + let b = unsafe { Pin::new_unchecked(&mut this.b) }; + match a.poll(cx) { + Poll::Ready(x) => Poll::Ready(Either::Left(x)), + Poll::Pending => match b.poll(cx) { + Poll::Ready(x) => Poll::Ready(Either::Right(x)), + Poll::Pending => Poll::Pending, + }, + } + } +} diff --git a/examples/nrf/src/bin/usb/cdc_acm.rs b/examples/nrf/src/bin/usb/cdc_acm.rs index 345d00389..b7c112ae6 100644 --- a/examples/nrf/src/bin/usb/cdc_acm.rs +++ b/examples/nrf/src/bin/usb/cdc_acm.rs @@ -38,14 +38,15 @@ const REQ_SET_CONTROL_LINE_STATE: u8 = 0x22; /// can be sent if there is no other data to send. This is because USB bulk transactions must be /// terminated with a short packet, even if the bulk endpoint is used for stream-like data. pub struct CdcAcmClass<'d, D: Driver<'d>> { - comm_if: InterfaceNumber, - comm_ep: D::EndpointIn, - data_if: InterfaceNumber, - read_ep: D::EndpointOut, - write_ep: D::EndpointIn, - line_coding: LineCoding, - dtr: bool, - rts: bool, + // TODO not pub + pub comm_if: InterfaceNumber, + pub comm_ep: D::EndpointIn, + pub data_if: InterfaceNumber, + pub read_ep: D::EndpointOut, + pub write_ep: D::EndpointIn, + pub line_coding: LineCoding, + pub dtr: bool, + pub rts: bool, } impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { diff --git a/examples/nrf/src/bin/usb/main.rs b/examples/nrf/src/bin/usb/main.rs index 21ca2ba4f..d175766bb 100644 --- a/examples/nrf/src/bin/usb/main.rs +++ b/examples/nrf/src/bin/usb/main.rs @@ -14,7 +14,9 @@ use embassy_nrf::interrupt; use embassy_nrf::pac; use embassy_nrf::usb::{self, Driver}; use embassy_nrf::Peripherals; +use embassy_usb::driver::EndpointOut; use embassy_usb::{Config, UsbDeviceBuilder}; +use futures::future::{join, select}; use crate::cdc_acm::CdcAcmClass; @@ -49,5 +51,16 @@ async fn main(_spawner: Spawner, p: Peripherals) { let mut class = CdcAcmClass::new(&mut builder, 64); let mut usb = builder.build(); - usb.run().await; + + let fut1 = usb.run(); + let fut2 = async { + let mut buf = [0; 64]; + loop { + let n = class.read_ep.read(&mut buf).await.unwrap(); + let data = &buf[..n]; + info!("data: {:x}", data); + } + }; + + join(fut1, fut2).await; }