From d19ce40494160393f4513a8fbbe5eb3af5a142c8 Mon Sep 17 00:00:00 2001 From: Julius Koskela Date: Wed, 3 Jan 2024 21:28:01 +0200 Subject: [PATCH] Refactor and document Tensor-type - Add documentation to all methods exposed by the Tensor type. - Remove some tests and methods to simplify structure, some might be introduced back later. - Add elementwise operations. - Add doctests to Tensor. Signed-off-by: Julius Koskela --- src/axis.rs | 275 +-------------------------- src/index.rs | 8 +- src/lib.rs | 192 ++----------------- src/tensor.rs | 515 ++++++++++++++++++++++++++++++++++++-------------- 4 files changed, 391 insertions(+), 599 deletions(-) diff --git a/src/axis.rs b/src/axis.rs index d01f0f0..19c9bf9 100644 --- a/src/axis.rs +++ b/src/axis.rs @@ -113,277 +113,4 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> { fn into_iter(self) -> Self::IntoIter { TensorAxisIterator::new(&self) } -} - -pub fn contract< - 'a, - T: Value + std::fmt::Debug, - const R: usize, - const S: usize, - const N: usize, ->( - lhs: (&'a Tensor, [usize; N]), - rhs: (&'a Tensor, [usize; N]), -) -> Tensor -where - [(); R - N]:, - [(); S - N]:, - [(); R + S - 2 * N]:, -{ - let (lhs, la) = lhs; - let (rhs, ra) = rhs; - let lnc = (0..R).filter(|i| !la.contains(i)).collect::>(); - let rnc = (0..S).filter(|i| !ra.contains(i)).collect::>(); - - let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::>(); - let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::>(); - - let mut shape = Vec::new(); - shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); - shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); - let shape: [usize; R + S - 2 * N] = - shape.try_into().expect("Failed to create shape array"); - - let shape = TensorShape::new(shape); - - let result = contract_axes(&lnc, &rnc); - - Tensor::new_with_buffer(shape, result) -} - -pub fn contract_axes< - 'a, - T: Value + std::fmt::Debug, - const R: usize, - const S: usize, - const N: usize, ->( - laxes: &'a [TensorAxis<'a, T, R>], - raxes: &'a [TensorAxis<'a, T, S>], -) -> Vec -where - [(); R - N]:, - [(); S - N]:, -{ - let mut result = vec![]; - - let axes = laxes.into_iter().zip(raxes); - - for (laxis, raxis) in axes { - let mut axes_result: Vec = vec![]; - for i in 0..raxis.len() { - for j in 0..laxis.len() { - let mut sum = T::zero(); - let llevel = laxis.into_iter(); - let rlevel = raxis.into_iter(); - let zip = llevel.level(j).zip(rlevel.level(i)); - for (lv, rv) in zip { - sum = sum + *lv * *rv; - } - axes_result.push(sum); - } - } - result.extend_from_slice(&axes_result); - } - - result -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_tensor_contraction_simple() { - // Define two 2D tensors (matrices) - // Tensor A is 2x3 - let a: Tensor = Tensor::from([[1, 2], [3, 4]]); - - // Tensor B is 1x3x2 - let b: Tensor = Tensor::from([[1, 2], [3, 4]]); - - // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); - assert_eq!(contracted_tensor.shape(), &TensorShape::new([2, 2])); - assert_eq!( - contracted_tensor.buffer(), - &[7, 10, 15, 22], - "Contracted tensor buffer does not match expected" - ); - } - - #[test] - fn test_tensor_contraction_23x32() { - // Define two 2D tensors (matrices) - - // Tensor A is 2x3 - let b: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); - println!("b: {}", b); - - // Tensor B is 3x2 - let a: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); - println!("a: {}", a); - - // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); - - println!("contracted_tensor: {}", contracted_tensor); - assert_eq!(contracted_tensor.shape(), &TensorShape::new([3, 3])); - assert_eq!( - contracted_tensor.buffer(), - &[9, 12, 15, 19, 26, 33, 29, 40, 51], - "Contracted tensor buffer does not match expected" - ); - } - - #[test] - fn test_tensor_contraction_rank3() { - let a: Tensor = - Tensor::new_with_buffer(TensorShape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24 - let b: Tensor = - Tensor::new_with_buffer(TensorShape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24 - let contracted_tensor: Tensor = contract((&a, [2]), (&b, [0])); - - println!("a: {}", a); - println!("b: {}", b); - println!("contracted_tensor: {}", contracted_tensor); - // assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]); - // Verify specific elements of contracted_tensor - // assert_eq!(contracted_tensor[0][0][0][0], 50); - // assert_eq!(contracted_tensor[0][0][0][1], 60); - // ... further checks for other elements ... - } - - // #[test] - // fn test_axis_iterator_disassemble() { - // // Creating a 2x2 Tensor for testing - // let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); - - // // Testing iteration over the first axis (axis = 0) - // let axis = TensorAxis::new(&tensor, 0); - - // let mut axis_iter = axis.into_iter().disassemble(); - - // assert_eq!(axis_iter[0].next(), Some(&1.0)); - // assert_eq!(axis_iter[0].next(), Some(&2.0)); - // assert_eq!(axis_iter[0].next(), None); - // assert_eq!(axis_iter[1].next(), Some(&3.0)); - // assert_eq!(axis_iter[1].next(), Some(&4.0)); - // assert_eq!(axis_iter[1].next(), None); - - // // Resetting the iterator for the second axis (axis = 1) - // let axis = TensorAxis::new(&tensor, 1); - - // let mut axis_iter = axis.into_iter().disassemble(); - - // assert_eq!(axis_iter[0].next(), Some(&1.0)); - // assert_eq!(axis_iter[0].next(), Some(&3.0)); - // assert_eq!(axis_iter[0].next(), None); - // assert_eq!(axis_iter[1].next(), Some(&2.0)); - // assert_eq!(axis_iter[1].next(), Some(&4.0)); - // assert_eq!(axis_iter[1].next(), None); - // } - - #[test] - fn test_axis_iterator() { - // Creating a 2x2 Tensor for testing - let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); - - // Testing iteration over the first axis (axis = 0) - let axis = TensorAxis::new(&tensor, 0); - - let mut axis_iter = axis.into_iter(); - - assert_eq!(axis_iter.next(), Some(&1.0)); - assert_eq!(axis_iter.next(), Some(&2.0)); - assert_eq!(axis_iter.next(), Some(&3.0)); - assert_eq!(axis_iter.next(), Some(&4.0)); - - // Resetting the iterator for the second axis (axis = 1) - let axis = TensorAxis::new(&tensor, 1); - - let mut axis_iter = axis.into_iter(); - - assert_eq!(axis_iter.next(), Some(&1.0)); - assert_eq!(axis_iter.next(), Some(&3.0)); - assert_eq!(axis_iter.next(), Some(&2.0)); - assert_eq!(axis_iter.next(), Some(&4.0)); - - let shape = tensor.shape(); - - let mut a: TensorIndex<2> = (shape, [0, 0]).into(); - let b: TensorIndex<2> = (shape, [1, 1]).into(); - - while a <= b { - println!("a: {}", a); - a.inc(); - } - } - - #[test] - fn test_3d_tensor_axis_iteration() { - // Create a 3D Tensor with specific values - // Tensor shape is 2x2x2 for simplicity - let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); - - // TensorAxis 0 (Layer-wise): - // - // t[0][0][0] = 1 - // t[0][0][1] = 2 - // t[0][1][0] = 3 - // t[0][1][1] = 4 - // t[1][0][0] = 5 - // t[1][0][1] = 6 - // t[1][1][0] = 7 - // t[1][1][1] = 8 - // [1, 2, 3, 4, 5, 6, 7, 8] - // - // This order suggests that for each "layer" (first level of arrays), - // the iterator goes through all rows and columns. It first completes - // the entire first layer, then moves to the second. - - let a0 = TensorAxis::new(&t, 0); - let a0_order = a0.into_iter().cloned().collect::>(); - assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); - - // TensorAxis 1 (Row-wise within each layer): - // - // t[0][0][0] = 1 - // t[0][0][1] = 2 - // t[1][0][0] = 5 - // t[1][0][1] = 6 - // t[0][1][0] = 3 - // t[0][1][1] = 4 - // t[1][1][0] = 7 - // t[1][1][1] = 8 - // [1, 2, 5, 6, 3, 4, 7, 8] - // - // This indicates that within each "layer", the iterator first - // completes the first row across all layers, then the second row - // across all layers. - - let a1 = TensorAxis::new(&t, 1); - let a1_order = a1.into_iter().cloned().collect::>(); - assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); - - // TensorAxis 2 (Column-wise within each layer): - // - // t[0][0][0] = 1 - // t[0][1][0] = 3 - // t[1][0][0] = 5 - // t[1][1][0] = 7 - // t[0][0][1] = 2 - // t[0][1][1] = 4 - // t[1][0][1] = 6 - // t[1][1][1] = 8 - // [1, 3, 5, 7, 2, 4, 6, 8] - // - // This indicates that within each "layer", the iterator first - // completes the first column across all layers, then the second - // column across all layers. - - let a2 = TensorAxis::new(&t, 2); - let a2_order = a2.into_iter().cloned().collect::>(); - assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); - } -} +} \ No newline at end of file diff --git a/src/index.rs b/src/index.rs index cc93a08..dc9bed4 100644 --- a/src/index.rs +++ b/src/index.rs @@ -24,7 +24,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> { shape.as_array().map(|dim_size| dim_size.saturating_sub(1)); Self { indices: max_indices, - shape: shape, + shape, } } @@ -34,7 +34,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> { } Self { indices, - shape: shape, + shape, } } @@ -252,9 +252,9 @@ impl<'a, const R: usize> TensorIndex<'a, R> { /// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20. /// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33. pub fn flat(&self) -> usize { - self.indices + self.indices() .iter() - .zip(&self.shape.as_array()) + .zip(&self.shape().as_array()) .rev() .fold((0, 1), |(flat_index, product), (&idx, &dim_size)| { (flat_index + idx * product, product * dim_size) diff --git a/src/lib.rs b/src/lib.rs index 3407622..ecdaf91 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,184 +43,16 @@ macro_rules! tensor { }; } -// ---- Tests ---- - -#[cfg(test)] -mod tests { - use super::*; - use serde_json; - - #[test] - fn test_tensor_product() { - let mut tensor1 = Tensor::::from([[2], [2]]); // 2x2 tensor - let mut tensor2 = Tensor::::from([2]); // 2-element vector - - // Fill tensors with some values - tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); - tensor2.buffer_mut().copy_from_slice(&[5, 6]); - - let product = tensor1.tensor_product(&tensor2); - - // Check shape of the resulting tensor - assert_eq!(*product.shape(), TensorShape::new([2, 2, 2])); - - // Check buffer of the resulting tensor - let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24]; - assert_eq!(product.buffer(), &expected_buffer); - } - - #[test] - fn serde_shape_serialization_test() { - // Create a shape instance - let shape: TensorShape<3> = [1, 2, 3].into(); - - // Serialize the shape to a JSON string - let serialized = - serde_json::to_string(&shape).expect("Failed to serialize"); - - // Deserialize the JSON string back into a shape - let deserialized: TensorShape<3> = - serde_json::from_str(&serialized).expect("Failed to deserialize"); - - // Check that the deserialized shape is equal to the original - assert_eq!(shape, deserialized); - } - - #[test] - fn tensor_serde_serialization_test() { - // Create an instance of Tensor - let tensor: Tensor = Tensor::new(TensorShape::new([2, 2])); - - // Serialize the Tensor to a JSON string - let serialized = - serde_json::to_string(&tensor).expect("Failed to serialize"); - - // Deserialize the JSON string back into a Tensor - let deserialized: Tensor = - serde_json::from_str(&serialized).expect("Failed to deserialize"); - - // Check that the deserialized Tensor is equal to the original - assert_eq!(tensor.buffer(), deserialized.buffer()); - assert_eq!(tensor.shape(), deserialized.shape()); - } - - #[test] - fn iterate_3d_tensor() { - let shape = TensorShape::new([2, 2, 2]); // 3D tensor with shape 2x2x2 - let mut tensor = Tensor::new(shape); - let mut num = 0; - - // Fill the tensor with sequential numbers - for i in 0..2 { - for j in 0..2 { - for k in 0..2 { - tensor.buffer_mut()[i * 4 + j * 2 + k] = num; - num += 1; - } - } - } - - println!("{}", tensor); - - // Iterate over the tensor and check that the numbers are correct - - let mut iter = TensorIterator::new(&tensor); - - println!("{}", iter); - - assert_eq!(iter.next(), Some(&0)); - - assert_eq!(iter.next(), Some(&1)); - assert_eq!(iter.next(), Some(&2)); - assert_eq!(iter.next(), Some(&3)); - assert_eq!(iter.next(), Some(&4)); - assert_eq!(iter.next(), Some(&5)); - assert_eq!(iter.next(), Some(&6)); - assert_eq!(iter.next(), Some(&7)); - assert_eq!(iter.next(), None); - assert_eq!(iter.next(), None); - } - - #[test] - fn iterate_rank_4_tensor() { - // Define the shape of the rank-4 tensor (e.g., 2x2x2x2) - let shape = TensorShape::new([2, 2, 2, 2]); - let mut tensor = Tensor::new(shape); - let mut num = 0; - - // Fill the tensor with sequential numbers - for i in 0..tensor.len() { - tensor.buffer_mut()[i] = num; - num += 1; - } - - // Iterate over the tensor and check that the numbers are correct - let mut iter = TensorIterator::new(&tensor); - for expected_value in 0..tensor.len() { - assert_eq!(*iter.next().unwrap(), expected_value); - } - - // Ensure the iterator is exhausted - assert!(iter.next().is_none()); - } - - #[test] - fn test_dec_method() { - let shape = TensorShape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor - let mut index = TensorIndex::zero(&shape); - - // Increment the index to the maximum - for _ in 0..26 { - // 3 * 3 * 3 - 1 = 26 increments to reach the end - index.inc(); - } - - // Check if the index is at the maximum - assert_eq!(index, TensorIndex::new(&shape, [2, 2, 2])); - - // Decrement step by step and check the index - let expected_indices = [ - [2, 2, 2], - [2, 2, 1], - [2, 2, 0], - [2, 1, 2], - [2, 1, 1], - [2, 1, 0], - [2, 0, 2], - [2, 0, 1], - [2, 0, 0], - [1, 2, 2], - [1, 2, 1], - [1, 2, 0], - [1, 1, 2], - [1, 1, 1], - [1, 1, 0], - [1, 0, 2], - [1, 0, 1], - [1, 0, 0], - [0, 2, 2], - [0, 2, 1], - [0, 2, 0], - [0, 1, 2], - [0, 1, 1], - [0, 1, 0], - [0, 0, 2], - [0, 0, 1], - [0, 0, 0], - ]; - - for (i, &expected) in expected_indices.iter().enumerate() { - assert_eq!( - index, - TensorIndex::new(&shape, expected), - "Failed at index {}", - i - ); - index.dec(); - } - - // Finally, the index should reach [0, 0, 0] - index.dec(); - assert_eq!(index, TensorIndex::zero(&shape)); - } +#[macro_export] +macro_rules! shape { + ($array:expr) => { + TensorShape::from($array) + }; } + +#[macro_export] +macro_rules! index { + ($array:expr) => { + TensorIndex::from($array) + }; +} \ No newline at end of file diff --git a/src/tensor.rs b/src/tensor.rs index b788a6c..6377808 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -3,6 +3,16 @@ use crate::error::*; use getset::{Getters, MutGetters}; use std::fmt; +/// A tensor is a multi-dimensional array of values. The rank of a tensor is the number of +/// dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is a vector, a rank 2 tensor is +/// a matrix, and so on. +/// +/// ``` +/// use manifold::{tensor, Tensor}; +/// +/// let t = tensor!([[1, 2], [3, 4]]); +/// assert_eq!(t.rank(), 2); +/// ``` #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)] pub struct Tensor { #[getset(get = "pub", get_mut = "pub")] @@ -11,7 +21,18 @@ pub struct Tensor { shape: TensorShape, } +// ---- Construction and Initialization --------------------------------------- + impl Tensor { + /// Create a new tensor with the given shape. The rank of the tensor is determined by the shape + /// and all elements are initialized to zero. + /// + /// ``` + /// use manifold::Tensor; + /// + /// let t = Tensor::::new([3, 3].into()); + /// assert_eq!(t.shape().as_array(), [3, 3]); + /// ``` pub fn new(shape: TensorShape) -> Self { // Handle rank 0 tensor (scalar) as a special case let total_size = if R == 0 { @@ -26,11 +47,326 @@ impl Tensor { Self { buffer, shape } } + /// Create a new tensor with the given shape and initialize it from the given buffer. The rank + /// of the tensor is determined by the shape. + /// + /// ``` + /// use manifold::Tensor; + /// + /// let buffer = vec![1, 2, 3, 4, 5, 6]; + /// let t = Tensor::::new_with_buffer([2, 3].into(), buffer); + /// assert_eq!(t.shape().as_array(), [2, 3]); + /// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]); + /// ``` pub fn new_with_buffer(shape: TensorShape, buffer: Vec) -> Self { Self { buffer, shape } } +} - pub fn reshape(self, shape: TensorShape) -> Result { +// ---- Trivial Getters ------------------------------------------------------- + +impl Tensor { + pub fn rank(&self) -> usize { + R + } + + pub fn len(&self) -> usize { + self.buffer().len() + } +} + +// ---- Get Values ------------------------------------------------------------ + +impl Tensor { + /// Get a reference to a value at the given index. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let i = (t.shape(), [1, 1]).into(); + /// assert_eq!(t.get(i), Some(&4)); + /// ``` + pub fn get(&self, index: TensorIndex) -> Option<&T> { + self.buffer().get(index.flat()) + } + + /// Get a reference to a value at the given index without bounds checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let i = (t.shape(), [1, 1]).into(); + /// unsafe { assert_eq!(t.get_unchecked(i), &4); } + /// ``` + pub unsafe fn get_unchecked(&self, index: TensorIndex) -> &T { + self.buffer().get_unchecked(index.flat()) + } + + /// Get a mutable reference to a value at the given index. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// let s = t.shape().clone(); + /// let i = (&s, [1, 1]).into(); + /// assert_eq!(t.get_mut(i), Some(&mut 4)); + /// ``` + pub fn get_mut(&mut self, index: TensorIndex) -> Option<&mut T> { + self.buffer_mut().get_mut(index.flat()) + } + + /// Get a mutable reference to a value at the given index without bounds checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// let s = t.shape().clone(); + /// let i = (&s, [1, 1]).into(); + /// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); } + /// ``` + pub unsafe fn get_unchecked_mut( + &mut self, + index: TensorIndex, + ) -> &mut T { + self.buffer_mut().get_unchecked_mut(index.flat()) + } + + /// Get a reference to a value at the given flat index. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// assert_eq!(t.get_flat(3), Some(&4)); + /// ``` + pub fn get_flat(&self, index: usize) -> Option<&T> { + self.buffer().get(index) + } + + /// Get a reference to a value at the given flat index without bounds checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// unsafe { assert_eq!(t.get_flat_unchecked(3), &4); } + /// ``` + pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T { + self.buffer().get_unchecked(index) + } + + /// Get a mutable reference to a value at the given flat index. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// assert_eq!(t.get_flat_mut(3), Some(&mut 4)); + /// ``` + pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> { + self.buffer_mut().get_mut(index) + } + + /// Get a mutable reference to a value at the given flat index without bounds checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// unsafe { assert_eq!(t.get_flat_unchecked_mut(3), &mut 4); } + /// ``` + pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T { + self.buffer_mut().get_unchecked_mut(index) + } +} + +// ---- Arithmetic ------------------------------------------------------------ + +impl Tensor { + /// Elementwise operation on two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_for_each(&a, &b, &mut c, &|a, b| a * b).unwrap(); + /// assert_eq!(c, tensor!([[5, 12], [21, 32]])); + /// ``` + pub fn ew_for_each( + &self, + other: &Tensor, + result: &mut Tensor, + f: &dyn Fn(T, T) -> T, + ) -> Result<()> { + if self.shape() != other.shape() { + return Err(TensorError::InvalidArgument(format!( + "TensorShape mismatch: {:?} != {:?}", + self.shape(), + other.shape() + ))); + } else if self.shape() != result.shape() { + return Err(TensorError::InvalidArgument(format!( + "TensorShape mismatch: {:?} != {:?}", + self.shape(), + result.shape() + ))); + } + + for (i, (a, b)) in + self.buffer().iter().zip(other.buffer().iter()).enumerate() + { + unsafe { + *result.get_flat_unchecked_mut(i) = f(*a, *b); + } + } + + Ok(()) + } + + /// Elementwise multiplication of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_multiply(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[5, 12], [21, 32]])); + /// ``` + pub fn ew_multiply( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a * b) + } + + /// Elementwise addition of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_add(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[6, 8], [10, 12]])); + /// ``` + pub fn ew_add( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a + b) + } + + /// Elementwise subtraction of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_subtract(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[-4, -4], [-4, -4]])); + /// ``` + pub fn ew_subtract( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a - b) + } + + /// Elementwise division of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[2, 4], [8, 16]]); + /// let b = tensor!([[2, 2], [4, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_divide(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[1, 2], [2, 2]])); + /// ``` + pub fn ew_divide( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a / b) + } + + /// Elementwise modulo of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[2, 2], [3, 3]]); + /// let b = tensor!([[4, 4], [6, 9]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_modulo(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[2, 2], [3, 3]])); + /// ``` + pub fn ew_modulo( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a % b) + } + + // pub fn product( + // &self, + // other: &Tensor, + // ) -> Tensor { + // let mut new_shape_vec = Vec::new(); + // new_shape_vec.extend_from_slice(&self.shape().as_array()); + // new_shape_vec.extend_from_slice(&other.shape().as_array()); + + // let new_shape_array: [usize; R + S] = new_shape_vec + // .try_into() + // .expect("Failed to create shape array"); + + // let mut new_buffer = + // Vec::with_capacity(self.buffer.len() * other.buffer.len()); + // for &item_self in &self.buffer { + // for &item_other in &other.buffer { + // new_buffer.push(item_self * item_other); + // } + // } + + // Tensor { + // buffer: new_buffer, + // shape: TensorShape::new(new_shape_array), + // } + // } +} + +// ---- Reshape --------------------------------------------------------------- + +impl Tensor { + + /// Reshape the tensor to the given shape. The total size of the new shape must be the same as + /// the total size of the old shape. + /// + /// ``` + /// use manifold::{tensor, shape, Tensor, TensorShape}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let s = shape!([4]); + /// let t = t.reshape(s).unwrap(); + /// assert_eq!(t, tensor!([1, 2, 3, 4])); + /// ``` + pub fn reshape(self, shape: TensorShape) -> Result> { if self.shape().size() != shape.size() { let (ls, rs) = (self.shape().as_array(), shape.as_array()); let (lsize, rsize) = (self.shape().size(), shape.size()); @@ -38,13 +374,25 @@ impl Tensor { "TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", ))) } else { - Ok(Self { - buffer: self.buffer, - shape, - }) + Ok(Tensor::new_with_buffer(shape, self.buffer)) } } +} +// ---- Transpose ------------------------------------------------------------- + +impl Tensor { + + /// Transpose the tensor according to the given order. The order must be a permutation of the + /// tensor's axes. + /// + /// ``` + /// use manifold::{tensor, Tensor, TensorShape}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let t = t.transpose([1, 0]).unwrap(); + /// assert_eq!(t, tensor!([[1, 3], [2, 4]])); + /// ``` pub fn transpose(self, order: [usize; R]) -> Result { let buffer = TensorIndex::from(self.shape()) .iter_transposed(order) @@ -56,139 +404,9 @@ impl Tensor { shape: self.shape().reorder(order), }) } - - pub fn idx(&self) -> TensorIndex { - TensorIndex::from(self) - } - - pub fn axis<'a>(&'a self, axis: usize) -> TensorAxis<'a, T, R> { - TensorAxis::new(self, axis) - } - - pub fn get(&self, index: TensorIndex) -> Option<&T> { - self.buffer.get(index.flat()) - } - - pub unsafe fn get_unchecked(&self, index: TensorIndex) -> &T { - self.buffer.get_unchecked(index.flat()) - } - - pub fn get_mut(&mut self, index: TensorIndex) -> Option<&mut T> { - self.buffer.get_mut(index.flat()) - } - - pub unsafe fn get_unchecked_mut(&mut self, index: TensorIndex) -> &mut T { - self.buffer.get_unchecked_mut(index.flat()) - } - - pub fn get_flat(&self, index: usize) -> Option<&T> { - self.buffer.get(index) - } - - pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T { - self.buffer.get_unchecked(index) - } - - pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> { - self.buffer.get_mut(index) - } - - pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T { - self.buffer.get_unchecked_mut(index) - } - - pub fn rank(&self) -> usize { - R - } - - pub fn len(&self) -> usize { - self.buffer.len() - } - - pub fn iter(&self) -> TensorIterator { - TensorIterator::new(self) - } - - pub fn elementwise_multiply(&self, other: &Tensor) -> Tensor { - if self.shape != other.shape { - panic!("TensorShapes of tensors do not match"); - } - - let mut result_buffer = Vec::with_capacity(self.buffer.len()); - - for (a, b) in self.buffer.iter().zip(other.buffer.iter()) { - result_buffer.push(*a * *b); - } - - Tensor { - buffer: result_buffer, - shape: self.shape, - } - } - - pub fn tensor_product( - &self, - other: &Tensor, - ) -> Tensor { - let mut new_shape_vec = Vec::new(); - new_shape_vec.extend_from_slice(&self.shape.as_array()); - new_shape_vec.extend_from_slice(&other.shape.as_array()); - - let new_shape_array: [usize; R + S] = new_shape_vec - .try_into() - .expect("Failed to create shape array"); - - let mut new_buffer = - Vec::with_capacity(self.buffer.len() * other.buffer.len()); - for &item_self in &self.buffer { - for &item_other in &other.buffer { - new_buffer.push(item_self * item_other); - } - } - - Tensor { - buffer: new_buffer, - shape: TensorShape::new(new_shape_array), - } - } - - // Retrieve an element based on a specific axis and index - pub fn get_by_axis(&self, axis: usize, index: usize) -> Option { - // Convert axis and index to a flat index - let flat_index = self.axis_to_flat_index(axis, index); - if flat_index >= self.buffer.len() { - return None; - } - - Some(self.buffer[flat_index]) - } - - // Convert axis and index to a flat index in the buffer - fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize { - let mut flat_index = 0; - let mut stride = 1; - - // Ensure the given axis is within the tensor's dimensions - if axis >= R { - panic!("TensorAxis out of bounds"); - } - - // Calculate the stride for each dimension and accumulate the flat index - for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() { - println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride); - if i > axis { - stride *= dim_size; - } else if i == axis { - flat_index += index * stride; - break; // We've reached the target axis - } - } - - flat_index - } } -// ---- Indexing ---- +// ---- Indexing -------------------------------------------------------------- impl<'a, T: Value, const R: usize> Index> for Tensor { type Output = T; @@ -198,7 +416,9 @@ impl<'a, T: Value, const R: usize> Index> for Tensor { } } -impl<'a, T: Value, const R: usize> IndexMut> for Tensor { +impl<'a, T: Value, const R: usize> IndexMut> + for Tensor +{ fn index_mut(&mut self, index: TensorIndex) -> &mut Self::Output { &mut self.buffer[index.flat()] } @@ -218,7 +438,7 @@ impl IndexMut for Tensor { } } -// ---- Display ---- +// ---- Display --------------------------------------------------------------- impl Tensor where @@ -256,7 +476,20 @@ where } } -// ---- Iterator ---- +// ---- Equality -------------------------------------------------------------- + +impl PartialEq for Tensor +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.shape == other.shape && self.buffer == other.buffer + } +} + +impl Eq for Tensor where T: Eq {} + +// ---- Iterator -------------------------------------------------------------- pub struct TensorIterator<'a, T: Value, const R: usize> { tensor: &'a Tensor, @@ -294,7 +527,7 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor { } } -// ---- Formatting ---- +// ---- Formatting ------------------------------------------------------------ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { @@ -323,7 +556,7 @@ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> { } } -// ---- From ---- +// ---- From ------------------------------------------------------------------ impl From> for Tensor { fn from(shape: TensorShape) -> Self {