diff --git a/Cargo.lock b/Cargo.lock index 5bbfa80..558f98c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -302,6 +302,7 @@ dependencies = [ "ndarray", "num", "rand", + "rayon", "serde", "serde_json", "static_assertions", diff --git a/Cargo.toml b/Cargo.toml index 22566c0..406bb40 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,6 +19,7 @@ bytemuck = "1.14.0" getset = "0.1.2" itertools = "0.12.0" num = "0.4.1" +rayon = "1.8.0" serde = { version = "1.0.193", features = ["derive"] } thiserror = "1.0.52" diff --git a/benches/manifold_benchmark.rs b/benches/manifold_benchmark.rs index 59e4d62..6ae440d 100644 --- a/benches/manifold_benchmark.rs +++ b/benches/manifold_benchmark.rs @@ -3,9 +3,11 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion}; use manifold::*; use rand::Rng; +const SIZE: usize = 10000; + fn random_tensor_r2_manifold() -> Tensor { let mut rng = rand::thread_rng(); - let mut tensor = tensor!([[0.0; 1000]; 1000]); + let mut tensor = tensor!([[0.0; SIZE]; SIZE]); for i in 0..tensor.len() { tensor[i] = rng.gen(); } @@ -14,7 +16,7 @@ fn random_tensor_r2_manifold() -> Tensor { fn random_tensor_r2_ndarray() -> ndarray::Array2 { let mut rng = rand::thread_rng(); - let (rows, cols) = (1000, 1000); + let (rows, cols) = (SIZE, SIZE); let mut tensor = ndarray::Array2::::zeros((rows, cols)); for i in 0..rows { for j in 0..cols { @@ -40,7 +42,7 @@ fn tensor_product(c: &mut Criterion) { let a = random_tensor_r2_manifold(); let b = random_tensor_r2_manifold(); let c = a + b; - assert!(c.shape().as_array() == &[1000, 1000]); + assert!(c.shape().as_array() == &[SIZE, SIZE]); }) }, ); @@ -53,7 +55,7 @@ fn tensor_product(c: &mut Criterion) { let a = random_tensor_r2_ndarray(); let b = random_tensor_r2_ndarray(); let c = a + b; - assert!(c.shape() == &[1000, 1000]); + assert!(c.shape() == &[SIZE, SIZE]); }) }, ); diff --git a/src/axis.rs b/src/axis.rs deleted file mode 100644 index 37c670b..0000000 --- a/src/axis.rs +++ /dev/null @@ -1,118 +0,0 @@ -use super::*; -use getset::{Getters, MutGetters}; - -#[derive(Clone, Debug, Getters)] -pub struct TensorAxis<'a, T: Value, const R: usize> { - #[getset(get = "pub")] - tensor: &'a Tensor, - #[getset(get = "pub")] - dim: usize, -} - -impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> { - pub fn new(tensor: &'a Tensor, dim: usize) -> Self { - assert!(dim < R, "TensorAxis out of bounds"); - Self { tensor, dim } - } - - pub fn len(&self) -> usize { - self.tensor.shape().get(self.dim) - } - - pub fn shape(&self) -> &TensorShape { - self.tensor.shape() - } - - pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> { - assert!(level < self.len(), "Level out of bounds"); - let mut index = TensorIndex::new(self.shape().clone(), [0; R]); - index.set_axis(self.dim, level); - TensorAxisIterator::new(self) - .set_start(level) - .set_end(level + 1) - } -} - -#[derive(Clone, Debug, Getters, MutGetters)] -pub struct TensorAxisIterator<'a, T: Value, const R: usize> { - #[getset(get = "pub")] - axis: &'a TensorAxis<'a, T, R>, - #[getset(get = "pub", get_mut = "pub")] - index: TensorIndex, - #[getset(get = "pub")] - end: Option, -} - -impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> { - pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self { - Self { - axis, - index: TensorIndex::new(axis.shape().clone(), [0; R]), - end: None, - } - } - - pub fn set_start(self, start: usize) -> Self { - assert!(start < self.axis().len(), "Start out of bounds"); - let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]); - index.set_axis(self.axis.dim, start); - Self { - axis: self.axis(), - index, - end: None, - } - } - - pub fn set_end(self, end: usize) -> Self { - assert!(end <= self.axis().len(), "End out of bounds"); - Self { - axis: self.axis(), - index: self.index().clone(), - end: Some(end), - } - } - - pub fn set_level(self, level: usize) -> Self { - assert!(level < self.axis().len(), "Level out of bounds"); - self.set_start(level).set_end(level + 1) - } - - pub fn level(&'a self, level: usize) -> impl Iterator + 'a { - Self::new(self.axis()).set_level(level) - } - - pub fn axis_max_idx(&self) -> usize { - self.end().unwrap_or(self.axis().len()) - } - - pub fn axis_idx(&self) -> usize { - self.index().get_axis(*self.axis().dim()) - } - - pub fn axis_dim(&self) -> usize { - self.axis().dim().clone() - } -} - -impl<'a, T: Value, const R: usize> Iterator for TensorAxisIterator<'a, T, R> { - type Item = &'a T; - - fn next(&mut self) -> Option { - if self.axis_idx() == self.axis_max_idx() { - return None; - } - let result = unsafe { self.axis().tensor().get_unchecked(self.index) }; - let axis_dim = self.axis_dim(); - self.index_mut().inc_axis(axis_dim); - Some(result) - } -} - -impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> { - type Item = &'a T; - type IntoIter = TensorAxisIterator<'a, T, R>; - - fn into_iter(self) -> Self::IntoIter { - TensorAxisIterator::new(&self) - } -} diff --git a/src/index.rs b/src/index.rs index 4632e6e..a6b153e 100644 --- a/src/index.rs +++ b/src/index.rs @@ -2,9 +2,18 @@ use super::*; use getset::{Getters, MutGetters}; use std::{ cmp::Ordering, - ops::{Add, Index, IndexMut, Sub}, + ops::{Index, IndexMut}, }; +/// A multi-dimensional index into a tensor. +/// +/// ``` +/// use manifold::*; +/// +/// let shape = TensorShape::new([2, 3, 4]); +/// let index = TensorIndex::new(shape, [1, 2, 3]); +/// assert_eq!(index.flat(), 23); +/// ``` #[derive(Clone, Copy, Debug, Getters, MutGetters)] pub struct TensorIndex { #[getset(get = "pub", get_mut = "pub")] @@ -16,13 +25,30 @@ pub struct TensorIndex { // ---- Construction and Initialization --------------------------------------- impl TensorIndex { + /// Creates a new `TensorIndex` instance. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let index = TensorIndex::new(shape, [1, 2, 3]); + /// assert_eq!(index.flat(), 23); + /// ``` pub fn new(shape: TensorShape, indices: [usize; R]) -> Self { - if !shape.check_indices(indices) { - panic!("indices out of bounds"); - } - Self { indices, shape } + let index = Self { indices, shape }; + assert!(index.check_indices(indices)); + index } + /// Creates a new `TensorIndex` instance with all indices set to zero. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let index = TensorIndex::zero(shape); + /// assert_eq!(index.flat(), 0); + /// ``` pub const fn zero(shape: TensorShape) -> Self { Self { indices: [0; R], @@ -30,106 +56,185 @@ impl TensorIndex { } } - pub fn last(shape: TensorShape) -> Self { - let max_indices = - shape.as_array().map(|dim_size| dim_size.saturating_sub(1)); - Self { - indices: max_indices, - shape, + pub fn from_flat(shape: TensorShape, flat_index: usize) -> Self { + let mut indices = [0; R]; + let mut remaining = flat_index; + + // - The method iterates over the dimensions of the tensor in reverse + // order (assuming row-major order). + // - In each iteration, it uses the modulo operation to find the index + // in the current dimension and integer division to reduce the flat + // index for the next higher dimension. + // - This process is repeated for each dimension to build the + // multi-dimensional index. + for (idx, &dim_size) in indices.iter_mut().zip(shape.0.iter()).rev() { + *idx = remaining % dim_size; + remaining /= dim_size; } + + // Reverse the indices to match the original dimension order + indices.reverse(); + Self::new(shape.clone(), indices) } } +// ---- Trivial Functions ----------------------------------------------------- + impl TensorIndex { - pub fn is_zero(&self) -> bool { - self.indices.iter().all(|&i| i == 0) - } - - pub fn is_overflow(&self) -> bool { - // Check if the last index is equal to the size of the last dimension - self.indices[0] >= self.shape.get(R - 1) - } - - pub fn reset(&mut self) { - self.indices = [0; R]; - } - - /// Increments the index and returns a boolean indicating whether the end - /// has been reached. + /// Returns `true` if all indices are zero. /// - /// # Returns - /// `true` if the increment does not overflow and is still within bounds; - /// `false` if it overflows, indicating the end of the tensor. - pub fn inc(&mut self) -> bool { - if self.indices()[0] >= self.shape().get(0) { - return false; - } - let shape = self.shape().as_array().clone(); - let mut carry = 1; - for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() { - if carry == 1 { - *i += 1; - if *i >= dim_size { - *i = 0; // Reset index in this dimension and carry over - } else { - carry = 0; // Increment successful, no carry needed - } - } - } - - // If carry is still 1 after the loop, it means we've incremented past - // the last dimension - if carry == 1 { - // Set the index to an invalid state to indicate the end of the - // iteration indicated by setting the first index to the - // size of the first dimension - self.indices[0] = self.shape.as_array()[0]; - return true; // Indicate that the iteration is complete - } - false + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let index = TensorIndex::zero(shape); + /// assert!(index.is_zero()); + /// ``` + pub fn is_zero(&self) -> bool { + self.indices().iter().all(|&i| i == 0) } - // fn inc_axis + /// Returns `true` if the last index is equal to the size of the last + /// dimension. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([1, 1]); + /// let mut index = TensorIndex::zero(shape); + /// assert!(!index.is_max()); + /// index.inc(); + /// assert!(index.is_max()); + /// ``` + pub fn is_max(&self) -> bool { + self.indices().get(0) == self.shape().dim_size(0) + } +} - pub fn inc_axis(&mut self, fixed_axis: usize) { - assert!(fixed_axis < R, "TensorAxis out of bounds"); - assert!( - self.indices()[fixed_axis] < self.shape().get(fixed_axis), - "Index out of bounds" - ); +// ---- Utils ----------------------------------------------------------------- - // Try to increment non-fixed axes - for i in (0..R).rev() { - if i != fixed_axis { - if self.indices[i] + 1 < self.shape.get(i) { - self.indices[i] += 1; - return; - } else { - self.indices[i] = 0; - } - } - } +impl TensorIndex { + /// Resets the index to zero. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let mut index = TensorIndex::new(shape, [1, 2, 3]); + /// assert_eq!(index.flat(), 23); + /// index.reset(); + /// assert_eq!(index.flat(), 0); + /// ``` + pub fn reset(&mut self) { + *self.indices_mut() = [0; R]; + } - if self.indices[fixed_axis] < self.shape.get(fixed_axis) { - self.indices[fixed_axis] += 1; - for i in 0..R { - if i != fixed_axis { - self.indices[i] = 0; - } - } + fn check_indices(&self, indices: [usize; R]) -> bool { + indices + .iter() + .zip(self.shape().as_array().iter()) + .all(|(&idx, &dim_size)| idx < dim_size) + } +} + +// ---- Increment and Decrement ----------------------------------------------- + +impl TensorIndex { + /// Increments the index by one. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let mut index = TensorIndex::zero(shape); + /// assert_eq!(index.flat(), 0); + /// index.inc(); + /// assert_eq!(index.flat(), 1); + /// ``` + pub fn inc(&mut self) { + if self.indices().get(0) == self.shape().dim_size(0) { return; } + + let shape = self.shape().as_array().clone(); + + for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() { + *dim += 1; + if *dim < shape[i] { + return; + } + *dim = 0; + } + + self.indices_mut()[0] = *self.shape().dim_size(0).unwrap(); } - pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool { - if self.indices()[order[0]] >= self.shape().get(order[0]) { - return false; + /// Increments the index by one, with the specified axis fixed, + /// i.e. the index of the fixed axis is incremented only if the + /// index of the other axes reaches the maximum. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let mut index = TensorIndex::zero(shape); + /// assert_eq!(index.flat(), 0); + /// index.inc_fixed_axis(Axis(0)); + /// assert_eq!(index.flat(), 4); + /// index.inc_fixed_axis(Axis(1)); + /// assert_eq!(index.flat(), 5); + /// ``` + pub fn inc_fixed_axis(&mut self, Axis(ax): Axis) { + let shape = self.shape().as_array().clone(); + assert!(ax < R, "TensorAxis out of bounds"); + if self.indices().get(ax) == self.shape().dim_size(ax) { + return; + } + + // Iterate over all axes, skipping the fixed axis 'ax' + for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() { + if i != ax { + // Skip the fixed axis + *dim += 1; + if *dim < shape[i] { + return; // No carry over needed + } + *dim = 0; // Reset the current axis and carry over to the next + } + } + + // Handle the case where incrementing has reached the end + if self.indices().get(ax) < self.shape().dim_size(ax) { + self.indices_mut()[ax] += 1; + } else { + // Reset if the fixed axis also overflows + self.indices_mut()[ax] = *self.shape().dim_size(ax).unwrap(); + } + } + + /// Increments the index by one, reordering the order in which the + /// axes are incremented. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let mut index = TensorIndex::zero(shape); + /// assert_eq!(index.flat(), 0); + /// index.inc_transposed(&[2, 1, 0]); + /// assert_eq!(index.flat(), 12); + /// index.inc_transposed(&[0, 1, 2]); + /// assert_eq!(index.flat(), 13); + /// ``` + pub fn inc_transposed(&mut self, order: &[usize; R]) { + if self.indices().get(order[0]) == self.shape().dim_size(order[0]) { + return; } let mut carry = 1; for i in order.iter().rev() { - let dim_size = self.shape().get(*i); + let dim_size = self.shape().dim_size(*i).unwrap().clone(); let i = self.index_mut(*i); if carry == 1 { *i += 1; @@ -142,13 +247,21 @@ impl TensorIndex { } if carry == 1 { - self.indices_mut()[order[0]] = self.shape().get(order[0]); - return true; + self.indices_mut()[order[0]] = *self.shape().dim_size(order[0]).unwrap(); } - - false } + /// Decrements the index by one. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let mut index = TensorIndex::new(shape, [1, 2, 3]); + /// assert_eq!(index.flat(), 23); + /// index.dec(); + /// assert_eq!(index.flat(), 22); + /// ``` pub fn dec(&mut self) { // Check if already at the start if self.indices.iter().all(|&i| i == 0) { @@ -169,102 +282,22 @@ impl TensorIndex { } } } +} - pub fn dec_axis(&mut self, fixed_axis: usize) -> bool { - // Check if the fixed axis index is already in an invalid state - if self.indices[fixed_axis] == self.shape.get(fixed_axis) { - return false; - } - - // Try to decrement non-fixed axes - for i in (0..R).rev() { - if i != fixed_axis { - if self.indices[i] > 0 { - self.indices[i] -= 1; - return true; - } else { - self.indices[i] = self.shape.get(i) - 1; - } - } - } - - // Decrement the fixed axis if possible and reset other axes to their - // max - if self.indices[fixed_axis] > 0 { - self.indices[fixed_axis] -= 1; - for i in 0..R { - if i != fixed_axis { - self.indices[i] = self.shape.get(i) - 1; - } - } - } else { - // Fixed axis already at minimum, set to invalid state - self.indices[fixed_axis] = self.shape.get(fixed_axis); - } - - true - } - - pub fn dec_transposed(&mut self, order: [usize; R]) { - // Iterate over the axes in the specified order - for &axis in &order { - // Try to decrement the current axis - if self.indices[axis] > 0 { - self.indices[axis] -= 1; - // Reset all preceding axes in the order to their maximum - for &prev_axis in &order { - if prev_axis == axis { - break; - } - self.indices[prev_axis] = self.shape.get(prev_axis) - 1; - } - return; - } - } - - // If no axis can be decremented, set the first axis in the order to - // indicate overflow - self.indices[order[0]] = self.shape.get(order[0]); - } +// ---- Conversion to Flat Index ---------------------------------------------- +impl TensorIndex { /// Converts the multi-dimensional index to a flat index. /// - /// This method calculates the flat index corresponding to the - /// multi-dimensional index stored in `self.indices`, given the shape of - /// the tensor stored in `self.shape`. The calculation is based on the - /// assumption that the tensor is stored in row-major order, - /// where the last dimension varies the fastest. + /// # Examples /// - /// # Returns - /// The flat index corresponding to the multi-dimensional index. + /// ``` + /// use manifold::*; /// - /// # How It Works - /// - The method iterates over each pair of corresponding index and - /// dimension size, starting from the last dimension (hence `rev()` is - /// used for reverse iteration). - /// - In each iteration, it performs two main operations: - /// 1. **Index Contribution**: Multiplies the current index (`idx`) by a - /// running product of dimension sizes (`product`). This calculates the - /// contribution of the current index to the overall flat index. - /// 2. **Product Update**: Multiplies `product` by the current dimension - /// size (`dim_size`). This updates `product` for the next iteration, - /// as each dimension contributes to the flat index based on the sizes - /// of all preceding dimensions. - /// - The `fold` operation accumulates these results, starting with an - /// initial state of `(0, 1)` where `0` is the initial flat index and `1` - /// is the initial product. - /// - The final flat index is obtained after the last iteration, which is - /// the first element of the tuple resulting from the `fold`. - /// - /// # Example - /// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`. - /// - Starting with a flat index of 0 and a product of 1, - /// - For the last dimension (size 5), add 3 * 1 to the flat index. Update - /// the product to 1 * 5 = 5. - /// - For the second dimension (size 4), add 2 * 5 to the flat index. Update - /// the product to 5 * 4 = 20. - /// - For the first dimension (size 3), add 1 * 20 to the flat index. The - /// final flat index is 3 + 10 + 20 = 33. + /// let shape = TensorShape::new([2, 3, 4]); + /// let index = TensorIndex::new(shape, [1, 2, 3]); + /// assert_eq!(index.flat(), 23); + /// ``` pub fn flat(&self) -> usize { self.indices() .iter() @@ -275,37 +308,9 @@ impl TensorIndex { }) .0 } - - pub fn set_axis(&mut self, axis: usize, value: usize) { - assert!(axis < R, "TensorAxis out of bounds"); - // assert!(value < self.shape.get(axis), "Value out of bounds"); - self.indices[axis] = value; - } - - pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool { - assert!(axis < R, "TensorAxis out of bounds"); - if value < self.shape.get(axis) { - self.indices[axis] = value; - true - } else { - false - } - } - - pub fn get_axis(&self, axis: usize) -> usize { - assert!(axis < R, "TensorAxis out of bounds"); - self.indices[axis] - } - - pub fn iter_transposed( - &self, - order: [usize; R], - ) -> TensorIndexTransposedIterator { - TensorIndexTransposedIterator::new(self.shape().clone(), order) - } } -// --- blanket impls --- +// --- Equality and Ordering -------------------------------------------------- impl PartialEq for TensorIndex { fn eq(&self, other: &Self) -> bool { @@ -327,6 +332,8 @@ impl Ord for TensorIndex { } } +// ---- Indexing -------------------------------------------------------------- + impl Index for TensorIndex { type Output = usize; @@ -341,17 +348,18 @@ impl IndexMut for TensorIndex { } } +// ---- Conversion to TensorIndex --------------------------------------------- + impl From<(TensorShape, [usize; R])> for TensorIndex { fn from((shape, indices): (TensorShape, [usize; R])) -> Self { - assert!(shape.check_indices(indices)); + // assert!(shape.check_indices(indices)); Self::new(shape, indices) } } impl From<(TensorShape, usize)> for TensorIndex { fn from((shape, flat_index): (TensorShape, usize)) -> Self { - let indices = shape.index_from_flat(flat_index).indices; - Self::new(shape, indices) + Self::from_flat(shape, flat_index) } } @@ -367,6 +375,8 @@ impl From> for TensorIndex { } } +// ---- Display --------------------------------------------------------------- + impl std::fmt::Display for TensorIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[")?; @@ -384,113 +394,3 @@ impl std::fmt::Display for TensorIndex { write!(f, "]") } } - -// ---- Arithmetic Operations ---- - -impl Add for TensorIndex { - type Output = Self; - - fn add(self, rhs: Self) -> Self::Output { - assert_eq!(self.shape, rhs.shape, "TensorShape mismatch"); - - let mut result_indices = [0; R]; - for i in 0..R { - result_indices[i] = self.indices[i] + rhs.indices[i]; - } - - Self { - indices: result_indices, - shape: self.shape, - } - } -} - -impl Sub for TensorIndex { - type Output = Self; - - fn sub(self, rhs: Self) -> Self::Output { - assert_eq!(self.shape, rhs.shape, "TensorShape mismatch"); - - let mut result_indices = [0; R]; - for i in 0..R { - result_indices[i] = self.indices[i].saturating_sub(rhs.indices[i]); - } - - Self { - indices: result_indices, - shape: self.shape, - } - } -} - -// ---- Iterator ---- - -pub struct TensorIndexIterator { - current: TensorIndex, - end: bool, -} - -impl TensorIndexIterator { - pub fn new(shape: TensorShape) -> Self { - Self { - current: TensorIndex::zero(shape), - end: false, - } - } -} - -impl Iterator for TensorIndexIterator { - type Item = TensorIndex; - - fn next(&mut self) -> Option { - if self.end { - return None; - } - - let result = self.current; - self.end = self.current.inc(); - Some(result) - } -} - -impl IntoIterator for TensorIndex { - type Item = TensorIndex; - type IntoIter = TensorIndexIterator; - - fn into_iter(self) -> Self::IntoIter { - TensorIndexIterator { - current: self, - end: false, - } - } -} - -pub struct TensorIndexTransposedIterator { - current: TensorIndex, - order: [usize; R], - end: bool, -} - -impl TensorIndexTransposedIterator { - pub fn new(shape: TensorShape, order: [usize; R]) -> Self { - Self { - current: TensorIndex::zero(shape), - end: false, - order, - } - } -} - -impl Iterator for TensorIndexTransposedIterator { - type Item = TensorIndex; - - fn next(&mut self) -> Option { - if self.end { - return None; - } - - let result = self.current; - self.end = self.current.inc_transposed(&self.order); - Some(result) - } -} diff --git a/src/iterators.rs b/src/iterators.rs new file mode 100644 index 0000000..1b3ce9b --- /dev/null +++ b/src/iterators.rs @@ -0,0 +1,71 @@ +use super::*; +use std::fmt::{Display, Formatter, Result as FmtResult}; + +// ---- Iterator -------------------------------------------------------------- + +pub struct TensorIterator<'a, T: Value, const R: usize> { + tensor: &'a Tensor, + index: TensorIndex, +} + +impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> { + pub fn new(tensor: &'a Tensor) -> Self { + Self { + tensor, + index: tensor.shape().index_zero(), + } + } +} + +impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> { + type Item = &'a T; + fn next(&mut self) -> Option { + if self.index.is_max() { + return None; + } + + let result = unsafe { self.tensor.get_unchecked(self.index) }; + self.index.inc(); + Some(result) + } +} + +impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor { + type Item = &'a T; + type IntoIter = TensorIterator<'a, T, R>; + + fn into_iter(self) -> Self::IntoIter { + TensorIterator::new(self) + } +} + +impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { + // Print the current index and flat index + write!( + f, + "Current Index: {}, Flat Index: {}", + self.index, + self.index.flat() + )?; + + // Print the tensor elements, highlighting the current element + write!(f, ", Tensor Elements: [")?; + for (i, elem) in self.tensor.buffer().iter().enumerate() { + if i == self.index.flat() { + write!(f, "*{}*", elem)?; // Highlight the current element + } else { + write!(f, "{}", elem)?; + } + if i < self.tensor.buffer().len() - 1 { + write!(f, ", ")?; + } + } + + write!(f, "]") + } +} + +// ---- Axis Iterator --------------------------------------------------------- + +// ---- Transposed Iterator --------------------------------------------------- diff --git a/src/lib.rs b/src/lib.rs index 396c220..8de1485 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,14 +2,36 @@ #![feature(generic_const_exprs)] #![warn(clippy::all)] -pub mod axis; pub mod error; pub mod index; +pub mod iterators; pub mod shape; pub mod tensor; -pub mod value; -pub use {axis::*, error::*, index::*, shape::*, tensor::*, value::*}; +pub use {error::*, index::*, iterators::*, shape::*, tensor::*}; + +use num::{Num, One, Zero}; +use serde::{Deserialize, Serialize}; +use std::{fmt::Display, iter::Sum}; + +/// A trait for types that can be used as values in a tensor. +pub trait Value: + Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static> +{ +} + +impl Value for T where + T: Num + + Zero + + One + + Copy + + Clone + + Display + + Serialize + + Deserialize<'static> + + Sum +{ +} #[macro_export] macro_rules! tensor { diff --git a/src/shape.rs b/src/shape.rs index a26112f..83c29dd 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -4,18 +4,90 @@ use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor}; use serde::ser::{Serialize, SerializeTuple, Serializer}; use std::fmt::{Formatter, Result as FmtResult}; +/// A tensor's shape. +/// +/// ``` +/// use manifold::*; +/// +/// let shape = shape!([2, 3]); +/// assert_eq!(shape.dim_size(0), Some(&2)); +/// assert_eq!(shape.dim_size(1), Some(&3)); +/// assert_eq!(shape.dim_size(2), None); +/// ``` #[derive(Clone, Copy, Debug)] -pub struct TensorShape([usize; R]); +pub struct TensorShape(pub(crate) [usize; R]); + +// ---- Construction and Initialization --------------------------------------- impl TensorShape { + /// Creates a new `TensorShape` instance. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3]); + /// assert_eq!(shape.dim_size(0), Some(&2)); + /// assert_eq!(shape.dim_size(1), Some(&3)); + /// assert_eq!(shape.dim_size(2), None); + /// ``` pub const fn new(shape: [usize; R]) -> Self { Self(shape) } +} - pub fn axis(&self, index: usize) -> Option<&usize> { +// ---- Getters --------------------------------------------------------------- + +impl TensorShape { + /// Get the size of the specified dimension. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3]); + /// assert_eq!(shape.dim_size(0).unwrap(), &2); + /// assert_eq!(shape.dim_size(1).unwrap(), &3); + /// ``` + pub fn dim_size(&self, index: usize) -> Option<&usize> { self.0.get(index) } + /// Returns the shape as an array. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3]); + /// assert_eq!(shape.as_array(), &[2, 3]); + /// ``` + pub const fn as_array(&self) -> &[usize; R] { + &self.0 + } + + /// Returns the size of the shape, meaning the product of all dimensions. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3]); + /// assert_eq!(shape.size(), 6); + /// ``` + pub fn size(&self) -> usize { + self.0.iter().product() + } +} + +// ---- Manipulation ---------------------------------------------------------- + +impl TensorShape { + /// Reorders the dimensions of the shape. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3, 4]); + /// let new_shape = shape.reorder([2, 0, 1]); + /// assert_eq!(new_shape, shape!([4, 2, 3])); + /// ``` pub fn reorder(&self, indices: [usize; R]) -> Self { let mut new_shape = TensorShape::new([0; R]); for (new_index, &index) in indices.iter().enumerate() { @@ -24,78 +96,15 @@ impl TensorShape { new_shape } - pub const fn as_array(&self) -> &[usize; R] { - &self.0 - } - - pub const fn rank(&self) -> usize { - R - } - - pub fn flat_max(&self) -> usize { - self.size() - 1 - } - - pub fn size(&self) -> usize { - self.0.iter().product() - } - - pub fn iter(&self) -> impl Iterator { - self.0.iter() - } - - pub const fn get(&self, index: usize) -> usize { - self.0[index] - } - - pub fn check_indices(&self, indices: [usize; R]) -> bool { - indices - .iter() - .zip(self.0.iter()) - .all(|(&idx, &dim_size)| idx < dim_size) - } - - /// Converts a flat index to a multi-dimensional index. + /// Creates a new shape by removing the specified dimensions. /// - /// # Arguments - /// * `flat_index` - The flat index to convert. + /// ``` + /// use manifold::*; /// - /// # Returns - /// An `TensorIndex` instance representing the multi-dimensional index - /// corresponding to the flat index. - /// - /// # How It Works - /// - The method iterates over the dimensions of the tensor in reverse order - /// (assuming row-major order). - /// - In each iteration, it uses the modulo operation to find the index in - /// the current dimension and integer division to reduce the flat index - /// for the next higher dimension. - /// - This process is repeated for each dimension to build the - /// multi-dimensional index. - pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex { - let mut indices = [0; R]; - let mut remaining = flat_index; - - for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() { - *idx = remaining % dim_size; - remaining /= dim_size; - } - - indices.reverse(); // Reverse the indices to match the original dimension order - TensorIndex::new(self.clone(), indices) - } - - pub fn index_zero(&self) -> TensorIndex { - TensorIndex::zero(self.clone()) - } - - pub fn index_max(&self) -> TensorIndex { - let max_indices = - self.0 - .map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 }); - TensorIndex::new(self.clone(), max_indices) - } - + /// let shape = TensorShape::new([2, 3, 4]); + /// let new_shape = shape.remove_dims([0, 2]); + /// assert_eq!(new_shape, shape!([3])); + /// ``` pub fn remove_dims( &self, dims_to_remove: [usize; NAX], @@ -118,34 +127,70 @@ impl TensorShape { TensorShape(new_shape) } +} - pub fn remove_axes<'a, T: Value, const NAX: usize>( - &self, - axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX], - ) -> TensorShape<{ R - NAX }> { - // Create a new array to store the remaining dimensions - let mut new_shape = [0; R - NAX]; - let mut new_index = 0; +// ---- Iterators ------------------------------------------------------------- - // Iterate over the original dimensions - for (index, &dim) in self.0.iter().enumerate() { - // Skip dimensions that are in the axes_to_remove array - for axis in axes_to_remove { - if *axis.dim() == index { - continue; - } - } - - // Add the dimension to the new shape array - new_shape[new_index] = dim; - new_index += 1; - } - - TensorShape(new_shape) +impl TensorShape { + /// Returns an iterator over the dimensions of the shape. + /// + /// ``` + /// use manifold::*; + /// + /// let shape = TensorShape::new([2, 3]); + /// let mut iter = shape.iter(); + /// assert_eq!(iter.next(), Some(&2)); + /// assert_eq!(iter.next(), Some(&3)); + /// assert_eq!(iter.next(), None); + /// ``` + pub fn iter(&self) -> impl Iterator { + self.0.iter() } } -// ---- Serialize and Deserialize ---- +// ---- Utils ----------------------------------------------------------------- + +impl TensorShape { + pub fn index_zero(&self) -> TensorIndex { + TensorIndex::zero(self.clone()) + } + + pub fn index_max(&self) -> TensorIndex { + let max_indices = + self.0 + .map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 }); + TensorIndex::new(self.clone(), max_indices) + } +} + +// ---- From ------------------------------------------------------------------ + +impl From<[usize; R]> for TensorShape { + fn from(shape: [usize; R]) -> Self { + Self::new(shape) + } +} + +impl From> for TensorShape +where + T: Value, +{ + fn from(tensor: Tensor) -> Self { + *tensor.shape() + } +} + +// ---- Equality -------------------------------------------------------------- + +impl PartialEq for TensorShape { + fn eq(&self, other: &Self) -> bool { + self.0 == other.0 + } +} + +impl Eq for TensorShape {} + +// ---- Serialization --------------------------------------------------------- struct TensorShapeVisitor; @@ -191,30 +236,3 @@ impl Serialize for TensorShape { seq.end() } } - -// ---- Blanket Implementations ---- - -impl From<[usize; R]> for TensorShape { - fn from(shape: [usize; R]) -> Self { - Self::new(shape) - } -} - -impl PartialEq for TensorShape { - fn eq(&self, other: &Self) -> bool { - self.0 == other.0 - } -} - -impl Eq for TensorShape {} - -// ---- From and Into Implementations ---- - -impl From> for TensorShape -where - T: Value, -{ - fn from(tensor: Tensor) -> Self { - *tensor.shape() - } -} diff --git a/src/tensor.rs b/src/tensor.rs index 6224859..a65c712 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -28,6 +28,9 @@ pub struct Tensor { shape: TensorShape, } +/// A type that represents an axis of a tensor. +pub struct Axis(pub usize); + // ---- Construction and Initialization --------------------------------------- impl Tensor { @@ -62,22 +65,49 @@ impl Tensor { /// use manifold::Tensor; /// /// let buffer = vec![1, 2, 3, 4, 5, 6]; - /// let t = Tensor::::new_with_buffer([2, 3].into(), buffer); + /// let t = Tensor::::new_with_buffer([2, 3].into(), buffer).unwrap(); /// assert_eq!(t.shape().as_array(), &[2, 3]); /// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]); /// ``` - pub fn new_with_buffer(shape: TensorShape, buffer: Vec) -> Self { - Self { buffer, shape } + pub fn new_with_buffer( + shape: TensorShape, + buffer: Vec, + ) -> Result { + if buffer.len() != shape.size() { + Err(TensorError::InvalidArgument(format!( + "Provided buffer has length {} but shape has size {}", + buffer.len(), + shape.size() + ))) + } else { + Ok(Self { buffer, shape }) + } } } // ---- Trivial Getters ------------------------------------------------------- impl Tensor { - pub fn rank(&self) -> usize { + /// Get the rank of the tensor. + /// + /// ``` + /// use manifold::Tensor; + /// + /// let t = Tensor::::new([3, 3].into()); + /// assert_eq!(t.rank(), 2); + /// ``` + pub const fn rank(&self) -> usize { R } + /// Get the length of the tensor's buffer. + /// + /// ``` + /// use manifold::Tensor; + /// + /// let t = Tensor::::new([3, 3].into()); + /// assert_eq!(t.len(), 9); + /// ``` pub fn len(&self) -> usize { self.buffer().len() } @@ -332,32 +362,6 @@ impl Tensor { ) -> Result<()> { self.ew_for_each(other, result, &|a, b| a % b) } - - // pub fn product( - // &self, - // other: &Tensor, - // ) -> Tensor { - // let mut new_shape_vec = Vec::new(); - // new_shape_vec.extend_from_slice(&self.shape().as_array()); - // new_shape_vec.extend_from_slice(&other.shape().as_array()); - - // let new_shape_array: [usize; R + S] = new_shape_vec - // .try_into() - // .expect("Failed to create shape array"); - - // let mut new_buffer = - // Vec::with_capacity(self.buffer.len() * other.buffer.len()); - // for &item_self in &self.buffer { - // for &item_other in &other.buffer { - // new_buffer.push(item_self * item_other); - // } - // } - - // Tensor { - // buffer: new_buffer, - // shape: TensorShape::new(new_shape_array), - // } - // } } // ---- Reshape --------------------------------------------------------------- @@ -385,37 +389,11 @@ impl Tensor { "TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", ))) } else { - Ok(Tensor::new_with_buffer(shape, self.buffer)) + Ok(Tensor::new_with_buffer(shape, self.buffer).unwrap()) } } } -// ---- Transpose ------------------------------------------------------------- - -impl Tensor { - /// Transpose the tensor according to the given order. The order must be a - /// permutation of the tensor's axes. - /// - /// ``` - /// use manifold::{tensor, Tensor, TensorShape}; - /// - /// let t = tensor!([[1, 2], [3, 4]]); - /// let t = t.transpose([1, 0]).unwrap(); - /// assert_eq!(t, tensor!([[1, 3], [2, 4]])); - /// ``` - pub fn transpose(self, order: [usize; R]) -> Result { - let buffer = TensorIndex::from(self.shape().clone()) - .iter_transposed(order) - .map(|index| self.get(index).unwrap().clone()) - .collect(); - - Ok(Tensor { - buffer, - shape: self.shape().reorder(order), - }) - } -} - // ---- Operations ------------------------------------------------------------ impl Add for Tensor { @@ -654,73 +632,6 @@ where impl Eq for Tensor where T: Eq {} -// ---- Iterator -------------------------------------------------------------- - -pub struct TensorIterator<'a, T: Value, const R: usize> { - tensor: &'a Tensor, - index: TensorIndex, -} - -impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> { - pub fn new(tensor: &'a Tensor) -> Self { - Self { - tensor, - index: tensor.shape.index_zero(), - } - } -} - -impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> { - type Item = &'a T; - fn next(&mut self) -> Option { - if self.index.is_overflow() { - return None; - } - - let result = unsafe { self.tensor.get_unchecked(self.index) }; - self.index.inc(); - Some(result) - } -} - -impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor { - type Item = &'a T; - type IntoIter = TensorIterator<'a, T, R>; - - fn into_iter(self) -> Self::IntoIter { - TensorIterator::new(self) - } -} - -// ---- Formatting ------------------------------------------------------------ - -impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> { - fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { - // Print the current index and flat index - write!( - f, - "Current Index: {}, Flat Index: {}", - self.index, - self.index.flat() - )?; - - // Print the tensor elements, highlighting the current element - write!(f, ", Tensor Elements: [")?; - for (i, elem) in self.tensor.buffer().iter().enumerate() { - if i == self.index.flat() { - write!(f, "*{}*", elem)?; // Highlight the current element - } else { - write!(f, "{}", elem)?; - } - if i < self.tensor.buffer().len() - 1 { - write!(f, ", ")?; - } - } - - write!(f, "]") - } -} - // ---- From ------------------------------------------------------------------ impl From> for Tensor { diff --git a/src/value.rs b/src/value.rs deleted file mode 100644 index ae93e38..0000000 --- a/src/value.rs +++ /dev/null @@ -1,22 +0,0 @@ -use num::{Num, One, Zero}; -use serde::{Deserialize, Serialize}; -use std::{fmt::Display, iter::Sum}; - -/// A trait for types that can be used as values in a tensor. -pub trait Value: - Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static> -{ -} - -impl Value for T where - T: Num - + Zero - + One - + Copy - + Clone - + Display - + Serialize - + Deserialize<'static> - + Sum -{ -} diff --git a/tests/basic_tests.rs b/tests/basic_tests.rs index c96996f..e4c6f8c 100644 --- a/tests/basic_tests.rs +++ b/tests/basic_tests.rs @@ -53,14 +53,10 @@ fn test_iterating_3d_tensor() { } } - println!("{}", tensor); - // Iterate over the tensor and check that the numbers are correct let mut iter = TensorIterator::new(&tensor); - println!("{}", iter); - assert_eq!(iter.next(), Some(&0)); assert_eq!(iter.next(), Some(&1)); @@ -157,105 +153,105 @@ fn test_index_dec_method() { assert_eq!(index, TensorIndex::zero(shape)); } -#[test] -fn test_axis_iterator() { - // Creating a 2x2 Tensor for testing - let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); +// #[test] +// fn test_axis_iterator() { +// // Creating a 2x2 Tensor for testing +// let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); - // Testing iteration over the first axis (axis = 0) - let axis = TensorAxis::new(&tensor, 0); +// // Testing iteration over the first axis (axis = 0) +// let axis = TensorAxis::new(&tensor, 0); - let mut axis_iter = axis.into_iter(); +// let mut axis_iter = axis.into_iter(); - assert_eq!(axis_iter.next(), Some(&1.0)); - assert_eq!(axis_iter.next(), Some(&2.0)); - assert_eq!(axis_iter.next(), Some(&3.0)); - assert_eq!(axis_iter.next(), Some(&4.0)); +// assert_eq!(axis_iter.next(), Some(&1.0)); +// assert_eq!(axis_iter.next(), Some(&2.0)); +// assert_eq!(axis_iter.next(), Some(&3.0)); +// assert_eq!(axis_iter.next(), Some(&4.0)); - // Resetting the iterator for the second axis (axis = 1) - let axis = TensorAxis::new(&tensor, 1); +// // Resetting the iterator for the second axis (axis = 1) +// let axis = TensorAxis::new(&tensor, 1); - let mut axis_iter = axis.into_iter(); +// let mut axis_iter = axis.into_iter(); - assert_eq!(axis_iter.next(), Some(&1.0)); - assert_eq!(axis_iter.next(), Some(&3.0)); - assert_eq!(axis_iter.next(), Some(&2.0)); - assert_eq!(axis_iter.next(), Some(&4.0)); +// assert_eq!(axis_iter.next(), Some(&1.0)); +// assert_eq!(axis_iter.next(), Some(&3.0)); +// assert_eq!(axis_iter.next(), Some(&2.0)); +// assert_eq!(axis_iter.next(), Some(&4.0)); - let shape = tensor.shape(); +// let shape = tensor.shape(); - let mut a: TensorIndex<2> = (shape.clone(), [0, 0]).into(); - let b: TensorIndex<2> = (shape.clone(), [1, 1]).into(); +// let mut a: TensorIndex<2> = (shape.clone(), [0, 0]).into(); +// let b: TensorIndex<2> = (shape.clone(), [1, 1]).into(); - while a <= b { - println!("a: {}", a); - a.inc(); - } -} +// while a <= b { +// println!("a: {}", a); +// a.inc(); +// } +// } -#[test] -fn test_3d_tensor_axis_iteration() { - // Create a 3D Tensor with specific values - // Tensor shape is 2x2x2 for simplicity - let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); +// #[test] +// fn test_3d_tensor_axis_iteration() { +// // Create a 3D Tensor with specific values +// // Tensor shape is 2x2x2 for simplicity +// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); - // TensorAxis 0 (Layer-wise): - // - // t[0][0][0] = 1 - // t[0][0][1] = 2 - // t[0][1][0] = 3 - // t[0][1][1] = 4 - // t[1][0][0] = 5 - // t[1][0][1] = 6 - // t[1][1][0] = 7 - // t[1][1][1] = 8 - // [1, 2, 3, 4, 5, 6, 7, 8] - // - // This order suggests that for each "layer" (first level of arrays), - // the iterator goes through all rows and columns. It first completes - // the entire first layer, then moves to the second. +// // TensorAxis 0 (Layer-wise): +// // +// // t[0][0][0] = 1 +// // t[0][0][1] = 2 +// // t[0][1][0] = 3 +// // t[0][1][1] = 4 +// // t[1][0][0] = 5 +// // t[1][0][1] = 6 +// // t[1][1][0] = 7 +// // t[1][1][1] = 8 +// // [1, 2, 3, 4, 5, 6, 7, 8] +// // +// // This order suggests that for each "layer" (first level of arrays), +// // the iterator goes through all rows and columns. It first completes +// // the entire first layer, then moves to the second. - let a0 = TensorAxis::new(&t, 0); - let a0_order = a0.into_iter().cloned().collect::>(); - assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); +// let a0 = TensorAxis::new(&t, 0); +// let a0_order = a0.into_iter().cloned().collect::>(); +// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); - // TensorAxis 1 (Row-wise within each layer): - // - // t[0][0][0] = 1 - // t[0][0][1] = 2 - // t[1][0][0] = 5 - // t[1][0][1] = 6 - // t[0][1][0] = 3 - // t[0][1][1] = 4 - // t[1][1][0] = 7 - // t[1][1][1] = 8 - // [1, 2, 5, 6, 3, 4, 7, 8] - // - // This indicates that within each "layer", the iterator first - // completes the first row across all layers, then the second row - // across all layers. +// // TensorAxis 1 (Row-wise within each layer): +// // +// // t[0][0][0] = 1 +// // t[0][0][1] = 2 +// // t[1][0][0] = 5 +// // t[1][0][1] = 6 +// // t[0][1][0] = 3 +// // t[0][1][1] = 4 +// // t[1][1][0] = 7 +// // t[1][1][1] = 8 +// // [1, 2, 5, 6, 3, 4, 7, 8] +// // +// // This indicates that within each "layer", the iterator first +// // completes the first row across all layers, then the second row +// // across all layers. - let a1 = TensorAxis::new(&t, 1); - let a1_order = a1.into_iter().cloned().collect::>(); - assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); +// let a1 = TensorAxis::new(&t, 1); +// let a1_order = a1.into_iter().cloned().collect::>(); +// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); - // TensorAxis 2 (Column-wise within each layer): - // - // t[0][0][0] = 1 - // t[0][1][0] = 3 - // t[1][0][0] = 5 - // t[1][1][0] = 7 - // t[0][0][1] = 2 - // t[0][1][1] = 4 - // t[1][0][1] = 6 - // t[1][1][1] = 8 - // [1, 3, 5, 7, 2, 4, 6, 8] - // - // This indicates that within each "layer", the iterator first - // completes the first column across all layers, then the second - // column across all layers. +// // TensorAxis 2 (Column-wise within each layer): +// // +// // t[0][0][0] = 1 +// // t[0][1][0] = 3 +// // t[1][0][0] = 5 +// // t[1][1][0] = 7 +// // t[0][0][1] = 2 +// // t[0][1][1] = 4 +// // t[1][0][1] = 6 +// // t[1][1][1] = 8 +// // [1, 3, 5, 7, 2, 4, 6, 8] +// // +// // This indicates that within each "layer", the iterator first +// // completes the first column across all layers, then the second +// // column across all layers. - let a2 = TensorAxis::new(&t, 2); - let a2_order = a2.into_iter().cloned().collect::>(); - assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); -} +// let a2 = TensorAxis::new(&t, 2); +// let a2_order = a2.into_iter().cloned().collect::>(); +// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); +// }