mirror of
https://github.com/embassy-rs/embassy.git
synced 2024-11-21 22:32:29 +00:00
Refactor ControlPipe to use the typestate pattern for safety
This commit is contained in:
parent
77e0aca03b
commit
f5ba022257
@ -1,5 +1,7 @@
|
||||
use core::mem;
|
||||
|
||||
use crate::descriptor::DescriptorWriter;
|
||||
use crate::driver::{self, ReadError};
|
||||
use crate::DEFAULT_ALTERNATE_SETTING;
|
||||
|
||||
use super::types::*;
|
||||
@ -191,3 +193,122 @@ pub trait ControlHandler {
|
||||
InResponse::Accepted(&buf[0..2])
|
||||
}
|
||||
}
|
||||
|
||||
/// Typestate representing a ControlPipe in the DATA IN stage
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
||||
pub(crate) struct DataInStage {
|
||||
length: usize,
|
||||
}
|
||||
|
||||
/// Typestate representing a ControlPipe in the DATA OUT stage
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
||||
pub(crate) struct DataOutStage {
|
||||
length: usize,
|
||||
}
|
||||
|
||||
/// Typestate representing a ControlPipe in the STATUS stage
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
||||
pub(crate) struct StatusStage {}
|
||||
|
||||
#[derive(Debug)]
|
||||
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
|
||||
pub(crate) enum Setup {
|
||||
DataIn(Request, DataInStage),
|
||||
DataOut(Request, DataOutStage),
|
||||
}
|
||||
|
||||
pub(crate) struct ControlPipe<C: driver::ControlPipe> {
|
||||
control: C,
|
||||
}
|
||||
|
||||
impl<C: driver::ControlPipe> ControlPipe<C> {
|
||||
pub(crate) fn new(control: C) -> Self {
|
||||
ControlPipe { control }
|
||||
}
|
||||
|
||||
pub(crate) async fn setup(&mut self) -> Setup {
|
||||
let req = self.control.setup().await;
|
||||
match (req.direction, req.length) {
|
||||
(UsbDirection::Out, n) => Setup::DataOut(
|
||||
req,
|
||||
DataOutStage {
|
||||
length: usize::from(n),
|
||||
},
|
||||
),
|
||||
(UsbDirection::In, n) => Setup::DataIn(
|
||||
req,
|
||||
DataInStage {
|
||||
length: usize::from(n),
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn data_out<'a>(
|
||||
&mut self,
|
||||
buf: &'a mut [u8],
|
||||
stage: DataOutStage,
|
||||
) -> Result<(&'a [u8], StatusStage), ReadError> {
|
||||
if stage.length == 0 {
|
||||
Ok((&[], StatusStage {}))
|
||||
} else {
|
||||
let req_length = stage.length;
|
||||
let max_packet_size = self.control.max_packet_size();
|
||||
let mut total = 0;
|
||||
|
||||
for chunk in buf.chunks_mut(max_packet_size) {
|
||||
let size = self.control.data_out(chunk).await?;
|
||||
total += size;
|
||||
if size < max_packet_size || total == req_length {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok((&buf[0..total], StatusStage {}))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn accept_in(&mut self, buf: &[u8], stage: DataInStage) {
|
||||
#[cfg(feature = "defmt")]
|
||||
debug!("control in accept {:x}", buf);
|
||||
#[cfg(not(feature = "defmt"))]
|
||||
debug!("control in accept {:x?}", buf);
|
||||
|
||||
let req_len = stage.length;
|
||||
let len = buf.len().min(req_len);
|
||||
let max_packet_size = self.control.max_packet_size();
|
||||
let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0;
|
||||
|
||||
let mut chunks = buf[0..len]
|
||||
.chunks(max_packet_size)
|
||||
.chain(need_zlp.then(|| -> &[u8] { &[] }));
|
||||
|
||||
while let Some(chunk) = chunks.next() {
|
||||
self.control.data_in(chunk, chunks.size_hint().0 == 0).await;
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn accept_in_writer(
|
||||
&mut self,
|
||||
req: Request,
|
||||
stage: DataInStage,
|
||||
f: impl FnOnce(&mut DescriptorWriter),
|
||||
) {
|
||||
let mut buf = [0; 256];
|
||||
let mut w = DescriptorWriter::new(&mut buf);
|
||||
f(&mut w);
|
||||
let pos = w.position().min(usize::from(req.length));
|
||||
self.accept_in(&buf[..pos], stage).await
|
||||
}
|
||||
|
||||
pub(crate) fn accept(&mut self, _: StatusStage) {
|
||||
self.control.accept();
|
||||
}
|
||||
|
||||
pub(crate) fn reject(&mut self) {
|
||||
self.control.reject();
|
||||
}
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ use heapless::Vec;
|
||||
|
||||
use self::control::*;
|
||||
use self::descriptor::*;
|
||||
use self::driver::*;
|
||||
use self::driver::{Bus, Driver, Event};
|
||||
use self::types::*;
|
||||
use self::util::*;
|
||||
|
||||
@ -92,10 +92,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
Self {
|
||||
bus: driver,
|
||||
config,
|
||||
control: ControlPipe {
|
||||
control,
|
||||
request: None,
|
||||
},
|
||||
control: ControlPipe::new(control),
|
||||
device_descriptor,
|
||||
config_descriptor,
|
||||
bos_descriptor,
|
||||
@ -134,57 +131,50 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
Either::Right(req) => {
|
||||
debug!("control request: {:x}", req);
|
||||
|
||||
match req.direction {
|
||||
UsbDirection::In => self.handle_control_in(req).await,
|
||||
UsbDirection::Out => self.handle_control_out(req).await,
|
||||
match req {
|
||||
Setup::DataIn(req, stage) => self.handle_control_in(req, stage).await,
|
||||
Setup::DataOut(req, stage) => self.handle_control_out(req, stage).await,
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_control_out(&mut self, req: Request) {
|
||||
async fn handle_control_out(&mut self, req: Request, stage: DataOutStage) {
|
||||
const CONFIGURATION_NONE_U16: u16 = CONFIGURATION_NONE as u16;
|
||||
const CONFIGURATION_VALUE_U16: u16 = CONFIGURATION_VALUE as u16;
|
||||
|
||||
// If the request has a data state, we must read it.
|
||||
let data = if req.length > 0 {
|
||||
match self.control.data_out(self.control_buf).await {
|
||||
Ok(data) => data,
|
||||
Err(_) => {
|
||||
warn!("usb: failed to read CONTROL OUT data stage.");
|
||||
return;
|
||||
}
|
||||
let (data, stage) = match self.control.data_out(self.control_buf, stage).await {
|
||||
Ok(data) => data,
|
||||
Err(_) => {
|
||||
warn!("usb: failed to read CONTROL OUT data stage.");
|
||||
return;
|
||||
}
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
|
||||
match (req.request_type, req.recipient) {
|
||||
(RequestType::Standard, Recipient::Device) => match (req.request, req.value) {
|
||||
(Request::CLEAR_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => {
|
||||
self.remote_wakeup_enabled = false;
|
||||
self.control.accept();
|
||||
self.control.accept(stage)
|
||||
}
|
||||
(Request::SET_FEATURE, Request::FEATURE_DEVICE_REMOTE_WAKEUP) => {
|
||||
self.remote_wakeup_enabled = true;
|
||||
self.control.accept();
|
||||
self.control.accept(stage)
|
||||
}
|
||||
(Request::SET_ADDRESS, 1..=127) => {
|
||||
self.pending_address = req.value as u8;
|
||||
self.control.accept();
|
||||
self.control.accept(stage)
|
||||
}
|
||||
(Request::SET_CONFIGURATION, CONFIGURATION_VALUE_U16) => {
|
||||
self.device_state = UsbDeviceState::Configured;
|
||||
self.control.accept();
|
||||
self.control.accept(stage)
|
||||
}
|
||||
(Request::SET_CONFIGURATION, CONFIGURATION_NONE_U16) => match self.device_state {
|
||||
UsbDeviceState::Default => {
|
||||
self.control.accept();
|
||||
}
|
||||
UsbDeviceState::Default => self.control.accept(stage),
|
||||
_ => {
|
||||
self.device_state = UsbDeviceState::Addressed;
|
||||
self.control.accept();
|
||||
self.control.accept(stage)
|
||||
}
|
||||
},
|
||||
_ => self.control.reject(),
|
||||
@ -193,12 +183,12 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
(Request::SET_FEATURE, Request::FEATURE_ENDPOINT_HALT) => {
|
||||
let ep_addr = ((req.index as u8) & 0x8f).into();
|
||||
self.bus.set_stalled(ep_addr, true);
|
||||
self.control.accept();
|
||||
self.control.accept(stage)
|
||||
}
|
||||
(Request::CLEAR_FEATURE, Request::FEATURE_ENDPOINT_HALT) => {
|
||||
let ep_addr = ((req.index as u8) & 0x8f).into();
|
||||
self.bus.set_stalled(ep_addr, false);
|
||||
self.control.accept();
|
||||
self.control.accept(stage)
|
||||
}
|
||||
_ => self.control.reject(),
|
||||
},
|
||||
@ -218,7 +208,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
_ => handler.control_out(req, data),
|
||||
};
|
||||
match response {
|
||||
OutResponse::Accepted => self.control.accept(),
|
||||
OutResponse::Accepted => self.control.accept(stage),
|
||||
OutResponse::Rejected => self.control.reject(),
|
||||
}
|
||||
}
|
||||
@ -229,7 +219,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_control_in(&mut self, req: Request) {
|
||||
async fn handle_control_in(&mut self, req: Request, stage: DataInStage) {
|
||||
match (req.request_type, req.recipient) {
|
||||
(RequestType::Standard, Recipient::Device) => match req.request {
|
||||
Request::GET_STATUS => {
|
||||
@ -240,17 +230,15 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
if self.remote_wakeup_enabled {
|
||||
status |= 0x0002;
|
||||
}
|
||||
self.control.accept_in(&status.to_le_bytes()).await;
|
||||
}
|
||||
Request::GET_DESCRIPTOR => {
|
||||
self.handle_get_descriptor(req).await;
|
||||
self.control.accept_in(&status.to_le_bytes(), stage).await
|
||||
}
|
||||
Request::GET_DESCRIPTOR => self.handle_get_descriptor(req, stage).await,
|
||||
Request::GET_CONFIGURATION => {
|
||||
let status = match self.device_state {
|
||||
UsbDeviceState::Configured => CONFIGURATION_VALUE,
|
||||
_ => CONFIGURATION_NONE,
|
||||
};
|
||||
self.control.accept_in(&status.to_le_bytes()).await;
|
||||
self.control.accept_in(&status.to_le_bytes(), stage).await
|
||||
}
|
||||
_ => self.control.reject(),
|
||||
},
|
||||
@ -261,7 +249,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
if self.bus.is_stalled(ep_addr) {
|
||||
status |= 0x0001;
|
||||
}
|
||||
self.control.accept_in(&status.to_le_bytes()).await;
|
||||
self.control.accept_in(&status.to_le_bytes(), stage).await
|
||||
}
|
||||
_ => self.control.reject(),
|
||||
},
|
||||
@ -285,7 +273,7 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
};
|
||||
|
||||
match response {
|
||||
InResponse::Accepted(data) => self.control.accept_in(data).await,
|
||||
InResponse::Accepted(data) => self.control.accept_in(data, stage).await,
|
||||
InResponse::Rejected => self.control.reject(),
|
||||
}
|
||||
}
|
||||
@ -296,17 +284,19 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
}
|
||||
}
|
||||
|
||||
async fn handle_get_descriptor(&mut self, req: Request) {
|
||||
async fn handle_get_descriptor(&mut self, req: Request, stage: DataInStage) {
|
||||
let (dtype, index) = req.descriptor_type_index();
|
||||
|
||||
match dtype {
|
||||
descriptor_type::BOS => self.control.accept_in(self.bos_descriptor).await,
|
||||
descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor).await,
|
||||
descriptor_type::CONFIGURATION => self.control.accept_in(self.config_descriptor).await,
|
||||
descriptor_type::BOS => self.control.accept_in(self.bos_descriptor, stage).await,
|
||||
descriptor_type::DEVICE => self.control.accept_in(self.device_descriptor, stage).await,
|
||||
descriptor_type::CONFIGURATION => {
|
||||
self.control.accept_in(self.config_descriptor, stage).await
|
||||
}
|
||||
descriptor_type::STRING => {
|
||||
if index == 0 {
|
||||
self.control
|
||||
.accept_in_writer(req, |w| {
|
||||
.accept_in_writer(req, stage, |w| {
|
||||
w.write(descriptor_type::STRING, &lang_id::ENGLISH_US.to_le_bytes());
|
||||
})
|
||||
.await
|
||||
@ -324,7 +314,9 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
};
|
||||
|
||||
if let Some(s) = s {
|
||||
self.control.accept_in_writer(req, |w| w.string(s)).await;
|
||||
self.control
|
||||
.accept_in_writer(req, stage, |w| w.string(s))
|
||||
.await
|
||||
} else {
|
||||
self.control.reject()
|
||||
}
|
||||
@ -334,81 +326,3 @@ impl<'d, D: Driver<'d>> UsbDevice<'d, D> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct ControlPipe<C: driver::ControlPipe> {
|
||||
control: C,
|
||||
request: Option<Request>,
|
||||
}
|
||||
|
||||
impl<C: driver::ControlPipe> ControlPipe<C> {
|
||||
async fn setup(&mut self) -> Request {
|
||||
assert!(self.request.is_none());
|
||||
let req = self.control.setup().await;
|
||||
self.request = Some(req);
|
||||
req
|
||||
}
|
||||
|
||||
async fn data_out<'a>(&mut self, buf: &'a mut [u8]) -> Result<&'a [u8], ReadError> {
|
||||
let req = self.request.unwrap();
|
||||
assert_eq!(req.direction, UsbDirection::Out);
|
||||
assert!(req.length > 0);
|
||||
let req_length = usize::from(req.length);
|
||||
|
||||
let max_packet_size = self.control.max_packet_size();
|
||||
let mut total = 0;
|
||||
|
||||
for chunk in buf.chunks_mut(max_packet_size) {
|
||||
let size = self.control.data_out(chunk).await?;
|
||||
total += size;
|
||||
if size < max_packet_size || total == req_length {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(&buf[0..total])
|
||||
}
|
||||
|
||||
async fn accept_in(&mut self, buf: &[u8]) -> () {
|
||||
#[cfg(feature = "defmt")]
|
||||
debug!("control in accept {:x}", buf);
|
||||
#[cfg(not(feature = "defmt"))]
|
||||
debug!("control in accept {:x?}", buf);
|
||||
let req = unwrap!(self.request);
|
||||
assert!(req.direction == UsbDirection::In);
|
||||
|
||||
let req_len = usize::from(req.length);
|
||||
let len = buf.len().min(req_len);
|
||||
let max_packet_size = self.control.max_packet_size();
|
||||
let need_zlp = len != req_len && (len % usize::from(max_packet_size)) == 0;
|
||||
|
||||
let mut chunks = buf[0..len]
|
||||
.chunks(max_packet_size)
|
||||
.chain(need_zlp.then(|| -> &[u8] { &[] }));
|
||||
|
||||
while let Some(chunk) = chunks.next() {
|
||||
self.control.data_in(chunk, chunks.size_hint().0 == 0).await;
|
||||
}
|
||||
|
||||
self.request = None;
|
||||
}
|
||||
|
||||
async fn accept_in_writer(&mut self, req: Request, f: impl FnOnce(&mut DescriptorWriter)) {
|
||||
let mut buf = [0; 256];
|
||||
let mut w = DescriptorWriter::new(&mut buf);
|
||||
f(&mut w);
|
||||
let pos = w.position().min(usize::from(req.length));
|
||||
self.accept_in(&buf[..pos]).await;
|
||||
}
|
||||
|
||||
fn accept(&mut self) {
|
||||
assert!(self.request.is_some());
|
||||
self.control.accept();
|
||||
self.request = None;
|
||||
}
|
||||
|
||||
fn reject(&mut self) {
|
||||
assert!(self.request.is_some());
|
||||
self.control.reject();
|
||||
self.request = None;
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user