🚀 Finish basic types and refactor (#17)
- Finished basic API's for types Tensor, TensorIndex, TensorShape, TensorError. - Make Axis a simple wrapper type used only for clarity. - Add documentation and doctests. Reviewed-on: #17 Co-authored-by: Julius Koskela <julius.koskela@unikie.com> Co-committed-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
9b53301513
commit
84e5cb256a
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -302,6 +302,7 @@ dependencies = [
|
||||
"ndarray",
|
||||
"num",
|
||||
"rand",
|
||||
"rayon",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"static_assertions",
|
||||
|
@ -19,6 +19,7 @@ bytemuck = "1.14.0"
|
||||
getset = "0.1.2"
|
||||
itertools = "0.12.0"
|
||||
num = "0.4.1"
|
||||
rayon = "1.8.0"
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
thiserror = "1.0.52"
|
||||
|
||||
|
@ -3,9 +3,11 @@ use criterion::{criterion_group, criterion_main, BenchmarkId, Criterion};
|
||||
use manifold::*;
|
||||
use rand::Rng;
|
||||
|
||||
const SIZE: usize = 10000;
|
||||
|
||||
fn random_tensor_r2_manifold() -> Tensor<f64, 2> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let mut tensor = tensor!([[0.0; 1000]; 1000]);
|
||||
let mut tensor = tensor!([[0.0; SIZE]; SIZE]);
|
||||
for i in 0..tensor.len() {
|
||||
tensor[i] = rng.gen();
|
||||
}
|
||||
@ -14,7 +16,7 @@ fn random_tensor_r2_manifold() -> Tensor<f64, 2> {
|
||||
|
||||
fn random_tensor_r2_ndarray() -> ndarray::Array2<f64> {
|
||||
let mut rng = rand::thread_rng();
|
||||
let (rows, cols) = (1000, 1000);
|
||||
let (rows, cols) = (SIZE, SIZE);
|
||||
let mut tensor = ndarray::Array2::<f64>::zeros((rows, cols));
|
||||
for i in 0..rows {
|
||||
for j in 0..cols {
|
||||
@ -40,7 +42,7 @@ fn tensor_product(c: &mut Criterion) {
|
||||
let a = random_tensor_r2_manifold();
|
||||
let b = random_tensor_r2_manifold();
|
||||
let c = a + b;
|
||||
assert!(c.shape().as_array() == &[1000, 1000]);
|
||||
assert!(c.shape().as_array() == &[SIZE, SIZE]);
|
||||
})
|
||||
},
|
||||
);
|
||||
@ -53,7 +55,7 @@ fn tensor_product(c: &mut Criterion) {
|
||||
let a = random_tensor_r2_ndarray();
|
||||
let b = random_tensor_r2_ndarray();
|
||||
let c = a + b;
|
||||
assert!(c.shape() == &[1000, 1000]);
|
||||
assert!(c.shape() == &[SIZE, SIZE]);
|
||||
})
|
||||
},
|
||||
);
|
||||
|
118
src/axis.rs
118
src/axis.rs
@ -1,118 +0,0 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
|
||||
#[derive(Clone, Debug, Getters)]
|
||||
pub struct TensorAxis<'a, T: Value, const R: usize> {
|
||||
#[getset(get = "pub")]
|
||||
tensor: &'a Tensor<T, R>,
|
||||
#[getset(get = "pub")]
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
|
||||
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
|
||||
assert!(dim < R, "TensorAxis out of bounds");
|
||||
Self { tensor, dim }
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.tensor.shape().get(self.dim)
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &TensorShape<R> {
|
||||
self.tensor.shape()
|
||||
}
|
||||
|
||||
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
|
||||
assert!(level < self.len(), "Level out of bounds");
|
||||
let mut index = TensorIndex::new(self.shape().clone(), [0; R]);
|
||||
index.set_axis(self.dim, level);
|
||||
TensorAxisIterator::new(self)
|
||||
.set_start(level)
|
||||
.set_end(level + 1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Getters, MutGetters)]
|
||||
pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
|
||||
#[getset(get = "pub")]
|
||||
axis: &'a TensorAxis<'a, T, R>,
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
index: TensorIndex<R>,
|
||||
#[getset(get = "pub")]
|
||||
end: Option<usize>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
|
||||
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
|
||||
Self {
|
||||
axis,
|
||||
index: TensorIndex::new(axis.shape().clone(), [0; R]),
|
||||
end: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_start(self, start: usize) -> Self {
|
||||
assert!(start < self.axis().len(), "Start out of bounds");
|
||||
let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]);
|
||||
index.set_axis(self.axis.dim, start);
|
||||
Self {
|
||||
axis: self.axis(),
|
||||
index,
|
||||
end: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_end(self, end: usize) -> Self {
|
||||
assert!(end <= self.axis().len(), "End out of bounds");
|
||||
Self {
|
||||
axis: self.axis(),
|
||||
index: self.index().clone(),
|
||||
end: Some(end),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_level(self, level: usize) -> Self {
|
||||
assert!(level < self.axis().len(), "Level out of bounds");
|
||||
self.set_start(level).set_end(level + 1)
|
||||
}
|
||||
|
||||
pub fn level(&'a self, level: usize) -> impl Iterator<Item = &'a T> + 'a {
|
||||
Self::new(self.axis()).set_level(level)
|
||||
}
|
||||
|
||||
pub fn axis_max_idx(&self) -> usize {
|
||||
self.end().unwrap_or(self.axis().len())
|
||||
}
|
||||
|
||||
pub fn axis_idx(&self) -> usize {
|
||||
self.index().get_axis(*self.axis().dim())
|
||||
}
|
||||
|
||||
pub fn axis_dim(&self) -> usize {
|
||||
self.axis().dim().clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Iterator for TensorAxisIterator<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.axis_idx() == self.axis_max_idx() {
|
||||
return None;
|
||||
}
|
||||
let result = unsafe { self.axis().tensor().get_unchecked(self.index) };
|
||||
let axis_dim = self.axis_dim();
|
||||
self.index_mut().inc_axis(axis_dim);
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = TensorAxisIterator<'a, T, R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
TensorAxisIterator::new(&self)
|
||||
}
|
||||
}
|
548
src/index.rs
548
src/index.rs
@ -2,9 +2,18 @@ use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::{
|
||||
cmp::Ordering,
|
||||
ops::{Add, Index, IndexMut, Sub},
|
||||
ops::{Index, IndexMut},
|
||||
};
|
||||
|
||||
/// A multi-dimensional index into a tensor.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||
pub struct TensorIndex<const R: usize> {
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
@ -16,13 +25,30 @@ pub struct TensorIndex<const R: usize> {
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Creates a new `TensorIndex` instance.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self { indices, shape }
|
||||
let index = Self { indices, shape };
|
||||
assert!(index.check_indices(indices));
|
||||
index
|
||||
}
|
||||
|
||||
/// Creates a new `TensorIndex` instance with all indices set to zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// ```
|
||||
pub const fn zero(shape: TensorShape<R>) -> Self {
|
||||
Self {
|
||||
indices: [0; R],
|
||||
@ -30,106 +56,185 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn last(shape: TensorShape<R>) -> Self {
|
||||
let max_indices =
|
||||
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
||||
Self {
|
||||
indices: max_indices,
|
||||
shape,
|
||||
pub fn from_flat(shape: TensorShape<R>, flat_index: usize) -> Self {
|
||||
let mut indices = [0; R];
|
||||
let mut remaining = flat_index;
|
||||
|
||||
// - The method iterates over the dimensions of the tensor in reverse
|
||||
// order (assuming row-major order).
|
||||
// - In each iteration, it uses the modulo operation to find the index
|
||||
// in the current dimension and integer division to reduce the flat
|
||||
// index for the next higher dimension.
|
||||
// - This process is repeated for each dimension to build the
|
||||
// multi-dimensional index.
|
||||
for (idx, &dim_size) in indices.iter_mut().zip(shape.0.iter()).rev() {
|
||||
*idx = remaining % dim_size;
|
||||
remaining /= dim_size;
|
||||
}
|
||||
|
||||
// Reverse the indices to match the original dimension order
|
||||
indices.reverse();
|
||||
Self::new(shape.clone(), indices)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Trivial Functions -----------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.indices.iter().all(|&i| i == 0)
|
||||
}
|
||||
|
||||
pub fn is_overflow(&self) -> bool {
|
||||
// Check if the last index is equal to the size of the last dimension
|
||||
self.indices[0] >= self.shape.get(R - 1)
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.indices = [0; R];
|
||||
}
|
||||
|
||||
/// Increments the index and returns a boolean indicating whether the end
|
||||
/// has been reached.
|
||||
/// Returns `true` if all indices are zero.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the increment does not overflow and is still within bounds;
|
||||
/// `false` if it overflows, indicating the end of the tensor.
|
||||
pub fn inc(&mut self) -> bool {
|
||||
if self.indices()[0] >= self.shape().get(0) {
|
||||
return false;
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::zero(shape);
|
||||
/// assert!(index.is_zero());
|
||||
/// ```
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.indices().iter().all(|&i| i == 0)
|
||||
}
|
||||
|
||||
/// Returns `true` if the last index is equal to the size of the last
|
||||
/// dimension.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([1, 1]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert!(!index.is_max());
|
||||
/// index.inc();
|
||||
/// assert!(index.is_max());
|
||||
/// ```
|
||||
pub fn is_max(&self) -> bool {
|
||||
self.indices().get(0) == self.shape().dim_size(0)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Utils -----------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Resets the index to zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// index.reset();
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// ```
|
||||
pub fn reset(&mut self) {
|
||||
*self.indices_mut() = [0; R];
|
||||
}
|
||||
|
||||
fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||
indices
|
||||
.iter()
|
||||
.zip(self.shape().as_array().iter())
|
||||
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Increment and Decrement -----------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Increments the index by one.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc();
|
||||
/// assert_eq!(index.flat(), 1);
|
||||
/// ```
|
||||
pub fn inc(&mut self) {
|
||||
if self.indices().get(0) == self.shape().dim_size(0) {
|
||||
return;
|
||||
}
|
||||
|
||||
let shape = self.shape().as_array().clone();
|
||||
let mut carry = 1;
|
||||
for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() {
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
if *i >= dim_size {
|
||||
*i = 0; // Reset index in this dimension and carry over
|
||||
} else {
|
||||
carry = 0; // Increment successful, no carry needed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If carry is still 1 after the loop, it means we've incremented past
|
||||
// the last dimension
|
||||
if carry == 1 {
|
||||
// Set the index to an invalid state to indicate the end of the
|
||||
// iteration indicated by setting the first index to the
|
||||
// size of the first dimension
|
||||
self.indices[0] = self.shape.as_array()[0];
|
||||
return true; // Indicate that the iteration is complete
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// fn inc_axis
|
||||
|
||||
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
||||
assert!(fixed_axis < R, "TensorAxis out of bounds");
|
||||
assert!(
|
||||
self.indices()[fixed_axis] < self.shape().get(fixed_axis),
|
||||
"Index out of bounds"
|
||||
);
|
||||
|
||||
// Try to increment non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] + 1 < self.shape.get(i) {
|
||||
self.indices[i] += 1;
|
||||
return;
|
||||
} else {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if self.indices[fixed_axis] < self.shape.get(fixed_axis) {
|
||||
self.indices[fixed_axis] += 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
|
||||
*dim += 1;
|
||||
if *dim < shape[i] {
|
||||
return;
|
||||
}
|
||||
*dim = 0;
|
||||
}
|
||||
|
||||
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool {
|
||||
if self.indices()[order[0]] >= self.shape().get(order[0]) {
|
||||
return false;
|
||||
self.indices_mut()[0] = *self.shape().dim_size(0).unwrap();
|
||||
}
|
||||
|
||||
/// Increments the index by one, with the specified axis fixed,
|
||||
/// i.e. the index of the fixed axis is incremented only if the
|
||||
/// index of the other axes reaches the maximum.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc_fixed_axis(Axis(0));
|
||||
/// assert_eq!(index.flat(), 4);
|
||||
/// index.inc_fixed_axis(Axis(1));
|
||||
/// assert_eq!(index.flat(), 5);
|
||||
/// ```
|
||||
pub fn inc_fixed_axis(&mut self, Axis(ax): Axis) {
|
||||
let shape = self.shape().as_array().clone();
|
||||
assert!(ax < R, "TensorAxis out of bounds");
|
||||
if self.indices().get(ax) == self.shape().dim_size(ax) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Iterate over all axes, skipping the fixed axis 'ax'
|
||||
for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
|
||||
if i != ax {
|
||||
// Skip the fixed axis
|
||||
*dim += 1;
|
||||
if *dim < shape[i] {
|
||||
return; // No carry over needed
|
||||
}
|
||||
*dim = 0; // Reset the current axis and carry over to the next
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the case where incrementing has reached the end
|
||||
if self.indices().get(ax) < self.shape().dim_size(ax) {
|
||||
self.indices_mut()[ax] += 1;
|
||||
} else {
|
||||
// Reset if the fixed axis also overflows
|
||||
self.indices_mut()[ax] = *self.shape().dim_size(ax).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Increments the index by one, reordering the order in which the
|
||||
/// axes are incremented.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc_transposed(&[2, 1, 0]);
|
||||
/// assert_eq!(index.flat(), 12);
|
||||
/// index.inc_transposed(&[0, 1, 2]);
|
||||
/// assert_eq!(index.flat(), 13);
|
||||
/// ```
|
||||
pub fn inc_transposed(&mut self, order: &[usize; R]) {
|
||||
if self.indices().get(order[0]) == self.shape().dim_size(order[0]) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut carry = 1;
|
||||
|
||||
for i in order.iter().rev() {
|
||||
let dim_size = self.shape().get(*i);
|
||||
let dim_size = self.shape().dim_size(*i).unwrap().clone();
|
||||
let i = self.index_mut(*i);
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
@ -142,13 +247,21 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
|
||||
if carry == 1 {
|
||||
self.indices_mut()[order[0]] = self.shape().get(order[0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
self.indices_mut()[order[0]] = *self.shape().dim_size(order[0]).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Decrements the index by one.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// index.dec();
|
||||
/// assert_eq!(index.flat(), 22);
|
||||
/// ```
|
||||
pub fn dec(&mut self) {
|
||||
// Check if already at the start
|
||||
if self.indices.iter().all(|&i| i == 0) {
|
||||
@ -169,102 +282,22 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
|
||||
// Check if the fixed axis index is already in an invalid state
|
||||
if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to decrement non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] > 0 {
|
||||
self.indices[i] -= 1;
|
||||
return true;
|
||||
} else {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement the fixed axis if possible and reset other axes to their
|
||||
// max
|
||||
if self.indices[fixed_axis] > 0 {
|
||||
self.indices[fixed_axis] -= 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fixed axis already at minimum, set to invalid state
|
||||
self.indices[fixed_axis] = self.shape.get(fixed_axis);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn dec_transposed(&mut self, order: [usize; R]) {
|
||||
// Iterate over the axes in the specified order
|
||||
for &axis in &order {
|
||||
// Try to decrement the current axis
|
||||
if self.indices[axis] > 0 {
|
||||
self.indices[axis] -= 1;
|
||||
// Reset all preceding axes in the order to their maximum
|
||||
for &prev_axis in &order {
|
||||
if prev_axis == axis {
|
||||
break;
|
||||
}
|
||||
self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If no axis can be decremented, set the first axis in the order to
|
||||
// indicate overflow
|
||||
self.indices[order[0]] = self.shape.get(order[0]);
|
||||
}
|
||||
// ---- Conversion to Flat Index ----------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Converts the multi-dimensional index to a flat index.
|
||||
///
|
||||
/// This method calculates the flat index corresponding to the
|
||||
/// multi-dimensional index stored in `self.indices`, given the shape of
|
||||
/// the tensor stored in `self.shape`. The calculation is based on the
|
||||
/// assumption that the tensor is stored in row-major order,
|
||||
/// where the last dimension varies the fastest.
|
||||
/// # Examples
|
||||
///
|
||||
/// # Returns
|
||||
/// The flat index corresponding to the multi-dimensional index.
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// # How It Works
|
||||
/// - The method iterates over each pair of corresponding index and
|
||||
/// dimension size, starting from the last dimension (hence `rev()` is
|
||||
/// used for reverse iteration).
|
||||
/// - In each iteration, it performs two main operations:
|
||||
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a
|
||||
/// running product of dimension sizes (`product`). This calculates the
|
||||
/// contribution of the current index to the overall flat index.
|
||||
/// 2. **Product Update**: Multiplies `product` by the current dimension
|
||||
/// size (`dim_size`). This updates `product` for the next iteration,
|
||||
/// as each dimension contributes to the flat index based on the sizes
|
||||
/// of all preceding dimensions.
|
||||
/// - The `fold` operation accumulates these results, starting with an
|
||||
/// initial state of `(0, 1)` where `0` is the initial flat index and `1`
|
||||
/// is the initial product.
|
||||
/// - The final flat index is obtained after the last iteration, which is
|
||||
/// the first element of the tuple resulting from the `fold`.
|
||||
///
|
||||
/// # Example
|
||||
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
|
||||
/// - Starting with a flat index of 0 and a product of 1,
|
||||
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update
|
||||
/// the product to 1 * 5 = 5.
|
||||
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update
|
||||
/// the product to 5 * 4 = 20.
|
||||
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The
|
||||
/// final flat index is 3 + 10 + 20 = 33.
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
pub fn flat(&self) -> usize {
|
||||
self.indices()
|
||||
.iter()
|
||||
@ -275,37 +308,9 @@ impl<const R: usize> TensorIndex<R> {
|
||||
})
|
||||
.0
|
||||
}
|
||||
|
||||
pub fn set_axis(&mut self, axis: usize, value: usize) {
|
||||
assert!(axis < R, "TensorAxis out of bounds");
|
||||
// assert!(value < self.shape.get(axis), "Value out of bounds");
|
||||
self.indices[axis] = value;
|
||||
}
|
||||
|
||||
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
|
||||
assert!(axis < R, "TensorAxis out of bounds");
|
||||
if value < self.shape.get(axis) {
|
||||
self.indices[axis] = value;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_axis(&self, axis: usize) -> usize {
|
||||
assert!(axis < R, "TensorAxis out of bounds");
|
||||
self.indices[axis]
|
||||
}
|
||||
|
||||
pub fn iter_transposed(
|
||||
&self,
|
||||
order: [usize; R],
|
||||
) -> TensorIndexTransposedIterator<R> {
|
||||
TensorIndexTransposedIterator::new(self.shape().clone(), order)
|
||||
}
|
||||
}
|
||||
|
||||
// --- blanket impls ---
|
||||
// --- Equality and Ordering --------------------------------------------------
|
||||
|
||||
impl<const R: usize> PartialEq for TensorIndex<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
@ -327,6 +332,8 @@ impl<const R: usize> Ord for TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Indexing --------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> Index<usize> for TensorIndex<R> {
|
||||
type Output = usize;
|
||||
|
||||
@ -341,17 +348,18 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Conversion to TensorIndex ---------------------------------------------
|
||||
|
||||
impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
|
||||
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
|
||||
assert!(shape.check_indices(indices));
|
||||
// assert!(shape.check_indices(indices));
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(TensorShape<R>, usize)> for TensorIndex<R> {
|
||||
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
|
||||
let indices = shape.index_from_flat(flat_index).indices;
|
||||
Self::new(shape, indices)
|
||||
Self::from_flat(shape, flat_index)
|
||||
}
|
||||
}
|
||||
|
||||
@ -367,6 +375,8 @@ impl<T: Value, const R: usize> From<Tensor<T, R>> for TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Display ---------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> std::fmt::Display for TensorIndex<R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[")?;
|
||||
@ -384,113 +394,3 @@ impl<const R: usize> std::fmt::Display for TensorIndex<R> {
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Arithmetic Operations ----
|
||||
|
||||
impl<const R: usize> Add for TensorIndex<R> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
|
||||
|
||||
let mut result_indices = [0; R];
|
||||
for i in 0..R {
|
||||
result_indices[i] = self.indices[i] + rhs.indices[i];
|
||||
}
|
||||
|
||||
Self {
|
||||
indices: result_indices,
|
||||
shape: self.shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Sub for TensorIndex<R> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
|
||||
|
||||
let mut result_indices = [0; R];
|
||||
for i in 0..R {
|
||||
result_indices[i] = self.indices[i].saturating_sub(rhs.indices[i]);
|
||||
}
|
||||
|
||||
Self {
|
||||
indices: result_indices,
|
||||
shape: self.shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Iterator ----
|
||||
|
||||
pub struct TensorIndexIterator<const R: usize> {
|
||||
current: TensorIndex<R>,
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<const R: usize> TensorIndexIterator<R> {
|
||||
pub fn new(shape: TensorShape<R>) -> Self {
|
||||
Self {
|
||||
current: TensorIndex::zero(shape),
|
||||
end: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Iterator for TensorIndexIterator<R> {
|
||||
type Item = TensorIndex<R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = self.current;
|
||||
self.end = self.current.inc();
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> IntoIterator for TensorIndex<R> {
|
||||
type Item = TensorIndex<R>;
|
||||
type IntoIter = TensorIndexIterator<R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
TensorIndexIterator {
|
||||
current: self,
|
||||
end: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TensorIndexTransposedIterator<const R: usize> {
|
||||
current: TensorIndex<R>,
|
||||
order: [usize; R],
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<const R: usize> TensorIndexTransposedIterator<R> {
|
||||
pub fn new(shape: TensorShape<R>, order: [usize; R]) -> Self {
|
||||
Self {
|
||||
current: TensorIndex::zero(shape),
|
||||
end: false,
|
||||
order,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Iterator for TensorIndexTransposedIterator<R> {
|
||||
type Item = TensorIndex<R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = self.current;
|
||||
self.end = self.current.inc_transposed(&self.order);
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
71
src/iterators.rs
Normal file
71
src/iterators.rs
Normal file
@ -0,0 +1,71 @@
|
||||
use super::*;
|
||||
use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||
|
||||
// ---- Iterator --------------------------------------------------------------
|
||||
|
||||
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||
tensor: &'a Tensor<T, R>,
|
||||
index: TensorIndex<R>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
|
||||
Self {
|
||||
tensor,
|
||||
index: tensor.shape().index_zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.index.is_max() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = unsafe { self.tensor.get_unchecked(self.index) };
|
||||
self.index.inc();
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = TensorIterator<'a, T, R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
TensorIterator::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
// Print the current index and flat index
|
||||
write!(
|
||||
f,
|
||||
"Current Index: {}, Flat Index: {}",
|
||||
self.index,
|
||||
self.index.flat()
|
||||
)?;
|
||||
|
||||
// Print the tensor elements, highlighting the current element
|
||||
write!(f, ", Tensor Elements: [")?;
|
||||
for (i, elem) in self.tensor.buffer().iter().enumerate() {
|
||||
if i == self.index.flat() {
|
||||
write!(f, "*{}*", elem)?; // Highlight the current element
|
||||
} else {
|
||||
write!(f, "{}", elem)?;
|
||||
}
|
||||
if i < self.tensor.buffer().len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Axis Iterator ---------------------------------------------------------
|
||||
|
||||
// ---- Transposed Iterator ---------------------------------------------------
|
28
src/lib.rs
28
src/lib.rs
@ -2,14 +2,36 @@
|
||||
#![feature(generic_const_exprs)]
|
||||
#![warn(clippy::all)]
|
||||
|
||||
pub mod axis;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod iterators;
|
||||
pub mod shape;
|
||||
pub mod tensor;
|
||||
pub mod value;
|
||||
|
||||
pub use {axis::*, error::*, index::*, shape::*, tensor::*, value::*};
|
||||
pub use {error::*, index::*, iterators::*, shape::*, tensor::*};
|
||||
|
||||
use num::{Num, One, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Display, iter::Sum};
|
||||
|
||||
/// A trait for types that can be used as values in a tensor.
|
||||
pub trait Value:
|
||||
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T> Value for T where
|
||||
T: Num
|
||||
+ Zero
|
||||
+ One
|
||||
+ Copy
|
||||
+ Clone
|
||||
+ Display
|
||||
+ Serialize
|
||||
+ Deserialize<'static>
|
||||
+ Sum
|
||||
{
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! tensor {
|
||||
|
262
src/shape.rs
262
src/shape.rs
@ -4,18 +4,90 @@ use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
||||
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
||||
use std::fmt::{Formatter, Result as FmtResult};
|
||||
|
||||
/// A tensor's shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = shape!([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0), Some(&2));
|
||||
/// assert_eq!(shape.dim_size(1), Some(&3));
|
||||
/// assert_eq!(shape.dim_size(2), None);
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct TensorShape<const R: usize>([usize; R]);
|
||||
pub struct TensorShape<const R: usize>(pub(crate) [usize; R]);
|
||||
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Creates a new `TensorShape` instance.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0), Some(&2));
|
||||
/// assert_eq!(shape.dim_size(1), Some(&3));
|
||||
/// assert_eq!(shape.dim_size(2), None);
|
||||
/// ```
|
||||
pub const fn new(shape: [usize; R]) -> Self {
|
||||
Self(shape)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn axis(&self, index: usize) -> Option<&usize> {
|
||||
// ---- Getters ---------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Get the size of the specified dimension.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0).unwrap(), &2);
|
||||
/// assert_eq!(shape.dim_size(1).unwrap(), &3);
|
||||
/// ```
|
||||
pub fn dim_size(&self, index: usize) -> Option<&usize> {
|
||||
self.0.get(index)
|
||||
}
|
||||
|
||||
/// Returns the shape as an array.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.as_array(), &[2, 3]);
|
||||
/// ```
|
||||
pub const fn as_array(&self) -> &[usize; R] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Returns the size of the shape, meaning the product of all dimensions.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.size(), 6);
|
||||
/// ```
|
||||
pub fn size(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Manipulation ----------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Reorders the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.reorder([2, 0, 1]);
|
||||
/// assert_eq!(new_shape, shape!([4, 2, 3]));
|
||||
/// ```
|
||||
pub fn reorder(&self, indices: [usize; R]) -> Self {
|
||||
let mut new_shape = TensorShape::new([0; R]);
|
||||
for (new_index, &index) in indices.iter().enumerate() {
|
||||
@ -24,78 +96,15 @@ impl<const R: usize> TensorShape<R> {
|
||||
new_shape
|
||||
}
|
||||
|
||||
pub const fn as_array(&self) -> &[usize; R] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub const fn rank(&self) -> usize {
|
||||
R
|
||||
}
|
||||
|
||||
pub fn flat_max(&self) -> usize {
|
||||
self.size() - 1
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &usize> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
pub const fn get(&self, index: usize) -> usize {
|
||||
self.0[index]
|
||||
}
|
||||
|
||||
pub fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||
indices
|
||||
.iter()
|
||||
.zip(self.0.iter())
|
||||
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||
}
|
||||
|
||||
/// Converts a flat index to a multi-dimensional index.
|
||||
/// Creates a new shape by removing the specified dimensions.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `flat_index` - The flat index to convert.
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// # Returns
|
||||
/// An `TensorIndex<R>` instance representing the multi-dimensional index
|
||||
/// corresponding to the flat index.
|
||||
///
|
||||
/// # How It Works
|
||||
/// - The method iterates over the dimensions of the tensor in reverse order
|
||||
/// (assuming row-major order).
|
||||
/// - In each iteration, it uses the modulo operation to find the index in
|
||||
/// the current dimension and integer division to reduce the flat index
|
||||
/// for the next higher dimension.
|
||||
/// - This process is repeated for each dimension to build the
|
||||
/// multi-dimensional index.
|
||||
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
|
||||
let mut indices = [0; R];
|
||||
let mut remaining = flat_index;
|
||||
|
||||
for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() {
|
||||
*idx = remaining % dim_size;
|
||||
remaining /= dim_size;
|
||||
}
|
||||
|
||||
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||
TensorIndex::new(self.clone(), indices)
|
||||
}
|
||||
|
||||
pub fn index_zero(&self) -> TensorIndex<R> {
|
||||
TensorIndex::zero(self.clone())
|
||||
}
|
||||
|
||||
pub fn index_max(&self) -> TensorIndex<R> {
|
||||
let max_indices =
|
||||
self.0
|
||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||
TensorIndex::new(self.clone(), max_indices)
|
||||
}
|
||||
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.remove_dims([0, 2]);
|
||||
/// assert_eq!(new_shape, shape!([3]));
|
||||
/// ```
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
&self,
|
||||
dims_to_remove: [usize; NAX],
|
||||
@ -118,34 +127,70 @@ impl<const R: usize> TensorShape<R> {
|
||||
|
||||
TensorShape(new_shape)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn remove_axes<'a, T: Value, const NAX: usize>(
|
||||
&self,
|
||||
axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX],
|
||||
) -> TensorShape<{ R - NAX }> {
|
||||
// Create a new array to store the remaining dimensions
|
||||
let mut new_shape = [0; R - NAX];
|
||||
let mut new_index = 0;
|
||||
// ---- Iterators -------------------------------------------------------------
|
||||
|
||||
// Iterate over the original dimensions
|
||||
for (index, &dim) in self.0.iter().enumerate() {
|
||||
// Skip dimensions that are in the axes_to_remove array
|
||||
for axis in axes_to_remove {
|
||||
if *axis.dim() == index {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the dimension to the new shape array
|
||||
new_shape[new_index] = dim;
|
||||
new_index += 1;
|
||||
}
|
||||
|
||||
TensorShape(new_shape)
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Returns an iterator over the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// let mut iter = shape.iter();
|
||||
/// assert_eq!(iter.next(), Some(&2));
|
||||
/// assert_eq!(iter.next(), Some(&3));
|
||||
/// assert_eq!(iter.next(), None);
|
||||
/// ```
|
||||
pub fn iter(&self) -> impl Iterator<Item = &usize> {
|
||||
self.0.iter()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Serialize and Deserialize ----
|
||||
// ---- Utils -----------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
pub fn index_zero(&self) -> TensorIndex<R> {
|
||||
TensorIndex::zero(self.clone())
|
||||
}
|
||||
|
||||
pub fn index_max(&self) -> TensorIndex<R> {
|
||||
let max_indices =
|
||||
self.0
|
||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||
TensorIndex::new(self.clone(), max_indices)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- From ------------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
|
||||
fn from(shape: [usize; R]) -> Self {
|
||||
Self::new(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
|
||||
where
|
||||
T: Value,
|
||||
{
|
||||
fn from(tensor: Tensor<T, R>) -> Self {
|
||||
*tensor.shape()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Equality --------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> PartialEq for TensorShape<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Eq for TensorShape<R> {}
|
||||
|
||||
// ---- Serialization ---------------------------------------------------------
|
||||
|
||||
struct TensorShapeVisitor<const R: usize>;
|
||||
|
||||
@ -191,30 +236,3 @@ impl<const R: usize> Serialize for TensorShape<R> {
|
||||
seq.end()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Blanket Implementations ----
|
||||
|
||||
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
|
||||
fn from(shape: [usize; R]) -> Self {
|
||||
Self::new(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> PartialEq for TensorShape<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Eq for TensorShape<R> {}
|
||||
|
||||
// ---- From and Into Implementations ----
|
||||
|
||||
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
|
||||
where
|
||||
T: Value,
|
||||
{
|
||||
fn from(tensor: Tensor<T, R>) -> Self {
|
||||
*tensor.shape()
|
||||
}
|
||||
}
|
||||
|
159
src/tensor.rs
159
src/tensor.rs
@ -28,6 +28,9 @@ pub struct Tensor<T, const R: usize> {
|
||||
shape: TensorShape<R>,
|
||||
}
|
||||
|
||||
/// A type that represents an axis of a tensor.
|
||||
pub struct Axis(pub usize);
|
||||
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
@ -62,22 +65,49 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let buffer = vec![1, 2, 3, 4, 5, 6];
|
||||
/// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer);
|
||||
/// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer).unwrap();
|
||||
/// assert_eq!(t.shape().as_array(), &[2, 3]);
|
||||
/// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]);
|
||||
/// ```
|
||||
pub fn new_with_buffer(shape: TensorShape<R>, buffer: Vec<T>) -> Self {
|
||||
Self { buffer, shape }
|
||||
pub fn new_with_buffer(
|
||||
shape: TensorShape<R>,
|
||||
buffer: Vec<T>,
|
||||
) -> Result<Self> {
|
||||
if buffer.len() != shape.size() {
|
||||
Err(TensorError::InvalidArgument(format!(
|
||||
"Provided buffer has length {} but shape has size {}",
|
||||
buffer.len(),
|
||||
shape.size()
|
||||
)))
|
||||
} else {
|
||||
Ok(Self { buffer, shape })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Trivial Getters -------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
pub fn rank(&self) -> usize {
|
||||
/// Get the rank of the tensor.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let t = Tensor::<f64, 2>::new([3, 3].into());
|
||||
/// assert_eq!(t.rank(), 2);
|
||||
/// ```
|
||||
pub const fn rank(&self) -> usize {
|
||||
R
|
||||
}
|
||||
|
||||
/// Get the length of the tensor's buffer.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let t = Tensor::<f64, 2>::new([3, 3].into());
|
||||
/// assert_eq!(t.len(), 9);
|
||||
/// ```
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer().len()
|
||||
}
|
||||
@ -332,32 +362,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
) -> Result<()> {
|
||||
self.ew_for_each(other, result, &|a, b| a % b)
|
||||
}
|
||||
|
||||
// pub fn product<const S: usize>(
|
||||
// &self,
|
||||
// other: &Tensor<T, S>,
|
||||
// ) -> Tensor<T, { R + S }> {
|
||||
// let mut new_shape_vec = Vec::new();
|
||||
// new_shape_vec.extend_from_slice(&self.shape().as_array());
|
||||
// new_shape_vec.extend_from_slice(&other.shape().as_array());
|
||||
|
||||
// let new_shape_array: [usize; R + S] = new_shape_vec
|
||||
// .try_into()
|
||||
// .expect("Failed to create shape array");
|
||||
|
||||
// let mut new_buffer =
|
||||
// Vec::with_capacity(self.buffer.len() * other.buffer.len());
|
||||
// for &item_self in &self.buffer {
|
||||
// for &item_other in &other.buffer {
|
||||
// new_buffer.push(item_self * item_other);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Tensor {
|
||||
// buffer: new_buffer,
|
||||
// shape: TensorShape::new(new_shape_array),
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
// ---- Reshape ---------------------------------------------------------------
|
||||
@ -385,37 +389,11 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
"TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
|
||||
)))
|
||||
} else {
|
||||
Ok(Tensor::new_with_buffer(shape, self.buffer))
|
||||
Ok(Tensor::new_with_buffer(shape, self.buffer).unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Transpose -------------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// Transpose the tensor according to the given order. The order must be a
|
||||
/// permutation of the tensor's axes.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor, TensorShape};
|
||||
///
|
||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let t = t.transpose([1, 0]).unwrap();
|
||||
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
|
||||
/// ```
|
||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||
let buffer = TensorIndex::from(self.shape().clone())
|
||||
.iter_transposed(order)
|
||||
.map(|index| self.get(index).unwrap().clone())
|
||||
.collect();
|
||||
|
||||
Ok(Tensor {
|
||||
buffer,
|
||||
shape: self.shape().reorder(order),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Operations ------------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Add for Tensor<T, R> {
|
||||
@ -654,73 +632,6 @@ where
|
||||
|
||||
impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
|
||||
|
||||
// ---- Iterator --------------------------------------------------------------
|
||||
|
||||
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||
tensor: &'a Tensor<T, R>,
|
||||
index: TensorIndex<R>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
|
||||
Self {
|
||||
tensor,
|
||||
index: tensor.shape.index_zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.index.is_overflow() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = unsafe { self.tensor.get_unchecked(self.index) };
|
||||
self.index.inc();
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = TensorIterator<'a, T, R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
TensorIterator::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Formatting ------------------------------------------------------------
|
||||
|
||||
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
// Print the current index and flat index
|
||||
write!(
|
||||
f,
|
||||
"Current Index: {}, Flat Index: {}",
|
||||
self.index,
|
||||
self.index.flat()
|
||||
)?;
|
||||
|
||||
// Print the tensor elements, highlighting the current element
|
||||
write!(f, ", Tensor Elements: [")?;
|
||||
for (i, elem) in self.tensor.buffer().iter().enumerate() {
|
||||
if i == self.index.flat() {
|
||||
write!(f, "*{}*", elem)?; // Highlight the current element
|
||||
} else {
|
||||
write!(f, "{}", elem)?;
|
||||
}
|
||||
if i < self.tensor.buffer().len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- From ------------------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> From<TensorShape<R>> for Tensor<T, R> {
|
||||
|
22
src/value.rs
22
src/value.rs
@ -1,22 +0,0 @@
|
||||
use num::{Num, One, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{fmt::Display, iter::Sum};
|
||||
|
||||
/// A trait for types that can be used as values in a tensor.
|
||||
pub trait Value:
|
||||
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T> Value for T where
|
||||
T: Num
|
||||
+ Zero
|
||||
+ One
|
||||
+ Copy
|
||||
+ Clone
|
||||
+ Display
|
||||
+ Serialize
|
||||
+ Deserialize<'static>
|
||||
+ Sum
|
||||
{
|
||||
}
|
@ -53,14 +53,10 @@ fn test_iterating_3d_tensor() {
|
||||
}
|
||||
}
|
||||
|
||||
println!("{}", tensor);
|
||||
|
||||
// Iterate over the tensor and check that the numbers are correct
|
||||
|
||||
let mut iter = TensorIterator::new(&tensor);
|
||||
|
||||
println!("{}", iter);
|
||||
|
||||
assert_eq!(iter.next(), Some(&0));
|
||||
|
||||
assert_eq!(iter.next(), Some(&1));
|
||||
@ -157,105 +153,105 @@ fn test_index_dec_method() {
|
||||
assert_eq!(index, TensorIndex::zero(shape));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
// #[test]
|
||||
// fn test_axis_iterator() {
|
||||
// // Creating a 2x2 Tensor for testing
|
||||
// let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = TensorAxis::new(&tensor, 0);
|
||||
// // Testing iteration over the first axis (axis = 0)
|
||||
// let axis = TensorAxis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
// let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = TensorAxis::new(&tensor, 1);
|
||||
// // Resetting the iterator for the second axis (axis = 1)
|
||||
// let axis = TensorAxis::new(&tensor, 1);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
// let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
let shape = tensor.shape();
|
||||
// let shape = tensor.shape();
|
||||
|
||||
let mut a: TensorIndex<2> = (shape.clone(), [0, 0]).into();
|
||||
let b: TensorIndex<2> = (shape.clone(), [1, 1]).into();
|
||||
// let mut a: TensorIndex<2> = (shape.clone(), [0, 0]).into();
|
||||
// let b: TensorIndex<2> = (shape.clone(), [1, 1]).into();
|
||||
|
||||
while a <= b {
|
||||
println!("a: {}", a);
|
||||
a.inc();
|
||||
}
|
||||
}
|
||||
// while a <= b {
|
||||
// println!("a: {}", a);
|
||||
// a.inc();
|
||||
// }
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn test_3d_tensor_axis_iteration() {
|
||||
// Create a 3D Tensor with specific values
|
||||
// Tensor shape is 2x2x2 for simplicity
|
||||
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
// #[test]
|
||||
// fn test_3d_tensor_axis_iteration() {
|
||||
// // Create a 3D Tensor with specific values
|
||||
// // Tensor shape is 2x2x2 for simplicity
|
||||
// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
|
||||
// TensorAxis 0 (Layer-wise):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
//
|
||||
// This order suggests that for each "layer" (first level of arrays),
|
||||
// the iterator goes through all rows and columns. It first completes
|
||||
// the entire first layer, then moves to the second.
|
||||
// // TensorAxis 0 (Layer-wise):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][0][1] = 2
|
||||
// // t[0][1][0] = 3
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][0][1] = 6
|
||||
// // t[1][1][0] = 7
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
// //
|
||||
// // This order suggests that for each "layer" (first level of arrays),
|
||||
// // the iterator goes through all rows and columns. It first completes
|
||||
// // the entire first layer, then moves to the second.
|
||||
|
||||
let a0 = TensorAxis::new(&t, 0);
|
||||
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
// let a0 = TensorAxis::new(&t, 0);
|
||||
// let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
||||
// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
|
||||
// TensorAxis 1 (Row-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first row across all layers, then the second row
|
||||
// across all layers.
|
||||
// // TensorAxis 1 (Row-wise within each layer):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][0][1] = 2
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][0][1] = 6
|
||||
// // t[0][1][0] = 3
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][1][0] = 7
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
// //
|
||||
// // This indicates that within each "layer", the iterator first
|
||||
// // completes the first row across all layers, then the second row
|
||||
// // across all layers.
|
||||
|
||||
let a1 = TensorAxis::new(&t, 1);
|
||||
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
// let a1 = TensorAxis::new(&t, 1);
|
||||
// let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
||||
// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
|
||||
// TensorAxis 2 (Column-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][1][0] = 3
|
||||
// t[1][0][0] = 5
|
||||
// t[1][1][0] = 7
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][1] = 8
|
||||
// [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first column across all layers, then the second
|
||||
// column across all layers.
|
||||
// // TensorAxis 2 (Column-wise within each layer):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][1][0] = 3
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][1][0] = 7
|
||||
// // t[0][0][1] = 2
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][0][1] = 6
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
// //
|
||||
// // This indicates that within each "layer", the iterator first
|
||||
// // completes the first column across all layers, then the second
|
||||
// // column across all layers.
|
||||
|
||||
let a2 = TensorAxis::new(&t, 2);
|
||||
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
}
|
||||
// let a2 = TensorAxis::new(&t, 2);
|
||||
// let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
||||
// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
// }
|
||||
|
Loading…
Reference in New Issue
Block a user