Wrap comments and fmt
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
d19ce40494
commit
b68e75814b
@ -1 +1,3 @@
|
|||||||
max_width = 80
|
max_width = 80
|
||||||
|
wrap_comments = true
|
||||||
|
comment_width = 80
|
@ -27,7 +27,9 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
|
|||||||
assert!(level < self.len(), "Level out of bounds");
|
assert!(level < self.len(), "Level out of bounds");
|
||||||
let mut index = TensorIndex::new(self.shape(), [0; R]);
|
let mut index = TensorIndex::new(self.shape(), [0; R]);
|
||||||
index.set_axis(self.dim, level);
|
index.set_axis(self.dim, level);
|
||||||
TensorAxisIterator::new(self).set_start(level).set_end(level + 1)
|
TensorAxisIterator::new(self)
|
||||||
|
.set_start(level)
|
||||||
|
.set_end(level + 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
80
src/index.rs
80
src/index.rs
@ -32,10 +32,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
if !shape.check_indices(indices) {
|
if !shape.check_indices(indices) {
|
||||||
panic!("indices out of bounds");
|
panic!("indices out of bounds");
|
||||||
}
|
}
|
||||||
Self {
|
Self { indices, shape }
|
||||||
indices,
|
|
||||||
shape,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn is_zero(&self) -> bool {
|
pub fn is_zero(&self) -> bool {
|
||||||
@ -51,7 +48,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
self.indices = [0; R];
|
self.indices = [0; R];
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Increments the index and returns a boolean indicating whether the end has been reached.
|
/// Increments the index and returns a boolean indicating whether the end
|
||||||
|
/// has been reached.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// `true` if the increment does not overflow and is still within bounds;
|
/// `true` if the increment does not overflow and is still within bounds;
|
||||||
@ -74,10 +72,12 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
// If carry is still 1 after the loop, it means we've incremented past
|
||||||
|
// the last dimension
|
||||||
if carry == 1 {
|
if carry == 1 {
|
||||||
// Set the index to an invalid state to indicate the end of the iteration indicated
|
// Set the index to an invalid state to indicate the end of the
|
||||||
// by setting the first index to the size of the first dimension
|
// iteration indicated by setting the first index to the
|
||||||
|
// size of the first dimension
|
||||||
self.indices[0] = self.shape.as_array()[0];
|
self.indices[0] = self.shape.as_array()[0];
|
||||||
return true; // Indicate that the iteration is complete
|
return true; // Indicate that the iteration is complete
|
||||||
}
|
}
|
||||||
@ -156,7 +156,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
{
|
{
|
||||||
if borrow {
|
if borrow {
|
||||||
if *i == 0 {
|
if *i == 0 {
|
||||||
*i = dim_size - 1; // Wrap around to the maximum index of this dimension
|
*i = dim_size - 1; // Wrap around to the maximum index of
|
||||||
|
// this dimension
|
||||||
} else {
|
} else {
|
||||||
*i -= 1; // Decrement the index
|
*i -= 1; // Decrement the index
|
||||||
borrow = false; // No more borrowing needed
|
borrow = false; // No more borrowing needed
|
||||||
@ -183,7 +184,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decrement the fixed axis if possible and reset other axes to their max
|
// Decrement the fixed axis if possible and reset other axes to their
|
||||||
|
// max
|
||||||
if self.indices[fixed_axis] > 0 {
|
if self.indices[fixed_axis] > 0 {
|
||||||
self.indices[fixed_axis] -= 1;
|
self.indices[fixed_axis] -= 1;
|
||||||
for i in 0..R {
|
for i in 0..R {
|
||||||
@ -216,41 +218,49 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no axis can be decremented, set the first axis in the order to indicate overflow
|
// If no axis can be decremented, set the first axis in the order to
|
||||||
|
// indicate overflow
|
||||||
self.indices[order[0]] = self.shape.get(order[0]);
|
self.indices[order[0]] = self.shape.get(order[0]);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Converts the multi-dimensional index to a flat index.
|
/// Converts the multi-dimensional index to a flat index.
|
||||||
///
|
///
|
||||||
/// This method calculates the flat index corresponding to the multi-dimensional index
|
/// This method calculates the flat index corresponding to the
|
||||||
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`.
|
/// multi-dimensional index stored in `self.indices`, given the shape of
|
||||||
/// The calculation is based on the assumption that the tensor is stored in row-major order,
|
/// the tensor stored in `self.shape`. The calculation is based on the
|
||||||
|
/// assumption that the tensor is stored in row-major order,
|
||||||
/// where the last dimension varies the fastest.
|
/// where the last dimension varies the fastest.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// The flat index corresponding to the multi-dimensional index.
|
/// The flat index corresponding to the multi-dimensional index.
|
||||||
///
|
///
|
||||||
/// # How It Works
|
/// # How It Works
|
||||||
/// - The method iterates over each pair of corresponding index and dimension size,
|
/// - The method iterates over each pair of corresponding index and
|
||||||
/// starting from the last dimension (hence `rev()` is used for reverse iteration).
|
/// dimension size, starting from the last dimension (hence `rev()` is
|
||||||
|
/// used for reverse iteration).
|
||||||
/// - In each iteration, it performs two main operations:
|
/// - In each iteration, it performs two main operations:
|
||||||
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product
|
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a
|
||||||
/// of dimension sizes (`product`). This calculates the contribution of the current index
|
/// running product of dimension sizes (`product`). This calculates the
|
||||||
/// to the overall flat index.
|
/// contribution of the current index to the overall flat index.
|
||||||
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`).
|
/// 2. **Product Update**: Multiplies `product` by the current dimension
|
||||||
/// This updates `product` for the next iteration, as each dimension contributes to the
|
/// size (`dim_size`). This updates `product` for the next iteration,
|
||||||
/// flat index based on the sizes of all preceding dimensions.
|
/// as each dimension contributes to the flat index based on the sizes
|
||||||
/// - The `fold` operation accumulates these results, starting with an initial state of
|
/// of all preceding dimensions.
|
||||||
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product.
|
/// - The `fold` operation accumulates these results, starting with an
|
||||||
/// - The final flat index is obtained after the last iteration, which is the first element
|
/// initial state of `(0, 1)` where `0` is the initial flat index and `1`
|
||||||
/// of the tuple resulting from the `fold`.
|
/// is the initial product.
|
||||||
|
/// - The final flat index is obtained after the last iteration, which is
|
||||||
|
/// the first element of the tuple resulting from the `fold`.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
|
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
|
||||||
/// - Starting with a flat index of 0 and a product of 1,
|
/// - Starting with a flat index of 0 and a product of 1,
|
||||||
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5.
|
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update
|
||||||
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
|
/// the product to 1 * 5 = 5.
|
||||||
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
|
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update
|
||||||
|
/// the product to 5 * 4 = 20.
|
||||||
|
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The
|
||||||
|
/// final flat index is 3 + 10 + 20 = 33.
|
||||||
pub fn flat(&self) -> usize {
|
pub fn flat(&self) -> usize {
|
||||||
self.indices()
|
self.indices()
|
||||||
.iter()
|
.iter()
|
||||||
@ -327,14 +337,18 @@ impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])> for TensorIndex<'a, R> {
|
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])>
|
||||||
|
for TensorIndex<'a, R>
|
||||||
|
{
|
||||||
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
|
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
|
||||||
assert!(shape.check_indices(indices));
|
assert!(shape.check_indices(indices));
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)> for TensorIndex<'a, R> {
|
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)>
|
||||||
|
for TensorIndex<'a, R>
|
||||||
|
{
|
||||||
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
|
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
|
||||||
let indices = shape.index_from_flat(flat_index).indices;
|
let indices = shape.index_from_flat(flat_index).indices;
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
@ -347,7 +361,9 @@ impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for TensorIndex<'a, R> {
|
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>>
|
||||||
|
for TensorIndex<'a, R>
|
||||||
|
{
|
||||||
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
||||||
Self::zero(tensor.shape())
|
Self::zero(tensor.shape())
|
||||||
}
|
}
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
#![allow(incomplete_features)]
|
#![allow(incomplete_features)]
|
||||||
#![feature(generic_const_exprs)]
|
#![feature(generic_const_exprs)]
|
||||||
|
#![warn(clippy::all)]
|
||||||
pub mod axis;
|
pub mod axis;
|
||||||
pub mod error;
|
pub mod error;
|
||||||
pub mod index;
|
pub mod index;
|
||||||
@ -8,7 +9,7 @@ pub mod tensor;
|
|||||||
|
|
||||||
pub use axis::*;
|
pub use axis::*;
|
||||||
pub use index::TensorIndex;
|
pub use index::TensorIndex;
|
||||||
pub use itertools::Itertools;
|
// pub use itertools::Itertools;
|
||||||
use num::{Num, One, Zero};
|
use num::{Num, One, Zero};
|
||||||
pub use serde::{Deserialize, Serialize};
|
pub use serde::{Deserialize, Serialize};
|
||||||
pub use shape::TensorShape;
|
pub use shape::TensorShape;
|
||||||
|
14
src/shape.rs
14
src/shape.rs
@ -60,13 +60,17 @@ impl<const R: usize> TensorShape<R> {
|
|||||||
/// * `flat_index` - The flat index to convert.
|
/// * `flat_index` - The flat index to convert.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// An `TensorIndex<R>` instance representing the multi-dimensional index corresponding to the flat index.
|
/// An `TensorIndex<R>` instance representing the multi-dimensional index
|
||||||
|
/// corresponding to the flat index.
|
||||||
///
|
///
|
||||||
/// # How It Works
|
/// # How It Works
|
||||||
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
|
/// - The method iterates over the dimensions of the tensor in reverse order
|
||||||
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
|
/// (assuming row-major order).
|
||||||
/// and integer division to reduce the flat index for the next higher dimension.
|
/// - In each iteration, it uses the modulo operation to find the index in
|
||||||
/// - This process is repeated for each dimension to build the multi-dimensional index.
|
/// the current dimension and integer division to reduce the flat index
|
||||||
|
/// for the next higher dimension.
|
||||||
|
/// - This process is repeated for each dimension to build the
|
||||||
|
/// multi-dimensional index.
|
||||||
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
|
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
|
||||||
let mut indices = [0; R];
|
let mut indices = [0; R];
|
||||||
let mut remaining = flat_index;
|
let mut remaining = flat_index;
|
||||||
|
@ -3,9 +3,9 @@ use crate::error::*;
|
|||||||
use getset::{Getters, MutGetters};
|
use getset::{Getters, MutGetters};
|
||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the number of
|
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
|
||||||
/// dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is a vector, a rank 2 tensor is
|
/// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is
|
||||||
/// a matrix, and so on.
|
/// a vector, a rank 2 tensor is a matrix, and so on.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::{tensor, Tensor};
|
/// use manifold::{tensor, Tensor};
|
||||||
@ -24,8 +24,8 @@ pub struct Tensor<T, const R: usize> {
|
|||||||
// ---- Construction and Initialization ---------------------------------------
|
// ---- Construction and Initialization ---------------------------------------
|
||||||
|
|
||||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||||
/// Create a new tensor with the given shape. The rank of the tensor is determined by the shape
|
/// Create a new tensor with the given shape. The rank of the tensor is
|
||||||
/// and all elements are initialized to zero.
|
/// determined by the shape and all elements are initialized to zero.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::Tensor;
|
/// use manifold::Tensor;
|
||||||
@ -39,7 +39,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
// A rank 0 tensor should still have a buffer with one element
|
// A rank 0 tensor should still have a buffer with one element
|
||||||
1
|
1
|
||||||
} else {
|
} else {
|
||||||
// For tensors of rank 1 or higher, calculate the total size normally
|
// For tensors of rank 1 or higher, calculate the total size
|
||||||
|
// normally
|
||||||
shape.iter().product()
|
shape.iter().product()
|
||||||
};
|
};
|
||||||
|
|
||||||
@ -47,8 +48,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
Self { buffer, shape }
|
Self { buffer, shape }
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Create a new tensor with the given shape and initialize it from the given buffer. The rank
|
/// Create a new tensor with the given shape and initialize it from the
|
||||||
/// of the tensor is determined by the shape.
|
/// given buffer. The rank of the tensor is determined by the shape.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::Tensor;
|
/// use manifold::Tensor;
|
||||||
@ -118,7 +119,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
self.buffer_mut().get_mut(index.flat())
|
self.buffer_mut().get_mut(index.flat())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a mutable reference to a value at the given index without bounds checking.
|
/// Get a mutable reference to a value at the given index without bounds
|
||||||
|
/// checking.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::{tensor, Tensor};
|
/// use manifold::{tensor, Tensor};
|
||||||
@ -147,7 +149,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
self.buffer().get(index)
|
self.buffer().get(index)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a reference to a value at the given flat index without bounds checking.
|
/// Get a reference to a value at the given flat index without bounds
|
||||||
|
/// checking.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::{tensor, Tensor};
|
/// use manifold::{tensor, Tensor};
|
||||||
@ -171,7 +174,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
self.buffer_mut().get_mut(index)
|
self.buffer_mut().get_mut(index)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get a mutable reference to a value at the given flat index without bounds checking.
|
/// Get a mutable reference to a value at the given flat index without
|
||||||
|
/// bounds checking.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::{tensor, Tensor};
|
/// use manifold::{tensor, Tensor};
|
||||||
@ -354,9 +358,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
// ---- Reshape ---------------------------------------------------------------
|
// ---- Reshape ---------------------------------------------------------------
|
||||||
|
|
||||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||||
|
/// Reshape the tensor to the given shape. The total size of the new shape
|
||||||
/// Reshape the tensor to the given shape. The total size of the new shape must be the same as
|
/// must be the same as the total size of the old shape.
|
||||||
/// the total size of the old shape.
|
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::{tensor, shape, Tensor, TensorShape};
|
/// use manifold::{tensor, shape, Tensor, TensorShape};
|
||||||
@ -366,7 +369,10 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
/// let t = t.reshape(s).unwrap();
|
/// let t = t.reshape(s).unwrap();
|
||||||
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
|
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
|
||||||
/// ```
|
/// ```
|
||||||
pub fn reshape<const S: usize>(self, shape: TensorShape<S>) -> Result<Tensor<T, S>> {
|
pub fn reshape<const S: usize>(
|
||||||
|
self,
|
||||||
|
shape: TensorShape<S>,
|
||||||
|
) -> Result<Tensor<T, S>> {
|
||||||
if self.shape().size() != shape.size() {
|
if self.shape().size() != shape.size() {
|
||||||
let (ls, rs) = (self.shape().as_array(), shape.as_array());
|
let (ls, rs) = (self.shape().as_array(), shape.as_array());
|
||||||
let (lsize, rsize) = (self.shape().size(), shape.size());
|
let (lsize, rsize) = (self.shape().size(), shape.size());
|
||||||
@ -382,9 +388,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
// ---- Transpose -------------------------------------------------------------
|
// ---- Transpose -------------------------------------------------------------
|
||||||
|
|
||||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||||
|
/// Transpose the tensor according to the given order. The order must be a
|
||||||
/// Transpose the tensor according to the given order. The order must be a permutation of the
|
/// permutation of the tensor's axes.
|
||||||
/// tensor's axes.
|
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::{tensor, Tensor, TensorShape};
|
/// use manifold::{tensor, Tensor, TensorShape};
|
||||||
|
Loading…
Reference in New Issue
Block a user