Refactored cryp din/dout into functions.

This commit is contained in:
Caleb Garrett 2024-03-05 11:25:56 -05:00
parent ac06ca2fa0
commit 6e9e8eeb5f
2 changed files with 144 additions and 146 deletions

View File

@ -4,12 +4,35 @@ use core::cmp::min;
use core::marker::PhantomData;
use embassy_hal_internal::{into_ref, PeripheralRef};
use embassy_sync::waitqueue::AtomicWaker;
use crate::{interrupt, pac, peripherals, Peripheral};
use crate::interrupt::typelevel::Interrupt;
use crate::{dma::NoDma, interrupt, pac, peripherals, Peripheral};
const DES_BLOCK_SIZE: usize = 8; // 64 bits
const AES_BLOCK_SIZE: usize = 16; // 128 bits
static CRYP_WAKER: AtomicWaker = AtomicWaker::new();
/// CRYP interrupt handler.
pub struct InterruptHandler<T: Instance> {
_phantom: PhantomData<T>,
}
impl<T: Instance> interrupt::typelevel::Handler<T::Interrupt> for InterruptHandler<T> {
unsafe fn on_interrupt() {
let bits = T::regs().misr().read();
if bits.inmis() {
T::regs().imscr().modify(|w| w.set_inim(false));
CRYP_WAKER.wake();
}
if bits.outmis() {
T::regs().imscr().modify(|w| w.set_outim(false));
CRYP_WAKER.wake();
}
}
}
/// This trait encapsulates all cipher-specific behavior/
pub trait Cipher<'c> {
/// Processing block size. Determined by the processor and the algorithm.
@ -32,7 +55,7 @@ pub trait Cipher<'c> {
fn prepare_key(&self, _p: &pac::cryp::Cryp) {}
/// Performs any cipher-specific initialization.
fn init_phase(&self, _p: &pac::cryp::Cryp) {}
fn init_phase<T: Instance, D>(&self, _p: &pac::cryp::Cryp, _cryp: &Cryp<T, D>) {}
/// Called prior to processing the last data block for cipher-specific operations.
fn pre_final_block(&self, _p: &pac::cryp::Cryp, _dir: Direction, _padding_len: usize) -> [u32; 4] {
@ -40,9 +63,10 @@ pub trait Cipher<'c> {
}
/// Called after processing the last data block for cipher-specific operations.
fn post_final_block(
fn post_final_block<T: Instance, D>(
&self,
_p: &pac::cryp::Cryp,
_cryp: &Cryp<T, D>,
_dir: Direction,
_int_data: &mut [u8; AES_BLOCK_SIZE],
_temp1: [u32; 4],
@ -425,7 +449,7 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
p.cr().modify(|w| w.set_algomode3(true));
}
fn init_phase(&self, p: &pac::cryp::Cryp) {
fn init_phase<T: Instance, D>(&self, p: &pac::cryp::Cryp, _cryp: &Cryp<T, D>) {
p.cr().modify(|w| w.set_gcm_ccmph(0));
p.cr().modify(|w| w.set_crypen(true));
while p.cr().read().crypen() {}
@ -453,9 +477,10 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
}
#[cfg(cryp_v2)]
fn post_final_block(
fn post_final_block<T: Instance, D>(
&self,
p: &pac::cryp::Cryp,
cryp: &Cryp<T, D>,
dir: Direction,
int_data: &mut [u8; AES_BLOCK_SIZE],
_temp1: [u32; 4],
@ -471,17 +496,9 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGcm<'c, KEY_SIZE> {
}
p.cr().modify(|w| w.set_crypen(true));
p.cr().modify(|w| w.set_gcm_ccmph(3));
let mut index = 0;
let end_index = Self::BLOCK_SIZE;
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&int_data[index..index + 4]);
p.din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
for _ in 0..4 {
p.dout().read();
}
cryp.write_bytes_blocking(Self::BLOCK_SIZE, int_data);
cryp.read_bytes_blocking(Self::BLOCK_SIZE, int_data);
}
}
}
@ -532,7 +549,7 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
p.cr().modify(|w| w.set_algomode3(true));
}
fn init_phase(&self, p: &pac::cryp::Cryp) {
fn init_phase<T: Instance, D>(&self, p: &pac::cryp::Cryp, _cryp: &Cryp<T, D>) {
p.cr().modify(|w| w.set_gcm_ccmph(0));
p.cr().modify(|w| w.set_crypen(true));
while p.cr().read().crypen() {}
@ -560,9 +577,10 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
}
#[cfg(cryp_v2)]
fn post_final_block(
fn post_final_block<T: Instance, D>(
&self,
p: &pac::cryp::Cryp,
cryp: &Cryp<T, D>,
dir: Direction,
int_data: &mut [u8; AES_BLOCK_SIZE],
_temp1: [u32; 4],
@ -578,17 +596,9 @@ impl<'c, const KEY_SIZE: usize> Cipher<'c> for AesGmac<'c, KEY_SIZE> {
}
p.cr().modify(|w| w.set_crypen(true));
p.cr().modify(|w| w.set_gcm_ccmph(3));
let mut index = 0;
let end_index = Self::BLOCK_SIZE;
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&int_data[index..index + 4]);
p.din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
for _ in 0..4 {
p.dout().read();
}
cryp.write_bytes_blocking(Self::BLOCK_SIZE, int_data);
cryp.read_bytes_blocking(Self::BLOCK_SIZE, int_data);
}
}
}
@ -697,18 +707,11 @@ impl<'c, const KEY_SIZE: usize, const TAG_SIZE: usize, const IV_SIZE: usize> Cip
p.cr().modify(|w| w.set_algomode3(true));
}
fn init_phase(&self, p: &pac::cryp::Cryp) {
fn init_phase<T: Instance, D>(&self, p: &pac::cryp::Cryp, cryp: &Cryp<T, D>) {
p.cr().modify(|w| w.set_gcm_ccmph(0));
let mut index = 0;
let end_index = index + Self::BLOCK_SIZE;
// Write block in
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&self.block0[index..index + 4]);
p.din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
cryp.write_bytes_blocking(Self::BLOCK_SIZE, &self.block0);
p.cr().modify(|w| w.set_crypen(true));
while p.cr().read().crypen() {}
}
@ -744,9 +747,10 @@ impl<'c, const KEY_SIZE: usize, const TAG_SIZE: usize, const IV_SIZE: usize> Cip
}
#[cfg(cryp_v2)]
fn post_final_block(
fn post_final_block<T: Instance, D>(
&self,
p: &pac::cryp::Cryp,
cryp: &Cryp<T, D>,
dir: Direction,
int_data: &mut [u8; AES_BLOCK_SIZE],
temp1: [u32; 4],
@ -774,8 +778,8 @@ impl<'c, const KEY_SIZE: usize, const TAG_SIZE: usize, const IV_SIZE: usize> Cip
let int_word = u32::from_le_bytes(int_bytes);
in_data[i] = int_word;
in_data[i] = in_data[i] ^ temp1[i] ^ temp2[i];
p.din().write_value(in_data[i]);
}
cryp.write_words_blocking(Self::BLOCK_SIZE, &in_data);
}
}
}
@ -845,16 +849,31 @@ pub enum Direction {
}
/// Crypto Accelerator Driver
pub struct Cryp<'d, T: Instance> {
pub struct Cryp<'d, T: Instance, D = NoDma> {
_peripheral: PeripheralRef<'d, T>,
indma: PeripheralRef<'d, D>,
outdma: PeripheralRef<'d, D>,
}
impl<'d, T: Instance> Cryp<'d, T> {
impl<'d, T: Instance, D> Cryp<'d, T, D> {
/// Create a new CRYP driver.
pub fn new(peri: impl Peripheral<P = T> + 'd) -> Self {
pub fn new(
peri: impl Peripheral<P = T> + 'd,
indma: impl Peripheral<P = D> + 'd,
outdma: impl Peripheral<P = D> + 'd,
_irq: impl interrupt::typelevel::Binding<T::Interrupt, InterruptHandler<T>> + 'd,
) -> Self {
T::enable_and_reset();
into_ref!(peri);
let instance = Self { _peripheral: peri };
into_ref!(peri, indma, outdma);
let instance = Self {
_peripheral: peri,
indma: indma,
outdma: outdma,
};
T::Interrupt::unpend();
unsafe { T::Interrupt::enable() };
instance
}
@ -929,7 +948,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
// Flush in/out FIFOs
T::regs().cr().modify(|w| w.fflush());
ctx.cipher.init_phase(&T::regs());
ctx.cipher.init_phase(&T::regs(), self);
self.store_context(&mut ctx);
@ -985,15 +1004,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
if ctx.aad_buffer_len < C::BLOCK_SIZE {
// The buffer isn't full and this is the last buffer, so process it as is (already padded).
if last_aad_block {
let mut index = 0;
let end_index = C::BLOCK_SIZE;
// Write block in
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]);
T::regs().din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
self.write_bytes_blocking(C::BLOCK_SIZE, &ctx.aad_buffer);
// Block until input FIFO is empty.
while !T::regs().sr().read().ifem() {}
@ -1008,15 +1019,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
}
} else {
// Load the full block from the buffer.
let mut index = 0;
let end_index = C::BLOCK_SIZE;
// Write block in
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]);
T::regs().din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
self.write_bytes_blocking(C::BLOCK_SIZE, &ctx.aad_buffer);
// Block until input FIFO is empty.
while !T::regs().sr().read().ifem() {}
}
@ -1032,33 +1035,13 @@ impl<'d, T: Instance> Cryp<'d, T> {
// Load full data blocks into core.
let num_full_blocks = aad_len_remaining / C::BLOCK_SIZE;
for block in 0..num_full_blocks {
let mut index = len_to_copy + (block * C::BLOCK_SIZE);
let end_index = index + C::BLOCK_SIZE;
// Write block in
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&aad[index..index + 4]);
T::regs().din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
// Block until input FIFO is empty.
while !T::regs().sr().read().ifem() {}
}
let start_index = len_to_copy;
let end_index = start_index + (C::BLOCK_SIZE * num_full_blocks);
self.write_bytes_blocking(C::BLOCK_SIZE, &aad[start_index..end_index]);
if last_aad_block {
if leftovers > 0 {
let mut index = 0;
let end_index = C::BLOCK_SIZE;
// Write block in
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&ctx.aad_buffer[index..index + 4]);
T::regs().din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
// Block until input FIFO is empty.
while !T::regs().sr().read().ifem() {}
self.write_bytes_blocking(C::BLOCK_SIZE, &ctx.aad_buffer);
}
// Switch to payload phase.
ctx.aad_complete = true;
@ -1125,25 +1108,11 @@ impl<'d, T: Instance> Cryp<'d, T> {
// Load data into core, block by block.
let num_full_blocks = input.len() / C::BLOCK_SIZE;
for block in 0..num_full_blocks {
let mut index = block * C::BLOCK_SIZE;
let end_index = index + C::BLOCK_SIZE;
let index = block * C::BLOCK_SIZE;
// Write block in
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&input[index..index + 4]);
T::regs().din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
let mut index = block * C::BLOCK_SIZE;
let end_index = index + C::BLOCK_SIZE;
// Block until there is output to read.
while !T::regs().sr().read().ofne() {}
self.write_bytes_blocking(C::BLOCK_SIZE, &input[index..index + 4]);
// Read block out
while index < end_index {
let out_word: u32 = T::regs().dout().read();
output[index..index + 4].copy_from_slice(u32::to_ne_bytes(out_word).as_slice());
index += 4;
}
self.read_bytes_blocking(C::BLOCK_SIZE, &mut output[index..index + 4]);
}
// Handle the final block, which is incomplete.
@ -1154,25 +1123,8 @@ impl<'d, T: Instance> Cryp<'d, T> {
let mut intermediate_data: [u8; AES_BLOCK_SIZE] = [0; AES_BLOCK_SIZE];
let mut last_block: [u8; AES_BLOCK_SIZE] = [0; AES_BLOCK_SIZE];
last_block[..last_block_remainder].copy_from_slice(&input[input.len() - last_block_remainder..input.len()]);
let mut index = 0;
let end_index = C::BLOCK_SIZE;
// Write block in
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&last_block[index..index + 4]);
T::regs().din().write_value(u32::from_ne_bytes(in_word));
index += 4;
}
let mut index = 0;
let end_index = C::BLOCK_SIZE;
// Block until there is output to read.
while !T::regs().sr().read().ofne() {}
// Read block out
while index < end_index {
let out_word: u32 = T::regs().dout().read();
intermediate_data[index..index + 4].copy_from_slice(u32::to_ne_bytes(out_word).as_slice());
index += 4;
}
self.write_bytes_blocking(C::BLOCK_SIZE, &last_block);
self.read_bytes_blocking(C::BLOCK_SIZE, &mut intermediate_data);
// Handle the last block depending on mode.
let output_len = output.len();
@ -1182,7 +1134,7 @@ impl<'d, T: Instance> Cryp<'d, T> {
let mut mask: [u8; 16] = [0; 16];
mask[..last_block_remainder].fill(0xFF);
ctx.cipher
.post_final_block(&T::regs(), ctx.dir, &mut intermediate_data, temp1, mask);
.post_final_block(&T::regs(), self, ctx.dir, &mut intermediate_data, temp1, mask);
}
ctx.payload_len += input.len() as u64;
@ -1213,28 +1165,21 @@ impl<'d, T: Instance> Cryp<'d, T> {
let payloadlen2: u32 = (ctx.payload_len * 8) as u32;
#[cfg(cryp_v2)]
{
T::regs().din().write_value(headerlen1.swap_bytes());
T::regs().din().write_value(headerlen2.swap_bytes());
T::regs().din().write_value(payloadlen1.swap_bytes());
T::regs().din().write_value(payloadlen2.swap_bytes());
}
let footer: [u32; 4] = [
headerlen1.swap_bytes(),
headerlen2.swap_bytes(),
payloadlen1.swap_bytes(),
payloadlen2.swap_bytes(),
];
#[cfg(cryp_v3)]
{
T::regs().din().write_value(headerlen1);
T::regs().din().write_value(headerlen2);
T::regs().din().write_value(payloadlen1);
T::regs().din().write_value(payloadlen2);
}
let footer: [u32; 4] = [headerlen1, headerlen2, payloadlen1, payloadlen2];
self.write_words_blocking(C::BLOCK_SIZE, &footer);
while !T::regs().sr().read().ofne() {}
let mut full_tag: [u8; 16] = [0; 16];
full_tag[0..4].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
full_tag[4..8].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
full_tag[8..12].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
full_tag[12..16].copy_from_slice(T::regs().dout().read().to_ne_bytes().as_slice());
self.read_bytes_blocking(C::BLOCK_SIZE, &mut full_tag);
let mut tag: [u8; TAG_SIZE] = [0; TAG_SIZE];
tag.copy_from_slice(&full_tag[0..TAG_SIZE]);
@ -1325,6 +1270,51 @@ impl<'d, T: Instance> Cryp<'d, T> {
// Enable crypto processor.
T::regs().cr().modify(|w| w.set_crypen(true));
}
fn write_bytes_blocking(&self, block_size: usize, blocks: &[u8]) {
// Ensure input is a multiple of block size.
assert_eq!(blocks.len() % block_size, 0);
let mut index = 0;
let end_index = blocks.len();
while index < end_index {
let mut in_word: [u8; 4] = [0; 4];
in_word.copy_from_slice(&blocks[index..index + 4]);
T::regs().din().write_value(u32::from_ne_bytes(in_word));
index += 4;
if index % block_size == 0 {
// Block until input FIFO is empty.
while !T::regs().sr().read().ifem() {}
}
}
}
fn write_words_blocking(&self, block_size: usize, blocks: &[u32]) {
assert_eq!((blocks.len() * 4) % block_size, 0);
let mut byte_counter: usize = 0;
for word in blocks {
T::regs().din().write_value(*word);
byte_counter += 4;
if byte_counter % block_size == 0 {
// Block until input FIFO is empty.
while !T::regs().sr().read().ifem() {}
}
}
}
fn read_bytes_blocking(&self, block_size: usize, blocks: &mut [u8]) {
// Block until there is output to read.
while !T::regs().sr().read().ofne() {}
// Ensure input is a multiple of block size.
assert_eq!(blocks.len() % block_size, 0);
// Read block out
let mut index = 0;
let end_index = blocks.len();
while index < end_index {
let out_word: u32 = T::regs().dout().read();
blocks[index..index + 4].copy_from_slice(u32::to_ne_bytes(out_word).as_slice());
index += 4;
}
}
}
pub(crate) mod sealed {

View File

@ -6,11 +6,19 @@ use aes_gcm::aead::{AeadInPlace, KeyInit};
use aes_gcm::Aes128Gcm;
use defmt::info;
use embassy_executor::Spawner;
use embassy_stm32::cryp::*;
use embassy_stm32::Config;
use embassy_stm32::dma::NoDma;
use embassy_stm32::{
bind_interrupts,
cryp::{self, *},
};
use embassy_stm32::{peripherals, Config};
use embassy_time::Instant;
use {defmt_rtt as _, panic_probe as _};
bind_interrupts!(struct Irqs {
CRYP => cryp::InterruptHandler<peripherals::CRYP>;
});
#[embassy_executor::main]
async fn main(_spawner: Spawner) -> ! {
let config = Config::default();
@ -19,7 +27,7 @@ async fn main(_spawner: Spawner) -> ! {
let payload: &[u8] = b"hello world";
let aad: &[u8] = b"additional data";
let hw_cryp = Cryp::new(p.CRYP);
let hw_cryp = Cryp::new(p.CRYP, NoDma, NoDma, Irqs);
let key: [u8; 16] = [0; 16];
let mut ciphertext: [u8; 11] = [0; 11];
let mut plaintext: [u8; 11] = [0; 11];