Refactoring, documentation and dodctests
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
aa7f35f748
commit
f734fa6434
111
src/index.rs
111
src/index.rs
@ -35,10 +35,9 @@ impl<const R: usize> TensorIndex<R> {
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self { indices, shape }
|
||||
let index = Self { indices, shape };
|
||||
assert!(index.check_indices(indices));
|
||||
index
|
||||
}
|
||||
|
||||
/// Creates a new `TensorIndex` instance with all indices set to zero.
|
||||
@ -56,6 +55,27 @@ impl<const R: usize> TensorIndex<R> {
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_flat(shape: TensorShape<R>, flat_index: usize) -> Self {
|
||||
let mut indices = [0; R];
|
||||
let mut remaining = flat_index;
|
||||
|
||||
// - The method iterates over the dimensions of the tensor in reverse order
|
||||
// (assuming row-major order).
|
||||
// - In each iteration, it uses the modulo operation to find the index in
|
||||
// the current dimension and integer division to reduce the flat index
|
||||
// for the next higher dimension.
|
||||
// - This process is repeated for each dimension to build the
|
||||
// multi-dimensional index.
|
||||
for (idx, &dim_size) in indices.iter_mut().zip(shape.0.iter()).rev() {
|
||||
*idx = remaining % dim_size;
|
||||
remaining /= dim_size;
|
||||
}
|
||||
|
||||
// Reverse the indices to match the original dimension order
|
||||
indices.reverse();
|
||||
Self::new(shape.clone(), indices)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Trivial Functions -----------------------------------------------------
|
||||
@ -108,6 +128,13 @@ impl<const R: usize> TensorIndex<R> {
|
||||
pub fn reset(&mut self) {
|
||||
*self.indices_mut() = [0; R];
|
||||
}
|
||||
|
||||
fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||
indices
|
||||
.iter()
|
||||
.zip(self.shape().as_array().iter())
|
||||
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Increment and Decrement -----------------------------------------------
|
||||
@ -157,8 +184,7 @@ impl<const R: usize> TensorIndex<R> {
|
||||
/// index.inc_fixed_axis(Axis(1));
|
||||
/// assert_eq!(index.flat(), 5);
|
||||
/// ```
|
||||
pub fn inc_fixed_axis(&mut self, ax: Axis) {
|
||||
let ax = ax.0;
|
||||
pub fn inc_fixed_axis(&mut self, Axis(ax): Axis) {
|
||||
let shape = self.shape().as_array().clone();
|
||||
assert!(ax < R, "TensorAxis out of bounds");
|
||||
if self.indices()[ax] == self.shape().get(ax) {
|
||||
@ -224,6 +250,17 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Decrements the index by one.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// index.dec();
|
||||
/// assert_eq!(index.flat(), 22);
|
||||
/// ```
|
||||
pub fn dec(&mut self) {
|
||||
// Check if already at the start
|
||||
if self.indices.iter().all(|&i| i == 0) {
|
||||
@ -244,63 +281,6 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
|
||||
// Check if the fixed axis index is already in an invalid state
|
||||
if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to decrement non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] > 0 {
|
||||
self.indices[i] -= 1;
|
||||
return true;
|
||||
} else {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement the fixed axis if possible and reset other axes to their
|
||||
// max
|
||||
if self.indices[fixed_axis] > 0 {
|
||||
self.indices[fixed_axis] -= 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fixed axis already at minimum, set to invalid state
|
||||
self.indices[fixed_axis] = self.shape.get(fixed_axis);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn dec_transposed(&mut self, order: [usize; R]) {
|
||||
// Iterate over the axes in the specified order
|
||||
for &axis in &order {
|
||||
// Try to decrement the current axis
|
||||
if self.indices[axis] > 0 {
|
||||
self.indices[axis] -= 1;
|
||||
// Reset all preceding axes in the order to their maximum
|
||||
for &prev_axis in &order {
|
||||
if prev_axis == axis {
|
||||
break;
|
||||
}
|
||||
self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If no axis can be decremented, set the first axis in the order to
|
||||
// indicate overflow
|
||||
self.indices[order[0]] = self.shape.get(order[0]);
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Conversion to Flat Index ----------------------------------------------
|
||||
@ -372,15 +352,14 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
|
||||
|
||||
impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
|
||||
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
|
||||
assert!(shape.check_indices(indices));
|
||||
// assert!(shape.check_indices(indices));
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(TensorShape<R>, usize)> for TensorIndex<R> {
|
||||
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
|
||||
let indices = shape.index_from_flat(flat_index).indices;
|
||||
Self::new(shape, indices)
|
||||
Self::from_flat(shape, flat_index)
|
||||
}
|
||||
}
|
||||
|
||||
|
240
src/shape.rs
240
src/shape.rs
@ -4,18 +4,90 @@ use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
||||
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
||||
use std::fmt::{Formatter, Result as FmtResult};
|
||||
|
||||
/// A tensor's shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = shape!([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0), Some(&2));
|
||||
/// assert_eq!(shape.dim_size(1), Some(&3));
|
||||
/// assert_eq!(shape.dim_size(2), None);
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct TensorShape<const R: usize>([usize; R]);
|
||||
pub struct TensorShape<const R: usize>(pub(crate) [usize; R]);
|
||||
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Creates a new `TensorShape` instance.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0), Some(&2));
|
||||
/// assert_eq!(shape.dim_size(1), Some(&3));
|
||||
/// assert_eq!(shape.dim_size(2), None);
|
||||
/// ```
|
||||
pub const fn new(shape: [usize; R]) -> Self {
|
||||
Self(shape)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dim_size(&self, index: usize) -> Option<&usize> {
|
||||
self.0.get(index)
|
||||
// ---- Getters ---------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Get the size of the specified dimension.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.get(0), 2);
|
||||
/// assert_eq!(shape.get(1), 3);
|
||||
/// ```
|
||||
pub const fn get(&self, index: usize) -> usize {
|
||||
self.0[index]
|
||||
}
|
||||
|
||||
/// Returns the shape as an array.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.as_array(), &[2, 3]);
|
||||
/// ```
|
||||
pub const fn as_array(&self) -> &[usize; R] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Returns the size of the shape, meaning the product of all dimensions.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.size(), 6);
|
||||
/// ```
|
||||
pub fn size(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Manipulation ----------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Reorders the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.reorder([2, 0, 1]);
|
||||
/// assert_eq!(new_shape, shape!([4, 2, 3]));
|
||||
/// ```
|
||||
pub fn reorder(&self, indices: [usize; R]) -> Self {
|
||||
let mut new_shape = TensorShape::new([0; R]);
|
||||
for (new_index, &index) in indices.iter().enumerate() {
|
||||
@ -24,75 +96,16 @@ impl<const R: usize> TensorShape<R> {
|
||||
new_shape
|
||||
}
|
||||
|
||||
pub const fn as_array(&self) -> &[usize; R] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
// pub fn flat_max(&self) -> usize {
|
||||
// self.size() - 1
|
||||
// }
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &usize> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
pub const fn get(&self, index: usize) -> usize {
|
||||
self.0[index]
|
||||
}
|
||||
|
||||
pub fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||
indices
|
||||
.iter()
|
||||
.zip(self.0.iter())
|
||||
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||
}
|
||||
|
||||
/// Converts a flat index to a multi-dimensional index.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `flat_index` - The flat index to convert.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `TensorIndex<R>` instance representing the multi-dimensional index
|
||||
/// corresponding to the flat index.
|
||||
///
|
||||
/// # How It Works
|
||||
/// - The method iterates over the dimensions of the tensor in reverse order
|
||||
/// (assuming row-major order).
|
||||
/// - In each iteration, it uses the modulo operation to find the index in
|
||||
/// the current dimension and integer division to reduce the flat index
|
||||
/// for the next higher dimension.
|
||||
/// - This process is repeated for each dimension to build the
|
||||
/// multi-dimensional index.
|
||||
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
|
||||
let mut indices = [0; R];
|
||||
let mut remaining = flat_index;
|
||||
|
||||
for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() {
|
||||
*idx = remaining % dim_size;
|
||||
remaining /= dim_size;
|
||||
}
|
||||
|
||||
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||
TensorIndex::new(self.clone(), indices)
|
||||
}
|
||||
|
||||
pub fn index_zero(&self) -> TensorIndex<R> {
|
||||
TensorIndex::zero(self.clone())
|
||||
}
|
||||
|
||||
pub fn index_max(&self) -> TensorIndex<R> {
|
||||
let max_indices =
|
||||
self.0
|
||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||
TensorIndex::new(self.clone(), max_indices)
|
||||
}
|
||||
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
/// Creates a new shape by removing the specified dimensions.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.remove_dims([0, 2]);
|
||||
/// assert_eq!(new_shape, shape!([3]));
|
||||
/// ```
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
&self,
|
||||
dims_to_remove: [usize; NAX],
|
||||
) -> TensorShape<{ R - NAX }> {
|
||||
@ -114,34 +127,43 @@ impl<const R: usize> TensorShape<R> {
|
||||
|
||||
TensorShape(new_shape)
|
||||
}
|
||||
|
||||
// pub fn remove_axes<'a, T: Value, const NAX: usize>(
|
||||
// &self,
|
||||
// axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX],
|
||||
// ) -> TensorShape<{ R - NAX }> {
|
||||
// // Create a new array to store the remaining dimensions
|
||||
// let mut new_shape = [0; R - NAX];
|
||||
// let mut new_index = 0;
|
||||
|
||||
// // Iterate over the original dimensions
|
||||
// for (index, &dim) in self.0.iter().enumerate() {
|
||||
// // Skip dimensions that are in the axes_to_remove array
|
||||
// for axis in axes_to_remove {
|
||||
// if *axis.dim() == index {
|
||||
// continue;
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Add the dimension to the new shape array
|
||||
// new_shape[new_index] = dim;
|
||||
// new_index += 1;
|
||||
// }
|
||||
|
||||
// TensorShape(new_shape)
|
||||
// }
|
||||
}
|
||||
|
||||
// ---- Blanket Implementations ----
|
||||
// ---- Iterators -------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Returns an iterator over the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// let mut iter = shape.iter();
|
||||
/// assert_eq!(iter.next(), Some(&2));
|
||||
/// assert_eq!(iter.next(), Some(&3));
|
||||
/// assert_eq!(iter.next(), None);
|
||||
/// ```
|
||||
pub fn iter(&self) -> impl Iterator<Item = &usize> {
|
||||
self.0.iter()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Utils -----------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
pub fn index_zero(&self) -> TensorIndex<R> {
|
||||
TensorIndex::zero(self.clone())
|
||||
}
|
||||
|
||||
pub fn index_max(&self) -> TensorIndex<R> {
|
||||
let max_indices =
|
||||
self.0
|
||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||
TensorIndex::new(self.clone(), max_indices)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- From ------------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
|
||||
fn from(shape: [usize; R]) -> Self {
|
||||
@ -149,16 +171,6 @@ impl<const R: usize> From<[usize; R]> for TensorShape<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> PartialEq for TensorShape<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Eq for TensorShape<R> {}
|
||||
|
||||
// ---- From and Into Implementations ----
|
||||
|
||||
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
|
||||
where
|
||||
T: Value,
|
||||
@ -168,7 +180,17 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Serialize and Deserialize ----
|
||||
// ---- Equality --------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> PartialEq for TensorShape<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Eq for TensorShape<R> {}
|
||||
|
||||
// ---- Serialization ---------------------------------------------------------
|
||||
|
||||
struct TensorShapeVisitor<const R: usize>;
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user