Major refactor, documentation, tests

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-05 01:48:15 +02:00
parent 9b53301513
commit aa7f35f748
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
11 changed files with 440 additions and 639 deletions

1
Cargo.lock generated
View File

@ -302,6 +302,7 @@ dependencies = [
"ndarray", "ndarray",
"num", "num",
"rand", "rand",
"rayon",
"serde", "serde",
"serde_json", "serde_json",
"static_assertions", "static_assertions",

View File

@ -19,6 +19,7 @@ bytemuck = "1.14.0"
getset = "0.1.2" getset = "0.1.2"
itertools = "0.12.0" itertools = "0.12.0"
num = "0.4.1" num = "0.4.1"
rayon = "1.8.0"
serde = { version = "1.0.193", features = ["derive"] } serde = { version = "1.0.193", features = ["derive"] }
thiserror = "1.0.52" thiserror = "1.0.52"

View File

@ -3,9 +3,11 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
use manifold::*; use manifold::*;
use rand::Rng; use rand::Rng;
const SIZE: usize = 10000;
fn random_tensor_r2_manifold() -> Tensor<f64, 2> { fn random_tensor_r2_manifold() -> Tensor<f64, 2> {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let mut tensor = tensor!([[0.0; 1000]; 1000]); let mut tensor = tensor!([[0.0; SIZE]; SIZE]);
for i in 0..tensor.len() { for i in 0..tensor.len() {
tensor[i] = rng.gen(); tensor[i] = rng.gen();
} }
@ -14,7 +16,7 @@ fn random_tensor_r2_manifold() -> Tensor<f64, 2> {
fn random_tensor_r2_ndarray() -> ndarray::Array2<f64> { fn random_tensor_r2_ndarray() -> ndarray::Array2<f64> {
let mut rng = rand::thread_rng(); let mut rng = rand::thread_rng();
let (rows, cols) = (1000, 1000); let (rows, cols) = (SIZE, SIZE);
let mut tensor = ndarray::Array2::<f64>::zeros((rows, cols)); let mut tensor = ndarray::Array2::<f64>::zeros((rows, cols));
for i in 0..rows { for i in 0..rows {
for j in 0..cols { for j in 0..cols {
@ -40,7 +42,7 @@ fn tensor_product(c: &mut Criterion) {
let a = random_tensor_r2_manifold(); let a = random_tensor_r2_manifold();
let b = random_tensor_r2_manifold(); let b = random_tensor_r2_manifold();
let c = a + b; let c = a + b;
assert!(c.shape().as_array() == &[1000, 1000]); assert!(c.shape().as_array() == &[SIZE, SIZE]);
}) })
}, },
); );
@ -53,7 +55,7 @@ fn tensor_product(c: &mut Criterion) {
let a = random_tensor_r2_ndarray(); let a = random_tensor_r2_ndarray();
let b = random_tensor_r2_ndarray(); let b = random_tensor_r2_ndarray();
let c = a + b; let c = a + b;
assert!(c.shape() == &[1000, 1000]); assert!(c.shape() == &[SIZE, SIZE]);
}) })
}, },
); );

View File

@ -1,118 +0,0 @@
use super::*;
use getset::{Getters, MutGetters};
#[derive(Clone, Debug, Getters)]
pub struct TensorAxis<'a, T: Value, const R: usize> {
#[getset(get = "pub")]
tensor: &'a Tensor<T, R>,
#[getset(get = "pub")]
dim: usize,
}
impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
assert!(dim < R, "TensorAxis out of bounds");
Self { tensor, dim }
}
pub fn len(&self) -> usize {
self.tensor.shape().get(self.dim)
}
pub fn shape(&self) -> &TensorShape<R> {
self.tensor.shape()
}
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
assert!(level < self.len(), "Level out of bounds");
let mut index = TensorIndex::new(self.shape().clone(), [0; R]);
index.set_axis(self.dim, level);
TensorAxisIterator::new(self)
.set_start(level)
.set_end(level + 1)
}
}
#[derive(Clone, Debug, Getters, MutGetters)]
pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
#[getset(get = "pub")]
axis: &'a TensorAxis<'a, T, R>,
#[getset(get = "pub", get_mut = "pub")]
index: TensorIndex<R>,
#[getset(get = "pub")]
end: Option<usize>,
}
impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
Self {
axis,
index: TensorIndex::new(axis.shape().clone(), [0; R]),
end: None,
}
}
pub fn set_start(self, start: usize) -> Self {
assert!(start < self.axis().len(), "Start out of bounds");
let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]);
index.set_axis(self.axis.dim, start);
Self {
axis: self.axis(),
index,
end: None,
}
}
pub fn set_end(self, end: usize) -> Self {
assert!(end <= self.axis().len(), "End out of bounds");
Self {
axis: self.axis(),
index: self.index().clone(),
end: Some(end),
}
}
pub fn set_level(self, level: usize) -> Self {
assert!(level < self.axis().len(), "Level out of bounds");
self.set_start(level).set_end(level + 1)
}
pub fn level(&'a self, level: usize) -> impl Iterator<Item = &'a T> + 'a {
Self::new(self.axis()).set_level(level)
}
pub fn axis_max_idx(&self) -> usize {
self.end().unwrap_or(self.axis().len())
}
pub fn axis_idx(&self) -> usize {
self.index().get_axis(*self.axis().dim())
}
pub fn axis_dim(&self) -> usize {
self.axis().dim().clone()
}
}
impl<'a, T: Value, const R: usize> Iterator for TensorAxisIterator<'a, T, R> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
if self.axis_idx() == self.axis_max_idx() {
return None;
}
let result = unsafe { self.axis().tensor().get_unchecked(self.index) };
let axis_dim = self.axis_dim();
self.index_mut().inc_axis(axis_dim);
Some(result)
}
}
impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
type Item = &'a T;
type IntoIter = TensorAxisIterator<'a, T, R>;
fn into_iter(self) -> Self::IntoIter {
TensorAxisIterator::new(&self)
}
}

View File

@ -2,9 +2,18 @@ use super::*;
use getset::{Getters, MutGetters}; use getset::{Getters, MutGetters};
use std::{ use std::{
cmp::Ordering, cmp::Ordering,
ops::{Add, Index, IndexMut, Sub}, ops::{Index, IndexMut},
}; };
/// A multi-dimensional index into a tensor.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let index = TensorIndex::new(shape, [1, 2, 3]);
/// assert_eq!(index.flat(), 23);
/// ```
#[derive(Clone, Copy, Debug, Getters, MutGetters)] #[derive(Clone, Copy, Debug, Getters, MutGetters)]
pub struct TensorIndex<const R: usize> { pub struct TensorIndex<const R: usize> {
#[getset(get = "pub", get_mut = "pub")] #[getset(get = "pub", get_mut = "pub")]
@ -16,6 +25,15 @@ pub struct TensorIndex<const R: usize> {
// ---- Construction and Initialization --------------------------------------- // ---- Construction and Initialization ---------------------------------------
impl<const R: usize> TensorIndex<R> { impl<const R: usize> TensorIndex<R> {
/// Creates a new `TensorIndex` instance.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let index = TensorIndex::new(shape, [1, 2, 3]);
/// assert_eq!(index.flat(), 23);
/// ```
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self { pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) { if !shape.check_indices(indices) {
panic!("indices out of bounds"); panic!("indices out of bounds");
@ -23,107 +41,167 @@ impl<const R: usize> TensorIndex<R> {
Self { indices, shape } Self { indices, shape }
} }
/// Creates a new `TensorIndex` instance with all indices set to zero.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let index = TensorIndex::zero(shape);
/// assert_eq!(index.flat(), 0);
/// ```
pub const fn zero(shape: TensorShape<R>) -> Self { pub const fn zero(shape: TensorShape<R>) -> Self {
Self { Self {
indices: [0; R], indices: [0; R],
shape, shape,
} }
} }
}
pub fn last(shape: TensorShape<R>) -> Self { // ---- Trivial Functions -----------------------------------------------------
let max_indices =
shape.as_array().map(|dim_size| dim_size.saturating_sub(1)); impl<const R: usize> TensorIndex<R> {
Self { /// Returns `true` if all indices are zero.
indices: max_indices, ///
shape, /// ```
} /// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let index = TensorIndex::zero(shape);
/// assert!(index.is_zero());
/// ```
pub fn is_zero(&self) -> bool {
self.indices().iter().all(|&i| i == 0)
}
/// Returns `true` if the last index is equal to the size of the last
/// dimension.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([1, 1]);
/// let mut index = TensorIndex::zero(shape);
/// assert!(!index.is_max());
/// index.inc();
/// assert!(index.is_max());
/// ```
pub fn is_max(&self) -> bool {
self.indices()[0] == self.shape().get(0)
} }
} }
// ---- Utils -----------------------------------------------------------------
impl<const R: usize> TensorIndex<R> { impl<const R: usize> TensorIndex<R> {
pub fn is_zero(&self) -> bool { /// Resets the index to zero.
self.indices.iter().all(|&i| i == 0) ///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
/// assert_eq!(index.flat(), 23);
/// index.reset();
/// assert_eq!(index.flat(), 0);
/// ```
pub fn reset(&mut self) {
*self.indices_mut() = [0; R];
} }
}
pub fn is_overflow(&self) -> bool { // ---- Increment and Decrement -----------------------------------------------
// Check if the last index is equal to the size of the last dimension
self.indices[0] >= self.shape.get(R - 1)
}
pub fn reset(&mut self) { impl<const R: usize> TensorIndex<R> {
self.indices = [0; R]; /// Increments the index by one.
} ///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let mut index = TensorIndex::zero(shape);
/// assert_eq!(index.flat(), 0);
/// index.inc();
/// assert_eq!(index.flat(), 1);
/// ```
pub fn inc(&mut self) {
if self.indices()[0] == self.shape().get(0) {
return;
}
/// Increments the index and returns a boolean indicating whether the end let shape = self.shape().as_array().clone();
/// has been reached.
///
/// # Returns
/// `true` if the increment does not overflow and is still within bounds;
/// `false` if it overflows, indicating the end of the tensor.
pub fn inc(&mut self) -> bool {
if self.indices()[0] >= self.shape().get(0) {
return false;
}
let shape = self.shape().as_array().clone();
let mut carry = 1;
for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() {
if carry == 1 {
*i += 1;
if *i >= dim_size {
*i = 0; // Reset index in this dimension and carry over
} else {
carry = 0; // Increment successful, no carry needed
}
}
}
// If carry is still 1 after the loop, it means we've incremented past for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
// the last dimension *dim += 1;
if carry == 1 { if *dim < shape[i] {
// Set the index to an invalid state to indicate the end of the return;
// iteration indicated by setting the first index to the }
// size of the first dimension *dim = 0;
self.indices[0] = self.shape.as_array()[0]; }
return true; // Indicate that the iteration is complete
}
false
}
// fn inc_axis self.indices_mut()[0] = self.shape().get(0);
}
pub fn inc_axis(&mut self, fixed_axis: usize) { /// Increments the index by one, with the specified axis fixed,
assert!(fixed_axis < R, "TensorAxis out of bounds"); /// i.e. the index of the fixed axis is incremented only if the
assert!( /// index of the other axes reaches the maximum.
self.indices()[fixed_axis] < self.shape().get(fixed_axis), ///
"Index out of bounds" /// ```
); /// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let mut index = TensorIndex::zero(shape);
/// assert_eq!(index.flat(), 0);
/// index.inc_fixed_axis(Axis(0));
/// assert_eq!(index.flat(), 4);
/// index.inc_fixed_axis(Axis(1));
/// assert_eq!(index.flat(), 5);
/// ```
pub fn inc_fixed_axis(&mut self, ax: Axis) {
let ax = ax.0;
let shape = self.shape().as_array().clone();
assert!(ax < R, "TensorAxis out of bounds");
if self.indices()[ax] == self.shape().get(ax) {
return;
}
// Try to increment non-fixed axes // Iterate over all axes, skipping the fixed axis 'ax'
for i in (0..R).rev() { for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
if i != fixed_axis { if i != ax { // Skip the fixed axis
if self.indices[i] + 1 < self.shape.get(i) { *dim += 1;
self.indices[i] += 1; if *dim < shape[i] {
return; return; // No carry over needed
} else { }
self.indices[i] = 0; *dim = 0; // Reset the current axis and carry over to the next
} }
} }
}
if self.indices[fixed_axis] < self.shape.get(fixed_axis) { // Handle the case where incrementing has reached the end
self.indices[fixed_axis] += 1; if self.indices()[ax] < self.shape().get(ax) {
for i in 0..R { self.indices_mut()[ax] += 1;
if i != fixed_axis { } else {
self.indices[i] = 0; // Reset if the fixed axis also overflows
} self.indices_mut()[ax] = self.shape().get(ax);
} }
return; }
}
}
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool { /// Increments the index by one, reordering the order in which the
/// axes are incremented.
///
/// ```
/// use manifold::*;
///
/// let shape = TensorShape::new([2, 3, 4]);
/// let mut index = TensorIndex::zero(shape);
/// assert_eq!(index.flat(), 0);
/// index.inc_transposed(&[2, 1, 0]);
/// assert_eq!(index.flat(), 12);
/// index.inc_transposed(&[0, 1, 2]);
/// assert_eq!(index.flat(), 13);
/// ```
pub fn inc_transposed(&mut self, order: &[usize; R]) {
if self.indices()[order[0]] >= self.shape().get(order[0]) { if self.indices()[order[0]] >= self.shape().get(order[0]) {
return false; return ;
} }
let mut carry = 1; let mut carry = 1;
@ -143,10 +221,7 @@ impl<const R: usize> TensorIndex<R> {
if carry == 1 { if carry == 1 {
self.indices_mut()[order[0]] = self.shape().get(order[0]); self.indices_mut()[order[0]] = self.shape().get(order[0]);
return true;
} }
false
} }
pub fn dec(&mut self) { pub fn dec(&mut self) {
@ -226,45 +301,23 @@ impl<const R: usize> TensorIndex<R> {
// indicate overflow // indicate overflow
self.indices[order[0]] = self.shape.get(order[0]); self.indices[order[0]] = self.shape.get(order[0]);
} }
}
// ---- Conversion to Flat Index ----------------------------------------------
impl <const R: usize> TensorIndex<R> {
/// Converts the multi-dimensional index to a flat index. /// Converts the multi-dimensional index to a flat index.
/// ///
/// This method calculates the flat index corresponding to the /// # Examples
/// multi-dimensional index stored in `self.indices`, given the shape of ///
/// the tensor stored in `self.shape`. The calculation is based on the /// ```
/// assumption that the tensor is stored in row-major order, /// use manifold::*;
/// where the last dimension varies the fastest. ///
/// /// let shape = TensorShape::new([2, 3, 4]);
/// # Returns /// let index = TensorIndex::new(shape, [1, 2, 3]);
/// The flat index corresponding to the multi-dimensional index. /// assert_eq!(index.flat(), 23);
/// /// ```
/// # How It Works
/// - The method iterates over each pair of corresponding index and
/// dimension size, starting from the last dimension (hence `rev()` is
/// used for reverse iteration).
/// - In each iteration, it performs two main operations:
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a
/// running product of dimension sizes (`product`). This calculates the
/// contribution of the current index to the overall flat index.
/// 2. **Product Update**: Multiplies `product` by the current dimension
/// size (`dim_size`). This updates `product` for the next iteration,
/// as each dimension contributes to the flat index based on the sizes
/// of all preceding dimensions.
/// - The `fold` operation accumulates these results, starting with an
/// initial state of `(0, 1)` where `0` is the initial flat index and `1`
/// is the initial product.
/// - The final flat index is obtained after the last iteration, which is
/// the first element of the tuple resulting from the `fold`.
///
/// # Example
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
/// - Starting with a flat index of 0 and a product of 1,
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update
/// the product to 1 * 5 = 5.
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update
/// the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The
/// final flat index is 3 + 10 + 20 = 33.
pub fn flat(&self) -> usize { pub fn flat(&self) -> usize {
self.indices() self.indices()
.iter() .iter()
@ -275,37 +328,9 @@ impl<const R: usize> TensorIndex<R> {
}) })
.0 .0
} }
pub fn set_axis(&mut self, axis: usize, value: usize) {
assert!(axis < R, "TensorAxis out of bounds");
// assert!(value < self.shape.get(axis), "Value out of bounds");
self.indices[axis] = value;
}
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
assert!(axis < R, "TensorAxis out of bounds");
if value < self.shape.get(axis) {
self.indices[axis] = value;
true
} else {
false
}
}
pub fn get_axis(&self, axis: usize) -> usize {
assert!(axis < R, "TensorAxis out of bounds");
self.indices[axis]
}
pub fn iter_transposed(
&self,
order: [usize; R],
) -> TensorIndexTransposedIterator<R> {
TensorIndexTransposedIterator::new(self.shape().clone(), order)
}
} }
// --- blanket impls --- // --- Equality and Ordering --------------------------------------------------
impl<const R: usize> PartialEq for TensorIndex<R> { impl<const R: usize> PartialEq for TensorIndex<R> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
@ -327,6 +352,8 @@ impl<const R: usize> Ord for TensorIndex<R> {
} }
} }
// ---- Indexing --------------------------------------------------------------
impl<const R: usize> Index<usize> for TensorIndex<R> { impl<const R: usize> Index<usize> for TensorIndex<R> {
type Output = usize; type Output = usize;
@ -341,6 +368,8 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
} }
} }
// ---- Conversion to TensorIndex ---------------------------------------------
impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> { impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self { fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices)); assert!(shape.check_indices(indices));
@ -367,6 +396,8 @@ impl<T: Value, const R: usize> From<Tensor<T, R>> for TensorIndex<R> {
} }
} }
// ---- Display ---------------------------------------------------------------
impl<const R: usize> std::fmt::Display for TensorIndex<R> { impl<const R: usize> std::fmt::Display for TensorIndex<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?; write!(f, "[")?;
@ -384,113 +415,3 @@ impl<const R: usize> std::fmt::Display for TensorIndex<R> {
write!(f, "]") write!(f, "]")
} }
} }
// ---- Arithmetic Operations ----
impl<const R: usize> Add for TensorIndex<R> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
let mut result_indices = [0; R];
for i in 0..R {
result_indices[i] = self.indices[i] + rhs.indices[i];
}
Self {
indices: result_indices,
shape: self.shape,
}
}
}
impl<const R: usize> Sub for TensorIndex<R> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
let mut result_indices = [0; R];
for i in 0..R {
result_indices[i] = self.indices[i].saturating_sub(rhs.indices[i]);
}
Self {
indices: result_indices,
shape: self.shape,
}
}
}
// ---- Iterator ----
pub struct TensorIndexIterator<const R: usize> {
current: TensorIndex<R>,
end: bool,
}
impl<const R: usize> TensorIndexIterator<R> {
pub fn new(shape: TensorShape<R>) -> Self {
Self {
current: TensorIndex::zero(shape),
end: false,
}
}
}
impl<const R: usize> Iterator for TensorIndexIterator<R> {
type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> {
if self.end {
return None;
}
let result = self.current;
self.end = self.current.inc();
Some(result)
}
}
impl<const R: usize> IntoIterator for TensorIndex<R> {
type Item = TensorIndex<R>;
type IntoIter = TensorIndexIterator<R>;
fn into_iter(self) -> Self::IntoIter {
TensorIndexIterator {
current: self,
end: false,
}
}
}
pub struct TensorIndexTransposedIterator<const R: usize> {
current: TensorIndex<R>,
order: [usize; R],
end: bool,
}
impl<const R: usize> TensorIndexTransposedIterator<R> {
pub fn new(shape: TensorShape<R>, order: [usize; R]) -> Self {
Self {
current: TensorIndex::zero(shape),
end: false,
order,
}
}
}
impl<const R: usize> Iterator for TensorIndexTransposedIterator<R> {
type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> {
if self.end {
return None;
}
let result = self.current;
self.end = self.current.inc_transposed(&self.order);
Some(result)
}
}

71
src/iterators.rs Normal file
View File

@ -0,0 +1,71 @@
use super::*;
use std::fmt::{Display, Formatter, Result as FmtResult};
// ---- Iterator --------------------------------------------------------------
pub struct TensorIterator<'a, T: Value, const R: usize> {
tensor: &'a Tensor<T, R>,
index: TensorIndex<R>,
}
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
Self {
tensor,
index: tensor.shape().index_zero(),
}
}
}
impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
if self.index.is_max() {
return None;
}
let result = unsafe { self.tensor.get_unchecked(self.index) };
self.index.inc();
Some(result)
}
}
impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
type Item = &'a T;
type IntoIter = TensorIterator<'a, T, R>;
fn into_iter(self) -> Self::IntoIter {
TensorIterator::new(self)
}
}
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
// Print the current index and flat index
write!(
f,
"Current Index: {}, Flat Index: {}",
self.index,
self.index.flat()
)?;
// Print the tensor elements, highlighting the current element
write!(f, ", Tensor Elements: [")?;
for (i, elem) in self.tensor.buffer().iter().enumerate() {
if i == self.index.flat() {
write!(f, "*{}*", elem)?; // Highlight the current element
} else {
write!(f, "{}", elem)?;
}
if i < self.tensor.buffer().len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "]")
}
}
// ---- Axis Iterator ---------------------------------------------------------
// ---- Transposed Iterator ---------------------------------------------------

View File

@ -2,14 +2,36 @@
#![feature(generic_const_exprs)] #![feature(generic_const_exprs)]
#![warn(clippy::all)] #![warn(clippy::all)]
pub mod axis;
pub mod error; pub mod error;
pub mod index; pub mod index;
pub mod shape; pub mod shape;
pub mod tensor; pub mod tensor;
pub mod value; pub mod iterators;
pub use {axis::*, error::*, index::*, shape::*, tensor::*, value::*}; pub use {iterators::*, error::*, index::*, shape::*, tensor::*};
use num::{Num, One, Zero};
use serde::{Deserialize, Serialize};
use std::{fmt::Display, iter::Sum};
/// A trait for types that can be used as values in a tensor.
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
{
}
impl<T> Value for T where
T: Num
+ Zero
+ One
+ Copy
+ Clone
+ Display
+ Serialize
+ Deserialize<'static>
+ Sum
{
}
#[macro_export] #[macro_export]
macro_rules! tensor { macro_rules! tensor {

View File

@ -12,7 +12,7 @@ impl<const R: usize> TensorShape<R> {
Self(shape) Self(shape)
} }
pub fn axis(&self, index: usize) -> Option<&usize> { pub fn dim_size(&self, index: usize) -> Option<&usize> {
self.0.get(index) self.0.get(index)
} }
@ -28,13 +28,9 @@ impl<const R: usize> TensorShape<R> {
&self.0 &self.0
} }
pub const fn rank(&self) -> usize { // pub fn flat_max(&self) -> usize {
R // self.size() - 1
} // }
pub fn flat_max(&self) -> usize {
self.size() - 1
}
pub fn size(&self) -> usize { pub fn size(&self) -> usize {
self.0.iter().product() self.0.iter().product()
@ -119,29 +115,56 @@ impl<const R: usize> TensorShape<R> {
TensorShape(new_shape) TensorShape(new_shape)
} }
pub fn remove_axes<'a, T: Value, const NAX: usize>( // pub fn remove_axes<'a, T: Value, const NAX: usize>(
&self, // &self,
axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX], // axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX],
) -> TensorShape<{ R - NAX }> { // ) -> TensorShape<{ R - NAX }> {
// Create a new array to store the remaining dimensions // // Create a new array to store the remaining dimensions
let mut new_shape = [0; R - NAX]; // let mut new_shape = [0; R - NAX];
let mut new_index = 0; // let mut new_index = 0;
// Iterate over the original dimensions // // Iterate over the original dimensions
for (index, &dim) in self.0.iter().enumerate() { // for (index, &dim) in self.0.iter().enumerate() {
// Skip dimensions that are in the axes_to_remove array // // Skip dimensions that are in the axes_to_remove array
for axis in axes_to_remove { // for axis in axes_to_remove {
if *axis.dim() == index { // if *axis.dim() == index {
continue; // continue;
} // }
} // }
// Add the dimension to the new shape array // // Add the dimension to the new shape array
new_shape[new_index] = dim; // new_shape[new_index] = dim;
new_index += 1; // new_index += 1;
} // }
TensorShape(new_shape) // TensorShape(new_shape)
// }
}
// ---- Blanket Implementations ----
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
fn from(shape: [usize; R]) -> Self {
Self::new(shape)
}
}
impl<const R: usize> PartialEq for TensorShape<R> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<const R: usize> Eq for TensorShape<R> {}
// ---- From and Into Implementations ----
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
where
T: Value,
{
fn from(tensor: Tensor<T, R>) -> Self {
*tensor.shape()
} }
} }
@ -191,30 +214,3 @@ impl<const R: usize> Serialize for TensorShape<R> {
seq.end() seq.end()
} }
} }
// ---- Blanket Implementations ----
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
fn from(shape: [usize; R]) -> Self {
Self::new(shape)
}
}
impl<const R: usize> PartialEq for TensorShape<R> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<const R: usize> Eq for TensorShape<R> {}
// ---- From and Into Implementations ----
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
where
T: Value,
{
fn from(tensor: Tensor<T, R>) -> Self {
*tensor.shape()
}
}

View File

@ -28,6 +28,9 @@ pub struct Tensor<T, const R: usize> {
shape: TensorShape<R>, shape: TensorShape<R>,
} }
/// A type that represents an axis of a tensor.
pub struct Axis(pub usize);
// ---- Construction and Initialization --------------------------------------- // ---- Construction and Initialization ---------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> { impl<T: Value, const R: usize> Tensor<T, R> {
@ -74,10 +77,27 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// ---- Trivial Getters ------------------------------------------------------- // ---- Trivial Getters -------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> { impl<T: Value, const R: usize> Tensor<T, R> {
/// Get the rank of the tensor.
///
/// ```
/// use manifold::Tensor;
///
/// let t = Tensor::<f64, 2>::new([3, 3].into());
/// assert_eq!(t.rank(), 2);
/// ```
pub fn rank(&self) -> usize { pub fn rank(&self) -> usize {
R R
} }
/// Get the length of the tensor's buffer.
///
/// ```
/// use manifold::Tensor;
///
/// let t = Tensor::<f64, 2>::new([3, 3].into());
/// assert_eq!(t.len(), 9);
/// ```
pub fn len(&self) -> usize { pub fn len(&self) -> usize {
self.buffer().len() self.buffer().len()
} }
@ -390,32 +410,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
} }
} }
// ---- Transpose -------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Transpose the tensor according to the given order. The order must be a
/// permutation of the tensor's axes.
///
/// ```
/// use manifold::{tensor, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let t = t.transpose([1, 0]).unwrap();
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// ```
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = TensorIndex::from(self.shape().clone())
.iter_transposed(order)
.map(|index| self.get(index).unwrap().clone())
.collect();
Ok(Tensor {
buffer,
shape: self.shape().reorder(order),
})
}
}
// ---- Operations ------------------------------------------------------------ // ---- Operations ------------------------------------------------------------
impl<T: Value, const R: usize> Add for Tensor<T, R> { impl<T: Value, const R: usize> Add for Tensor<T, R> {
@ -654,73 +648,6 @@ where
impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {} impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
// ---- Iterator --------------------------------------------------------------
pub struct TensorIterator<'a, T: Value, const R: usize> {
tensor: &'a Tensor<T, R>,
index: TensorIndex<R>,
}
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
Self {
tensor,
index: tensor.shape.index_zero(),
}
}
}
impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
if self.index.is_overflow() {
return None;
}
let result = unsafe { self.tensor.get_unchecked(self.index) };
self.index.inc();
Some(result)
}
}
impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
type Item = &'a T;
type IntoIter = TensorIterator<'a, T, R>;
fn into_iter(self) -> Self::IntoIter {
TensorIterator::new(self)
}
}
// ---- Formatting ------------------------------------------------------------
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
// Print the current index and flat index
write!(
f,
"Current Index: {}, Flat Index: {}",
self.index,
self.index.flat()
)?;
// Print the tensor elements, highlighting the current element
write!(f, ", Tensor Elements: [")?;
for (i, elem) in self.tensor.buffer().iter().enumerate() {
if i == self.index.flat() {
write!(f, "*{}*", elem)?; // Highlight the current element
} else {
write!(f, "{}", elem)?;
}
if i < self.tensor.buffer().len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "]")
}
}
// ---- From ------------------------------------------------------------------ // ---- From ------------------------------------------------------------------
impl<T: Value, const R: usize> From<TensorShape<R>> for Tensor<T, R> { impl<T: Value, const R: usize> From<TensorShape<R>> for Tensor<T, R> {

View File

@ -1,22 +0,0 @@
use num::{Num, One, Zero};
use serde::{Deserialize, Serialize};
use std::{fmt::Display, iter::Sum};
/// A trait for types that can be used as values in a tensor.
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
{
}
impl<T> Value for T where
T: Num
+ Zero
+ One
+ Copy
+ Clone
+ Display
+ Serialize
+ Deserialize<'static>
+ Sum
{
}

View File

@ -157,105 +157,105 @@ fn test_index_dec_method() {
assert_eq!(index, TensorIndex::zero(shape)); assert_eq!(index, TensorIndex::zero(shape));
} }
#[test] // #[test]
fn test_axis_iterator() { // fn test_axis_iterator() {
// Creating a 2x2 Tensor for testing // // Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); // let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Testing iteration over the first axis (axis = 0) // // Testing iteration over the first axis (axis = 0)
let axis = TensorAxis::new(&tensor, 0); // let axis = TensorAxis::new(&tensor, 0);
let mut axis_iter = axis.into_iter(); // let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0)); // assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&2.0)); // assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&3.0)); // assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&4.0)); // assert_eq!(axis_iter.next(), Some(&4.0));
// Resetting the iterator for the second axis (axis = 1) // // Resetting the iterator for the second axis (axis = 1)
let axis = TensorAxis::new(&tensor, 1); // let axis = TensorAxis::new(&tensor, 1);
let mut axis_iter = axis.into_iter(); // let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0)); // assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&3.0)); // assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&2.0)); // assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&4.0)); // assert_eq!(axis_iter.next(), Some(&4.0));
let shape = tensor.shape(); // let shape = tensor.shape();
let mut a: TensorIndex<2> = (shape.clone(), [0, 0]).into(); // let mut a: TensorIndex<2> = (shape.clone(), [0, 0]).into();
let b: TensorIndex<2> = (shape.clone(), [1, 1]).into(); // let b: TensorIndex<2> = (shape.clone(), [1, 1]).into();
while a <= b { // while a <= b {
println!("a: {}", a); // println!("a: {}", a);
a.inc(); // a.inc();
} // }
} // }
#[test] // #[test]
fn test_3d_tensor_axis_iteration() { // fn test_3d_tensor_axis_iteration() {
// Create a 3D Tensor with specific values // // Create a 3D Tensor with specific values
// Tensor shape is 2x2x2 for simplicity // // Tensor shape is 2x2x2 for simplicity
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); // let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
// TensorAxis 0 (Layer-wise): // // TensorAxis 0 (Layer-wise):
// // //
// t[0][0][0] = 1 // // t[0][0][0] = 1
// t[0][0][1] = 2 // // t[0][0][1] = 2
// t[0][1][0] = 3 // // t[0][1][0] = 3
// t[0][1][1] = 4 // // t[0][1][1] = 4
// t[1][0][0] = 5 // // t[1][0][0] = 5
// t[1][0][1] = 6 // // t[1][0][1] = 6
// t[1][1][0] = 7 // // t[1][1][0] = 7
// t[1][1][1] = 8 // // t[1][1][1] = 8
// [1, 2, 3, 4, 5, 6, 7, 8] // // [1, 2, 3, 4, 5, 6, 7, 8]
// // //
// This order suggests that for each "layer" (first level of arrays), // // This order suggests that for each "layer" (first level of arrays),
// the iterator goes through all rows and columns. It first completes // // the iterator goes through all rows and columns. It first completes
// the entire first layer, then moves to the second. // // the entire first layer, then moves to the second.
let a0 = TensorAxis::new(&t, 0); // let a0 = TensorAxis::new(&t, 0);
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>(); // let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); // assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
// TensorAxis 1 (Row-wise within each layer): // // TensorAxis 1 (Row-wise within each layer):
// // //
// t[0][0][0] = 1 // // t[0][0][0] = 1
// t[0][0][1] = 2 // // t[0][0][1] = 2
// t[1][0][0] = 5 // // t[1][0][0] = 5
// t[1][0][1] = 6 // // t[1][0][1] = 6
// t[0][1][0] = 3 // // t[0][1][0] = 3
// t[0][1][1] = 4 // // t[0][1][1] = 4
// t[1][1][0] = 7 // // t[1][1][0] = 7
// t[1][1][1] = 8 // // t[1][1][1] = 8
// [1, 2, 5, 6, 3, 4, 7, 8] // // [1, 2, 5, 6, 3, 4, 7, 8]
// // //
// This indicates that within each "layer", the iterator first // // This indicates that within each "layer", the iterator first
// completes the first row across all layers, then the second row // // completes the first row across all layers, then the second row
// across all layers. // // across all layers.
let a1 = TensorAxis::new(&t, 1); // let a1 = TensorAxis::new(&t, 1);
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>(); // let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); // assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
// TensorAxis 2 (Column-wise within each layer): // // TensorAxis 2 (Column-wise within each layer):
// // //
// t[0][0][0] = 1 // // t[0][0][0] = 1
// t[0][1][0] = 3 // // t[0][1][0] = 3
// t[1][0][0] = 5 // // t[1][0][0] = 5
// t[1][1][0] = 7 // // t[1][1][0] = 7
// t[0][0][1] = 2 // // t[0][0][1] = 2
// t[0][1][1] = 4 // // t[0][1][1] = 4
// t[1][0][1] = 6 // // t[1][0][1] = 6
// t[1][1][1] = 8 // // t[1][1][1] = 8
// [1, 3, 5, 7, 2, 4, 6, 8] // // [1, 3, 5, 7, 2, 4, 6, 8]
// // //
// This indicates that within each "layer", the iterator first // // This indicates that within each "layer", the iterator first
// completes the first column across all layers, then the second // // completes the first column across all layers, then the second
// column across all layers. // // column across all layers.
let a2 = TensorAxis::new(&t, 2); // let a2 = TensorAxis::new(&t, 2);
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>(); // let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); // assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
} // }