From f5ba022257ccd9ddd371f1dcd10c0775cc5a3110 Mon Sep 17 00:00:00 2001 From: alexmoon Date: Wed, 30 Mar 2022 14:17:15 -0400 Subject: [PATCH] Refactor ControlPipe to use the typestate pattern for safety --- embassy-usb/src/control.rs | 121 ++++++++++++++++++++++++++++ embassy-usb/src/lib.rs | 158 +++++++++---------------------------- 2 files changed, 157 insertions(+), 122 deletions(-) diff --git a/embassy-usb/src/control.rs b/embassy-usb/src/control.rs index b5077c732..9f1115ff2 100644 --- a/embassy-usb/src/control.rs +++ b/embassy-usb/src/control.rs @@ -1,5 +1,7 @@ use core::mem; +use crate::descriptor::DescriptorWriter; +use crate::driver::{self, ReadError}; use crate::DEFAULT_ALTERNATE_SETTING; use super::types::*; @@ -191,3 +193,122 @@ pub trait ControlHandler { InResponse::Accepted(&buf[0..2]) } } + +/// Typestate representing a ControlPipe in the DATA IN stage +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct DataInStage { + length: usize, +} + +/// Typestate representing a ControlPipe in the DATA OUT stage +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct DataOutStage { + length: usize, +} + +/// Typestate representing a ControlPipe in the STATUS stage +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct StatusStage {} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum Setup { + DataIn(Request, DataInStage), + DataOut(Request, DataOutStage), +} + +pub(crate) struct ControlPipe { + control: C, +} + +impl ControlPipe { + pub(crate) fn new(control: C) -> Self { + ControlPipe { control } + } + + pub(crate) async fn setup(&mut self) -> Setup { + let req = self.control.setup().await; + match (req.direction, req.length) { + (UsbDirection::Out, n) => Setup::DataOut( + req, + DataOutStage { + length: usize::from(n), + }, + ), + (UsbDirection::In, n) => Setup::DataIn( + req, + DataInStage { + length: usize::from(n), + }, + ), + } + } + + pub(crate) async fn data_out<'a>( + &mut self, + buf: &'a mut [u8], + stage: DataOutStage, + ) -> Result<(&'a [u8], StatusStage), ReadError> { + if stage.length == 0 { + Ok((&[], StatusStage {})) + } else { + let req_length = stage.length; + let max_packet_size = self.control.max_packet_size(); + let mut total = 0; + + for chunk in buf.chunks_mut(max_packet_size) { + let size = self.control.data_out(chunk).await?; + total += size; + if size < max_packet_size || total == req_length { + break; + } + } + + Ok((&buf[0..total], StatusStage {})) + } + } + + pub(crate) async fn accept_in(&mut self, buf: &[u8], stage: DataInStage) { + #[cfg(feature = "defmt")] + debug!("control in accept {:x}", buf); + #[cfg(not(feature = "defmt"))] + debug!("control in accept {:x?}", buf); + + let req_len = stage.length; + let len = buf.len().min(req_len); + let max_packet_size = self.control.max_packet_size(); + let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0; + + let mut chunks = buf[0..len] + .chunks(max_packet_size) + .chain(need_zlp.then(|| -> &[u8] { &[] })); + + while let Some(chunk) = chunks.next() { + self.control.data_in(chunk, chunks.size_hint().0 == 0).await; + } + } + + pub(crate) async fn accept_in_writer( + &mut self, + req: Request, + stage: DataInStage, + f: impl FnOnce(&mut DescriptorWriter), + ) { + let mut buf = [0; 256]; + let mut w = DescriptorWriter::new(&mut buf); + f(&mut w); + let pos = w.position().min(usize::from(req.length)); + self.accept_in(&buf[..pos], stage).await + } + + pub(crate) fn accept(&mut self, _: StatusStage) { + self.control.accept(); + } + + pub(crate) fn reject(&mut self) { + self.control.reject(); + } +} diff --git a/embassy-usb/src/lib.rs b/embassy-usb/src/lib.rs index 77a9c33be..067b5b07f 100644 --- a/embassy-usb/src/lib.rs +++ b/embassy-usb/src/lib.rs @@ -16,7 +16,7 @@ use heapless::Vec; use self::control::*; use self::descriptor::*; -use self::driver::*; +use self::driver::{Bus, Driver, Event}; use self::types::*; use self::util::*; @@ -92,10 +92,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { Self { bus: driver, config, - control: ControlPipe { - control, - request: None, - }, + control: ControlPipe::new(control), device_descriptor, config_descriptor, bos_descriptor, @@ -134,57 +131,50 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { Either::Right(req) => { debug!("control request: {:x}", req); - match req.direction { - UsbDirection::In => self.handle_control_in(req).await, - UsbDirection::Out => self.handle_control_out(req).await, + match req { + Setup::DataIn(req, stage) => self.handle_control_in(req, stage).await, + Setup::DataOut(req, stage) => self.handle_control_out(req, stage).await, } } } } } - async fn handle_control_out(&mut self, req: Request) { + async fn handle_control_out(&mut self, req: Request, stage: DataOutStage) { const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16; const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16; - // If the request has a data state, we must read it. - let data = if req.length > 0 { - match self.control.data_out(self.control_buf).await { - Ok(data) => data, - Err(_) => { - warn!("usb: failed to read CONTROL OUT data stage."); - return; - } + let (data, stage) = match self.control.data_out(self.control_buf, stage).await { + Ok(data) => data, + Err(_) => { + warn!("usb: failed to read CONTROL OUT data stage."); + return; } - } else { - &[] }; match (req.request_type, req.recipient) { (RequestType::Standard, Recipient::Device) => match (req.request, req.value) { (Request::CLEAR_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => { self.remote_wakeup_enabled = false; - self.control.accept(); + self.control.accept(stage) } (Request::SET_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => { self.remote_wakeup_enabled = true; - self.control.accept(); + self.control.accept(stage) } (Request::SET_ADDRESS, 1..=127) => { self.pending_address = req.value as u8; - self.control.accept(); + self.control.accept(stage) } (Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => { self.device_state = UsbDeviceState::Configured; - self.control.accept(); + self.control.accept(stage) } (Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => match self.device_state { - UsbDeviceState::Default => { - self.control.accept(); - } + UsbDeviceState::Default => self.control.accept(stage), _ => { self.device_state = UsbDeviceState::Addressed; - self.control.accept(); + self.control.accept(stage) } }, _ => self.control.reject(), @@ -193,12 +183,12 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { (Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { let ep_addr = ((req.index as u8) & 0x8f).into(); self.bus.set_stalled(ep_addr, true); - self.control.accept(); + self.control.accept(stage) } (Request::CLEAR_FEATURE, Request::FEATURE_ENDPOINT_HALT) => { let ep_addr = ((req.index as u8) & 0x8f).into(); self.bus.set_stalled(ep_addr, false); - self.control.accept(); + self.control.accept(stage) } _ => self.control.reject(), }, @@ -218,7 +208,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { _ => handler.control_out(req, data), }; match response { - OutResponse::Accepted => self.control.accept(), + OutResponse::Accepted => self.control.accept(stage), OutResponse::Rejected => self.control.reject(), } } @@ -229,7 +219,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } } - async fn handle_control_in(&mut self, req: Request) { + async fn handle_control_in(&mut self, req: Request, stage: DataInStage) { match (req.request_type, req.recipient) { (RequestType::Standard, Recipient::Device) => match req.request { Request::GET_STATUS => { @@ -240,17 +230,15 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if self.remote_wakeup_enabled { status |= 0x0002; } - self.control.accept_in(&status.to_le_bytes()).await; - } - Request::GET_DESCRIPTOR => { - self.handle_get_descriptor(req).await; + self.control.accept_in(&status.to_le_bytes(), stage).await } + Request::GET_DESCRIPTOR => self.handle_get_descriptor(req, stage).await, Request::GET_CONFIGURATION => { let status = match self.device_state { UsbDeviceState::Configured => CONFIGURATION_VALUE, _ => CONFIGURATION_NONE, }; - self.control.accept_in(&status.to_le_bytes()).await; + self.control.accept_in(&status.to_le_bytes(), stage).await } _ => self.control.reject(), }, @@ -261,7 +249,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { if self.bus.is_stalled(ep_addr) { status |= 0x0001; } - self.control.accept_in(&status.to_le_bytes()).await; + self.control.accept_in(&status.to_le_bytes(), stage).await } _ => self.control.reject(), }, @@ -285,7 +273,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { }; match response { - InResponse::Accepted(data) => self.control.accept_in(data).await, + InResponse::Accepted(data) => self.control.accept_in(data, stage).await, InResponse::Rejected => self.control.reject(), } } @@ -296,17 +284,19 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } } - async fn handle_get_descriptor(&mut self, req: Request) { + async fn handle_get_descriptor(&mut self, req: Request, stage: DataInStage) { let (dtype, index) = req.descriptor_type_index(); match dtype { - descriptor_type::BOS => self.control.accept_in(self.bos_descriptor).await, - descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor).await, - descriptor_type::CONFIGURATION => self.control.accept_in(self.config_descriptor).await, + descriptor_type::BOS => self.control.accept_in(self.bos_descriptor, stage).await, + descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor, stage).await, + descriptor_type::CONFIGURATION => { + self.control.accept_in(self.config_descriptor, stage).await + } descriptor_type::STRING => { if index == 0 { self.control - .accept_in_writer(req, |w| { + .accept_in_writer(req, stage, |w| { w.write(descriptor_type::STRING, &lang_id::ENGLISH_US.to_le_bytes()); }) .await @@ -324,7 +314,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { }; if let Some(s) = s { - self.control.accept_in_writer(req, |w| w.string(s)).await; + self.control + .accept_in_writer(req, stage, |w| w.string(s)) + .await } else { self.control.reject() } @@ -334,81 +326,3 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> { } } } - -struct ControlPipe { - control: C, - request: Option, -} - -impl ControlPipe { - async fn setup(&mut self) -> Request { - assert!(self.request.is_none()); - let req = self.control.setup().await; - self.request = Some(req); - req - } - - async fn data_out<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], ReadError> { - let req = self.request.unwrap(); - assert_eq!(req.direction, UsbDirection::Out); - assert!(req.length > 0); - let req_length = usize::from(req.length); - - let max_packet_size = self.control.max_packet_size(); - let mut total = 0; - - for chunk in buf.chunks_mut(max_packet_size) { - let size = self.control.data_out(chunk).await?; - total += size; - if size < max_packet_size || total == req_length { - break; - } - } - - Ok(&buf[0..total]) - } - - async fn accept_in(&mut self, buf: &[u8]) -> () { - #[cfg(feature = "defmt")] - debug!("control in accept {:x}", buf); - #[cfg(not(feature = "defmt"))] - debug!("control in accept {:x?}", buf); - let req = unwrap!(self.request); - assert!(req.direction == UsbDirection::In); - - let req_len = usize::from(req.length); - let len = buf.len().min(req_len); - let max_packet_size = self.control.max_packet_size(); - let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0; - - let mut chunks = buf[0..len] - .chunks(max_packet_size) - .chain(need_zlp.then(|| -> &[u8] { &[] })); - - while let Some(chunk) = chunks.next() { - self.control.data_in(chunk, chunks.size_hint().0 == 0).await; - } - - self.request = None; - } - - async fn accept_in_writer(&mut self, req: Request, f: impl FnOnce(&mut DescriptorWriter)) { - let mut buf = [0; 256]; - let mut w = DescriptorWriter::new(&mut buf); - f(&mut w); - let pos = w.position().min(usize::from(req.length)); - self.accept_in(&buf[..pos]).await; - } - - fn accept(&mut self) { - assert!(self.request.is_some()); - self.control.accept(); - self.request = None; - } - - fn reject(&mut self) { - assert!(self.request.is_some()); - self.control.reject(); - self.request = None; - } -}