Refactoring, documentation and doctests

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-06 01:36:20 +02:00
parent aa7f35f748
commit f734fa6434
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
2 changed files with 176 additions and 175 deletions

View File

@ -35,10 +35,9 @@ impl<const R: usize> TensorIndex<R> {
/// assert_eq!(index.flat(), 23);
/// ```
/// Creates a new `TensorIndex` from a shape and a coordinate array.
///
/// # Panics
///
/// Panics if any coordinate is out of bounds for its dimension.
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
    // Construct first, then validate through the instance's own
    // bounds check so there is a single validation path.
    let index = Self { indices, shape };
    assert!(index.check_indices(indices), "indices out of bounds");
    index
}
/// Creates a new `TensorIndex` instance with all indices set to zero.
@ -56,6 +55,27 @@ impl<const R: usize> TensorIndex<R> {
shape,
}
}
/// Creates a `TensorIndex` from a flat (linear) index, assuming
/// row-major order (the last dimension varies fastest).
///
/// # Panics
///
/// Panics if `flat_index >= shape.size()` — without this check an
/// oversized flat index would silently wrap around to a valid-looking
/// but wrong multi-dimensional index.
pub fn from_flat(shape: TensorShape<R>, flat_index: usize) -> Self {
    assert!(flat_index < shape.size(), "flat index out of bounds");
    let mut indices = [0; R];
    let mut remaining = flat_index;
    // Walk the dimensions from last to first: the modulo extracts the
    // coordinate in the current dimension and the integer division
    // shifts `remaining` to the next (slower-varying) dimension.
    //
    // Reversing the ZIPPED iterator only changes the processing
    // order; each coordinate is still written to its own position, so
    // the array must NOT be reversed afterwards. (The previous
    // trailing `indices.reverse()` scrambled the result for
    // non-symmetric shapes, e.g. shape [2, 3, 4], flat 22 →
    // correct [1, 2, 2], reversed [2, 2, 1].)
    for (idx, &dim_size) in indices.iter_mut().zip(shape.0.iter()).rev() {
        *idx = remaining % dim_size;
        remaining /= dim_size;
    }
    Self::new(shape, indices)
}
}
// ---- Trivial Functions -----------------------------------------------------
@ -108,6 +128,13 @@ impl<const R: usize> TensorIndex<R> {
pub fn reset(&mut self) {
*self.indices_mut() = [0; R];
}
/// Returns `true` when every coordinate in `indices` lies strictly
/// below the size of its corresponding dimension.
fn check_indices(&self, indices: [usize; R]) -> bool {
    let dims = self.shape().as_array();
    (0..R).all(|axis| indices[axis] < dims[axis])
}
}
// ---- Increment and Decrement -----------------------------------------------
@ -157,8 +184,7 @@ impl<const R: usize> TensorIndex<R> {
/// index.inc_fixed_axis(Axis(1));
/// assert_eq!(index.flat(), 5);
/// ```
pub fn inc_fixed_axis(&mut self, ax: Axis) {
let ax = ax.0;
pub fn inc_fixed_axis(&mut self, Axis(ax): Axis) {
let shape = self.shape().as_array().clone();
assert!(ax < R, "TensorAxis out of bounds");
if self.indices()[ax] == self.shape().get(ax) {
@ -224,6 +250,17 @@ impl<const R: usize> TensorIndex<R> {
}
}
/// Decrements the index by one.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
/// assert_eq!(index.flat(), 23);
/// index.dec();
/// assert_eq!(index.flat(), 22);
/// ```
pub fn dec(&mut self) {
// Check if already at the start
if self.indices.iter().all(|&i| i == 0) {
@ -244,63 +281,6 @@ impl<const R: usize> TensorIndex<R> {
}
}
}
/// Decrements the index while holding `fixed_axis` constant for as
/// long as possible, wrapping the non-fixed axes like digits of a
/// mixed-radix countdown (last axis fastest).
///
/// Returns `false` when the index is already in the overflow
/// ("invalid") state for `fixed_axis`, i.e.
/// `indices[fixed_axis] == shape.get(fixed_axis)`; returns `true`
/// otherwise, including on the transition INTO the invalid state.
///
/// NOTE(review): when the fixed axis is already 0 and all non-fixed
/// axes are 0, the loop below has still reset the non-fixed axes to
/// their maxima before the sentinel is written — confirm callers
/// treat the invalid state as terminal and never read the other
/// coordinates afterwards. Also assumes every dimension size is > 0
/// (`shape.get(i) - 1` would underflow otherwise) — TODO confirm.
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
    // The fixed-axis coordinate equal to its dimension size is the
    // sentinel "iteration finished" state; nothing left to do.
    if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
        return false;
    }
    // Try to decrement non-fixed axes, last axis first: the first
    // non-zero axis is decremented and we are done; zero axes wrap
    // around to their maximum (dim - 1), like a borrow in
    // mixed-radix subtraction.
    for i in (0..R).rev() {
        if i != fixed_axis {
            if self.indices[i] > 0 {
                self.indices[i] -= 1;
                return true;
            } else {
                self.indices[i] = self.shape.get(i) - 1;
            }
        }
    }
    // All non-fixed axes wrapped: borrow from the fixed axis itself.
    // Decrement the fixed axis if possible and reset other axes to
    // their max.
    if self.indices[fixed_axis] > 0 {
        self.indices[fixed_axis] -= 1;
        for i in 0..R {
            if i != fixed_axis {
                self.indices[i] = self.shape.get(i) - 1;
            }
        }
    } else {
        // Fixed axis already at minimum: enter the invalid state
        // (sentinel value == dimension size).
        self.indices[fixed_axis] = self.shape.get(fixed_axis);
    }
    true
}
/// Decrements the index following a transposed axis order.
///
/// `order` lists the axes from fastest- to slowest-varying: the
/// first axis in `order` is the one "ticked" most often. When no
/// axis can be decremented, `indices[order[0]]` is set to its
/// dimension size as an overflow sentinel (one past the largest
/// valid coordinate).
///
/// NOTE(review): assumes every dimension size is > 0
/// (`shape.get(prev_axis) - 1` underflows otherwise) and that
/// `order` is a permutation of `0..R` — TODO confirm at call sites.
pub fn dec_transposed(&mut self, order: [usize; R]) {
    // Iterate over the axes in the specified order.
    for &axis in &order {
        // Try to decrement the current axis.
        if self.indices[axis] > 0 {
            self.indices[axis] -= 1;
            // Reset all preceding (faster-varying) axes in the order
            // to their maximum — they sit "below" the decremented
            // axis in this mixed-radix countdown and must wrap to
            // dim - 1.
            for &prev_axis in &order {
                if prev_axis == axis {
                    break;
                }
                self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
            }
            return;
        }
    }
    // No axis could be decremented: mark overflow on the first
    // (fastest) axis in the order.
    self.indices[order[0]] = self.shape.get(order[0]);
}
}
// ---- Conversion to Flat Index ----------------------------------------------
@ -372,15 +352,14 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
    /// Builds an index from a `(shape, indices)` pair.
    ///
    /// Bounds validation is delegated to `TensorIndex::new`, which
    /// asserts that every coordinate is within its dimension — a
    /// separate pre-check here would run the same validation twice.
    fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
        Self::new(shape, indices)
    }
}
impl<const R: usize> From<(TensorShape<R>, usize)> for TensorIndex<R> {
    /// Builds an index from a `(shape, flat_index)` pair.
    ///
    /// Delegates directly to `TensorIndex::from_flat` instead of the
    /// old two-step dance (decompose into coordinates, then
    /// re-validate via `new`), which validated the same index twice.
    fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
        Self::from_flat(shape, flat_index)
    }
}

View File

@ -4,18 +4,90 @@ use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeTuple, Serializer};
use std::fmt::{Formatter, Result as FmtResult};
/// A tensor's shape.
///
/// ```
/// use manifold::*;
///
/// let shape = shape!([2, 3]);
/// assert_eq!(shape.dim_size(0), Some(&2));
/// assert_eq!(shape.dim_size(1), Some(&3));
/// assert_eq!(shape.dim_size(2), None);
/// ```
#[derive(Clone, Copy, Debug)]
pub struct TensorShape<const R: usize>([usize; R]);
pub struct TensorShape<const R: usize>(pub(crate) [usize; R]);
// ---- Construction and Initialization ---------------------------------------
impl<const R: usize> TensorShape<R> {
/// Creates a new `TensorShape` instance.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3]);
/// assert_eq!(shape.dim_size(0), Some(&2));
/// assert_eq!(shape.dim_size(1), Some(&3));
/// assert_eq!(shape.dim_size(2), None);
/// ```
pub const fn new(shape: [usize; R]) -> Self {
Self(shape)
}
}
pub fn dim_size(&self, index: usize) -> Option<&usize> {
self.0.get(index)
// ---- Getters ---------------------------------------------------------------
impl<const R: usize> TensorShape<R> {
    /// Get the size of the specified dimension.
    ///
    /// ```
    /// use manifold::*;
    ///
    /// let shape = TensorShape::new([2, 3]);
    /// assert_eq!(shape.get(0), 2);
    /// assert_eq!(shape.get(1), 3);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if `index >= R` (direct array indexing; use
    /// `dim_size` for a checked, `Option`-returning lookup).
    pub const fn get(&self, index: usize) -> usize {
        self.0[index]
    }
    /// Returns the shape as an array.
    ///
    /// ```
    /// use manifold::*;
    ///
    /// let shape = TensorShape::new([2, 3]);
    /// assert_eq!(shape.as_array(), &[2, 3]);
    /// ```
    pub const fn as_array(&self) -> &[usize; R] {
        &self.0
    }
    /// Returns the size of the shape, meaning the product of all dimensions.
    ///
    /// Note: the empty product is 1, so a rank-0 shape has size 1,
    /// and any shape containing a zero-sized dimension has size 0.
    ///
    /// ```
    /// use manifold::*;
    ///
    /// let shape = TensorShape::new([2, 3]);
    /// assert_eq!(shape.size(), 6);
    /// ```
    pub fn size(&self) -> usize {
        self.0.iter().product()
    }
}
// ---- Manipulation ----------------------------------------------------------
impl<const R: usize> TensorShape<R> {
/// Reorders the dimensions of the shape.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let new_shape = shape.reorder([2, 0, 1]);
/// assert_eq!(new_shape, shape!([4, 2, 3]));
/// ```
pub fn reorder(&self, indices: [usize; R]) -> Self {
let mut new_shape = TensorShape::new([0; R]);
for (new_index, &index) in indices.iter().enumerate() {
@ -24,75 +96,16 @@ impl<const R: usize> TensorShape<R> {
new_shape
}
pub const fn as_array(&self) -> &[usize; R] {
&self.0
}
// pub fn flat_max(&self) -> usize {
// self.size() - 1
// }
pub fn size(&self) -> usize {
self.0.iter().product()
}
pub fn iter(&self) -> impl Iterator<Item = &usize> {
self.0.iter()
}
pub const fn get(&self, index: usize) -> usize {
self.0[index]
}
/// Returns `true` when every index is strictly smaller than the
/// size of its corresponding dimension.
pub fn check_indices(&self, indices: [usize; R]) -> bool {
    !indices
        .iter()
        .enumerate()
        .any(|(axis, &idx)| idx >= self.0[axis])
}
/// Converts a flat index to a multi-dimensional index.
///
/// # Arguments
/// * `flat_index` - The flat index to convert.
///
/// # Returns
/// An `TensorIndex<R>` instance representing the multi-dimensional index
/// corresponding to the flat index.
///
/// # How It Works
/// - The method iterates over the dimensions of the tensor in reverse order
/// (assuming row-major order).
/// - In each iteration, it uses the modulo operation to find the index in
/// the current dimension and integer division to reduce the flat index
/// for the next higher dimension.
/// - This process is repeated for each dimension to build the
/// multi-dimensional index.
/// Converts a flat index into its multi-dimensional `TensorIndex`,
/// assuming row-major order (last dimension varies fastest).
///
/// # Panics
///
/// Panics if `flat_index >= self.size()` — an oversized flat index
/// would otherwise silently wrap to a wrong coordinate tuple.
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
    assert!(flat_index < self.size(), "flat index out of bounds");
    let mut indices = [0; R];
    let mut remaining = flat_index;
    // Peel coordinates off from the last (fastest-varying) dimension
    // to the first. Reversing the ZIPPED iterator keeps each
    // coordinate aligned with its own dimension, so the array must
    // NOT be reversed afterwards — the previous trailing
    // `indices.reverse()` scrambled the result for non-symmetric
    // shapes (e.g. shape [2, 3, 4], flat 22 → correct [1, 2, 2],
    // reversed [2, 2, 1]).
    for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() {
        *idx = remaining % dim_size;
        remaining /= dim_size;
    }
    TensorIndex::new(self.clone(), indices)
}
pub fn index_zero(&self) -> TensorIndex<R> {
TensorIndex::zero(self.clone())
}
pub fn index_max(&self) -> TensorIndex<R> {
let max_indices =
self.0
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
TensorIndex::new(self.clone(), max_indices)
}
pub fn remove_dims<const NAX: usize>(
/// Creates a new shape by removing the specified dimensions.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let new_shape = shape.remove_dims([0, 2]);
/// assert_eq!(new_shape, shape!([3]));
/// ```
pub fn remove_dims<const NAX: usize>(
&self,
dims_to_remove: [usize; NAX],
) -> TensorShape<{ R - NAX }> {
@ -114,34 +127,43 @@ impl<const R: usize> TensorShape<R> {
TensorShape(new_shape)
}
// pub fn remove_axes<'a, T: Value, const NAX: usize>(
// &self,
// axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX],
// ) -> TensorShape<{ R - NAX }> {
// // Create a new array to store the remaining dimensions
// let mut new_shape = [0; R - NAX];
// let mut new_index = 0;
// // Iterate over the original dimensions
// for (index, &dim) in self.0.iter().enumerate() {
// // Skip dimensions that are in the axes_to_remove array
// for axis in axes_to_remove {
// if *axis.dim() == index {
// continue;
// }
// }
// // Add the dimension to the new shape array
// new_shape[new_index] = dim;
// new_index += 1;
// }
// TensorShape(new_shape)
// }
}
// ---- Blanket Implementations ----
// ---- Iterators -------------------------------------------------------------
impl<const R: usize> TensorShape<R> {
/// Returns an iterator over the dimensions of the shape.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3]);
/// let mut iter = shape.iter();
/// assert_eq!(iter.next(), Some(&2));
/// assert_eq!(iter.next(), Some(&3));
/// assert_eq!(iter.next(), None);
/// ```
pub fn iter(&self) -> impl Iterator<Item = &usize> {
self.0.iter()
}
}
// ---- Utils -----------------------------------------------------------------
impl<const R: usize> TensorShape<R> {
    /// Returns the all-zeros index for this shape.
    pub fn index_zero(&self) -> TensorIndex<R> {
        TensorIndex::zero(*self)
    }

    /// Returns the largest valid index for this shape: every
    /// coordinate is `dim - 1`, with zero-sized dimensions clamped
    /// to 0.
    pub fn index_max(&self) -> TensorIndex<R> {
        TensorIndex::new(*self, self.0.map(|dim| dim.saturating_sub(1)))
    }
}
// ---- From ------------------------------------------------------------------
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
fn from(shape: [usize; R]) -> Self {
@ -149,16 +171,6 @@ impl<const R: usize> From<[usize; R]> for TensorShape<R> {
}
}
impl<const R: usize> PartialEq for TensorShape<R> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<const R: usize> Eq for TensorShape<R> {}
// ---- From and Into Implementations ----
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
where
T: Value,
@ -168,7 +180,17 @@ where
}
}
// ---- Serialize and Deserialize ----
// ---- Equality --------------------------------------------------------------
impl<const R: usize> PartialEq for TensorShape<R> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<const R: usize> Eq for TensorShape<R> {}
// ---- Serialization ---------------------------------------------------------
struct TensorShapeVisitor<const R: usize>;