Add TransferOptions with timeout to USB Host transfers

This commit is contained in:
Ingmar Jager 2024-09-25 14:31:44 +02:00
parent 2b8923cb50
commit bfda351452
4 changed files with 123 additions and 42 deletions

View File

@ -6,8 +6,10 @@ use core::task::Poll;
use embassy_hal_internal::into_ref;
use embassy_sync::waitqueue::AtomicWaker;
use embassy_time::Timer;
use embassy_usb_driver::host::{ChannelError, ChannelIn, ChannelOut, EndpointDescriptor, USBHostDriverTrait};
use embassy_time::{Duration, Instant, Timer};
use embassy_usb_driver::host::{
ChannelError, ChannelIn, ChannelOut, EndpointDescriptor, TransferOptions, USBHostDriverTrait,
};
use embassy_usb_driver::EndpointType;
use super::{DmPin, DpPin, Instance};
@ -444,9 +446,15 @@ impl<'d, T: Instance> Channel<'d, T, In> {
}
impl<'d, T: Instance> ChannelIn for Channel<'d, T, In> {
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ChannelError> {
async fn read(
&mut self,
buf: &mut [u8],
options: impl Into<Option<TransferOptions>>,
) -> Result<usize, ChannelError> {
let index = self.index;
let options: TransferOptions = options.into().unwrap_or_default();
let regs = T::regs();
let epr = regs.epr(index);
let epr_val = epr.read();
@ -461,9 +469,25 @@ impl<'d, T: Instance> ChannelIn for Channel<'d, T, In> {
let mut count: usize = 0;
let res = poll_fn(|cx| {
let t0 = Instant::now();
poll_fn(|cx| {
EP_IN_WAKERS[index].register(cx.waker());
if let Some(timeout_ms) = options.timeout_ms {
if t0.elapsed() > Duration::from_millis(timeout_ms as u64) {
// Timeout, we need to stop the current transaction.
// stat_rx can only be toggled by writing a 1 to its bits.
// We want to set it to InActive (0b00). So we should write back the current value.
let epr_val = epr.read();
let current_stat_rx = epr_val.stat_rx();
let mut epr_val = invariant(epr_val);
epr_val.set_stat_rx(current_stat_rx);
regs.epr(index).write_value(epr_val);
return Poll::Ready(Err(ChannelError::Timeout));
}
}
let stat = regs.epr(index).read().stat_rx();
match stat {
Stat::DISABLED => {
@ -494,9 +518,7 @@ impl<'d, T: Instance> ChannelIn for Channel<'d, T, In> {
}
}
})
.await;
res
.await
}
}
@ -509,40 +531,56 @@ impl<'d, T: Instance> Channel<'d, T, Out> {
}
impl<'d, T: Instance> ChannelOut for Channel<'d, T, Out> {
async fn write(&mut self, buf: &[u8]) -> Result<(), ChannelError> {
async fn write(&mut self, buf: &[u8], options: impl Into<Option<TransferOptions>>) -> Result<(), ChannelError> {
self.write_data(buf);
let index = self.index;
let options: TransferOptions = options.into().unwrap_or_default();
let regs = T::regs();
let epr = regs.epr(index).read();
let current_stat_tx = epr.stat_tx().to_bits();
let regs = T::regs();
let epr = regs.epr(index);
let epr_val = epr.read();
let current_stat_tx = epr_val.stat_tx().to_bits();
// stat_rx can only be toggled by writing a 1.
// We want to set it to Active (0b11)
let stat_valid = Stat::from_bits(!current_stat_tx & 0x3);
let mut epr = invariant(epr);
epr.set_stat_tx(stat_valid);
regs.epr(index).write_value(epr);
let mut epr_val = invariant(epr_val);
epr_val.set_stat_tx(stat_valid);
epr.write_value(epr_val);
let stat = poll_fn(|cx| {
let t0 = Instant::now();
poll_fn(|cx| {
EP_OUT_WAKERS[index].register(cx.waker());
if let Some(timeout_ms) = options.timeout_ms {
if t0.elapsed() > Duration::from_millis(timeout_ms as u64) {
// Timeout, we need to stop the current transaction.
// stat_tx can only be toggled by writing a 1 to its bits.
// We want to set it to InActive (0b00). So we should write back the current value.
let epr_val = epr.read();
let current_stat_tx = epr_val.stat_tx();
let mut epr_val = invariant(epr_val);
epr_val.set_stat_tx(current_stat_tx);
epr.write_value(epr_val);
return Poll::Ready(Err(ChannelError::Timeout));
}
}
let regs = T::regs();
let stat = regs.epr(index).read().stat_tx();
if matches!(stat, Stat::STALL | Stat::DISABLED) {
Poll::Ready(stat)
} else {
Poll::Pending
let stat = epr.read().stat_tx();
match stat {
Stat::DISABLED => Poll::Ready(Ok(())),
Stat::STALL => Poll::Ready(Err(ChannelError::Stall)),
Stat::NAK | Stat::VALID => Poll::Pending,
}
})
.await;
if stat == Stat::STALL {
return Err(ChannelError::Stall);
}
Ok(())
.await
}
}
@ -651,21 +689,33 @@ impl<'d, T: Instance> USBHostDriverTrait for USBHostDriver<'d, T> {
.await;
}
async fn control_request_out(&mut self, bytes: &[u8]) -> Result<(), ()> {
async fn control_request_out(&mut self, bytes: &[u8], data: &[u8]) -> Result<(), ()> {
let epr0 = T::regs().epr(0);
// setup stage
let mut epr_val = invariant(epr0.read());
epr_val.set_setup(true);
epr0.write_value(epr_val);
self.control_channel_out.write(bytes).await.map_err(|_| ())?;
let options = TransferOptions::default().set_timeout_ms(1000);
self.control_channel_out
.write(bytes, options.clone())
.await
.map_err(|_| ())?;
// TODO data stage
// self.control_channel_out.write(bytes).await.map_err(|_| ())?;
// data stage
if data.len() > 0 {
self.control_channel_out
.write(data, options.clone())
.await
.map_err(|_| ())?;
}
// Status stage
let mut status = [0u8; 0];
self.control_channel_in.read(&mut status).await.map_err(|_| ())?;
self.control_channel_in
.read(&mut status, options)
.await
.map_err(|_| ())?;
Ok(())
}
@ -677,18 +727,25 @@ impl<'d, T: Instance> USBHostDriverTrait for USBHostDriver<'d, T> {
let mut epr_val = invariant(epr0.read());
epr_val.set_setup(true);
epr0.write_value(epr_val);
let options = TransferOptions::default().set_timeout_ms(1000);
self.control_channel_out.write(bytes).await.map_err(|_| ())?;
self.control_channel_out
.write(bytes, options.clone())
.await
.map_err(|_| ())?;
// data stage
let count = self.control_channel_in.read(dest).await.map_err(|_| ())?;
let count = self
.control_channel_in
.read(dest, options.clone())
.await
.map_err(|_| ())?;
// status stage
// Send 0 bytes
let zero = [0u8; 0];
self.control_channel_out.write(&zero).await.map_err(|_| ())?;
self.control_channel_out.write(&zero, options).await.map_err(|_| ())?;
Ok(count)
}

View File

@ -12,6 +12,9 @@ pub enum ChannelError {
/// The device endpoint is stalled.
Stall,
/// The request timed out (no proper response from device)
Timeout,
}
/// USB endpoint descriptor as defined in the USB 2.0 specification.
@ -64,7 +67,7 @@ pub trait USBHostDriverTrait {
async fn wait_for_device_disconnect(&mut self);
/// Issue a control request out (sending data to device).
async fn control_request_out(&mut self, bytes: &[u8]) -> Result<(), ()>;
async fn control_request_out(&mut self, bytes: &[u8], data: &[u8]) -> Result<(), ()>;
/// Issue a control request in (receiving data from device).
async fn control_request_in(&mut self, bytes: &[u8], dest: &mut [u8]) -> Result<usize, ()>;
@ -79,17 +82,38 @@ pub trait USBHostDriverTrait {
fn alloc_channel_out(&mut self, desc: &EndpointDescriptor) -> Result<Self::ChannelOut, ()>;
}
/// USB Transfer Options
#[derive(Default, Clone, Debug)]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct TransferOptions {
/// Optional Timeout in milliseconds.
/// If the device has not Acked the transfer within this time the transfer will return a timeout error.
pub timeout_ms: Option<u32>,
}
impl TransferOptions {
/// Set timeout in milliseconds
pub fn set_timeout_ms(mut self, timeout_ms: u32) -> Self {
self.timeout_ms = Some(timeout_ms);
self
}
}
/// USB Host Channel for an IN Endpoint
pub trait ChannelIn {
/// Attempt to read `buf.len()` bytes from an IN Endpoint.
/// This reads multiple USB packets if `buf.len()` is larger than the maximum packet size.
/// Returns the number of bytes read, which may be be less than `buf.len()` if the device responds with non full packet.
async fn read(&mut self, buf: &mut [u8]) -> Result<usize, ChannelError>;
async fn read(
&mut self,
buf: &mut [u8],
options: impl Into<Option<TransferOptions>>,
) -> Result<usize, ChannelError>;
}
/// USB Host Channel for an OUT Endpoint
pub trait ChannelOut {
/// Write `buf.len()` bytes to an OUT Endpoint.
/// This writes multiple USB packets if `buf.len()` is larger than the maximum packet size.
async fn write(&mut self, buf: &[u8]) -> Result<(), ChannelError>;
async fn write(&mut self, buf: &[u8], options: impl Into<Option<TransferOptions>>) -> Result<(), ChannelError>;
}

View File

@ -477,7 +477,7 @@ impl<D: USBHostDriverTrait> UsbHost<D> {
length: 0,
};
self.driver.control_request_out(packet.as_bytes()).await?;
self.driver.control_request_out(packet.as_bytes(), &[]).await?;
self.device_address = addr;
Ok(())
}
@ -555,7 +555,7 @@ impl<D: USBHostDriverTrait> UsbHost<D> {
length: buf.len() as u16,
};
self.driver.control_request_out(packet.as_bytes()).await
self.driver.control_request_out(packet.as_bytes(), &[]).await
}
/// SET_CONFIGURATION control request.
@ -569,7 +569,7 @@ impl<D: USBHostDriverTrait> UsbHost<D> {
length: 0,
};
self.driver.control_request_out(packet.as_bytes()).await
self.driver.control_request_out(packet.as_bytes(), &[]).await
}
/// Claim/allocate an endpoint. Returns the channel if successful.

View File

@ -242,7 +242,7 @@ mod hid_keyboard {
pub async fn listen(&mut self) {
let mut buffer = [0u8; 8];
if let Ok(_l) = self.channel.read(&mut buffer[..]).await {
if let Ok(_l) = self.channel.read(&mut buffer[..], None).await {
let keycodes = parse_payload(&buffer);
for keycode in keycodes {