From bdc6e0481c42d20d5cca19dfc8ec56306e47296e Mon Sep 17 00:00:00 2001 From: alexmoon Date: Fri, 25 Mar 2022 16:46:14 -0400 Subject: [PATCH] Add support for USB classes handling control requests. --- embassy-nrf/src/usb.rs | 432 ++++++++++++++++++---------- embassy-usb/src/builder.rs | 14 +- embassy-usb/src/class.rs | 190 ++++++++++++ embassy-usb/src/control.rs | 17 +- embassy-usb/src/driver.rs | 42 +++ embassy-usb/src/lib.rs | 169 +++++------ examples/nrf/src/bin/usb/cdc_acm.rs | 127 +++++++- examples/nrf/src/bin/usb/main.rs | 3 +- 8 files changed, 703 insertions(+), 291 deletions(-) create mode 100644 embassy-usb/src/class.rs diff --git a/embassy-nrf/src/usb.rs b/embassy-nrf/src/usb.rs index 5df5053ac..203f08fc0 100644 --- a/embassy-nrf/src/usb.rs +++ b/embassy-nrf/src/usb.rs @@ -9,6 +9,7 @@ use embassy::time::{with_timeout, Duration}; use embassy::util::Unborrow; use embassy::waitqueue::AtomicWaker; use embassy_hal_common::unborrow; +use embassy_usb::control::Request; use embassy_usb::driver::{self, Event, ReadError, WriteError}; use embassy_usb::types::{EndpointAddress, EndpointInfo, EndpointType, UsbDirection}; use futures::future::poll_fn; @@ -134,6 +135,7 @@ impl<'d, T: Instance> Driver<'d, T> { impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> { type EndpointOut = Endpoint<'d, T, Out>; type EndpointIn = Endpoint<'d, T, In>; + type ControlPipe = ControlPipe<'d, T>; type Bus = Bus<'d, T>; fn alloc_endpoint_in( @@ -174,6 +176,19 @@ impl<'d, T: Instance> driver::Driver<'d> for Driver<'d, T> { })) } + fn alloc_control_pipe( + &mut self, + max_packet_size: u16, + ) -> Result { + self.alloc_endpoint_out(Some(0x00.into()), EndpointType::Control, max_packet_size, 0)?; + self.alloc_endpoint_in(Some(0x80.into()), EndpointType::Control, max_packet_size, 0)?; + Ok(ControlPipe { + _phantom: PhantomData, + max_packet_size, + request: None, + }) + } + fn enable(self) -> Self::Bus { let regs = T::regs(); @@ -344,99 +359,110 @@ impl<'d, T: Instance, Dir> driver::Endpoint for Endpoint<'d, T, Dir> { } } +unsafe fn read_dma(i: usize, buf: &mut [u8]) -> Result { + let regs = T::regs(); + + // Check that the packet fits into the buffer + let size = regs.size.epout[0].read().bits() as usize; + if size > buf.len() { + return Err(ReadError::BufferOverflow); + } + + if i == 0 { + regs.events_ep0datadone.reset(); + } + + let epout = [ + ®s.epout0, + ®s.epout1, + ®s.epout2, + ®s.epout3, + ®s.epout4, + ®s.epout5, + ®s.epout6, + ®s.epout7, + ]; + epout[i].ptr.write(|w| w.bits(buf.as_ptr() as u32)); + // MAXCNT must match SIZE + epout[i].maxcnt.write(|w| w.bits(size as u32)); + + dma_start(); + regs.events_endepout[i].reset(); + regs.tasks_startepout[i].write(|w| w.tasks_startepout().set_bit()); + while regs.events_endepout[i] + .read() + .events_endepout() + .bit_is_clear() + {} + regs.events_endepout[i].reset(); + dma_end(); + + regs.size.epout[i].reset(); + + Ok(size) +} + +unsafe fn write_dma(i: usize, buf: &[u8]) -> Result<(), WriteError> { + let regs = T::regs(); + if buf.len() > 64 { + return Err(WriteError::BufferOverflow); + } + + // EasyDMA can't read FLASH, so we copy through RAM + let mut ram_buf: MaybeUninit<[u8; 64]> = MaybeUninit::uninit(); + let ptr = ram_buf.as_mut_ptr() as *mut u8; + core::ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len()); + + let epin = [ + ®s.epin0, + ®s.epin1, + ®s.epin2, + ®s.epin3, + ®s.epin4, + ®s.epin5, + ®s.epin6, + ®s.epin7, + ]; + + // Set the buffer length so the right number of bytes are transmitted. + // Safety: `buf.len()` has been checked to be <= the max buffer length. + epin[i].ptr.write(|w| w.bits(ptr as u32)); + epin[i].maxcnt.write(|w| w.maxcnt().bits(buf.len() as u8)); + + regs.events_endepin[i].reset(); + + dma_start(); + regs.tasks_startepin[i].write(|w| w.bits(1)); + while regs.events_endepin[i].read().bits() == 0 {} + dma_end(); + + Ok(()) +} + impl<'d, T: Instance> driver::EndpointOut for Endpoint<'d, T, Out> { type ReadFuture<'a> = impl Future> + 'a where Self: 'a; fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a> { async move { - let regs = T::regs(); let i = self.info.addr.index(); + assert!(i != 0); - if i == 0 { - if buf.len() == 0 { - regs.tasks_ep0status.write(|w| unsafe { w.bits(1) }); - return Ok(0); + // Wait until ready + poll_fn(|cx| { + EP_OUT_WAKERS[i].register(cx.waker()); + let r = READY_ENDPOINTS.load(Ordering::Acquire); + if r & (1 << (i + 16)) != 0 { + Poll::Ready(()) + } else { + Poll::Pending } + }) + .await; - // Wait for SETUP packet - regs.events_ep0setup.reset(); - regs.intenset.write(|w| w.ep0setup().set()); - poll_fn(|cx| { - EP_OUT_WAKERS[0].register(cx.waker()); - let regs = T::regs(); - if regs.events_ep0setup.read().bits() != 0 { - Poll::Ready(()) - } else { - Poll::Pending - } - }) - .await; + // Mark as not ready + READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); - if buf.len() < 8 { - return Err(ReadError::BufferOverflow); - } - - buf[0] = regs.bmrequesttype.read().bits() as u8; - buf[1] = regs.brequest.read().brequest().bits(); - buf[2] = regs.wvaluel.read().wvaluel().bits(); - buf[3] = regs.wvalueh.read().wvalueh().bits(); - buf[4] = regs.windexl.read().windexl().bits(); - buf[5] = regs.windexh.read().windexh().bits(); - buf[6] = regs.wlengthl.read().wlengthl().bits(); - buf[7] = regs.wlengthh.read().wlengthh().bits(); - - Ok(8) - } else { - // Wait until ready - poll_fn(|cx| { - EP_OUT_WAKERS[i].register(cx.waker()); - let r = READY_ENDPOINTS.load(Ordering::Acquire); - if r & (1 << (i + 16)) != 0 { - Poll::Ready(()) - } else { - Poll::Pending - } - }) - .await; - - // Mark as not ready - READY_ENDPOINTS.fetch_and(!(1 << (i + 16)), Ordering::AcqRel); - - // Check that the packet fits into the buffer - let size = regs.size.epout[i].read().bits(); - if size as usize > buf.len() { - return Err(ReadError::BufferOverflow); - } - - let epout = [ - ®s.epout0, - ®s.epout1, - ®s.epout2, - ®s.epout3, - ®s.epout4, - ®s.epout5, - ®s.epout6, - ®s.epout7, - ]; - epout[i] - .ptr - .write(|w| unsafe { w.bits(buf.as_ptr() as u32) }); - // MAXCNT must match SIZE - epout[i].maxcnt.write(|w| unsafe { w.bits(size) }); - - dma_start(); - regs.events_endepout[i].reset(); - regs.tasks_startepout[i].write(|w| w.tasks_startepout().set_bit()); - while regs.events_endepout[i] - .read() - .events_endepout() - .bit_is_clear() - {} - regs.events_endepout[i].reset(); - dma_end(); - - Ok(size as usize) - } + unsafe { read_dma::(i, buf) } } } } @@ -446,87 +472,181 @@ impl<'d, T: Instance> driver::EndpointIn for Endpoint<'d, T, In> { fn write<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteFuture<'a> { async move { - let regs = T::regs(); let i = self.info.addr.index(); + assert!(i != 0); // Wait until ready. - if i != 0 { - poll_fn(|cx| { - EP_IN_WAKERS[i].register(cx.waker()); - let r = READY_ENDPOINTS.load(Ordering::Acquire); - if r & (1 << i) != 0 { - Poll::Ready(()) - } else { - Poll::Pending - } - }) - .await; - - // Mark as not ready - READY_ENDPOINTS.fetch_and(!(1 << i), Ordering::AcqRel); - } - - if i == 0 { - regs.events_ep0datadone.reset(); - } - - assert!(buf.len() <= 64); - - // EasyDMA can't read FLASH, so we copy through RAM - let mut ram_buf: MaybeUninit<[u8; 64]> = MaybeUninit::uninit(); - let ptr = ram_buf.as_mut_ptr() as *mut u8; - unsafe { core::ptr::copy_nonoverlapping(buf.as_ptr(), ptr, buf.len()) }; - - let epin = [ - ®s.epin0, - ®s.epin1, - ®s.epin2, - ®s.epin3, - ®s.epin4, - ®s.epin5, - ®s.epin6, - ®s.epin7, - ]; - - // Set the buffer length so the right number of bytes are transmitted. - // Safety: `buf.len()` has been checked to be <= the max buffer length. - unsafe { - epin[i].ptr.write(|w| w.bits(ptr as u32)); - epin[i].maxcnt.write(|w| w.maxcnt().bits(buf.len() as u8)); - } - - regs.events_endepin[i].reset(); - - dma_start(); - regs.tasks_startepin[i].write(|w| unsafe { w.bits(1) }); - while regs.events_endepin[i].read().bits() == 0 {} - dma_end(); - - if i == 0 { - regs.intenset.write(|w| w.ep0datadone().set()); - let res = with_timeout( - Duration::from_millis(10), - poll_fn(|cx| { - EP_IN_WAKERS[0].register(cx.waker()); - let regs = T::regs(); - if regs.events_ep0datadone.read().bits() != 0 { - Poll::Ready(()) - } else { - Poll::Pending - } - }), - ) - .await; - - if res.is_err() { - // todo wrong error - return Err(driver::WriteError::BufferOverflow); + poll_fn(|cx| { + EP_IN_WAKERS[i].register(cx.waker()); + let r = READY_ENDPOINTS.load(Ordering::Acquire); + if r & (1 << i) != 0 { + Poll::Ready(()) + } else { + Poll::Pending } + }) + .await; + + // Mark as not ready + READY_ENDPOINTS.fetch_and(!(1 << i), Ordering::AcqRel); + + unsafe { write_dma::(i, buf) } + } + } +} + +pub struct ControlPipe<'d, T: Instance> { + _phantom: PhantomData<&'d mut T>, + max_packet_size: u16, + request: Option, +} + +impl<'d, T: Instance> ControlPipe<'d, T> { + async fn write(&mut self, buf: &[u8], last_chunk: bool) { + let regs = T::regs(); + regs.events_ep0datadone.reset(); + unsafe { + write_dma::(0, buf).unwrap(); + } + + regs.shorts + .modify(|_, w| w.ep0datadone_ep0status().bit(last_chunk)); + + regs.intenset.write(|w| w.ep0datadone().set()); + let res = with_timeout( + Duration::from_millis(10), + poll_fn(|cx| { + EP_IN_WAKERS[0].register(cx.waker()); + let regs = T::regs(); + if regs.events_ep0datadone.read().bits() != 0 { + Poll::Ready(()) + } else { + Poll::Pending + } + }), + ) + .await; + + if res.is_err() { + error!("ControlPipe::write timed out."); + } + } +} + +impl<'d, T: Instance> driver::ControlPipe for ControlPipe<'d, T> { + type SetupFuture<'a> = impl Future + 'a where Self: 'a; + type DataOutFuture<'a> = impl Future> + 'a where Self: 'a; + type AcceptInFuture<'a> = impl Future + 'a where Self: 'a; + + fn setup<'a>(&'a mut self) -> Self::SetupFuture<'a> { + async move { + assert!(self.request.is_none()); + + let regs = T::regs(); + + // Wait for SETUP packet + regs.intenset.write(|w| w.ep0setup().set()); + poll_fn(|cx| { + EP_OUT_WAKERS[0].register(cx.waker()); + let regs = T::regs(); + if regs.events_ep0setup.read().bits() != 0 { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + + // Reset shorts + regs.shorts + .modify(|_, w| w.ep0datadone_ep0status().clear_bit()); + regs.events_ep0setup.reset(); + + let mut buf = [0; 8]; + buf[0] = regs.bmrequesttype.read().bits() as u8; + buf[1] = regs.brequest.read().brequest().bits(); + buf[2] = regs.wvaluel.read().wvaluel().bits(); + buf[3] = regs.wvalueh.read().wvalueh().bits(); + buf[4] = regs.windexl.read().windexl().bits(); + buf[5] = regs.windexh.read().windexh().bits(); + buf[6] = regs.wlengthl.read().wlengthl().bits(); + buf[7] = regs.wlengthh.read().wlengthh().bits(); + + let req = Request::parse(&buf); + + if req.direction == UsbDirection::Out { + regs.tasks_ep0rcvout + .write(|w| w.tasks_ep0rcvout().set_bit()); } - Ok(()) + self.request = Some(req); + req } } + + fn data_out<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::DataOutFuture<'a> { + async move { + let req = self.request.unwrap(); + assert_eq!(req.direction, UsbDirection::Out); + assert!(req.length > 0); + assert!(buf.len() >= usize::from(req.length)); + + let regs = T::regs(); + + // Wait until ready + regs.intenset.write(|w| w.ep0datadone().set()); + poll_fn(|cx| { + EP_OUT_WAKERS[0].register(cx.waker()); + let regs = T::regs(); + if regs + .events_ep0datadone + .read() + .events_ep0datadone() + .bit_is_set() + { + Poll::Ready(()) + } else { + Poll::Pending + } + }) + .await; + + unsafe { read_dma::(0, buf) } + } + } + + fn accept(&mut self) { + let regs = T::regs(); + regs.tasks_ep0status + .write(|w| w.tasks_ep0status().bit(true)); + self.request = None; + } + + fn accept_in<'a>(&'a mut self, buf: &'a [u8]) -> Self::AcceptInFuture<'a> { + async move { + info!("control accept {=[u8]:x}", buf); + let req = self.request.unwrap(); + assert_eq!(req.direction, UsbDirection::In); + + let req_len = usize::from(req.length); + let len = buf.len().min(req_len); + let need_zlp = len != req_len && (len % usize::from(self.max_packet_size)) == 0; + let mut chunks = buf[0..len] + .chunks(usize::from(self.max_packet_size)) + .chain(need_zlp.then(|| -> &[u8] { &[] })); + while let Some(chunk) = chunks.next() { + self.write(chunk, chunks.size_hint().0 == 0).await; + } + + self.request = None; + } + } + + fn reject(&mut self) { + let regs = T::regs(); + regs.tasks_ep0stall.write(|w| w.tasks_ep0stall().bit(true)); + self.request = None; + } } fn dma_start() { diff --git a/embassy-usb/src/builder.rs b/embassy-usb/src/builder.rs index e92cc8ef2..f0f94b932 100644 --- a/embassy-usb/src/builder.rs +++ b/embassy-usb/src/builder.rs @@ -1,3 +1,4 @@ +use super::class::UsbClass; use super::descriptor::{BosWriter, DescriptorWriter}; use super::driver::{Driver, EndpointAllocError}; use super::types::*; @@ -174,7 +175,10 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { } /// Creates the [`UsbDevice`] instance with the configuration in this builder. - pub fn build(mut self) -> UsbDevice<'d, D> { + /// + /// If a device has mutliple [`UsbClass`]es, they can be provided as a tuple list: + /// `(class1, (class2, (class3, ()))`. + pub fn build>(mut self, classes: C) -> UsbDevice<'d, D, C> { self.config_descriptor.end_configuration(); self.bos_descriptor.end_bos(); @@ -184,6 +188,7 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { self.device_descriptor.into_buf(), self.config_descriptor.into_buf(), self.bos_descriptor.writer.into_buf(), + classes, ) } @@ -268,9 +273,10 @@ impl<'d, D: Driver<'d>> UsbDeviceBuilder<'d, D> { /// Panics if endpoint allocation fails, because running out of endpoints or memory is not /// feasibly recoverable. #[inline] - pub fn alloc_control_endpoint_out(&mut self, max_packet_size: u16) -> D::EndpointOut { - self.alloc_endpoint_out(None, EndpointType::Control, max_packet_size, 0) - .expect("alloc_ep failed") + pub fn alloc_control_pipe(&mut self, max_packet_size: u16) -> D::ControlPipe { + self.bus + .alloc_control_pipe(max_packet_size) + .expect("alloc_control_pipe failed") } /// Allocates a bulk in endpoint. diff --git a/embassy-usb/src/class.rs b/embassy-usb/src/class.rs new file mode 100644 index 000000000..97bf7aba1 --- /dev/null +++ b/embassy-usb/src/class.rs @@ -0,0 +1,190 @@ +use core::future::Future; + +use crate::control::Request; +use crate::driver::{ControlPipe, Driver}; + +#[derive(Copy, Clone, Eq, PartialEq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RequestStatus { + Unhandled, + Accepted, + Rejected, +} + +impl Default for RequestStatus { + fn default() -> Self { + RequestStatus::Unhandled + } +} + +/// A trait for implementing USB classes. +/// +/// All methods are optional callbacks that will be called by +/// [`UsbDevice::run()`](crate::UsbDevice::run) +pub trait UsbClass<'d, D: Driver<'d>> { + type ControlOutFuture<'a>: Future + 'a + where + Self: 'a, + 'd: 'a, + D: 'a; + + type ControlInFuture<'a>: Future + 'a + where + Self: 'a, + 'd: 'a, + D: 'a; + + /// Called after a USB reset after the bus reset sequence is complete. + fn reset(&mut self) {} + + /// Called when a control request is received with direction HostToDevice. + /// + /// All requests are passed to classes in turn, which can choose to accept, ignore or report an + /// error. Classes can even choose to override standard requests, but doing that is rarely + /// necessary. + /// + /// When implementing your own class, you should ignore any requests that are not meant for your + /// class so that any other classes in the composite device can process them. + /// + /// # Arguments + /// + /// * `req` - The request from the SETUP packet. + /// * `data` - The data from the request. + fn control_out<'a>(&'a mut self, req: Request, data: &'a [u8]) -> Self::ControlOutFuture<'a> + where + 'd: 'a, + D: 'a; + + /// Called when a control request is received with direction DeviceToHost. + /// + /// All requests are passed to classes in turn, which can choose to accept, ignore or report an + /// error. Classes can even choose to override standard requests, but doing that is rarely + /// necessary. + /// + /// See [`ControlIn`] for how to respond to the transfer. + /// + /// When implementing your own class, you should ignore any requests that are not meant for your + /// class so that any other classes in the composite device can process them. + /// + /// # Arguments + /// + /// * `req` - The request from the SETUP packet. + /// * `control` - The control pipe. + fn control_in<'a>( + &'a mut self, + req: Request, + control: ControlIn<'a, 'd, D>, + ) -> Self::ControlInFuture<'a> + where + 'd: 'a; +} + +impl<'d, D: Driver<'d>> UsbClass<'d, D> for () { + type ControlOutFuture<'a> = impl Future + 'a where Self: 'a, 'd: 'a, D: 'a; + type ControlInFuture<'a> = impl Future + 'a where Self: 'a, 'd: 'a, D: 'a; + + fn control_out<'a>(&'a mut self, _req: Request, _data: &'a [u8]) -> Self::ControlOutFuture<'a> + where + 'd: 'a, + D: 'a, + { + async move { RequestStatus::default() } + } + + fn control_in<'a>( + &'a mut self, + _req: Request, + control: ControlIn<'a, 'd, D>, + ) -> Self::ControlInFuture<'a> + where + 'd: 'a, + D: 'a, + { + async move { control.ignore() } + } +} + +impl<'d, D: Driver<'d>, Head, Tail> UsbClass<'d, D> for (Head, Tail) +where + Head: UsbClass<'d, D>, + Tail: UsbClass<'d, D>, +{ + type ControlOutFuture<'a> = impl Future + 'a where Self: 'a, 'd: 'a, D: 'a; + type ControlInFuture<'a> = impl Future + 'a where Self: 'a, 'd: 'a, D: 'a; + + fn control_out<'a>(&'a mut self, req: Request, data: &'a [u8]) -> Self::ControlOutFuture<'a> + where + 'd: 'a, + D: 'a, + { + async move { + match self.0.control_out(req, data).await { + RequestStatus::Unhandled => self.1.control_out(req, data).await, + status => status, + } + } + } + + fn control_in<'a>( + &'a mut self, + req: Request, + control: ControlIn<'a, 'd, D>, + ) -> Self::ControlInFuture<'a> + where + 'd: 'a, + { + async move { + match self + .0 + .control_in(req, ControlIn::new(control.control)) + .await + { + ControlInRequestStatus(RequestStatus::Unhandled) => { + self.1.control_in(req, control).await + } + status => status, + } + } + } +} + +/// Handle for a control IN transfer. When implementing a class, use the methods of this object to +/// response to the transfer with either data or an error (STALL condition). To ignore the request +/// and pass it on to the next class, call [`Self::ignore()`]. +pub struct ControlIn<'a, 'd: 'a, D: Driver<'d>> { + control: &'a mut D::ControlPipe, +} + +#[derive(Eq, PartialEq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ControlInRequestStatus(pub(crate) RequestStatus); + +impl ControlInRequestStatus { + pub fn status(self) -> RequestStatus { + self.0 + } +} + +impl<'a, 'd: 'a, D: Driver<'d>> ControlIn<'a, 'd, D> { + pub(crate) fn new(control: &'a mut D::ControlPipe) -> Self { + ControlIn { control } + } + + /// Ignores the request and leaves it unhandled. + pub fn ignore(self) -> ControlInRequestStatus { + ControlInRequestStatus(RequestStatus::Unhandled) + } + + /// Accepts the transfer with the supplied buffer. + pub async fn accept(self, data: &[u8]) -> ControlInRequestStatus { + self.control.accept_in(data).await; + + ControlInRequestStatus(RequestStatus::Accepted) + } + + /// Rejects the transfer by stalling the pipe. + pub fn reject(self) -> ControlInRequestStatus { + self.control.reject(); + ControlInRequestStatus(RequestStatus::Rejected) + } +} diff --git a/embassy-usb/src/control.rs b/embassy-usb/src/control.rs index f1148ac76..77bc10aa4 100644 --- a/embassy-usb/src/control.rs +++ b/embassy-usb/src/control.rs @@ -2,12 +2,6 @@ use core::mem; use super::types::*; -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -#[cfg_attr(feature = "defmt", derive(defmt::Format))] -pub enum ParseError { - InvalidLength, -} - /// Control request type. #[repr(u8)] #[derive(Copy, Clone, Eq, PartialEq, Debug)] @@ -104,15 +98,12 @@ impl Request { /// Standard USB feature Device Remote Wakeup for Set/Clear Feature pub const FEATURE_DEVICE_REMOTE_WAKEUP: u16 = 1; - pub(crate) fn parse(buf: &[u8]) -> Result { - if buf.len() != 8 { - return Err(ParseError::InvalidLength); - } - + /// Parses a USB control request from a byte array. + pub fn parse(buf: &[u8; 8]) -> Request { let rt = buf[0]; let recipient = rt & 0b11111; - Ok(Request { + Request { direction: rt.into(), request_type: unsafe { mem::transmute((rt >> 5) & 0b11) }, recipient: if recipient <= 3 { @@ -124,7 +115,7 @@ impl Request { value: (buf[2] as u16) | ((buf[3] as u16) << 8), index: (buf[4] as u16) | ((buf[5] as u16) << 8), length: (buf[6] as u16) | ((buf[7] as u16) << 8), - }) + } } /// Gets the descriptor type and index from the value field of a GET_DESCRIPTOR request. diff --git a/embassy-usb/src/driver.rs b/embassy-usb/src/driver.rs index a7b16efa5..1c6ba1f52 100644 --- a/embassy-usb/src/driver.rs +++ b/embassy-usb/src/driver.rs @@ -1,5 +1,7 @@ use core::future::Future; +use crate::control::Request; + use super::types::*; /// Driver for a specific USB peripheral. Implement this to add support for a new hardware @@ -7,6 +9,7 @@ use super::types::*; pub trait Driver<'a> { type EndpointOut: EndpointOut + 'a; type EndpointIn: EndpointIn + 'a; + type ControlPipe: ControlPipe + 'a; type Bus: Bus + 'a; /// Allocates an endpoint and specified endpoint parameters. This method is called by the device @@ -36,6 +39,11 @@ pub trait Driver<'a> { interval: u8, ) -> Result; + fn alloc_control_pipe( + &mut self, + max_packet_size: u16, + ) -> Result; + /// Enables and initializes the USB peripheral. Soon after enabling the device will be reset, so /// there is no need to perform a USB reset in this method. fn enable(self) -> Self::Bus; @@ -122,6 +130,40 @@ pub trait EndpointOut: Endpoint { fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::ReadFuture<'a>; } +pub trait ControlPipe { + type SetupFuture<'a>: Future + 'a + where + Self: 'a; + type DataOutFuture<'a>: Future> + 'a + where + Self: 'a; + type AcceptInFuture<'a>: Future + 'a + where + Self: 'a; + + /// Reads a single setup packet from the endpoint. + fn setup<'a>(&'a mut self) -> Self::SetupFuture<'a>; + + /// Reads the data packet of a control write sequence. + /// + /// Must be called after `setup()` for requests with `direction` of `Out` + /// and `length` greater than zero. + /// + /// `buf.len()` must be greater than or equal to the request's `length`. + fn data_out<'a>(&'a mut self, buf: &'a mut [u8]) -> Self::DataOutFuture<'a>; + + /// Accepts a control request. + fn accept(&mut self); + + /// Accepts a control read request with `data`. + /// + /// `data.len()` must be less than or equal to the request's `length`. + fn accept_in<'a>(&'a mut self, data: &'a [u8]) -> Self::AcceptInFuture<'a>; + + /// Rejects a control request. + fn reject(&mut self); +} + pub trait EndpointIn: Endpoint { type WriteFuture<'a>: Future> + 'a where diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index 95f78804d..4082868fb 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -1,16 +1,19 @@ #![no_std] #![feature(generic_associated_types)] +#![feature(type_alias_impl_trait)] // This mod MUST go first, so that the others see its macros. pub(crate) mod fmt; mod builder; -mod control; +pub mod class; +pub mod control; pub mod descriptor; pub mod driver; pub mod types; mod util; +use self::class::{RequestStatus, UsbClass}; use self::control::*; use self::descriptor::*; use self::driver::*; @@ -48,10 +51,9 @@ pub const CONFIGURATION_VALUE: u8 = 1; /// The default value for bAlternateSetting for all interfaces. pub const DEFAULT_ALTERNATE_SETTING: u8 = 0; -pub struct UsbDevice<'d, D: Driver<'d>> { +pub struct UsbDevice<'d, D: Driver<'d>, C: UsbClass<'d, D>> { bus: D::Bus, - control_in: D::EndpointIn, - control_out: D::EndpointOut, + control: D::ControlPipe, config: Config<'d>, device_descriptor: &'d [u8], @@ -62,32 +64,21 @@ pub struct UsbDevice<'d, D: Driver<'d>> { remote_wakeup_enabled: bool, self_powered: bool, pending_address: u8, + + classes: C, } -impl<'d, D: Driver<'d>> UsbDevice<'d, D> { +impl<'d, D: Driver<'d>, C: UsbClass<'d, D>> UsbDevice<'d, D, C> { pub(crate) fn build( mut driver: D, config: Config<'d>, device_descriptor: &'d [u8], config_descriptor: &'d [u8], bos_descriptor: &'d [u8], + classes: C, ) -> Self { - let control_out = driver - .alloc_endpoint_out( - Some(0x00.into()), - EndpointType::Control, - config.max_packet_size_0 as u16, - 0, - ) - .expect("failed to alloc control endpoint"); - - let control_in = driver - .alloc_endpoint_in( - Some(0x80.into()), - EndpointType::Control, - config.max_packet_size_0 as u16, - 0, - ) + let control = driver + .alloc_control_pipe(config.max_packet_size_0 as u16) .expect("failed to alloc control endpoint"); // Enable the USB bus. @@ -97,8 +88,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { Self { bus: driver, config, - control_in, - control_out, + control, device_descriptor, config_descriptor, bos_descriptor, @@ -106,14 +96,13 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { remote_wakeup_enabled: false, self_powered: false, pending_address: 0, + classes, } } pub async fn run(&mut self) { - let mut buf = [0; 8]; - loop { - let control_fut = self.control_out.read(&mut buf); + let control_fut = self.control.setup(); let bus_fut = self.bus.poll(); match select(bus_fut, control_fut).await { Either::Left(evt) => match evt { @@ -124,11 +113,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { self.remote_wakeup_enabled = false; self.pending_address = 0; - // TODO - //self.control.reset(); - //for cls in classes { - // cls.reset(); - //} + self.classes.reset(); } Event::Resume => {} Event::Suspend => { @@ -136,16 +121,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { self.device_state = UsbDeviceState::Suspend; } }, - Either::Right(n) => { - let n = n.unwrap(); - assert_eq!(n, 8); - let req = Request::parse(&buf).unwrap(); + Either::Right(req) => { info!("control request: {:x}", req); - // Now that we have properly parsed the setup packet, ensure the end-point is no longer in - // a stalled state. - self.control_out.set_stalled(false); - match req.direction { UsbDirection::In => self.handle_control_in(req).await, UsbDirection::Out => self.handle_control_out(req).await, @@ -155,36 +133,6 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } } - async fn write_chunked(&mut self, data: &[u8]) -> Result<(), driver::WriteError> { - for c in data.chunks(8) { - self.control_in.write(c).await?; - } - if data.len() % 8 == 0 { - self.control_in.write(&[]).await?; - } - Ok(()) - } - - async fn control_out_accept(&mut self, req: Request) { - info!("control out accept"); - // status phase - // todo: cleanup - self.control_out.read(&mut []).await.unwrap(); - } - - async fn control_in_accept(&mut self, req: Request, data: &[u8]) { - info!("control accept {:x}", data); - - let len = data.len().min(req.length as _); - if let Err(e) = self.write_chunked(&data[..len]).await { - info!("write_chunked failed: {:?}", e); - } - - // status phase - // todo: cleanup - self.control_out.read(&mut []).await.unwrap(); - } - async fn control_in_accept_writer( &mut self, req: Request, @@ -193,17 +141,26 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { let mut buf = [0; 256]; let mut w = DescriptorWriter::new(&mut buf); f(&mut w); - let pos = w.position(); - self.control_in_accept(req, &buf[..pos]).await; - } - - fn control_reject(&mut self, req: Request) { - info!("control reject"); - self.control_out.set_stalled(true); + let pos = w.position().min(usize::from(req.length)); + self.control.accept_in(&buf[..pos]).await; } async fn handle_control_out(&mut self, req: Request) { - // TODO actually read the data if there's an OUT data phase. + { + let mut buf = [0; 128]; + let data = if req.length > 0 { + let size = self.control.data_out(&mut buf).await.unwrap(); + &buf[0..size] + } else { + &[] + }; + + match self.classes.control_out(req, data).await { + RequestStatus::Accepted => return self.control.accept(), + RequestStatus::Rejected => return self.control.reject(), + RequestStatus::Unhandled => (), + } + } const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16; const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16; @@ -217,12 +174,12 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { Request::FEATURE_DEVICE_REMOTE_WAKEUP, ) => { self.remote_wakeup_enabled = false; - self.control_out_accept(req).await; + self.control.accept(); } (Recipient::Endpoint, Request::CLEAR_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { //self.bus.set_stalled(((req.index as u8) & 0x8f).into(), false); - self.control_out_accept(req).await; + self.control.accept(); } ( @@ -231,51 +188,61 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { Request::FEATURE_DEVICE_REMOTE_WAKEUP, ) => { self.remote_wakeup_enabled = true; - self.control_out_accept(req).await; + self.control.accept(); } (Recipient::Endpoint, Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { self.bus .set_stalled(((req.index as u8) & 0x8f).into(), true); - self.control_out_accept(req).await; + self.control.accept(); } (Recipient::Device, Request::SET_ADDRESS, 1..=127) => { self.pending_address = req.value as u8; // on NRF the hardware auto-handles SET_ADDRESS. - self.control_out_accept(req).await; + self.control.accept(); } (Recipient::Device, Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => { self.device_state = UsbDeviceState::Configured; - self.control_out_accept(req).await; + self.control.accept(); } (Recipient::Device, Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => { match self.device_state { UsbDeviceState::Default => { - self.control_out_accept(req).await; + self.control.accept(); } _ => { self.device_state = UsbDeviceState::Addressed; - self.control_out_accept(req).await; + self.control.accept(); } } } (Recipient::Interface, Request::SET_INTERFACE, DEFAULT_ALTERNATE_SETTING_U16) => { // TODO: do something when alternate settings are implemented - self.control_out_accept(req).await; + self.control.accept(); } - _ => self.control_reject(req), + _ => self.control.reject(), }, - _ => self.control_reject(req), + _ => self.control.reject(), } } async fn handle_control_in(&mut self, req: Request) { + match self + .classes + .control_in(req, class::ControlIn::new(&mut self.control)) + .await + .status() + { + RequestStatus::Accepted | RequestStatus::Rejected => return, + RequestStatus::Unhandled => (), + } + match req.request_type { RequestType::Standard => match (req.recipient, req.request) { (Recipient::Device, Request::GET_STATUS) => { @@ -286,12 +253,12 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if self.remote_wakeup_enabled { status |= 0x0002; } - self.control_in_accept(req, &status.to_le_bytes()).await; + self.control.accept_in(&status.to_le_bytes()).await; } (Recipient::Interface, Request::GET_STATUS) => { let status: u16 = 0x0000; - self.control_in_accept(req, &status.to_le_bytes()).await; + self.control.accept_in(&status.to_le_bytes()).await; } (Recipient::Endpoint, Request::GET_STATUS) => { @@ -300,7 +267,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if self.bus.is_stalled(ep_addr) { status |= 0x0001; } - self.control_in_accept(req, &status.to_le_bytes()).await; + self.control.accept_in(&status.to_le_bytes()).await; } (Recipient::Device, Request::GET_DESCRIPTOR) => { @@ -312,17 +279,17 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { UsbDeviceState::Configured => CONFIGURATION_VALUE, _ => CONFIGURATION_NONE, }; - self.control_in_accept(req, &status.to_le_bytes()).await; + self.control.accept_in(&status.to_le_bytes()).await; } (Recipient::Interface, Request::GET_INTERFACE) => { // TODO: change when alternate settings are implemented let status = DEFAULT_ALTERNATE_SETTING; - self.control_in_accept(req, &status.to_le_bytes()).await; + self.control.accept_in(&status.to_le_bytes()).await; } - _ => self.control_reject(req), + _ => self.control.reject(), }, - _ => self.control_reject(req), + _ => self.control.reject(), } } @@ -331,11 +298,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { let config = self.config.clone(); match dtype { - descriptor_type::BOS => self.control_in_accept(req, self.bos_descriptor).await, - descriptor_type::DEVICE => self.control_in_accept(req, self.device_descriptor).await, - descriptor_type::CONFIGURATION => { - self.control_in_accept(req, self.config_descriptor).await - } + descriptor_type::BOS => self.control.accept_in(self.bos_descriptor).await, + descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor).await, + descriptor_type::CONFIGURATION => self.control.accept_in(self.config_descriptor).await, descriptor_type::STRING => { if index == 0 { self.control_in_accept_writer(req, |w| { @@ -363,11 +328,11 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { self.control_in_accept_writer(req, |w| w.string(s).unwrap()) .await; } else { - self.control_reject(req) + self.control.reject() } } } - _ => self.control_reject(req), + _ => self.control.reject(), } } } diff --git a/examples/nrf/src/bin/usb/cdc_acm.rs b/examples/nrf/src/bin/usb/cdc_acm.rs index b7c112ae6..eebf89221 100644 --- a/examples/nrf/src/bin/usb/cdc_acm.rs +++ b/examples/nrf/src/bin/usb/cdc_acm.rs @@ -1,5 +1,8 @@ -use core::convert::TryInto; +use core::future::Future; use core::mem; +use defmt::info; +use embassy_usb::class::{ControlInRequestStatus, RequestStatus, UsbClass}; +use embassy_usb::control::{self, Request}; use embassy_usb::driver::{Endpoint, EndpointIn, EndpointOut, ReadError, WriteError}; use embassy_usb::{driver::Driver, types::*, UsbDeviceBuilder}; @@ -39,16 +42,107 @@ const REQ_SET_CONTROL_LINE_STATE: u8 = 0x22; /// terminated with a short packet, even if the bulk endpoint is used for stream-like data. pub struct CdcAcmClass<'d, D: Driver<'d>> { // TODO not pub - pub comm_if: InterfaceNumber, pub comm_ep: D::EndpointIn, pub data_if: InterfaceNumber, pub read_ep: D::EndpointOut, pub write_ep: D::EndpointIn, + pub control: CdcAcmControl, +} + +pub struct CdcAcmControl { + pub comm_if: InterfaceNumber, pub line_coding: LineCoding, pub dtr: bool, pub rts: bool, } +impl<'d, D: Driver<'d>> UsbClass<'d, D> for CdcAcmControl { + type ControlOutFuture<'a> = impl Future + 'a where Self: 'a, 'd: 'a, D: 'a; + type ControlInFuture<'a> = impl Future + 'a where Self: 'a, 'd: 'a, D: 'a; + + fn reset(&mut self) { + self.line_coding = LineCoding::default(); + self.dtr = false; + self.rts = false; + } + + fn control_out<'a>( + &'a mut self, + req: control::Request, + data: &'a [u8], + ) -> Self::ControlOutFuture<'a> + where + 'd: 'a, + D: 'a, + { + async move { + if !(req.request_type == control::RequestType::Class + && req.recipient == control::Recipient::Interface + && req.index == u8::from(self.comm_if) as u16) + { + return RequestStatus::Unhandled; + } + + match req.request { + REQ_SEND_ENCAPSULATED_COMMAND => { + // We don't actually support encapsulated commands but pretend we do for standards + // compatibility. + RequestStatus::Accepted + } + REQ_SET_LINE_CODING if data.len() >= 7 => { + self.line_coding.data_rate = u32::from_le_bytes(data[0..4].try_into().unwrap()); + self.line_coding.stop_bits = data[4].into(); + self.line_coding.parity_type = data[5].into(); + self.line_coding.data_bits = data[6]; + info!("Set line coding to: {:?}", self.line_coding); + + RequestStatus::Accepted + } + REQ_SET_CONTROL_LINE_STATE => { + self.dtr = (req.value & 0x0001) != 0; + self.rts = (req.value & 0x0002) != 0; + info!("Set dtr {}, rts {}", self.dtr, self.rts); + + RequestStatus::Accepted + } + _ => RequestStatus::Rejected, + } + } + } + + fn control_in<'a>( + &'a mut self, + req: Request, + control: embassy_usb::class::ControlIn<'a, 'd, D>, + ) -> Self::ControlInFuture<'a> + where + 'd: 'a, + { + async move { + if !(req.request_type == control::RequestType::Class + && req.recipient == control::Recipient::Interface + && req.index == u8::from(self.comm_if) as u16) + { + return control.ignore(); + } + + match req.request { + // REQ_GET_ENCAPSULATED_COMMAND is not really supported - it will be rejected below. + REQ_GET_LINE_CODING if req.length == 7 => { + info!("Sending line coding"); + let mut data = [0; 7]; + data[0..4].copy_from_slice(&self.line_coding.data_rate.to_le_bytes()); + data[4] = self.line_coding.stop_bits as u8; + data[5] = self.line_coding.parity_type as u8; + data[6] = self.line_coding.data_bits; + control.accept(&data).await + } + _ => control.reject(), + } + } + } +} + impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { /// Creates a new CdcAcmClass with the provided UsbBus and max_packet_size in bytes. For /// full-speed devices, max_packet_size has to be one of 8, 16, 32 or 64. @@ -133,19 +227,21 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { builder.config_descriptor.endpoint(read_ep.info()).unwrap(); CdcAcmClass { - comm_if, comm_ep, data_if, read_ep, write_ep, - line_coding: LineCoding { - stop_bits: StopBits::One, - data_bits: 8, - parity_type: ParityType::None, - data_rate: 8_000, + control: CdcAcmControl { + comm_if, + dtr: false, + rts: false, + line_coding: LineCoding { + stop_bits: StopBits::One, + data_bits: 8, + parity_type: ParityType::None, + data_rate: 8_000, + }, }, - dtr: false, - rts: false, } } @@ -158,17 +254,17 @@ impl<'d, D: Driver<'d>> CdcAcmClass<'d, D> { /// Gets the current line coding. The line coding contains information that's mainly relevant /// for USB to UART serial port emulators, and can be ignored if not relevant. pub fn line_coding(&self) -> &LineCoding { - &self.line_coding + &self.control.line_coding } /// Gets the DTR (data terminal ready) state pub fn dtr(&self) -> bool { - self.dtr + self.control.dtr } /// Gets the RTS (request to send) state pub fn rts(&self) -> bool { - self.rts + self.control.rts } /// Writes a single packet into the IN endpoint. @@ -270,7 +366,7 @@ impl UsbClass for CdcAcmClass<'_, B> { */ /// Number of stop bits for LineCoding -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, defmt::Format)] pub enum StopBits { /// 1 stop bit One = 0, @@ -293,7 +389,7 @@ impl From for StopBits { } /// Parity for LineCoding -#[derive(Copy, Clone, PartialEq, Eq)] +#[derive(Copy, Clone, PartialEq, Eq, defmt::Format)] pub enum ParityType { None = 0, Odd = 1, @@ -316,6 +412,7 @@ impl From for ParityType { /// /// This is provided by the host for specifying the standard UART parameters such as baud rate. Can /// be ignored if you don't plan to interface with a physical UART. +#[derive(defmt::Format)] pub struct LineCoding { stop_bits: StopBits, data_bits: u8, diff --git a/examples/nrf/src/bin/usb/main.rs b/examples/nrf/src/bin/usb/main.rs index ecbdc3461..71285579c 100644 --- a/examples/nrf/src/bin/usb/main.rs +++ b/examples/nrf/src/bin/usb/main.rs @@ -1,5 +1,6 @@ #![no_std] #![no_main] +#![feature(generic_associated_types)] #![feature(type_alias_impl_trait)] #[path = "../../example_common.rs"] @@ -58,7 +59,7 @@ async fn main(_spawner: Spawner, p: Peripherals) { let mut class = CdcAcmClass::new(&mut builder, 64); // Build the builder. - let mut usb = builder.build(); + let mut usb = builder.build(class.control); // Run the USB device. let fut1 = usb.run();