Wrap comments and fmt
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
d19ce40494
commit
b68e75814b
@ -1 +1,3 @@
|
||||
max_width = 80
|
||||
wrap_comments = true
|
||||
comment_width = 80
|
@ -27,7 +27,9 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
|
||||
assert!(level < self.len(), "Level out of bounds");
|
||||
let mut index = TensorIndex::new(self.shape(), [0; R]);
|
||||
index.set_axis(self.dim, level);
|
||||
TensorAxisIterator::new(self).set_start(level).set_end(level + 1)
|
||||
TensorAxisIterator::new(self)
|
||||
.set_start(level)
|
||||
.set_end(level + 1)
|
||||
}
|
||||
}
|
||||
|
||||
@ -113,4 +115,4 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
TensorAxisIterator::new(&self)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
80
src/index.rs
80
src/index.rs
@ -32,10 +32,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self {
|
||||
indices,
|
||||
shape,
|
||||
}
|
||||
Self { indices, shape }
|
||||
}
|
||||
|
||||
pub fn is_zero(&self) -> bool {
|
||||
@ -51,7 +48,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
self.indices = [0; R];
|
||||
}
|
||||
|
||||
/// Increments the index and returns a boolean indicating whether the end has been reached.
|
||||
/// Increments the index and returns a boolean indicating whether the end
|
||||
/// has been reached.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the increment does not overflow and is still within bounds;
|
||||
@ -74,10 +72,12 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
||||
// If carry is still 1 after the loop, it means we've incremented past
|
||||
// the last dimension
|
||||
if carry == 1 {
|
||||
// Set the index to an invalid state to indicate the end of the iteration indicated
|
||||
// by setting the first index to the size of the first dimension
|
||||
// Set the index to an invalid state to indicate the end of the
|
||||
// iteration indicated by setting the first index to the
|
||||
// size of the first dimension
|
||||
self.indices[0] = self.shape.as_array()[0];
|
||||
return true; // Indicate that the iteration is complete
|
||||
}
|
||||
@ -156,7 +156,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
{
|
||||
if borrow {
|
||||
if *i == 0 {
|
||||
*i = dim_size - 1; // Wrap around to the maximum index of this dimension
|
||||
*i = dim_size - 1; // Wrap around to the maximum index of
|
||||
// this dimension
|
||||
} else {
|
||||
*i -= 1; // Decrement the index
|
||||
borrow = false; // No more borrowing needed
|
||||
@ -183,7 +184,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement the fixed axis if possible and reset other axes to their max
|
||||
// Decrement the fixed axis if possible and reset other axes to their
|
||||
// max
|
||||
if self.indices[fixed_axis] > 0 {
|
||||
self.indices[fixed_axis] -= 1;
|
||||
for i in 0..R {
|
||||
@ -216,41 +218,49 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// If no axis can be decremented, set the first axis in the order to indicate overflow
|
||||
// If no axis can be decremented, set the first axis in the order to
|
||||
// indicate overflow
|
||||
self.indices[order[0]] = self.shape.get(order[0]);
|
||||
}
|
||||
|
||||
/// Converts the multi-dimensional index to a flat index.
|
||||
///
|
||||
/// This method calculates the flat index corresponding to the multi-dimensional index
|
||||
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`.
|
||||
/// The calculation is based on the assumption that the tensor is stored in row-major order,
|
||||
/// This method calculates the flat index corresponding to the
|
||||
/// multi-dimensional index stored in `self.indices`, given the shape of
|
||||
/// the tensor stored in `self.shape`. The calculation is based on the
|
||||
/// assumption that the tensor is stored in row-major order,
|
||||
/// where the last dimension varies the fastest.
|
||||
///
|
||||
/// # Returns
|
||||
/// The flat index corresponding to the multi-dimensional index.
|
||||
///
|
||||
/// # How It Works
|
||||
/// - The method iterates over each pair of corresponding index and dimension size,
|
||||
/// starting from the last dimension (hence `rev()` is used for reverse iteration).
|
||||
/// - The method iterates over each pair of corresponding index and
|
||||
/// dimension size, starting from the last dimension (hence `rev()` is
|
||||
/// used for reverse iteration).
|
||||
/// - In each iteration, it performs two main operations:
|
||||
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product
|
||||
/// of dimension sizes (`product`). This calculates the contribution of the current index
|
||||
/// to the overall flat index.
|
||||
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`).
|
||||
/// This updates `product` for the next iteration, as each dimension contributes to the
|
||||
/// flat index based on the sizes of all preceding dimensions.
|
||||
/// - The `fold` operation accumulates these results, starting with an initial state of
|
||||
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product.
|
||||
/// - The final flat index is obtained after the last iteration, which is the first element
|
||||
/// of the tuple resulting from the `fold`.
|
||||
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a
|
||||
/// running product of dimension sizes (`product`). This calculates the
|
||||
/// contribution of the current index to the overall flat index.
|
||||
/// 2. **Product Update**: Multiplies `product` by the current dimension
|
||||
/// size (`dim_size`). This updates `product` for the next iteration,
|
||||
/// as each dimension contributes to the flat index based on the sizes
|
||||
/// of all preceding dimensions.
|
||||
/// - The `fold` operation accumulates these results, starting with an
|
||||
/// initial state of `(0, 1)` where `0` is the initial flat index and `1`
|
||||
/// is the initial product.
|
||||
/// - The final flat index is obtained after the last iteration, which is
|
||||
/// the first element of the tuple resulting from the `fold`.
|
||||
///
|
||||
/// # Example
|
||||
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
|
||||
/// - Starting with a flat index of 0 and a product of 1,
|
||||
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5.
|
||||
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
|
||||
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
|
||||
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update
|
||||
/// the product to 1 * 5 = 5.
|
||||
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update
|
||||
/// the product to 5 * 4 = 20.
|
||||
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The
|
||||
/// final flat index is 3 + 10 + 20 = 33.
|
||||
pub fn flat(&self) -> usize {
|
||||
self.indices()
|
||||
.iter()
|
||||
@ -327,14 +337,18 @@ impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])> for TensorIndex<'a, R> {
|
||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])>
|
||||
for TensorIndex<'a, R>
|
||||
{
|
||||
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
|
||||
assert!(shape.check_indices(indices));
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)> for TensorIndex<'a, R> {
|
||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)>
|
||||
for TensorIndex<'a, R>
|
||||
{
|
||||
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
|
||||
let indices = shape.index_from_flat(flat_index).indices;
|
||||
Self::new(shape, indices)
|
||||
@ -347,7 +361,9 @@ impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for TensorIndex<'a, R> {
|
||||
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>>
|
||||
for TensorIndex<'a, R>
|
||||
{
|
||||
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
||||
Self::zero(tensor.shape())
|
||||
}
|
||||
|
17
src/lib.rs
17
src/lib.rs
@ -1,5 +1,6 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(generic_const_exprs)]
|
||||
#![warn(clippy::all)]
|
||||
pub mod axis;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
@ -8,7 +9,7 @@ pub mod tensor;
|
||||
|
||||
pub use axis::*;
|
||||
pub use index::TensorIndex;
|
||||
pub use itertools::Itertools;
|
||||
// pub use itertools::Itertools;
|
||||
use num::{Num, One, Zero};
|
||||
pub use serde::{Deserialize, Serialize};
|
||||
pub use shape::TensorShape;
|
||||
@ -45,14 +46,14 @@ macro_rules! tensor {
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! shape {
|
||||
($array:expr) => {
|
||||
TensorShape::from($array)
|
||||
};
|
||||
($array:expr) => {
|
||||
TensorShape::from($array)
|
||||
};
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! index {
|
||||
($array:expr) => {
|
||||
TensorIndex::from($array)
|
||||
};
|
||||
}
|
||||
($array:expr) => {
|
||||
TensorIndex::from($array)
|
||||
};
|
||||
}
|
||||
|
14
src/shape.rs
14
src/shape.rs
@ -60,13 +60,17 @@ impl<const R: usize> TensorShape<R> {
|
||||
/// * `flat_index` - The flat index to convert.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `TensorIndex<R>` instance representing the multi-dimensional index corresponding to the flat index.
|
||||
/// An `TensorIndex<R>` instance representing the multi-dimensional index
|
||||
/// corresponding to the flat index.
|
||||
///
|
||||
/// # How It Works
|
||||
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
|
||||
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
|
||||
/// and integer division to reduce the flat index for the next higher dimension.
|
||||
/// - This process is repeated for each dimension to build the multi-dimensional index.
|
||||
/// - The method iterates over the dimensions of the tensor in reverse order
|
||||
/// (assuming row-major order).
|
||||
/// - In each iteration, it uses the modulo operation to find the index in
|
||||
/// the current dimension and integer division to reduce the flat index
|
||||
/// for the next higher dimension.
|
||||
/// - This process is repeated for each dimension to build the
|
||||
/// multi-dimensional index.
|
||||
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
|
||||
let mut indices = [0; R];
|
||||
let mut remaining = flat_index;
|
||||
|
@ -3,9 +3,9 @@ use crate::error::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::fmt;
|
||||
|
||||
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the number of
|
||||
/// dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is a vector, a rank 2 tensor is
|
||||
/// a matrix, and so on.
|
||||
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
|
||||
/// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is
|
||||
/// a vector, a rank 2 tensor is a matrix, and so on.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor};
|
||||
@ -24,8 +24,8 @@ pub struct Tensor<T, const R: usize> {
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// Create a new tensor with the given shape. The rank of the tensor is determined by the shape
|
||||
/// and all elements are initialized to zero.
|
||||
/// Create a new tensor with the given shape. The rank of the tensor is
|
||||
/// determined by the shape and all elements are initialized to zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
@ -39,7 +39,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
// A rank 0 tensor should still have a buffer with one element
|
||||
1
|
||||
} else {
|
||||
// For tensors of rank 1 or higher, calculate the total size normally
|
||||
// For tensors of rank 1 or higher, calculate the total size
|
||||
// normally
|
||||
shape.iter().product()
|
||||
};
|
||||
|
||||
@ -47,8 +48,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
/// Create a new tensor with the given shape and initialize it from the given buffer. The rank
|
||||
/// of the tensor is determined by the shape.
|
||||
/// Create a new tensor with the given shape and initialize it from the
|
||||
/// given buffer. The rank of the tensor is determined by the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
@ -118,7 +119,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
self.buffer_mut().get_mut(index.flat())
|
||||
}
|
||||
|
||||
/// Get a mutable reference to a value at the given index without bounds checking.
|
||||
/// Get a mutable reference to a value at the given index without bounds
|
||||
/// checking.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor};
|
||||
@ -147,7 +149,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
self.buffer().get(index)
|
||||
}
|
||||
|
||||
/// Get a reference to a value at the given flat index without bounds checking.
|
||||
/// Get a reference to a value at the given flat index without bounds
|
||||
/// checking.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor};
|
||||
@ -171,7 +174,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
self.buffer_mut().get_mut(index)
|
||||
}
|
||||
|
||||
/// Get a mutable reference to a value at the given flat index without bounds checking.
|
||||
/// Get a mutable reference to a value at the given flat index without
|
||||
/// bounds checking.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor};
|
||||
@ -354,19 +358,21 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
// ---- Reshape ---------------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
|
||||
/// Reshape the tensor to the given shape. The total size of the new shape must be the same as
|
||||
/// the total size of the old shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, shape, Tensor, TensorShape};
|
||||
///
|
||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let s = shape!([4]);
|
||||
/// let t = t.reshape(s).unwrap();
|
||||
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
|
||||
/// ```
|
||||
pub fn reshape<const S: usize>(self, shape: TensorShape<S>) -> Result<Tensor<T, S>> {
|
||||
/// Reshape the tensor to the given shape. The total size of the new shape
|
||||
/// must be the same as the total size of the old shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, shape, Tensor, TensorShape};
|
||||
///
|
||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let s = shape!([4]);
|
||||
/// let t = t.reshape(s).unwrap();
|
||||
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
|
||||
/// ```
|
||||
pub fn reshape<const S: usize>(
|
||||
self,
|
||||
shape: TensorShape<S>,
|
||||
) -> Result<Tensor<T, S>> {
|
||||
if self.shape().size() != shape.size() {
|
||||
let (ls, rs) = (self.shape().as_array(), shape.as_array());
|
||||
let (lsize, rsize) = (self.shape().size(), shape.size());
|
||||
@ -382,17 +388,16 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
// ---- Transpose -------------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
|
||||
/// Transpose the tensor according to the given order. The order must be a permutation of the
|
||||
/// tensor's axes.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor, TensorShape};
|
||||
///
|
||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let t = t.transpose([1, 0]).unwrap();
|
||||
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
|
||||
/// ```
|
||||
/// Transpose the tensor according to the given order. The order must be a
|
||||
/// permutation of the tensor's axes.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor, TensorShape};
|
||||
///
|
||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let t = t.transpose([1, 0]).unwrap();
|
||||
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
|
||||
/// ```
|
||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||
let buffer = TensorIndex::from(self.shape())
|
||||
.iter_transposed(order)
|
||||
|
Loading…
Reference in New Issue
Block a user