🚀 Implement Tensor-type and basic methods #15

Merged
julius merged 8 commits from core-types into master 2024-01-03 21:52:54 +00:00
6 changed files with 112 additions and 82 deletions
Showing only changes of commit b68e75814b - Show all commits

View File

@ -1 +1,3 @@
max_width = 80
wrap_comments = true
comment_width = 80

View File

@ -27,7 +27,9 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
assert!(level < self.len(), "Level out of bounds");
let mut index = TensorIndex::new(self.shape(), [0; R]);
index.set_axis(self.dim, level);
TensorAxisIterator::new(self).set_start(level).set_end(level + 1)
TensorAxisIterator::new(self)
.set_start(level)
.set_end(level + 1)
}
}
@ -113,4 +115,4 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
fn into_iter(self) -> Self::IntoIter {
TensorAxisIterator::new(&self)
}
}
}

View File

@ -32,10 +32,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self {
indices,
shape,
}
Self { indices, shape }
}
pub fn is_zero(&self) -> bool {
@ -51,7 +48,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
self.indices = [0; R];
}
/// Increments the index and returns a boolean indicating whether the end has been reached.
/// Increments the index and returns a boolean indicating whether the end
/// has been reached.
///
/// # Returns
/// `true` if the increment does not overflow and is still within bounds;
@ -74,10 +72,12 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
}
}
// If carry is still 1 after the loop, it means we've incremented past the last dimension
// If carry is still 1 after the loop, it means we've incremented past
// the last dimension
if carry == 1 {
// Set the index to an invalid state to indicate the end of the iteration indicated
// by setting the first index to the size of the first dimension
// Set the index to an invalid state to indicate the end of the
// iteration indicated by setting the first index to the
// size of the first dimension
self.indices[0] = self.shape.as_array()[0];
return true; // Indicate that the iteration is complete
}
@ -156,7 +156,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
{
if borrow {
if *i == 0 {
*i = dim_size - 1; // Wrap around to the maximum index of this dimension
*i = dim_size - 1; // Wrap around to the maximum index of
// this dimension
} else {
*i -= 1; // Decrement the index
borrow = false; // No more borrowing needed
@ -183,7 +184,8 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
}
}
// Decrement the fixed axis if possible and reset other axes to their max
// Decrement the fixed axis if possible and reset other axes to their
// max
if self.indices[fixed_axis] > 0 {
self.indices[fixed_axis] -= 1;
for i in 0..R {
@ -216,41 +218,49 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
}
}
// If no axis can be decremented, set the first axis in the order to indicate overflow
// If no axis can be decremented, set the first axis in the order to
// indicate overflow
self.indices[order[0]] = self.shape.get(order[0]);
}
/// Converts the multi-dimensional index to a flat index.
///
/// This method calculates the flat index corresponding to the multi-dimensional index
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`.
/// The calculation is based on the assumption that the tensor is stored in row-major order,
/// This method calculates the flat index corresponding to the
/// multi-dimensional index stored in `self.indices`, given the shape of
/// the tensor stored in `self.shape`. The calculation is based on the
/// assumption that the tensor is stored in row-major order,
/// where the last dimension varies the fastest.
///
/// # Returns
/// The flat index corresponding to the multi-dimensional index.
///
/// # How It Works
/// - The method iterates over each pair of corresponding index and dimension size,
/// starting from the last dimension (hence `rev()` is used for reverse iteration).
/// - The method iterates over each pair of corresponding index and
/// dimension size, starting from the last dimension (hence `rev()` is
/// used for reverse iteration).
/// - In each iteration, it performs two main operations:
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product
/// of dimension sizes (`product`). This calculates the contribution of the current index
/// to the overall flat index.
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`).
/// This updates `product` for the next iteration, as each dimension contributes to the
/// flat index based on the sizes of all preceding dimensions.
/// - The `fold` operation accumulates these results, starting with an initial state of
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product.
/// - The final flat index is obtained after the last iteration, which is the first element
/// of the tuple resulting from the `fold`.
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a
/// running product of dimension sizes (`product`). This calculates the
/// contribution of the current index to the overall flat index.
/// 2. **Product Update**: Multiplies `product` by the current dimension
/// size (`dim_size`). This updates `product` for the next iteration,
/// as each dimension contributes to the flat index based on the sizes
/// of all preceding dimensions.
/// - The `fold` operation accumulates these results, starting with an
/// initial state of `(0, 1)` where `0` is the initial flat index and `1`
/// is the initial product.
/// - The final flat index is obtained after the last iteration, which is
/// the first element of the tuple resulting from the `fold`.
///
/// # Example
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
/// - Starting with a flat index of 0 and a product of 1,
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5.
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update
/// the product to 1 * 5 = 5.
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update
/// the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The
/// final flat index is 3 + 10 + 20 = 33.
pub fn flat(&self) -> usize {
self.indices()
.iter()
@ -327,14 +337,18 @@ impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
}
}
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])> for TensorIndex<'a, R> {
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])>
for TensorIndex<'a, R>
{
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices));
Self::new(shape, indices)
}
}
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)> for TensorIndex<'a, R> {
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)>
for TensorIndex<'a, R>
{
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
let indices = shape.index_from_flat(flat_index).indices;
Self::new(shape, indices)
@ -347,7 +361,9 @@ impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
}
}
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for TensorIndex<'a, R> {
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>>
for TensorIndex<'a, R>
{
fn from(tensor: &'a Tensor<T, R>) -> Self {
Self::zero(tensor.shape())
}

View File

@ -1,5 +1,6 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
#![warn(clippy::all)]
pub mod axis;
pub mod error;
pub mod index;
@ -8,7 +9,7 @@ pub mod tensor;
pub use axis::*;
pub use index::TensorIndex;
pub use itertools::Itertools;
// pub use itertools::Itertools;
use num::{Num, One, Zero};
pub use serde::{Deserialize, Serialize};
pub use shape::TensorShape;
@ -45,14 +46,14 @@ macro_rules! tensor {
#[macro_export]
macro_rules! shape {
($array:expr) => {
TensorShape::from($array)
};
($array:expr) => {
TensorShape::from($array)
};
}
#[macro_export]
macro_rules! index {
($array:expr) => {
TensorIndex::from($array)
};
}
($array:expr) => {
TensorIndex::from($array)
};
}

View File

@ -60,13 +60,17 @@ impl<const R: usize> TensorShape<R> {
/// * `flat_index` - The flat index to convert.
///
/// # Returns
/// An `TensorIndex<R>` instance representing the multi-dimensional index corresponding to the flat index.
/// An `TensorIndex<R>` instance representing the multi-dimensional index
/// corresponding to the flat index.
///
/// # How It Works
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
/// and integer division to reduce the flat index for the next higher dimension.
/// - This process is repeated for each dimension to build the multi-dimensional index.
/// - The method iterates over the dimensions of the tensor in reverse order
/// (assuming row-major order).
/// - In each iteration, it uses the modulo operation to find the index in
/// the current dimension and integer division to reduce the flat index
/// for the next higher dimension.
/// - This process is repeated for each dimension to build the
/// multi-dimensional index.
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
let mut indices = [0; R];
let mut remaining = flat_index;

View File

@ -3,9 +3,9 @@ use crate::error::*;
use getset::{Getters, MutGetters};
use std::fmt;
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the number of
/// dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is a vector, a rank 2 tensor is
/// a matrix, and so on.
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
/// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is
/// a vector, a rank 2 tensor is a matrix, and so on.
///
/// ```
/// use manifold::{tensor, Tensor};
@ -24,8 +24,8 @@ pub struct Tensor<T, const R: usize> {
// ---- Construction and Initialization ---------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Create a new tensor with the given shape. The rank of the tensor is determined by the shape
/// and all elements are initialized to zero.
/// Create a new tensor with the given shape. The rank of the tensor is
/// determined by the shape and all elements are initialized to zero.
///
/// ```
/// use manifold::Tensor;
@ -39,7 +39,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// A rank 0 tensor should still have a buffer with one element
1
} else {
// For tensors of rank 1 or higher, calculate the total size normally
// For tensors of rank 1 or higher, calculate the total size
// normally
shape.iter().product()
};
@ -47,8 +48,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
Self { buffer, shape }
}
/// Create a new tensor with the given shape and initialize it from the given buffer. The rank
/// of the tensor is determined by the shape.
/// Create a new tensor with the given shape and initialize it from the
/// given buffer. The rank of the tensor is determined by the shape.
///
/// ```
/// use manifold::Tensor;
@ -118,7 +119,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer_mut().get_mut(index.flat())
}
/// Get a mutable reference to a value at the given index without bounds checking.
/// Get a mutable reference to a value at the given index without bounds
/// checking.
///
/// ```
/// use manifold::{tensor, Tensor};
@ -147,7 +149,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer().get(index)
}
/// Get a reference to a value at the given flat index without bounds checking.
/// Get a reference to a value at the given flat index without bounds
/// checking.
///
/// ```
/// use manifold::{tensor, Tensor};
@ -171,7 +174,8 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer_mut().get_mut(index)
}
/// Get a mutable reference to a value at the given flat index without bounds checking.
/// Get a mutable reference to a value at the given flat index without
/// bounds checking.
///
/// ```
/// use manifold::{tensor, Tensor};
@ -354,19 +358,21 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// ---- Reshape ---------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Reshape the tensor to the given shape. The total size of the new shape must be the same as
/// the total size of the old shape.
///
/// ```
/// use manifold::{tensor, shape, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let s = shape!([4]);
/// let t = t.reshape(s).unwrap();
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
/// ```
pub fn reshape<const S: usize>(self, shape: TensorShape<S>) -> Result<Tensor<T, S>> {
/// Reshape the tensor to the given shape. The total size of the new shape
/// must be the same as the total size of the old shape.
///
/// ```
/// use manifold::{tensor, shape, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let s = shape!([4]);
/// let t = t.reshape(s).unwrap();
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
/// ```
pub fn reshape<const S: usize>(
self,
shape: TensorShape<S>,
) -> Result<Tensor<T, S>> {
if self.shape().size() != shape.size() {
let (ls, rs) = (self.shape().as_array(), shape.as_array());
let (lsize, rsize) = (self.shape().size(), shape.size());
@ -382,17 +388,16 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// ---- Transpose -------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Transpose the tensor according to the given order. The order must be a permutation of the
/// tensor's axes.
///
/// ```
/// use manifold::{tensor, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let t = t.transpose([1, 0]).unwrap();
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// ```
/// Transpose the tensor according to the given order. The order must be a
/// permutation of the tensor's axes.
///
/// ```
/// use manifold::{tensor, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let t = t.transpose([1, 0]).unwrap();
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// ```
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = TensorIndex::from(self.shape())
.iter_transposed(order)