Wrap comments and fmt

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-03 21:34:07 +02:00
parent d19ce40494
commit b68e75814b
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
6 changed files with 112 additions and 82 deletions

View File

@ -1 +1,3 @@
max_width = 80 max_width = 80
wrap_comments = true
comment_width = 80

View File

@ -27,7 +27,9 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
assert!(level < self.len(), "Level out of bounds"); assert!(level < self.len(), "Level out of bounds");
let mut index = TensorIndex::new(self.shape(), [0; R]); let mut index = TensorIndex::new(self.shape(), [0; R]);
index.set_axis(self.dim, level); index.set_axis(self.dim, level);
TensorAxisIterator::new(self).set_start(level).set_end(level + 1) TensorAxisIterator::new(self)
.set_start(level)
.set_end(level + 1)
} }
} }
@ -113,4 +115,4 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
fn into_iter(self) -> Self::IntoIter { fn into_iter(self) -> Self::IntoIter {
TensorAxisIterator::new(&self) TensorAxisIterator::new(&self)
} }
} }

View File

@ -32,10 +32,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
if !shape.check_indices(indices) { if !shape.check_indices(indices) {
panic!("indices out of bounds"); panic!("indices out of bounds");
} }
Self { Self { indices, shape }
indices,
shape,
}
} }
pub fn is_zero(&self) -> bool { pub fn is_zero(&self) -> bool {
@ -51,7 +48,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
self.indices = [0; R]; self.indices = [0; R];
} }
/// Increments the index and returns a boolean indicating whether the end has been reached. /// Increments the index and returns a boolean indicating whether the end
/// has been reached.
/// ///
/// # Returns /// # Returns
/// `true` if the increment does not overflow and is still within bounds; /// `true` if the increment does not overflow and is still within bounds;
@ -74,10 +72,12 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
} }
} }
// If carry is still 1 after the loop, it means we've incremented past the last dimension // If carry is still 1 after the loop, it means we've incremented past
// the last dimension
if carry == 1 { if carry == 1 {
// Set the index to an invalid state to indicate the end of the iteration indicated // Set the index to an invalid state to indicate the end of the
// by setting the first index to the size of the first dimension // iteration indicated by setting the first index to the
// size of the first dimension
self.indices[0] = self.shape.as_array()[0]; self.indices[0] = self.shape.as_array()[0];
return true; // Indicate that the iteration is complete return true; // Indicate that the iteration is complete
} }
@ -156,7 +156,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
{ {
if borrow { if borrow {
if *i == 0 { if *i == 0 {
*i = dim_size - 1; // Wrap around to the maximum index of this dimension *i = dim_size - 1; // Wrap around to the maximum index of
// this dimension
} else { } else {
*i -= 1; // Decrement the index *i -= 1; // Decrement the index
borrow = false; // No more borrowing needed borrow = false; // No more borrowing needed
@ -183,7 +184,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
} }
} }
// Decrement the fixed axis if possible and reset other axes to their max // Decrement the fixed axis if possible and reset other axes to their
// max
if self.indices[fixed_axis] > 0 { if self.indices[fixed_axis] > 0 {
self.indices[fixed_axis] -= 1; self.indices[fixed_axis] -= 1;
for i in 0..R { for i in 0..R {
@ -216,41 +218,49 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
} }
} }
// If no axis can be decremented, set the first axis in the order to indicate overflow // If no axis can be decremented, set the first axis in the order to
// indicate overflow
self.indices[order[0]] = self.shape.get(order[0]); self.indices[order[0]] = self.shape.get(order[0]);
} }
/// Converts the multi-dimensional index to a flat index. /// Converts the multi-dimensional index to a flat index.
/// ///
/// This method calculates the flat index corresponding to the multi-dimensional index /// This method calculates the flat index corresponding to the
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`. /// multi-dimensional index stored in `self.indices`, given the shape of
/// The calculation is based on the assumption that the tensor is stored in row-major order, /// the tensor stored in `self.shape`. The calculation is based on the
/// assumption that the tensor is stored in row-major order,
/// where the last dimension varies the fastest. /// where the last dimension varies the fastest.
/// ///
/// # Returns /// # Returns
/// The flat index corresponding to the multi-dimensional index. /// The flat index corresponding to the multi-dimensional index.
/// ///
/// # How It Works /// # How It Works
/// - The method iterates over each pair of corresponding index and dimension size, /// - The method iterates over each pair of corresponding index and
/// starting from the last dimension (hence `rev()` is used for reverse iteration). /// dimension size, starting from the last dimension (hence `rev()` is
/// used for reverse iteration).
/// - In each iteration, it performs two main operations: /// - In each iteration, it performs two main operations:
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product /// 1. **Index Contribution**: Multiplies the current index (`idx`) by a
/// of dimension sizes (`product`). This calculates the contribution of the current index /// running product of dimension sizes (`product`). This calculates the
/// to the overall flat index. /// contribution of the current index to the overall flat index.
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`). /// 2. **Product Update**: Multiplies `product` by the current dimension
/// This updates `product` for the next iteration, as each dimension contributes to the /// size (`dim_size`). This updates `product` for the next iteration,
/// flat index based on the sizes of all preceding dimensions. /// as each dimension contributes to the flat index based on the sizes
/// - The `fold` operation accumulates these results, starting with an initial state of /// of all preceding dimensions.
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product. /// - The `fold` operation accumulates these results, starting with an
/// - The final flat index is obtained after the last iteration, which is the first element /// initial state of `(0, 1)` where `0` is the initial flat index and `1`
/// of the tuple resulting from the `fold`. /// is the initial product.
/// - The final flat index is obtained after the last iteration, which is
/// the first element of the tuple resulting from the `fold`.
/// ///
/// # Example /// # Example
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`. /// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
/// - Starting with a flat index of 0 and a product of 1, /// - Starting with a flat index of 0 and a product of 1,
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5. /// - For the last dimension (size 5), add 3 * 1 to the flat index. Update
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20. /// the product to 1 * 5 = 5.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33. /// - For the second dimension (size 4), add 2 * 5 to the flat index. Update
/// the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The
/// final flat index is 3 + 10 + 20 = 33.
pub fn flat(&self) -> usize { pub fn flat(&self) -> usize {
self.indices() self.indices()
.iter() .iter()
@ -327,14 +337,18 @@ impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
} }
} }
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])> for TensorIndex<'a, R> { impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])>
for TensorIndex<'a, R>
{
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self { fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices)); assert!(shape.check_indices(indices));
Self::new(shape, indices) Self::new(shape, indices)
} }
} }
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)> for TensorIndex<'a, R> { impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)>
for TensorIndex<'a, R>
{
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self { fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
let indices = shape.index_from_flat(flat_index).indices; let indices = shape.index_from_flat(flat_index).indices;
Self::new(shape, indices) Self::new(shape, indices)
@ -347,7 +361,9 @@ impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
} }
} }
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for TensorIndex<'a, R> { impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>>
for TensorIndex<'a, R>
{
fn from(tensor: &'a Tensor<T, R>) -> Self { fn from(tensor: &'a Tensor<T, R>) -> Self {
Self::zero(tensor.shape()) Self::zero(tensor.shape())
} }

View File

@ -1,5 +1,6 @@
#![allow(incomplete_features)] #![allow(incomplete_features)]
#![feature(generic_const_exprs)] #![feature(generic_const_exprs)]
#![warn(clippy::all)]
pub mod axis; pub mod axis;
pub mod error; pub mod error;
pub mod index; pub mod index;
@ -8,7 +9,7 @@ pub mod tensor;
pub use axis::*; pub use axis::*;
pub use index::TensorIndex; pub use index::TensorIndex;
pub use itertools::Itertools; // pub use itertools::Itertools;
use num::{Num, One, Zero}; use num::{Num, One, Zero};
pub use serde::{Deserialize, Serialize}; pub use serde::{Deserialize, Serialize};
pub use shape::TensorShape; pub use shape::TensorShape;
@ -45,14 +46,14 @@ macro_rules! tensor {
#[macro_export] #[macro_export]
macro_rules! shape { macro_rules! shape {
($array:expr) => { ($array:expr) => {
TensorShape::from($array) TensorShape::from($array)
}; };
} }
#[macro_export] #[macro_export]
macro_rules! index { macro_rules! index {
($array:expr) => { ($array:expr) => {
TensorIndex::from($array) TensorIndex::from($array)
}; };
} }

View File

@ -60,13 +60,17 @@ impl<const R: usize> TensorShape<R> {
/// * `flat_index` - The flat index to convert. /// * `flat_index` - The flat index to convert.
/// ///
/// # Returns /// # Returns
/// An `TensorIndex<R>` instance representing the multi-dimensional index corresponding to the flat index. /// An `TensorIndex<R>` instance representing the multi-dimensional index
/// corresponding to the flat index.
/// ///
/// # How It Works /// # How It Works
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order). /// - The method iterates over the dimensions of the tensor in reverse order
/// - In each iteration, it uses the modulo operation to find the index in the current dimension /// (assuming row-major order).
/// and integer division to reduce the flat index for the next higher dimension. /// - In each iteration, it uses the modulo operation to find the index in
/// - This process is repeated for each dimension to build the multi-dimensional index. /// the current dimension and integer division to reduce the flat index
/// for the next higher dimension.
/// - This process is repeated for each dimension to build the
/// multi-dimensional index.
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> { pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
let mut indices = [0; R]; let mut indices = [0; R];
let mut remaining = flat_index; let mut remaining = flat_index;

View File

@ -3,9 +3,9 @@ use crate::error::*;
use getset::{Getters, MutGetters}; use getset::{Getters, MutGetters};
use std::fmt; use std::fmt;
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the number of /// A tensor is a multi-dimensional array of values. The rank of a tensor is the
/// dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is a vector, a rank 2 tensor is /// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is
/// a matrix, and so on. /// a vector, a rank 2 tensor is a matrix, and so on.
/// ///
/// ``` /// ```
/// use manifold::{tensor, Tensor}; /// use manifold::{tensor, Tensor};
@ -24,8 +24,8 @@ pub struct Tensor<T, const R: usize> {
// ---- Construction and Initialization --------------------------------------- // ---- Construction and Initialization ---------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> { impl<T: Value, const R: usize> Tensor<T, R> {
/// Create a new tensor with the given shape. The rank of the tensor is determined by the shape /// Create a new tensor with the given shape. The rank of the tensor is
/// and all elements are initialized to zero. /// determined by the shape and all elements are initialized to zero.
/// ///
/// ``` /// ```
/// use manifold::Tensor; /// use manifold::Tensor;
@ -39,7 +39,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// A rank 0 tensor should still have a buffer with one element // A rank 0 tensor should still have a buffer with one element
1 1
} else { } else {
// For tensors of rank 1 or higher, calculate the total size normally // For tensors of rank 1 or higher, calculate the total size
// normally
shape.iter().product() shape.iter().product()
}; };
@ -47,8 +48,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
Self { buffer, shape } Self { buffer, shape }
} }
/// Create a new tensor with the given shape and initialize it from the given buffer. The rank /// Create a new tensor with the given shape and initialize it from the
/// of the tensor is determined by the shape. /// given buffer. The rank of the tensor is determined by the shape.
/// ///
/// ``` /// ```
/// use manifold::Tensor; /// use manifold::Tensor;
@ -118,7 +119,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer_mut().get_mut(index.flat()) self.buffer_mut().get_mut(index.flat())
} }
/// Get a mutable reference to a value at the given index without bounds checking. /// Get a mutable reference to a value at the given index without bounds
/// checking.
/// ///
/// ``` /// ```
/// use manifold::{tensor, Tensor}; /// use manifold::{tensor, Tensor};
@ -147,7 +149,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer().get(index) self.buffer().get(index)
} }
/// Get a reference to a value at the given flat index without bounds checking. /// Get a reference to a value at the given flat index without bounds
/// checking.
/// ///
/// ``` /// ```
/// use manifold::{tensor, Tensor}; /// use manifold::{tensor, Tensor};
@ -171,7 +174,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer_mut().get_mut(index) self.buffer_mut().get_mut(index)
} }
/// Get a mutable reference to a value at the given flat index without bounds checking. /// Get a mutable reference to a value at the given flat index without
/// bounds checking.
/// ///
/// ``` /// ```
/// use manifold::{tensor, Tensor}; /// use manifold::{tensor, Tensor};
@ -354,19 +358,21 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// ---- Reshape --------------------------------------------------------------- // ---- Reshape ---------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> { impl<T: Value, const R: usize> Tensor<T, R> {
/// Reshape the tensor to the given shape. The total size of the new shape
/// Reshape the tensor to the given shape. The total size of the new shape must be the same as /// must be the same as the total size of the old shape.
/// the total size of the old shape. ///
/// /// ```
/// ``` /// use manifold::{tensor, shape, Tensor, TensorShape};
/// use manifold::{tensor, shape, Tensor, TensorShape}; ///
/// /// let t = tensor!([[1, 2], [3, 4]]);
/// let t = tensor!([[1, 2], [3, 4]]); /// let s = shape!([4]);
/// let s = shape!([4]); /// let t = t.reshape(s).unwrap();
/// let t = t.reshape(s).unwrap(); /// assert_eq!(t, tensor!([1, 2, 3, 4]));
/// assert_eq!(t, tensor!([1, 2, 3, 4])); /// ```
/// ``` pub fn reshape<const S: usize>(
pub fn reshape<const S: usize>(self, shape: TensorShape<S>) -> Result<Tensor<T, S>> { self,
shape: TensorShape<S>,
) -> Result<Tensor<T, S>> {
if self.shape().size() != shape.size() { if self.shape().size() != shape.size() {
let (ls, rs) = (self.shape().as_array(), shape.as_array()); let (ls, rs) = (self.shape().as_array(), shape.as_array());
let (lsize, rsize) = (self.shape().size(), shape.size()); let (lsize, rsize) = (self.shape().size(), shape.size());
@ -382,17 +388,16 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// ---- Transpose ------------------------------------------------------------- // ---- Transpose -------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> { impl<T: Value, const R: usize> Tensor<T, R> {
/// Transpose the tensor according to the given order. The order must be a
/// Transpose the tensor according to the given order. The order must be a permutation of the /// permutation of the tensor's axes.
/// tensor's axes. ///
/// /// ```
/// ``` /// use manifold::{tensor, Tensor, TensorShape};
/// use manifold::{tensor, Tensor, TensorShape}; ///
/// /// let t = tensor!([[1, 2], [3, 4]]);
/// let t = tensor!([[1, 2], [3, 4]]); /// let t = t.transpose([1, 0]).unwrap();
/// let t = t.transpose([1, 0]).unwrap(); /// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// assert_eq!(t, tensor!([[1, 3], [2, 4]])); /// ```
/// ```
pub fn transpose(self, order: [usize; R]) -> Result<Self> { pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = TensorIndex::from(self.shape()) let buffer = TensorIndex::from(self.shape())
.iter_transposed(order) .iter_transposed(order)