From 7ed462a6575cba95e8f07d2d9516d5e7b33d7196 Mon Sep 17 00:00:00 2001 From: Dario Nieuwenhuis Date: Mon, 9 May 2022 02:11:02 +0200 Subject: [PATCH] usb: simplify control in/out handlng, calling response from a single place. --- embassy-usb-ncm/src/lib.rs | 14 +- embassy-usb/src/control.rs | 29 +--- embassy-usb/src/descriptor.rs | 1 + embassy-usb/src/lib.rs | 270 +++++++++++++++++++--------------- 4 files changed, 158 insertions(+), 156 deletions(-) diff --git a/embassy-usb-ncm/src/lib.rs b/embassy-usb-ncm/src/lib.rs index 3a5abee8d..71d0691d4 100644 --- a/embassy-usb-ncm/src/lib.rs +++ b/embassy-usb-ncm/src/lib.rs @@ -133,6 +133,7 @@ impl Default for ControlShared { struct CommControl<'a> { mac_addr_string: StringIndex, shared: &'a ControlShared, + mac_addr_str: [u8; 12], } impl<'d> ControlHandler for CommControl<'d> { @@ -178,24 +179,20 @@ impl<'d> ControlHandler for CommControl<'d> { } } - fn get_string<'a>( - &'a mut self, - index: StringIndex, - _lang_id: u16, - buf: &'a mut [u8], - ) -> Option<&'a str> { + fn get_string(&mut self, index: StringIndex, _lang_id: u16) -> Option<&str> { if index == self.mac_addr_string { let mac_addr = self.shared.mac_addr.get(); + let s = &mut self.mac_addr_str; for i in 0..12 { let n = (mac_addr[i / 2] >> ((1 - i % 2) * 4)) & 0xF; - buf[i] = match n { + s[i] = match n { 0x0..=0x9 => b'0' + n, 0xA..=0xF => b'A' + n - 0xA, _ => unreachable!(), } } - Some(unsafe { core::str::from_utf8_unchecked(&buf[..12]) }) + Some(unsafe { core::str::from_utf8_unchecked(s) }) } else { warn!("unknown string index requested"); None @@ -244,6 +241,7 @@ impl<'d, D: Driver<'d>> CdcNcmClass<'d, D> { iface.handler(state.comm_control.write(CommControl { mac_addr_string, shared: &control_shared, + mac_addr_str: [0; 12], })); let comm_if = iface.interface_number(); let mut alt = iface.alt_setting(USB_CLASS_CDC, CDC_SUBCLASS_NCM, CDC_PROTOCOL_NONE); diff --git a/embassy-usb/src/control.rs b/embassy-usb/src/control.rs index ff42f9d78..4fc65b6a5 100644 --- a/embassy-usb/src/control.rs +++ b/embassy-usb/src/control.rs @@ -1,9 +1,7 @@ use core::mem; -use crate::descriptor::DescriptorWriter; -use crate::driver::{self, EndpointError}; - use super::types::*; +use crate::driver::{self, EndpointError}; /// Control request type. #[repr(u8)] @@ -191,16 +189,8 @@ pub trait ControlHandler { } /// Called when a GET_DESCRIPTOR STRING control request is received. - /// - /// Write the response string somewhere (usually to `buf`, but you may use another buffer - /// owned by yourself, or a static buffer), then return it. - fn get_string<'a>( - &'a mut self, - index: StringIndex, - lang_id: u16, - buf: &'a mut [u8], - ) -> Option<&'a str> { - let _ = (index, lang_id, buf); + fn get_string(&mut self, index: StringIndex, lang_id: u16) -> Option<&str> { + let _ = (index, lang_id); None } } @@ -316,19 +306,6 @@ impl ControlPipe { } } - pub(crate) async fn accept_in_writer( - &mut self, - req: Request, - stage: DataInStage, - f: impl FnOnce(&mut DescriptorWriter), - ) { - let mut buf = [0; 256]; - let mut w = DescriptorWriter::new(&mut buf); - f(&mut w); - let pos = w.position().min(usize::from(req.length)); - self.accept_in(&buf[..pos], stage).await - } - pub(crate) fn accept(&mut self, _: StatusStage) { trace!(" control accept"); self.control.accept(); diff --git a/embassy-usb/src/descriptor.rs b/embassy-usb/src/descriptor.rs index dce326780..7f23fd921 100644 --- a/embassy-usb/src/descriptor.rs +++ b/embassy-usb/src/descriptor.rs @@ -244,6 +244,7 @@ impl<'a> DescriptorWriter<'a> { } /// Writes a string descriptor. + #[allow(unused)] pub(crate) fn string(&mut self, string: &str) { let mut pos = self.position; diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index 3bfedc048..690305885 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -100,15 +100,19 @@ struct Interface<'d> { } pub struct UsbDevice<'d, D: Driver<'d>> { + control_buf: &'d mut [u8], + control: ControlPipe, + inner: Inner<'d, D>, +} + +struct Inner<'d, D: Driver<'d>> { bus: D::Bus, handler: Option<&'d dyn DeviceStateHandler>, - control: ControlPipe, config: Config<'d>, device_descriptor: &'d [u8], config_descriptor: &'d [u8], bos_descriptor: &'d [u8], - control_buf: &'d mut [u8], device_state: UsbDeviceState, suspended: bool, @@ -139,20 +143,23 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { let bus = driver.into_bus(); Self { - bus, - config, - handler, - control: ControlPipe::new(control), - device_descriptor, - config_descriptor, - bos_descriptor, control_buf, - device_state: UsbDeviceState::Disabled, - suspended: false, - remote_wakeup_enabled: false, - self_powered: false, - pending_address: 0, - interfaces, + control: ControlPipe::new(control), + inner: Inner { + bus, + config, + handler, + device_descriptor, + config_descriptor, + bos_descriptor, + + device_state: UsbDeviceState::Disabled, + suspended: false, + remote_wakeup_enabled: false, + self_powered: false, + pending_address: 0, + interfaces, + }, } } @@ -176,28 +183,60 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { /// before calling any other `UsbDevice` methods to fully reset the /// peripheral. pub async fn run_until_suspend(&mut self) -> () { - if self.device_state == UsbDeviceState::Disabled { - self.bus.enable().await; - self.device_state = UsbDeviceState::Default; + if self.inner.device_state == UsbDeviceState::Disabled { + self.inner.bus.enable().await; + self.inner.device_state = UsbDeviceState::Default; - if let Some(h) = &self.handler { + if let Some(h) = &self.inner.handler { h.enabled(true); } } loop { let control_fut = self.control.setup(); - let bus_fut = self.bus.poll(); + let bus_fut = self.inner.bus.poll(); match select(bus_fut, control_fut).await { Either::First(evt) => { - self.handle_bus_event(evt); - if self.suspended { + self.inner.handle_bus_event(evt); + if self.inner.suspended { return; } } Either::Second(req) => match req { - Setup::DataIn(req, stage) => self.handle_control_in(req, stage).await, - Setup::DataOut(req, stage) => self.handle_control_out(req, stage).await, + Setup::DataIn(req, mut stage) => { + // If we don't have an address yet, respond with max 1 packet. + // The host doesn't know our EP0 max packet size yet, and might assume + // a full-length packet is a short packet, thinking we're done sending data. + // See https://github.com/hathach/tinyusb/issues/184 + const DEVICE_DESCRIPTOR_LEN: u8 = 18; + if self.inner.pending_address == 0 + && self.inner.config.max_packet_size_0 < DEVICE_DESCRIPTOR_LEN + && (self.inner.config.max_packet_size_0 as usize) < stage.length + { + trace!("received control req while not addressed: capping response to 1 packet."); + stage.length = self.inner.config.max_packet_size_0 as _; + } + + match self.inner.handle_control_in(req, &mut self.control_buf) { + InResponse::Accepted(data) => self.control.accept_in(data, stage).await, + InResponse::Rejected => self.control.reject(), + } + } + Setup::DataOut(req, stage) => { + let (data, stage) = + match self.control.data_out(self.control_buf, stage).await { + Ok(data) => data, + Err(_) => { + warn!("usb: failed to read CONTROL OUT data stage."); + return; + } + }; + + match self.inner.handle_control_out(req, data) { + OutResponse::Accepted => self.control.accept(stage), + OutResponse::Rejected => self.control.reject(), + } + } }, } } @@ -205,13 +244,13 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { /// Disables the USB peripheral. pub async fn disable(&mut self) { - if self.device_state != UsbDeviceState::Disabled { - self.bus.disable().await; - self.device_state = UsbDeviceState::Disabled; - self.suspended = false; - self.remote_wakeup_enabled = false; + if self.inner.device_state != UsbDeviceState::Disabled { + self.inner.bus.disable().await; + self.inner.device_state = UsbDeviceState::Disabled; + self.inner.suspended = false; + self.inner.remote_wakeup_enabled = false; - if let Some(h) = &self.handler { + if let Some(h) = &self.inner.handler { h.enabled(false); } } @@ -221,9 +260,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { /// /// This future is cancel-safe. pub async fn wait_resume(&mut self) { - while self.suspended { - let evt = self.bus.poll().await; - self.handle_bus_event(evt); + while self.inner.suspended { + let evt = self.inner.bus.poll().await; + self.inner.handle_bus_event(evt); } } @@ -236,11 +275,11 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { /// After dropping the future, [`UsbDevice::disable()`] should be called /// before calling any other `UsbDevice` methods to fully reset the peripheral. pub async fn remote_wakeup(&mut self) -> Result<(), RemoteWakeupError> { - if self.suspended && self.remote_wakeup_enabled { - self.bus.remote_wakeup().await?; - self.suspended = false; + if self.inner.suspended && self.inner.remote_wakeup_enabled { + self.inner.bus.remote_wakeup().await?; + self.inner.suspended = false; - if let Some(h) = &self.handler { + if let Some(h) = &self.inner.handler { h.suspended(false); } @@ -249,7 +288,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { Err(RemoteWakeupError::InvalidState) } } +} +impl<'d, D: Driver<'d>> Inner<'d, D> { fn handle_bus_event(&mut self, evt: Event) { match evt { Event::Reset => { @@ -288,18 +329,10 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } } - async fn handle_control_out(&mut self, req: Request, stage: DataOutStage) { + fn handle_control_out(&mut self, req: Request, data: &[u8]) -> OutResponse { const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16; const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16; - let (data, stage) = match self.control.data_out(self.control_buf, stage).await { - Ok(data) => data, - Err(_) => { - warn!("usb: failed to read CONTROL OUT data stage."); - return; - } - }; - match (req.request_type, req.recipient) { (RequestType::Standard, Recipient::Device) => match (req.request, req.value) { (Request::CLEAR_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => { @@ -307,14 +340,14 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if let Some(h) = &self.handler { h.remote_wakeup_enabled(false); } - self.control.accept(stage) + OutResponse::Accepted } (Request::SET_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => { self.remote_wakeup_enabled = true; if let Some(h) = &self.handler { h.remote_wakeup_enabled(true); } - self.control.accept(stage) + OutResponse::Accepted } (Request::SET_ADDRESS, addr @ 1..=127) => { self.pending_address = addr as u8; @@ -323,7 +356,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if let Some(h) = &self.handler { h.addressed(self.pending_address); } - self.control.accept(stage) + OutResponse::Accepted } (Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => { debug!("SET_CONFIGURATION: configured"); @@ -344,10 +377,10 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { h.configured(true); } - self.control.accept(stage) + OutResponse::Accepted } (Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => match self.device_state { - UsbDeviceState::Default => self.control.accept(stage), + UsbDeviceState::Default => OutResponse::Accepted, _ => { debug!("SET_CONFIGURATION: unconfigured"); self.device_state = UsbDeviceState::Addressed; @@ -363,15 +396,15 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { h.configured(false); } - self.control.accept(stage) + OutResponse::Accepted } }, - _ => self.control.reject(), + _ => OutResponse::Rejected, }, (RequestType::Standard, Recipient::Interface) => { let iface = match self.interfaces.get_mut(req.index as usize) { Some(iface) => iface, - None => return self.control.reject(), + None => return OutResponse::Rejected, }; match req.request { @@ -380,7 +413,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if new_altsetting >= iface.num_alt_settings { warn!("SET_INTERFACE: trying to select alt setting out of range."); - return self.control.reject(); + return OutResponse::Rejected; } iface.current_alt_setting = new_altsetting; @@ -402,55 +435,39 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if let Some(handler) = &mut iface.handler { handler.set_alternate_setting(new_altsetting); } - self.control.accept(stage) + OutResponse::Accepted } - _ => self.control.reject(), + _ => OutResponse::Rejected, } } (RequestType::Standard, Recipient::Endpoint) => match (req.request, req.value) { (Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { let ep_addr = ((req.index as u8) & 0x8f).into(); self.bus.endpoint_set_stalled(ep_addr, true); - self.control.accept(stage) + OutResponse::Accepted } (Request::CLEAR_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { let ep_addr = ((req.index as u8) & 0x8f).into(); self.bus.endpoint_set_stalled(ep_addr, false); - self.control.accept(stage) + OutResponse::Accepted } - _ => self.control.reject(), + _ => OutResponse::Rejected, }, (RequestType::Class, Recipient::Interface) => { let iface = match self.interfaces.get_mut(req.index as usize) { Some(iface) => iface, - None => return self.control.reject(), + None => return OutResponse::Rejected, }; match &mut iface.handler { - Some(handler) => match handler.control_out(req, data) { - OutResponse::Accepted => self.control.accept(stage), - OutResponse::Rejected => self.control.reject(), - }, - None => self.control.reject(), + Some(handler) => handler.control_out(req, data), + None => OutResponse::Rejected, } } - _ => self.control.reject(), + _ => OutResponse::Rejected, } } - async fn handle_control_in(&mut self, req: Request, mut stage: DataInStage) { - // If we don't have an address yet, respond with max 1 packet. - // The host doesn't know our EP0 max packet size yet, and might assume - // a full-length packet is a short packet, thinking we're done sending data. - // See https://github.com/hathach/tinyusb/issues/184 - const DEVICE_DESCRIPTOR_LEN: u8 = 18; - if self.pending_address == 0 - && self.config.max_packet_size_0 < DEVICE_DESCRIPTOR_LEN - && (self.config.max_packet_size_0 as usize) < stage.length - { - trace!("received control req while not addressed: capping response to 1 packet."); - stage.length = self.config.max_packet_size_0 as _; - } - + fn handle_control_in<'a>(&'a mut self, req: Request, buf: &'a mut [u8]) -> InResponse<'a> { match (req.request_type, req.recipient) { (RequestType::Standard, Recipient::Device) => match req.request { Request::GET_STATUS => { @@ -461,42 +478,41 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if self.remote_wakeup_enabled { status |= 0x0002; } - self.control.accept_in(&status.to_le_bytes(), stage).await + buf[..2].copy_from_slice(&status.to_le_bytes()); + InResponse::Accepted(&buf[..2]) } - Request::GET_DESCRIPTOR => self.handle_get_descriptor(req, stage).await, + Request::GET_DESCRIPTOR => self.handle_get_descriptor(req, buf), Request::GET_CONFIGURATION => { let status = match self.device_state { UsbDeviceState::Configured => CONFIGURATION_VALUE, _ => CONFIGURATION_NONE, }; - self.control.accept_in(&status.to_le_bytes(), stage).await + buf[0] = status; + InResponse::Accepted(&buf[..1]) } - _ => self.control.reject(), + _ => InResponse::Rejected, }, (RequestType::Standard, Recipient::Interface) => { let iface = match self.interfaces.get_mut(req.index as usize) { Some(iface) => iface, - None => return self.control.reject(), + None => return InResponse::Rejected, }; match req.request { Request::GET_STATUS => { let status: u16 = 0; - self.control.accept_in(&status.to_le_bytes(), stage).await + buf[..2].copy_from_slice(&status.to_le_bytes()); + InResponse::Accepted(&buf[..2]) } Request::GET_INTERFACE => { - self.control - .accept_in(&[iface.current_alt_setting], stage) - .await; + buf[0] = iface.current_alt_setting; + InResponse::Accepted(&buf[..1]) } Request::GET_DESCRIPTOR => match &mut iface.handler { - Some(handler) => match handler.get_descriptor(req, self.control_buf) { - InResponse::Accepted(data) => self.control.accept_in(data, stage).await, - InResponse::Rejected => self.control.reject(), - }, - None => self.control.reject(), + Some(handler) => handler.get_descriptor(req, buf), + None => InResponse::Rejected, }, - _ => self.control.reject(), + _ => InResponse::Rejected, } } (RequestType::Standard, Recipient::Endpoint) => match req.request { @@ -506,44 +522,40 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if self.bus.endpoint_is_stalled(ep_addr) { status |= 0x0001; } - self.control.accept_in(&status.to_le_bytes(), stage).await + buf[..2].copy_from_slice(&status.to_le_bytes()); + InResponse::Accepted(&buf[..2]) } - _ => self.control.reject(), + _ => InResponse::Rejected, }, (RequestType::Class, Recipient::Interface) => { let iface = match self.interfaces.get_mut(req.index as usize) { Some(iface) => iface, - None => return self.control.reject(), + None => return InResponse::Rejected, }; match &mut iface.handler { - Some(handler) => match handler.control_in(req, self.control_buf) { - InResponse::Accepted(data) => self.control.accept_in(data, stage).await, - InResponse::Rejected => self.control.reject(), - }, - None => self.control.reject(), + Some(handler) => handler.control_in(req, buf), + None => InResponse::Rejected, } } - _ => self.control.reject(), + _ => InResponse::Rejected, } } - async fn handle_get_descriptor(&mut self, req: Request, stage: DataInStage) { + fn handle_get_descriptor<'a>(&'a mut self, req: Request, buf: &'a mut [u8]) -> InResponse<'a> { let (dtype, index) = req.descriptor_type_index(); match dtype { - descriptor_type::BOS => self.control.accept_in(self.bos_descriptor, stage).await, - descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor, stage).await, - descriptor_type::CONFIGURATION => { - self.control.accept_in(self.config_descriptor, stage).await - } + descriptor_type::BOS => InResponse::Accepted(self.bos_descriptor), + descriptor_type::DEVICE => InResponse::Accepted(self.device_descriptor), + descriptor_type::CONFIGURATION => InResponse::Accepted(self.config_descriptor), descriptor_type::STRING => { if index == 0 { - self.control - .accept_in_writer(req, stage, |w| { - w.write(descriptor_type::STRING, &lang_id::ENGLISH_US.to_le_bytes()); - }) - .await + buf[0] = 4; // len + buf[1] = descriptor_type::STRING; + buf[2] = lang_id::ENGLISH_US as u8; + buf[3] = (lang_id::ENGLISH_US >> 8) as u8; + InResponse::Accepted(&buf[..4]) } else { let s = match index { STRING_INDEX_MANUFACTURER => self.config.manufacturer, @@ -565,7 +577,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if let Some(handler) = &mut iface.handler { let index = StringIndex::new(index); let lang_id = req.index; - handler.get_string(index, lang_id, self.control_buf) + handler.get_string(index, lang_id) } else { warn!("String requested to an interface with no handler."); None @@ -578,15 +590,29 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { }; if let Some(s) = s { - self.control - .accept_in_writer(req, stage, |w| w.string(s)) - .await + if buf.len() < 2 { + panic!("control buffer too small"); + } + + buf[1] = descriptor_type::STRING; + let mut pos = 2; + for c in s.encode_utf16() { + if pos >= buf.len() { + panic!("control buffer too small"); + } + + buf[pos..pos + 2].copy_from_slice(&c.to_le_bytes()); + pos += 2; + } + + buf[0] = pos as u8; + InResponse::Accepted(&buf[..pos]) } else { - self.control.reject() + InResponse::Rejected } } } - _ => self.control.reject(), + _ => InResponse::Rejected, } } }