Refactoring

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-06 02:08:36 +02:00
parent f734fa6434
commit d8a8551016
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
6 changed files with 274 additions and 294 deletions

View File

@ -60,11 +60,11 @@ impl<const R: usize> TensorIndex<R> {
let mut indices = [0; R]; let mut indices = [0; R];
let mut remaining = flat_index; let mut remaining = flat_index;
// - The method iterates over the dimensions of the tensor in reverse order // - The method iterates over the dimensions of the tensor in reverse
// (assuming row-major order). // order (assuming row-major order).
// - In each iteration, it uses the modulo operation to find the index in // - In each iteration, it uses the modulo operation to find the index
// the current dimension and integer division to reduce the flat index // in the current dimension and integer division to reduce the flat
// for the next higher dimension. // index for the next higher dimension.
// - This process is repeated for each dimension to build the // - This process is repeated for each dimension to build the
// multi-dimensional index. // multi-dimensional index.
for (idx, &dim_size) in indices.iter_mut().zip(shape.0.iter()).rev() { for (idx, &dim_size) in indices.iter_mut().zip(shape.0.iter()).rev() {
@ -107,7 +107,7 @@ impl<const R: usize> TensorIndex<R> {
/// assert!(index.is_max()); /// assert!(index.is_max());
/// ``` /// ```
pub fn is_max(&self) -> bool { pub fn is_max(&self) -> bool {
self.indices()[0] == self.shape().get(0) self.indices().get(0) == self.shape().dim_size(0)
} }
} }
@ -152,7 +152,7 @@ impl<const R: usize> TensorIndex<R> {
/// assert_eq!(index.flat(), 1); /// assert_eq!(index.flat(), 1);
/// ``` /// ```
pub fn inc(&mut self) { pub fn inc(&mut self) {
if self.indices()[0] == self.shape().get(0) { if self.indices().get(0) == self.shape().dim_size(0) {
return; return;
} }
@ -166,7 +166,7 @@ impl<const R: usize> TensorIndex<R> {
*dim = 0; *dim = 0;
} }
self.indices_mut()[0] = self.shape().get(0); self.indices_mut()[0] = *self.shape().dim_size(0).unwrap();
} }
/// Increments the index by one, with the specified axis fixed, /// Increments the index by one, with the specified axis fixed,
@ -187,13 +187,14 @@ impl<const R: usize> TensorIndex<R> {
pub fn inc_fixed_axis(&mut self, Axis(ax): Axis) { pub fn inc_fixed_axis(&mut self, Axis(ax): Axis) {
let shape = self.shape().as_array().clone(); let shape = self.shape().as_array().clone();
assert!(ax < R, "TensorAxis out of bounds"); assert!(ax < R, "TensorAxis out of bounds");
if self.indices()[ax] == self.shape().get(ax) { if self.indices().get(ax) == self.shape().dim_size(ax) {
return; return;
} }
// Iterate over all axes, skipping the fixed axis 'ax' // Iterate over all axes, skipping the fixed axis 'ax'
for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() { for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
if i != ax { // Skip the fixed axis if i != ax {
// Skip the fixed axis
*dim += 1; *dim += 1;
if *dim < shape[i] { if *dim < shape[i] {
return; // No carry over needed return; // No carry over needed
@ -203,11 +204,11 @@ impl<const R: usize> TensorIndex<R> {
} }
// Handle the case where incrementing has reached the end // Handle the case where incrementing has reached the end
if self.indices()[ax] < self.shape().get(ax) { if self.indices().get(ax) < self.shape().dim_size(ax) {
self.indices_mut()[ax] += 1; self.indices_mut()[ax] += 1;
} else { } else {
// Reset if the fixed axis also overflows // Reset if the fixed axis also overflows
self.indices_mut()[ax] = self.shape().get(ax); self.indices_mut()[ax] = *self.shape().dim_size(ax).unwrap();
} }
} }
@ -226,14 +227,14 @@ impl<const R: usize> TensorIndex<R> {
/// assert_eq!(index.flat(), 13); /// assert_eq!(index.flat(), 13);
/// ``` /// ```
pub fn inc_transposed(&mut self, order: &[usize; R]) { pub fn inc_transposed(&mut self, order: &[usize; R]) {
if self.indices()[order[0]] >= self.shape().get(order[0]) { if self.indices().get(order[0]) == self.shape().dim_size(order[0]) {
return ; return;
} }
let mut carry = 1; let mut carry = 1;
for i in order.iter().rev() { for i in order.iter().rev() {
let dim_size = self.shape().get(*i); let dim_size = self.shape().dim_size(*i).unwrap().clone();
let i = self.index_mut(*i); let i = self.index_mut(*i);
if carry == 1 { if carry == 1 {
*i += 1; *i += 1;
@ -246,7 +247,7 @@ impl<const R: usize> TensorIndex<R> {
} }
if carry == 1 { if carry == 1 {
self.indices_mut()[order[0]] = self.shape().get(order[0]); self.indices_mut()[order[0]] = *self.shape().dim_size(order[0]).unwrap();
} }
} }
@ -285,8 +286,7 @@ impl<const R: usize> TensorIndex<R> {
// ---- Conversion to Flat Index ---------------------------------------------- // ---- Conversion to Flat Index ----------------------------------------------
impl <const R: usize> TensorIndex<R> { impl<const R: usize> TensorIndex<R> {
/// Converts the multi-dimensional index to a flat index. /// Converts the multi-dimensional index to a flat index.
/// ///
/// # Examples /// # Examples

View File

@ -4,11 +4,11 @@
pub mod error; pub mod error;
pub mod index; pub mod index;
pub mod iterators;
pub mod shape; pub mod shape;
pub mod tensor; pub mod tensor;
pub mod iterators;
pub use {iterators::*, error::*, index::*, shape::*, tensor::*}; pub use {error::*, index::*, iterators::*, shape::*, tensor::*};
use num::{Num, One, Zero}; use num::{Num, One, Zero};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};

View File

@ -44,11 +44,11 @@ impl<const R: usize> TensorShape<R> {
/// use manifold::*; /// use manifold::*;
/// ///
/// let shape = TensorShape::new([2, 3]); /// let shape = TensorShape::new([2, 3]);
/// assert_eq!(shape.get(0), 2); /// assert_eq!(shape.dim_size(0).unwrap(), &2);
/// assert_eq!(shape.get(1), 3); /// assert_eq!(shape.dim_size(1).unwrap(), &3);
/// ``` /// ```
pub const fn get(&self, index: usize) -> usize { pub fn dim_size(&self, index: usize) -> Option<&usize> {
self.0[index] self.0.get(index)
} }
/// Returns the shape as an array. /// Returns the shape as an array.

View File

@ -65,19 +65,29 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// use manifold::Tensor; /// use manifold::Tensor;
/// ///
/// let buffer = vec![1, 2, 3, 4, 5, 6]; /// let buffer = vec![1, 2, 3, 4, 5, 6];
/// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer); /// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer).unwrap();
/// assert_eq!(t.shape().as_array(), &[2, 3]); /// assert_eq!(t.shape().as_array(), &[2, 3]);
/// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]); /// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]);
/// ``` /// ```
pub fn new_with_buffer(shape: TensorShape<R>, buffer: Vec<T>) -> Self { pub fn new_with_buffer(
Self { buffer, shape } shape: TensorShape<R>,
buffer: Vec<T>,
) -> Result<Self> {
if buffer.len() != shape.size() {
Err(TensorError::InvalidArgument(format!(
"Provided buffer has length {} but shape has size {}",
buffer.len(),
shape.size()
)))
} else {
Ok(Self { buffer, shape })
}
} }
} }
// ---- Trivial Getters ------------------------------------------------------- // ---- Trivial Getters -------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> { impl<T: Value, const R: usize> Tensor<T, R> {
/// Get the rank of the tensor. /// Get the rank of the tensor.
/// ///
/// ``` /// ```
@ -86,7 +96,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// let t = Tensor::<f64, 2>::new([3, 3].into()); /// let t = Tensor::<f64, 2>::new([3, 3].into());
/// assert_eq!(t.rank(), 2); /// assert_eq!(t.rank(), 2);
/// ``` /// ```
pub fn rank(&self) -> usize { pub const fn rank(&self) -> usize {
R R
} }
@ -352,32 +362,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
) -> Result<()> { ) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a % b) self.ew_for_each(other, result, &|a, b| a % b)
} }
// pub fn product<const S: usize>(
// &self,
// other: &Tensor<T, S>,
// ) -> Tensor<T, { R + S }> {
// let mut new_shape_vec = Vec::new();
// new_shape_vec.extend_from_slice(&self.shape().as_array());
// new_shape_vec.extend_from_slice(&other.shape().as_array());
// let new_shape_array: [usize; R + S] = new_shape_vec
// .try_into()
// .expect("Failed to create shape array");
// let mut new_buffer =
// Vec::with_capacity(self.buffer.len() * other.buffer.len());
// for &item_self in &self.buffer {
// for &item_other in &other.buffer {
// new_buffer.push(item_self * item_other);
// }
// }
// Tensor {
// buffer: new_buffer,
// shape: TensorShape::new(new_shape_array),
// }
// }
} }
// ---- Reshape --------------------------------------------------------------- // ---- Reshape ---------------------------------------------------------------
@ -405,7 +389,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
"TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", "TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
))) )))
} else { } else {
Ok(Tensor::new_with_buffer(shape, self.buffer)) Ok(Tensor::new_with_buffer(shape, self.buffer).unwrap())
} }
} }
} }

View File

@ -53,14 +53,10 @@ fn test_iterating_3d_tensor() {
} }
} }
println!("{}", tensor);
// Iterate over the tensor and check that the numbers are correct // Iterate over the tensor and check that the numbers are correct
let mut iter = TensorIterator::new(&tensor); let mut iter = TensorIterator::new(&tensor);
println!("{}", iter);
assert_eq!(iter.next(), Some(&0)); assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&1)); assert_eq!(iter.next(), Some(&1));