diff --git a/Cargo.toml b/Cargo.toml index dadfb5c5a..6e3237448 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,7 +14,7 @@ firmware-logs = [] embassy-time = { version = "0.1.0" } embassy-sync = { version = "0.1.0" } embassy-futures = { version = "0.1.0" } -embassy-net = { version = "0.1.0" } +embassy-net-driver-channel = { version = "0.1.0" } atomic-polyfill = "0.1.5" defmt = { version = "0.3", optional = true } diff --git a/examples/rpi-pico-w/Cargo.toml b/examples/rpi-pico-w/Cargo.toml index b817289e5..fa1cad8c7 100644 --- a/examples/rpi-pico-w/Cargo.toml +++ b/examples/rpi-pico-w/Cargo.toml @@ -9,7 +9,7 @@ cyw43 = { path = "../../", features = ["defmt", "firmware-logs"]} embassy-executor = { version = "0.1.0", features = ["defmt", "integrated-timers"] } embassy-time = { version = "0.1.0", features = ["defmt", "defmt-timestamp-uptime"] } embassy-rp = { version = "0.1.0", features = ["defmt", "unstable-traits", "nightly", "unstable-pac", "time-driver"] } -embassy-net = { version = "0.1.0", features = ["defmt", "tcp", "dhcpv4", "medium-ethernet", "pool-16", "unstable-traits", "nightly"] } +embassy-net = { version = "0.1.0", features = ["defmt", "tcp", "dhcpv4", "medium-ethernet", "unstable-traits", "nightly"] } atomic-polyfill = "0.1.5" static_cell = "1.0" @@ -28,12 +28,14 @@ heapless = "0.7.15" [patch.crates-io] -embassy-executor = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" } -embassy-time = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" } -embassy-futures = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" } -embassy-sync = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" } -embassy-rp = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" } -embassy-net = { git = "https://github.com/embassy-rs/embassy", rev = "645fb66a5122bdc8180e0e65d076ca103431a426" } +embassy-executor = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } +embassy-time = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } +embassy-futures = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } +embassy-sync = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } +embassy-rp = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } +embassy-net = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } +embassy-net-driver = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } +embassy-net-driver-channel = { git = "https://github.com/embassy-rs/embassy", rev = "771806be790a2758f1314d6460defe7c2f0d3e99" } [profile.dev] debug = 2 @@ -43,7 +45,7 @@ overflow-checks = true [profile.release] codegen-units = 1 -debug = 2 +debug = 1 debug-assertions = false incremental = false lto = 'fat' diff --git a/examples/rpi-pico-w/src/main.rs b/examples/rpi-pico-w/src/main.rs index a19f38591..fd58e46df 100644 --- a/examples/rpi-pico-w/src/main.rs +++ b/examples/rpi-pico-w/src/main.rs @@ -34,7 +34,7 @@ async fn wifi_task( } #[embassy_executor::task] -async fn net_task(stack: &'static Stack>) -> ! { +async fn net_task(stack: &'static Stack>) -> ! { stack.run().await } @@ -66,11 +66,11 @@ async fn main(spawner: Spawner) { let spi = ExclusiveDevice::new(bus, cs); let state = singleton!(cyw43::State::new()); - let (mut control, runner) = cyw43::new(state, pwr, spi, fw).await; + let (net_device, mut control, runner) = cyw43::new(state, pwr, spi, fw).await; spawner.spawn(wifi_task(runner)).unwrap(); - let net_device = control.init(clm).await; + control.init(clm).await; //control.join_open(env!("WIFI_NETWORK")).await; control.join_wpa2(env!("WIFI_NETWORK"), env!("WIFI_PASSWORD")).await; diff --git a/src/lib.rs b/src/lib.rs index fa73b32e0..25e6f8f16 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,14 +15,10 @@ mod structs; use core::cell::Cell; use core::cmp::{max, min}; use core::slice; -use core::sync::atomic::Ordering; -use core::task::Waker; -use atomic_polyfill::AtomicBool; +use ch::driver::LinkState; use embassy_futures::yield_now; -use embassy_net::{PacketBoxExt, PacketBuf}; -use embassy_sync::blocking_mutex::raw::NoopRawMutex; -use embassy_sync::channel::Channel; +use embassy_net_driver_channel as ch; use embassy_time::{block_for, Duration, Timer}; use embedded_hal_1::digital::OutputPin; use embedded_hal_async::spi::{SpiBusRead, SpiBusWrite, SpiDevice}; @@ -32,6 +28,8 @@ use crate::consts::*; use crate::events::Event; use crate::structs::*; +const MTU: usize = 1514; + #[derive(Clone, Copy)] pub enum IoctlType { Get = 0, @@ -128,30 +126,25 @@ enum IoctlState { pub struct State { ioctl_state: Cell, - - tx_channel: Channel, - rx_channel: Channel, - link_up: AtomicBool, + ch: ch::State, } impl State { pub fn new() -> Self { Self { ioctl_state: Cell::new(IoctlState::Idle), - - tx_channel: Channel::new(), - rx_channel: Channel::new(), - link_up: AtomicBool::new(true), // TODO set up/down as we join/deassociate + ch: ch::State::new(), } } } pub struct Control<'a> { - state: &'a State, + state_ch: ch::StateRunner<'a>, + ioctl_state: &'a Cell, } impl<'a> Control<'a> { - pub async fn init(&mut self, clm: &[u8]) -> NetDevice<'a> { + pub async fn init(&mut self, clm: &[u8]) { const CHUNK_SIZE: usize = 1024; info!("Downloading CLM..."); @@ -258,12 +251,10 @@ impl<'a> Control<'a> { Timer::after(Duration::from_millis(100)).await; - info!("INIT DONE"); + self.state_ch.set_ethernet_address(mac_addr); + self.state_ch.set_link_state(LinkState::Up); // TODO do on join/leave - NetDevice { - state: self.state, - mac_addr, - } + info!("INIT DONE"); } pub async fn join_open(&mut self, ssid: &str) { @@ -381,75 +372,30 @@ impl<'a> Control<'a> { async fn ioctl(&mut self, kind: IoctlType, cmd: u32, iface: u32, buf: &mut [u8]) -> usize { // TODO cancel ioctl on future drop. - while !matches!(self.state.ioctl_state.get(), IoctlState::Idle) { + while !matches!(self.ioctl_state.get(), IoctlState::Idle) { yield_now().await; } - self.state - .ioctl_state - .set(IoctlState::Pending { kind, cmd, iface, buf }); + self.ioctl_state.set(IoctlState::Pending { kind, cmd, iface, buf }); let resp_len = loop { - if let IoctlState::Done { resp_len } = self.state.ioctl_state.get() { + if let IoctlState::Done { resp_len } = self.ioctl_state.get() { break resp_len; } yield_now().await; }; - self.state.ioctl_state.set(IoctlState::Idle); + self.ioctl_state.set(IoctlState::Idle); resp_len } } -pub struct NetDevice<'a> { - state: &'a State, - mac_addr: [u8; 6], -} - -impl<'a> embassy_net::Device for NetDevice<'a> { - fn register_waker(&mut self, waker: &Waker) { - // loopy loopy wakey wakey - waker.wake_by_ref() - } - - fn link_state(&mut self) -> embassy_net::LinkState { - match self.state.link_up.load(Ordering::Relaxed) { - true => embassy_net::LinkState::Up, - false => embassy_net::LinkState::Down, - } - } - - fn capabilities(&self) -> embassy_net::DeviceCapabilities { - let mut caps = embassy_net::DeviceCapabilities::default(); - caps.max_transmission_unit = 1514; // 1500 IP + 14 ethernet header - caps.medium = embassy_net::Medium::Ethernet; - caps - } - - fn is_transmit_ready(&mut self) -> bool { - true - } - - fn transmit(&mut self, pkt: PacketBuf) { - if self.state.tx_channel.try_send(pkt).is_err() { - warn!("TX failed") - } - } - - fn receive(&mut self) -> Option { - self.state.rx_channel.try_recv().ok() - } - - fn ethernet_address(&self) -> [u8; 6] { - self.mac_addr - } -} - pub struct Runner<'a, PWR, SPI> { + ch: ch::Runner<'a, MTU>, bus: Bus, - state: &'a State, + ioctl_state: &'a Cell, ioctl_id: u16, sdpcm_seq: u8, sdpcm_seq_max: u8, @@ -466,21 +412,27 @@ struct LogState { buf_count: usize, } +pub type NetDriver<'a> = ch::Device<'a, MTU>; + pub async fn new<'a, PWR, SPI>( - state: &'a State, + state: &'a mut State, pwr: PWR, spi: SPI, firmware: &[u8], -) -> (Control<'a>, Runner<'a, PWR, SPI>) +) -> (NetDriver<'a>, Control<'a>, Runner<'a, PWR, SPI>) where PWR: OutputPin, SPI: SpiDevice, SPI::Bus: SpiBusRead + SpiBusWrite, { + let (ch_runner, device) = ch::new(&mut state.ch, [0; 6]); + let state_ch = ch_runner.state_runner(); + let mut runner = Runner { + ch: ch_runner, bus: Bus::new(pwr, spi), - state, + ioctl_state: &state.ioctl_state, ioctl_id: 0, sdpcm_seq: 0, sdpcm_seq_max: 1, @@ -496,7 +448,14 @@ where runner.init(firmware).await; - (Control { state }, runner) + ( + device, + Control { + state_ch, + ioctl_state: &state.ioctl_state, + }, + runner, + ) } impl<'a, PWR, SPI> Runner<'a, PWR, SPI> @@ -662,15 +621,55 @@ where if !self.has_credit() { warn!("TX stalled"); } else { - if let IoctlState::Pending { kind, cmd, iface, buf } = self.state.ioctl_state.get() { + if let IoctlState::Pending { kind, cmd, iface, buf } = self.ioctl_state.get() { self.send_ioctl(kind, cmd, iface, unsafe { &*buf }).await; - self.state.ioctl_state.set(IoctlState::Sent { buf }); + self.ioctl_state.set(IoctlState::Sent { buf }); } if !self.has_credit() { warn!("TX stalled"); } else { - if let Ok(p) = self.state.tx_channel.try_recv() { - self.send_packet(&p).await; + if let Some(packet) = self.ch.try_tx_buf() { + trace!("tx pkt {:02x}", &packet[..packet.len().min(48)]); + + let mut buf = [0; 512]; + let buf8 = slice8_mut(&mut buf); + + let total_len = SdpcmHeader::SIZE + BcdHeader::SIZE + packet.len(); + + let seq = self.sdpcm_seq; + self.sdpcm_seq = self.sdpcm_seq.wrapping_add(1); + + let sdpcm_header = SdpcmHeader { + len: total_len as u16, // TODO does this len need to be rounded up to u32? + len_inv: !total_len as u16, + sequence: seq, + channel_and_flags: CHANNEL_TYPE_DATA, + next_length: 0, + header_length: SdpcmHeader::SIZE as _, + wireless_flow_control: 0, + bus_data_credit: 0, + reserved: [0, 0], + }; + + let bcd_header = BcdHeader { + flags: BDC_VERSION << BDC_VERSION_SHIFT, + priority: 0, + flags2: 0, + data_offset: 0, + }; + trace!("tx {:?}", sdpcm_header); + trace!(" {:?}", bcd_header); + + buf8[0..SdpcmHeader::SIZE].copy_from_slice(&sdpcm_header.to_bytes()); + buf8[SdpcmHeader::SIZE..][..BcdHeader::SIZE].copy_from_slice(&bcd_header.to_bytes()); + buf8[SdpcmHeader::SIZE + BcdHeader::SIZE..][..packet.len()].copy_from_slice(packet); + + let total_len = (total_len + 3) & !3; // round up to 4byte + + trace!(" {:02x}", &buf8[..total_len.min(48)]); + + self.bus.wlan_write(&buf[..(total_len / 4)]).await; + self.ch.tx_done(); } } } @@ -686,7 +685,6 @@ where if status & STATUS_F2_PKT_AVAILABLE != 0 { let len = (status & STATUS_F2_PKT_LEN_MASK) >> STATUS_F2_PKT_LEN_SHIFT; - self.bus.wlan_read(&mut buf[..(len as usize + 3) / 4]).await; trace!("rx {:02x}", &slice8_mut(&mut buf)[..(len as usize).min(48)]); self.rx(&slice8_mut(&mut buf)[..len as usize]); @@ -698,49 +696,6 @@ where } } - async fn send_packet(&mut self, packet: &[u8]) { - trace!("tx pkt {:02x}", &packet[..packet.len().min(48)]); - - let mut buf = [0; 512]; - let buf8 = slice8_mut(&mut buf); - - let total_len = SdpcmHeader::SIZE + BcdHeader::SIZE + packet.len(); - - let seq = self.sdpcm_seq; - self.sdpcm_seq = self.sdpcm_seq.wrapping_add(1); - - let sdpcm_header = SdpcmHeader { - len: total_len as u16, // TODO does this len need to be rounded up to u32? - len_inv: !total_len as u16, - sequence: seq, - channel_and_flags: CHANNEL_TYPE_DATA, - next_length: 0, - header_length: SdpcmHeader::SIZE as _, - wireless_flow_control: 0, - bus_data_credit: 0, - reserved: [0, 0], - }; - - let bcd_header = BcdHeader { - flags: BDC_VERSION << BDC_VERSION_SHIFT, - priority: 0, - flags2: 0, - data_offset: 0, - }; - trace!("tx {:?}", sdpcm_header); - trace!(" {:?}", bcd_header); - - buf8[0..SdpcmHeader::SIZE].copy_from_slice(&sdpcm_header.to_bytes()); - buf8[SdpcmHeader::SIZE..][..BcdHeader::SIZE].copy_from_slice(&bcd_header.to_bytes()); - buf8[SdpcmHeader::SIZE + BcdHeader::SIZE..][..packet.len()].copy_from_slice(packet); - - let total_len = (total_len + 3) & !3; // round up to 4byte - - trace!(" {:02x}", &buf8[..total_len.min(48)]); - - self.bus.wlan_write(&buf[..(total_len / 4)]).await; - } - fn rx(&mut self, packet: &[u8]) { if packet.len() < SdpcmHeader::SIZE { warn!("packet too short, len={}", packet.len()); @@ -775,7 +730,7 @@ where let cdc_header = CdcHeader::from_bytes(payload[..CdcHeader::SIZE].try_into().unwrap()); trace!(" {:?}", cdc_header); - if let IoctlState::Sent { buf } = self.state.ioctl_state.get() { + if let IoctlState::Sent { buf } = self.ioctl_state.get() { if cdc_header.id == self.ioctl_id { if cdc_header.status != 0 { // TODO: propagate error instead @@ -786,7 +741,7 @@ where info!("IOCTL Response: {:02x}", &payload[CdcHeader::SIZE..][..resp_len]); (unsafe { &mut *buf }[..resp_len]).copy_from_slice(&payload[CdcHeader::SIZE..][..resp_len]); - self.state.ioctl_state.set(IoctlState::Done { resp_len }); + self.ioctl_state.set(IoctlState::Done { resp_len }); } } } @@ -859,11 +814,12 @@ where let packet = &payload[packet_start..]; trace!("rx pkt {:02x}", &packet[..(packet.len() as usize).min(48)]); - let mut p = unwrap!(embassy_net::PacketBox::new(embassy_net::Packet::new())); - p[..packet.len()].copy_from_slice(packet); - - if let Err(_) = self.state.rx_channel.try_send(p.slice(0..packet.len())) { - warn!("failed to push rxd packet to the channel.") + match self.ch.try_rx_buf() { + Some(buf) => { + buf[..packet.len()].copy_from_slice(packet); + self.ch.rx_done(packet.len()) + } + None => warn!("failed to push rxd packet to the channel."), } } _ => {}