From eb1ca20158c21f2791783161cf8cf9cef0400938 Mon Sep 17 00:00:00 2001 From: Julius Koskela Date: Wed, 3 Jan 2024 21:52:53 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20Implement=20Tensor-type=20and=20?= =?UTF-8?q?basic=20methods=20(#15)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Basic implementation of the tensor-type. - Basic implementations for the Tensor-type - Documentation and doctests - Various refactoring - Small corrections to other types Reviewed-on: https://nordic-dev.net/julius/manifold/pulls/15 Co-authored-by: Julius Koskela Co-committed-by: Julius Koskela --- README.md | 23 +- docs/tensor-contraction.md | 34 --- docs/tensor-operations.md | 239 --------------- examples/operations.rs | 94 ------ rustfmt.toml | 2 + src/axis.rs | 311 ++------------------ src/error.rs | 4 +- src/index.rs | 201 +++++++------ src/lib.rs | 228 ++------------- src/shape.rs | 75 ++--- src/tensor.rs | 578 ++++++++++++++++++++++++++----------- src/value.rs | 25 ++ 12 files changed, 631 insertions(+), 1183 deletions(-) delete mode 100644 docs/tensor-contraction.md delete mode 100644 docs/tensor-operations.md delete mode 100644 examples/operations.rs create mode 100644 src/value.rs diff --git a/README.md b/README.md index afbb613..25a8e76 100644 --- a/README.md +++ b/README.md @@ -1,22 +1,3 @@ -# Mainfold +# Manifold -```rust -// Create two tensors with different ranks and shapes -let mut tensor1 = Tensor::::from([2, 2]); // 2x2 tensor -let mut tensor2 = Tensor::::from([2]); // 2-element vector - -// Fill tensors with some values -tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); -tensor2.buffer_mut().copy_from_slice(&[5, 6]); - -// Calculate tensor product -let product = tensor1.tensor_product(&tensor2); - -println!("T1 * T2 = {}", product); - -// Check shape of the resulting tensor -assert_eq!(product.shape(), Shape::new([2, 2, 2])); - -// Check buffer of the resulting tensor -assert_eq!(product.buffer(), &[5, 6, 10, 12, 15, 18, 20, 24]); -``` +A tensor implementation in Rust. diff --git a/docs/tensor-contraction.md b/docs/tensor-contraction.md deleted file mode 100644 index 0361388..0000000 --- a/docs/tensor-contraction.md +++ /dev/null @@ -1,34 +0,0 @@ -To understand how the tensor contraction should work for the given tensors `a` and `b`, let's first clarify their shapes and then walk through the contraction steps: - -1. **Tensor Shapes**: - - Tensor `a` is a 2x3 matrix (3 rows and 2 columns): \[\begin{matrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{matrix}\] - - Tensor `b` is a 3x2 matrix (2 rows and 3 columns): \[\begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix}\] - -2. **Tensor Contraction Operation**: - - The contraction operation in this case involves multiplying corresponding elements along the shared dimension (the second dimension of `a` and the first dimension of `b`) and summing the results. - - The resulting tensor will have the shape determined by the other dimensions of the original tensors, which in this case is 3x3. - -3. **Contraction Steps**: - - - Step 1: Multiply each element of the first row of `a` with each element of the first column of `b`, then sum these products. This forms the first element of the resulting matrix. - - \( (1 \times 1) + (2 \times 4) = 1 + 8 = 9 \) - - Step 2: Multiply each element of the first row of `a` with each element of the second column of `b`, then sum these products. This forms the second element of the first row of the resulting matrix. - - \( (1 \times 2) + (2 \times 5) = 2 + 10 = 12 \) - - Step 3: Multiply each element of the first row of `a` with each element of the third column of `b`, then sum these products. This forms the third element of the first row of the resulting matrix. - - \( (1 \times 3) + (2 \times 6) = 3 + 12 = 15 \) - - - Continue this process for the remaining rows of `a` and columns of `b`: - - For the second row of `a`: - - \( (3 \times 1) + (4 \times 4) = 3 + 16 = 19 \) - - \( (3 \times 2) + (4 \times 5) = 6 + 20 = 26 \) - - \( (3 \times 3) + (4 \times 6) = 9 + 24 = 33 \) - - For the third row of `a`: - - \( (5 \times 1) + (6 \times 4) = 5 + 24 = 29 \) - - \( (5 \times 2) + (6 \times 5) = 10 + 30 = 40 \) - - \( (5 \times 3) + (6 \times 6) = 15 + 36 = 51 \) - -4. **Resulting Tensor**: - - The resulting 3x3 tensor from the contraction of `a` and `b` will be: - \[\begin{matrix} 9 & 12 & 15 \\ 19 & 26 & 33 \\ 29 & 40 & 51 \end{matrix}\] - -These steps provide the detailed calculations for each element of the resulting tensor after contracting tensors `a` and `b`. \ No newline at end of file diff --git a/docs/tensor-operations.md b/docs/tensor-operations.md deleted file mode 100644 index 277cefd..0000000 --- a/docs/tensor-operations.md +++ /dev/null @@ -1,239 +0,0 @@ -# Operations Index - -## 1. Addition - -Element-wize addition of two tensors. - -\( C = A + B \) where \( C_{ijk...} = A_{ijk...} + B_{ijk...} \) for all indices \( i, j, k, ... \). - -```rust -let t1 = tensor!([[1, 2], [3, 4]]); -let t2 = tensor!([[5, 6], [7, 8]]); -let sum = t1 + t2; -``` - -```sh -[[7, 8], [10, 12]] -``` - -## 2. Subtraction - -Element-wize substraction of two tensors. - -\( C = A - B \) where \( C_{ijk...} = A_{ijk...} - B_{ijk...} \). - -```rust -let t1 = tensor!([[1, 2], [3, 4]]); -let t2 = tensor!([[5, 6], [7, 8]]); -let diff = i1 - t2; -``` - -```sh -[[-4, -4], [-4, -4]] -``` - -## 3. Multiplication - -Element-wize multiplication of two tensors. - -\( C = A \odot B \) where \( C_{ijk...} = A_{ijk...} \times B_{ijk...} \). - -```rust -let t1 = tensor!([[1, 2], [3, 4]]); -let t2 = tensor!([[5, 6], [7, 8]]); -let prod = t1 * t2; -``` - -```sh -[[5, 12], [21, 32]] -``` - -## 4. Division - -Element-wize division of two tensors. - -\( C = A \div B \) where \( C_{ijk...} = A_{ijk...} \div B_{ijk...} \). - -```rust -let t1 = tensor!([[1, 2], [3, 4]]); -let t2 = tensor!([[1, 2], [3, 4]]); -let quot = t1 / t2; -``` - -```sh -[[1, 1], [1, 1]] -``` - -## 5. Contraction - -Contract two tensors over given axes. - -For matrices \( A \) and \( B \), \( C = AB \) where \( C_{ij} = \sum_k A_{ik} B_{kj} \). - -```rust -let t1 = tensor!([[1, 2], [3, 4], [5, 6]]); -let t2 = tensor!([[1, 2, 3], [4, 5, 6]]); - -let cont = contract((t1, [1]), (t2, [0])); -``` - -```sh -TODO! -``` - -## 6. Reduction (e.g., Sum) - -\( \text{sum}(A) \) where sum over all elements of A. - -```rust -let t1 = tensor!([[1, 2], [3, 4]]); -let total = t1.sum(); -``` - -```sh -10 -``` - -## 7. Broadcasting - -Adjusts tensors with different shapes to make them compatible for element-wise operations automatically -when using supported functions. - -## 8. Reshape - -Changing the shape of a tensor without altering its data. - -```rust -let t1 = tensor!([1, 2, 3, 4, 5, 6]); -let tr = t1.reshape([2, 3]); -``` - -```sh -[[1, 2, 3], [4, 5, 6]] -``` - -## 9. Transpose - -Transpose a tensor over given axes. - -\( B = A^T \) where \( B_{ij} = A_{ji} \). - -```rust -let t1 = tensor!([1, 2, 3, 4]); -let transposed = t1.transpose(); -``` - -```sh -TODO! -``` - -## 10. Concatenation - -Joining tensors along a specified dimension. - -```rust -let t1 = tensor!([1, 2, 3]); -let t2 = tensor!([4, 5, 6]); -let cat = t1.concat(&t2, 0); -``` - -```sh -TODO! -``` - -## 11. Slicing and Indexing - -Extracting parts of tensors based on indices. - -```rust -let t1 = tensor!([1, 2, 3, 4, 5, 6]); -let slice = t1.slice(s![1, ..]); -``` - -```sh -TODO! -``` - -## 12. Element-wise Functions (e.g., Sigmoid) - -**Mathematical Definition**: - -Applying a function to each element of a tensor, like \( \sigma(x) = \frac{1}{1 + e^{-x}} \) for sigmoid. - -**Rust Code Example**: - -```rust -let tensor = Tensor::::from([-1.0, 0.0, 1.0, 2.0]); // 2x2 tensor -let sigmoid_tensor = tensor.map(|x| 1.0 / (1.0 + (-x).exp())); // Apply sigmoid element-wise -``` - -## 13. Gradient Computation/Automatic Differentiation - -**Description**: - -Calculating the derivatives of tensors, crucial for training machine learning models. - -**Rust Code Example**: Depends on if your tensor library supports automatic differentiation. This is typically more complex and may involve constructing computational graphs. - -## 14. Normalization Operations (e.g., Batch Normalization) - -**Description**: Standardizing the inputs of a model across the batch dimension. - -**Rust Code Example**: This is specific to deep learning libraries and may not be directly supported in a general-purpose tensor library. - -## 15. Convolution Operations - -**Description**: Essential for image processing and CNNs. - -**Rust Code Example**: If your library supports it, convolutions typically involve using a specialized function that takes the input tensor and a kernel tensor. - -## 16. Pooling Operations (e.g., Max Pooling) - -**Description**: Reducing the spatial dimensions of - a tensor, commonly used in CNNs. - -**Rust Code Example**: Again, this depends on your library's support for such operations. - -## 17. Tensor Slicing and Joining - -**Description**: Operations to slice a tensor into sub-tensors or join multiple tensors into a larger tensor. - -**Rust Code Example**: Similar to the slicing and concatenation examples provided above. - -## 18. Dimension Permutation - -**Description**: Rearranging the dimensions of a tensor. - -**Rust Code Example**: - -```rust -let tensor = Tensor::::from([...]); // 3D tensor -let permuted_tensor = tensor.permute_dims([2, 0, 1]); // Permute dimensions -``` - -## 19. Expand and Squeeze Operations - -**Description**: Increasing or decreasing the dimensions of a tensor (adding/removing singleton dimensions). - -**Rust Code Example**: Depends on the specific functions provided by your library. - -## 20. Data Type Conversions - -**Description**: Converting tensors from one data type to another. - -**Rust Code Example**: - -```rust -let tensor = Tensor::::from([1, 2, 3, 4]); // 2x2 tensor -let converted_tensor = tensor.to_type::(); // Convert to f32 tensor -``` - -These examples provide a general guide. The actual implementation details may vary depending on the specific features and capabilities of the Rust tensor library you're using. - -## 21. Tensor Decompositions - -**CANDECOMP/PARAFAC (CP) Decomposition**: This decomposes a tensor into a sum of component rank-one tensors. For a third-order tensor, it's like expressing it as a sum of outer products of vectors. This is useful in applications like signal processing, psychometrics, and chemometrics. - -**Tucker Decomposition**: Similar to PCA for matrices, Tucker Decomposition decomposes a tensor into a core tensor multiplied by a matrix along each mode (dimension). It's more general than CP Decomposition and is useful in areas like data compression and tensor completion. - -**Higher-Order Singular Value Decomposition (HOSVD)**: A generalization of SVD for higher-order tensors, HOSVD decomposes a tensor into a core tensor and a set of orthogonal matrices for each mode. It's used in image processing, computer vision, and multilinear subspace learning. diff --git a/examples/operations.rs b/examples/operations.rs deleted file mode 100644 index b0babae..0000000 --- a/examples/operations.rs +++ /dev/null @@ -1,94 +0,0 @@ -#![allow(mixed_script_confusables)] -#![allow(non_snake_case)] -use bytemuck::cast_slice; -use manifold::contract; -use manifold::*; - -fn tensor_product() { - println!("Tensor Product\n"); - let mut tensor1 = Tensor::::from([[2], [2]]); // 2x2 tensor - let mut tensor2 = Tensor::::from([2]); // 2-element vector - - // Fill tensors with some values - tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); - tensor2.buffer_mut().copy_from_slice(&[5, 6]); - - println!("T1: {}", tensor1); - println!("T2: {}", tensor2); - - let product = tensor1.tensor_product(&tensor2); - - println!("T1 * T2 = {}", product); - - // Check shape of the resulting tensor - assert_eq!(product.shape(), &Shape::new([2, 2, 2])); - - // Check buffer of the resulting tensor - let expect: &[i32] = - cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]); - assert_eq!(product.buffer(), expect); -} - -fn test_tensor_contraction_23x32() { - // Define two 2D tensors (matrices) - - // Tensor A is 2x3 - let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); - println!("a: {:?}\n{}\n", a.shape(), a); - - // Tensor B is 3x2 - let b: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); - println!("b: {:?}\n{}\n", b.shape(), b); - - // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let ctr10 = contract((&a, [1]), (&b, [0])); - - println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10); - - let ctr01 = contract((&a, [0]), (&b, [1])); - - println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01); - // assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); - // assert_eq!( - // contracted_tensor.buffer(), - // &[9, 12, 15, 19, 26, 33, 29, 40, 51], - // "Contracted tensor buffer does not match expected" - // ); -} - -fn test_tensor_contraction_rank3() { - let a = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); - let b = tensor!([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]); - let contracted_tensor = contract((&a, [2]), (&b, [0])); - - println!("a: {}", a); - println!("b: {}", b); - println!("contracted_tensor: {}", contracted_tensor); - // assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]); - // Verify specific elements of contracted_tensor - // assert_eq!(contracted_tensor[0][0][0][0], 50); - // assert_eq!(contracted_tensor[0][0][0][1], 60); - // ... further checks for other elements ... -} - -fn transpose() { - let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); - let b = tensor!([[1, 2, 3], [4, 5, 6]]); - - // let iter = a.idx().iter_transposed([1, 0]); - - // for idx in iter { - // println!("{idx}"); - // } - let b = a.clone().transpose([1, 0]).unwrap(); - println!("a: {}", a); - println!("ta: {}", b); -} - -fn main() { - // tensor_product(); - // test_tensor_contraction_23x32(); - // test_tensor_contraction_rank3(); - - transpose(); -} diff --git a/rustfmt.toml b/rustfmt.toml index df99c69..54c6c64 100644 --- a/rustfmt.toml +++ b/rustfmt.toml @@ -1 +1,3 @@ max_width = 80 +wrap_comments = true +comment_width = 80 \ No newline at end of file diff --git a/src/axis.rs b/src/axis.rs index 97472e7..37c670b 100644 --- a/src/axis.rs +++ b/src/axis.rs @@ -2,16 +2,16 @@ use super::*; use getset::{Getters, MutGetters}; #[derive(Clone, Debug, Getters)] -pub struct Axis<'a, T: Value, const R: usize> { +pub struct TensorAxis<'a, T: Value, const R: usize> { #[getset(get = "pub")] tensor: &'a Tensor, #[getset(get = "pub")] dim: usize, } -impl<'a, T: Value, const R: usize> Axis<'a, T, R> { +impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> { pub fn new(tensor: &'a Tensor, dim: usize) -> Self { - assert!(dim < R, "Axis out of bounds"); + assert!(dim < R, "TensorAxis out of bounds"); Self { tensor, dim } } @@ -19,40 +19,42 @@ impl<'a, T: Value, const R: usize> Axis<'a, T, R> { self.tensor.shape().get(self.dim) } - pub fn shape(&self) -> &Shape { + pub fn shape(&self) -> &TensorShape { self.tensor.shape() } - pub fn iter_level(&'a self, level: usize) -> AxisIterator<'a, T, R> { + pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> { assert!(level < self.len(), "Level out of bounds"); - let mut index = Idx::new(self.shape(), [0; R]); + let mut index = TensorIndex::new(self.shape().clone(), [0; R]); index.set_axis(self.dim, level); - AxisIterator::new(self).set_start(level).set_end(level + 1) + TensorAxisIterator::new(self) + .set_start(level) + .set_end(level + 1) } } #[derive(Clone, Debug, Getters, MutGetters)] -pub struct AxisIterator<'a, T: Value, const R: usize> { +pub struct TensorAxisIterator<'a, T: Value, const R: usize> { #[getset(get = "pub")] - axis: &'a Axis<'a, T, R>, + axis: &'a TensorAxis<'a, T, R>, #[getset(get = "pub", get_mut = "pub")] - index: Idx<'a, R>, + index: TensorIndex, #[getset(get = "pub")] end: Option, } -impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> { - pub fn new(axis: &'a Axis<'a, T, R>) -> Self { +impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> { + pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self { Self { axis, - index: Idx::new(axis.shape(), [0; R]), + index: TensorIndex::new(axis.shape().clone(), [0; R]), end: None, } } pub fn set_start(self, start: usize) -> Self { assert!(start < self.axis().len(), "Start out of bounds"); - let mut index = Idx::new(self.axis().shape(), [0; R]); + let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]); index.set_axis(self.axis.dim, start); Self { axis: self.axis(), @@ -92,7 +94,7 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> { } } -impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> { +impl<'a, T: Value, const R: usize> Iterator for TensorAxisIterator<'a, T, R> { type Item = &'a T; fn next(&mut self) -> Option { @@ -106,284 +108,11 @@ impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> { } } -impl<'a, T: Value, const R: usize> IntoIterator for &'a Axis<'a, T, R> { +impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> { type Item = &'a T; - type IntoIter = AxisIterator<'a, T, R>; + type IntoIter = TensorAxisIterator<'a, T, R>; fn into_iter(self) -> Self::IntoIter { - AxisIterator::new(&self) - } -} - -pub fn contract< - 'a, - T: Value + std::fmt::Debug, - const R: usize, - const S: usize, - const N: usize, ->( - lhs: (&'a Tensor, [usize; N]), - rhs: (&'a Tensor, [usize; N]), -) -> Tensor -where - [(); R - N]:, - [(); S - N]:, - [(); R + S - 2 * N]:, -{ - let (lhs, la) = lhs; - let (rhs, ra) = rhs; - let lnc = (0..R).filter(|i| !la.contains(i)).collect::>(); - let rnc = (0..S).filter(|i| !ra.contains(i)).collect::>(); - - let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::>(); - let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::>(); - - let mut shape = Vec::new(); - shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); - shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); - let shape: [usize; R + S - 2 * N] = - shape.try_into().expect("Failed to create shape array"); - - let shape = Shape::new(shape); - - let result = contract_axes(&lnc, &rnc); - - Tensor::new_with_buffer(shape, result) -} - -pub fn contract_axes< - 'a, - T: Value + std::fmt::Debug, - const R: usize, - const S: usize, - const N: usize, ->( - laxes: &'a [Axis<'a, T, R>], - raxes: &'a [Axis<'a, T, S>], -) -> Vec -where - [(); R - N]:, - [(); S - N]:, -{ - let mut result = vec![]; - - let axes = laxes.into_iter().zip(raxes); - - for (laxis, raxis) in axes { - let mut axes_result: Vec = vec![]; - for i in 0..raxis.len() { - for j in 0..laxis.len() { - let mut sum = T::zero(); - let llevel = laxis.into_iter(); - let rlevel = raxis.into_iter(); - let zip = llevel.level(j).zip(rlevel.level(i)); - for (lv, rv) in zip { - sum = sum + *lv * *rv; - } - axes_result.push(sum); - } - } - result.extend_from_slice(&axes_result); - } - - result -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_tensor_contraction_simple() { - // Define two 2D tensors (matrices) - // Tensor A is 2x3 - let a: Tensor = Tensor::from([[1, 2], [3, 4]]); - - // Tensor B is 1x3x2 - let b: Tensor = Tensor::from([[1, 2], [3, 4]]); - - // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); - assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2])); - assert_eq!( - contracted_tensor.buffer(), - &[7, 10, 15, 22], - "Contracted tensor buffer does not match expected" - ); - } - - #[test] - fn test_tensor_contraction_23x32() { - // Define two 2D tensors (matrices) - - // Tensor A is 2x3 - let b: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); - println!("b: {}", b); - - // Tensor B is 3x2 - let a: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); - println!("a: {}", a); - - // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); - - println!("contracted_tensor: {}", contracted_tensor); - assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); - assert_eq!( - contracted_tensor.buffer(), - &[9, 12, 15, 19, 26, 33, 29, 40, 51], - "Contracted tensor buffer does not match expected" - ); - } - - #[test] - fn test_tensor_contraction_rank3() { - let a: Tensor = - Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24 - let b: Tensor = - Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24 - let contracted_tensor: Tensor = contract((&a, [2]), (&b, [0])); - - println!("a: {}", a); - println!("b: {}", b); - println!("contracted_tensor: {}", contracted_tensor); - // assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]); - // Verify specific elements of contracted_tensor - // assert_eq!(contracted_tensor[0][0][0][0], 50); - // assert_eq!(contracted_tensor[0][0][0][1], 60); - // ... further checks for other elements ... - } - - // #[test] - // fn test_axis_iterator_disassemble() { - // // Creating a 2x2 Tensor for testing - // let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); - - // // Testing iteration over the first axis (axis = 0) - // let axis = Axis::new(&tensor, 0); - - // let mut axis_iter = axis.into_iter().disassemble(); - - // assert_eq!(axis_iter[0].next(), Some(&1.0)); - // assert_eq!(axis_iter[0].next(), Some(&2.0)); - // assert_eq!(axis_iter[0].next(), None); - // assert_eq!(axis_iter[1].next(), Some(&3.0)); - // assert_eq!(axis_iter[1].next(), Some(&4.0)); - // assert_eq!(axis_iter[1].next(), None); - - // // Resetting the iterator for the second axis (axis = 1) - // let axis = Axis::new(&tensor, 1); - - // let mut axis_iter = axis.into_iter().disassemble(); - - // assert_eq!(axis_iter[0].next(), Some(&1.0)); - // assert_eq!(axis_iter[0].next(), Some(&3.0)); - // assert_eq!(axis_iter[0].next(), None); - // assert_eq!(axis_iter[1].next(), Some(&2.0)); - // assert_eq!(axis_iter[1].next(), Some(&4.0)); - // assert_eq!(axis_iter[1].next(), None); - // } - - #[test] - fn test_axis_iterator() { - // Creating a 2x2 Tensor for testing - let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); - - // Testing iteration over the first axis (axis = 0) - let axis = Axis::new(&tensor, 0); - - let mut axis_iter = axis.into_iter(); - - assert_eq!(axis_iter.next(), Some(&1.0)); - assert_eq!(axis_iter.next(), Some(&2.0)); - assert_eq!(axis_iter.next(), Some(&3.0)); - assert_eq!(axis_iter.next(), Some(&4.0)); - - // Resetting the iterator for the second axis (axis = 1) - let axis = Axis::new(&tensor, 1); - - let mut axis_iter = axis.into_iter(); - - assert_eq!(axis_iter.next(), Some(&1.0)); - assert_eq!(axis_iter.next(), Some(&3.0)); - assert_eq!(axis_iter.next(), Some(&2.0)); - assert_eq!(axis_iter.next(), Some(&4.0)); - - let shape = tensor.shape(); - - let mut a: Idx<2> = (shape, [0, 0]).into(); - let b: Idx<2> = (shape, [1, 1]).into(); - - while a <= b { - println!("a: {}", a); - a.inc(); - } - } - - #[test] - fn test_3d_tensor_axis_iteration() { - // Create a 3D Tensor with specific values - // Tensor shape is 2x2x2 for simplicity - let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); - - // Axis 0 (Layer-wise): - // - // t[0][0][0] = 1 - // t[0][0][1] = 2 - // t[0][1][0] = 3 - // t[0][1][1] = 4 - // t[1][0][0] = 5 - // t[1][0][1] = 6 - // t[1][1][0] = 7 - // t[1][1][1] = 8 - // [1, 2, 3, 4, 5, 6, 7, 8] - // - // This order suggests that for each "layer" (first level of arrays), - // the iterator goes through all rows and columns. It first completes - // the entire first layer, then moves to the second. - - let a0 = Axis::new(&t, 0); - let a0_order = a0.into_iter().cloned().collect::>(); - assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); - - // Axis 1 (Row-wise within each layer): - // - // t[0][0][0] = 1 - // t[0][0][1] = 2 - // t[1][0][0] = 5 - // t[1][0][1] = 6 - // t[0][1][0] = 3 - // t[0][1][1] = 4 - // t[1][1][0] = 7 - // t[1][1][1] = 8 - // [1, 2, 5, 6, 3, 4, 7, 8] - // - // This indicates that within each "layer", the iterator first - // completes the first row across all layers, then the second row - // across all layers. - - let a1 = Axis::new(&t, 1); - let a1_order = a1.into_iter().cloned().collect::>(); - assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); - - // Axis 2 (Column-wise within each layer): - // - // t[0][0][0] = 1 - // t[0][1][0] = 3 - // t[1][0][0] = 5 - // t[1][1][0] = 7 - // t[0][0][1] = 2 - // t[0][1][1] = 4 - // t[1][0][1] = 6 - // t[1][1][1] = 8 - // [1, 3, 5, 7, 2, 4, 6, 8] - // - // This indicates that within each "layer", the iterator first - // completes the first column across all layers, then the second - // column across all layers. - - let a2 = Axis::new(&t, 2); - let a2_order = a2.into_iter().cloned().collect::>(); - assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); + TensorAxisIterator::new(&self) } } diff --git a/src/error.rs b/src/error.rs index 76f24a8..66bc0c9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,9 @@ use thiserror::Error; -pub type Result = std::result::Result; +pub type Result = std::result::Result; #[derive(Error, Debug)] -pub enum Error { +pub enum TensorError { #[error("Invalid argument: {0}")] InvalidArgument(String), } diff --git a/src/index.rs b/src/index.rs index cde85cd..ac91c6c 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,43 +1,47 @@ use super::*; use getset::{Getters, MutGetters}; -use std::cmp::Ordering; -use std::ops::{Add, Sub}; +use std::{ + ops::{Index, IndexMut, Add, Sub}, + cmp::Ordering, +}; #[derive(Clone, Copy, Debug, Getters, MutGetters)] -pub struct Idx<'a, const R: usize> { +pub struct TensorIndex { #[getset(get = "pub", get_mut = "pub")] indices: [usize; R], #[getset(get = "pub")] - shape: &'a Shape, + shape: TensorShape, } -impl<'a, const R: usize> Idx<'a, R> { - pub const fn zero(shape: &'a Shape) -> Self { +// ---- Construction and Initialization --------------------------------------- + +impl TensorIndex { + + pub fn new(shape: TensorShape, indices: [usize; R]) -> Self { + if !shape.check_indices(indices) { + panic!("indices out of bounds"); + } + Self { indices, shape } + } + + pub const fn zero(shape: TensorShape) -> Self { Self { indices: [0; R], shape, } } - pub fn last(shape: &'a Shape) -> Self { + pub fn last(shape: TensorShape) -> Self { let max_indices = shape.as_array().map(|dim_size| dim_size.saturating_sub(1)); Self { indices: max_indices, - shape: shape, - } - } - - pub fn new(shape: &'a Shape, indices: [usize; R]) -> Self { - if !shape.check_indices(indices) { - panic!("indices out of bounds"); - } - Self { - indices, - shape: shape, + shape, } } +} +impl TensorIndex { pub fn is_zero(&self) -> bool { self.indices.iter().all(|&i| i == 0) } @@ -51,7 +55,8 @@ impl<'a, const R: usize> Idx<'a, R> { self.indices = [0; R]; } - /// Increments the index and returns a boolean indicating whether the end has been reached. + /// Increments the index and returns a boolean indicating whether the end + /// has been reached. /// /// # Returns /// `true` if the increment does not overflow and is still within bounds; @@ -74,10 +79,12 @@ impl<'a, const R: usize> Idx<'a, R> { } } - // If carry is still 1 after the loop, it means we've incremented past the last dimension + // If carry is still 1 after the loop, it means we've incremented past + // the last dimension if carry == 1 { - // Set the index to an invalid state to indicate the end of the iteration indicated - // by setting the first index to the size of the first dimension + // Set the index to an invalid state to indicate the end of the + // iteration indicated by setting the first index to the + // size of the first dimension self.indices[0] = self.shape.as_array()[0]; return true; // Indicate that the iteration is complete } @@ -87,7 +94,7 @@ impl<'a, const R: usize> Idx<'a, R> { // fn inc_axis pub fn inc_axis(&mut self, fixed_axis: usize) { - assert!(fixed_axis < R, "Axis out of bounds"); + assert!(fixed_axis < R, "TensorAxis out of bounds"); assert!( self.indices()[fixed_axis] < self.shape().get(fixed_axis), "Index out of bounds" @@ -156,7 +163,8 @@ impl<'a, const R: usize> Idx<'a, R> { { if borrow { if *i == 0 { - *i = dim_size - 1; // Wrap around to the maximum index of this dimension + *i = dim_size - 1; // Wrap around to the maximum index of + // this dimension } else { *i -= 1; // Decrement the index borrow = false; // No more borrowing needed @@ -183,7 +191,8 @@ impl<'a, const R: usize> Idx<'a, R> { } } - // Decrement the fixed axis if possible and reset other axes to their max + // Decrement the fixed axis if possible and reset other axes to their + // max if self.indices[fixed_axis] > 0 { self.indices[fixed_axis] -= 1; for i in 0..R { @@ -216,45 +225,53 @@ impl<'a, const R: usize> Idx<'a, R> { } } - // If no axis can be decremented, set the first axis in the order to indicate overflow + // If no axis can be decremented, set the first axis in the order to + // indicate overflow self.indices[order[0]] = self.shape.get(order[0]); } /// Converts the multi-dimensional index to a flat index. /// - /// This method calculates the flat index corresponding to the multi-dimensional index - /// stored in `self.indices`, given the shape of the tensor stored in `self.shape`. - /// The calculation is based on the assumption that the tensor is stored in row-major order, + /// This method calculates the flat index corresponding to the + /// multi-dimensional index stored in `self.indices`, given the shape of + /// the tensor stored in `self.shape`. The calculation is based on the + /// assumption that the tensor is stored in row-major order, /// where the last dimension varies the fastest. /// /// # Returns /// The flat index corresponding to the multi-dimensional index. /// /// # How It Works - /// - The method iterates over each pair of corresponding index and dimension size, - /// starting from the last dimension (hence `rev()` is used for reverse iteration). + /// - The method iterates over each pair of corresponding index and + /// dimension size, starting from the last dimension (hence `rev()` is + /// used for reverse iteration). /// - In each iteration, it performs two main operations: - /// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product - /// of dimension sizes (`product`). This calculates the contribution of the current index - /// to the overall flat index. - /// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`). - /// This updates `product` for the next iteration, as each dimension contributes to the - /// flat index based on the sizes of all preceding dimensions. - /// - The `fold` operation accumulates these results, starting with an initial state of - /// `(0, 1)` where `0` is the initial flat index and `1` is the initial product. - /// - The final flat index is obtained after the last iteration, which is the first element - /// of the tuple resulting from the `fold`. + /// 1. **Index Contribution**: Multiplies the current index (`idx`) by a + /// running product of dimension sizes (`product`). This calculates the + /// contribution of the current index to the overall flat index. + /// 2. **Product Update**: Multiplies `product` by the current dimension + /// size (`dim_size`). This updates `product` for the next iteration, + /// as each dimension contributes to the flat index based on the sizes + /// of all preceding dimensions. + /// - The `fold` operation accumulates these results, starting with an + /// initial state of `(0, 1)` where `0` is the initial flat index and `1` + /// is the initial product. + /// - The final flat index is obtained after the last iteration, which is + /// the first element of the tuple resulting from the `fold`. /// /// # Example /// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`. /// - Starting with a flat index of 0 and a product of 1, - /// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5. - /// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20. - /// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33. + /// - For the last dimension (size 5), add 3 * 1 to the flat index. Update + /// the product to 1 * 5 = 5. + /// - For the second dimension (size 4), add 2 * 5 to the flat index. Update + /// the product to 5 * 4 = 20. + /// - For the first dimension (size 3), add 1 * 20 to the flat index. The + /// final flat index is 3 + 10 + 20 = 33. pub fn flat(&self) -> usize { - self.indices + self.indices() .iter() - .zip(&self.shape.as_array()) + .zip(&self.shape().as_array()) .rev() .fold((0, 1), |(flat_index, product), (&idx, &dim_size)| { (flat_index + idx * product, product * dim_size) @@ -263,13 +280,13 @@ impl<'a, const R: usize> Idx<'a, R> { } pub fn set_axis(&mut self, axis: usize, value: usize) { - assert!(axis < R, "Axis out of bounds"); + assert!(axis < R, "TensorAxis out of bounds"); // assert!(value < self.shape.get(axis), "Value out of bounds"); self.indices[axis] = value; } pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool { - assert!(axis < R, "Axis out of bounds"); + assert!(axis < R, "TensorAxis out of bounds"); if value < self.shape.get(axis) { self.indices[axis] = value; true @@ -279,41 +296,41 @@ impl<'a, const R: usize> Idx<'a, R> { } pub fn get_axis(&self, axis: usize) -> usize { - assert!(axis < R, "Axis out of bounds"); + assert!(axis < R, "TensorAxis out of bounds"); self.indices[axis] } pub fn iter_transposed( &self, order: [usize; R], - ) -> IdxTransposedIterator<'a, R> { - IdxTransposedIterator::new(self.shape(), order) + ) -> TensorIndexTransposedIterator { + TensorIndexTransposedIterator::new(self.shape().clone(), order) } } // --- blanket impls --- -impl<'a, const R: usize> PartialEq for Idx<'a, R> { +impl PartialEq for TensorIndex { fn eq(&self, other: &Self) -> bool { self.flat() == other.flat() } } -impl<'a, const R: usize> Eq for Idx<'a, R> {} +impl Eq for TensorIndex {} -impl<'a, const R: usize> PartialOrd for Idx<'a, R> { +impl PartialOrd for TensorIndex { fn partial_cmp(&self, other: &Self) -> Option { self.flat().partial_cmp(&other.flat()) } } -impl<'a, const R: usize> Ord for Idx<'a, R> { +impl Ord for TensorIndex { fn cmp(&self, other: &Self) -> Ordering { self.flat().cmp(&other.flat()) } } -impl<'a, const R: usize> Index for Idx<'a, R> { +impl Index for TensorIndex { type Output = usize; fn index(&self, index: usize) -> &Self::Output { @@ -321,39 +338,45 @@ impl<'a, const R: usize> Index for Idx<'a, R> { } } -impl<'a, const R: usize> IndexMut for Idx<'a, R> { +impl IndexMut for TensorIndex { fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.indices[index] } } -impl<'a, const R: usize> From<(&'a Shape, [usize; R])> for Idx<'a, R> { - fn from((shape, indices): (&'a Shape, [usize; R])) -> Self { +impl From<(TensorShape, [usize; R])> + for TensorIndex +{ + fn from((shape, indices): (TensorShape, [usize; R])) -> Self { assert!(shape.check_indices(indices)); Self::new(shape, indices) } } -impl<'a, const R: usize> From<(&'a Shape, usize)> for Idx<'a, R> { - fn from((shape, flat_index): (&'a Shape, usize)) -> Self { +impl From<(TensorShape, usize)> + for TensorIndex +{ + fn from((shape, flat_index): (TensorShape, usize)) -> Self { let indices = shape.index_from_flat(flat_index).indices; Self::new(shape, indices) } } -impl<'a, const R: usize> From<&'a Shape> for Idx<'a, R> { - fn from(shape: &'a Shape) -> Self { +impl From> for TensorIndex { + fn from(shape: TensorShape) -> Self { Self::zero(shape) } } -impl<'a, T: Value, const R: usize> From<&'a Tensor> for Idx<'a, R> { - fn from(tensor: &'a Tensor) -> Self { - Self::zero(tensor.shape()) +impl From> + for TensorIndex +{ + fn from(tensor: Tensor) -> Self { + Self::zero(tensor.shape().clone()) } } -impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> { +impl std::fmt::Display for TensorIndex { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[")?; for (i, (&idx, &dim_size)) in self @@ -373,11 +396,11 @@ impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> { // ---- Arithmetic Operations ---- -impl<'a, const R: usize> Add for Idx<'a, R> { +impl Add for TensorIndex { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - assert_eq!(self.shape, rhs.shape, "Shape mismatch"); + assert_eq!(self.shape, rhs.shape, "TensorShape mismatch"); let mut result_indices = [0; R]; for i in 0..R { @@ -391,11 +414,11 @@ impl<'a, const R: usize> Add for Idx<'a, R> { } } -impl<'a, const R: usize> Sub for Idx<'a, R> { +impl Sub for TensorIndex { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - assert_eq!(self.shape, rhs.shape, "Shape mismatch"); + assert_eq!(self.shape, rhs.shape, "TensorShape mismatch"); let mut result_indices = [0; R]; for i in 0..R { @@ -411,22 +434,22 @@ impl<'a, const R: usize> Sub for Idx<'a, R> { // ---- Iterator ---- -pub struct IdxIterator<'a, const R: usize> { - current: Idx<'a, R>, +pub struct TensorIndexIterator { + current: TensorIndex, end: bool, } -impl<'a, const R: usize> IdxIterator<'a, R> { - pub fn new(shape: &'a Shape) -> Self { +impl TensorIndexIterator { + pub fn new(shape: TensorShape) -> Self { Self { - current: Idx::zero(shape), + current: TensorIndex::zero(shape), end: false, } } } -impl<'a, const R: usize> Iterator for IdxIterator<'a, R> { - type Item = Idx<'a, R>; +impl Iterator for TensorIndexIterator { + type Item = TensorIndex; fn next(&mut self) -> Option { if self.end { @@ -439,36 +462,36 @@ impl<'a, const R: usize> Iterator for IdxIterator<'a, R> { } } -impl<'a, const R: usize> IntoIterator for Idx<'a, R> { - type Item = Idx<'a, R>; - type IntoIter = IdxIterator<'a, R>; +impl IntoIterator for TensorIndex { + type Item = TensorIndex; + type IntoIter = TensorIndexIterator; fn into_iter(self) -> Self::IntoIter { - IdxIterator { + TensorIndexIterator { current: self, end: false, } } } -pub struct IdxTransposedIterator<'a, const R: usize> { - current: Idx<'a, R>, +pub struct TensorIndexTransposedIterator { + current: TensorIndex, order: [usize; R], end: bool, } -impl<'a, const R: usize> IdxTransposedIterator<'a, R> { - pub fn new(shape: &'a Shape, order: [usize; R]) -> Self { +impl TensorIndexTransposedIterator { + pub fn new(shape: TensorShape, order: [usize; R]) -> Self { Self { - current: Idx::zero(shape), + current: TensorIndex::zero(shape), end: false, order, } } } -impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> { - type Item = Idx<'a, R>; +impl Iterator for TensorIndexTransposedIterator { + type Item = TensorIndex; fn next(&mut self) -> Option { if self.end { diff --git a/src/lib.rs b/src/lib.rs index 52e6cf6..3234e67 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,40 +1,15 @@ #![allow(incomplete_features)] #![feature(generic_const_exprs)] +#![warn(clippy::all)] + pub mod axis; pub mod error; pub mod index; pub mod shape; pub mod tensor; +pub mod value; -pub use axis::*; -pub use index::Idx; -pub use itertools::Itertools; -use num::{Num, One, Zero}; -pub use serde::{Deserialize, Serialize}; -pub use shape::Shape; -pub use static_assertions::const_assert; -pub use std::fmt::{Display, Formatter, Result as FmtResult}; -use std::ops::{Index, IndexMut}; -pub use std::sync::Arc; -pub use tensor::{Tensor, TensorIterator}; - -pub trait Value: - Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static> -{ -} - -impl Value for T where - T: Num - + Zero - + One - + Copy - + Clone - + Display - + Serialize - + Deserialize<'static> - + std::iter::Sum -{ -} +pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*}; #[macro_export] macro_rules! tensor { @@ -43,184 +18,19 @@ macro_rules! tensor { }; } -// ---- Tests ---- - -#[cfg(test)] -mod tests { - use super::*; - use serde_json; - - #[test] - fn test_tensor_product() { - let mut tensor1 = Tensor::::from([[2], [2]]); // 2x2 tensor - let mut tensor2 = Tensor::::from([2]); // 2-element vector - - // Fill tensors with some values - tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); - tensor2.buffer_mut().copy_from_slice(&[5, 6]); - - let product = tensor1.tensor_product(&tensor2); - - // Check shape of the resulting tensor - assert_eq!(*product.shape(), Shape::new([2, 2, 2])); - - // Check buffer of the resulting tensor - let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24]; - assert_eq!(product.buffer(), &expected_buffer); - } - - #[test] - fn serde_shape_serialization_test() { - // Create a shape instance - let shape: Shape<3> = [1, 2, 3].into(); - - // Serialize the shape to a JSON string - let serialized = - serde_json::to_string(&shape).expect("Failed to serialize"); - - // Deserialize the JSON string back into a shape - let deserialized: Shape<3> = - serde_json::from_str(&serialized).expect("Failed to deserialize"); - - // Check that the deserialized shape is equal to the original - assert_eq!(shape, deserialized); - } - - #[test] - fn tensor_serde_serialization_test() { - // Create an instance of Tensor - let tensor: Tensor = Tensor::new(Shape::new([2, 2])); - - // Serialize the Tensor to a JSON string - let serialized = - serde_json::to_string(&tensor).expect("Failed to serialize"); - - // Deserialize the JSON string back into a Tensor - let deserialized: Tensor = - serde_json::from_str(&serialized).expect("Failed to deserialize"); - - // Check that the deserialized Tensor is equal to the original - assert_eq!(tensor.buffer(), deserialized.buffer()); - assert_eq!(tensor.shape(), deserialized.shape()); - } - - #[test] - fn iterate_3d_tensor() { - let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2 - let mut tensor = Tensor::new(shape); - let mut num = 0; - - // Fill the tensor with sequential numbers - for i in 0..2 { - for j in 0..2 { - for k in 0..2 { - tensor.buffer_mut()[i * 4 + j * 2 + k] = num; - num += 1; - } - } - } - - println!("{}", tensor); - - // Iterate over the tensor and check that the numbers are correct - - let mut iter = TensorIterator::new(&tensor); - - println!("{}", iter); - - assert_eq!(iter.next(), Some(&0)); - - assert_eq!(iter.next(), Some(&1)); - assert_eq!(iter.next(), Some(&2)); - assert_eq!(iter.next(), Some(&3)); - assert_eq!(iter.next(), Some(&4)); - assert_eq!(iter.next(), Some(&5)); - assert_eq!(iter.next(), Some(&6)); - assert_eq!(iter.next(), Some(&7)); - assert_eq!(iter.next(), None); - assert_eq!(iter.next(), None); - } - - #[test] - fn iterate_rank_4_tensor() { - // Define the shape of the rank-4 tensor (e.g., 2x2x2x2) - let shape = Shape::new([2, 2, 2, 2]); - let mut tensor = Tensor::new(shape); - let mut num = 0; - - // Fill the tensor with sequential numbers - for i in 0..tensor.len() { - tensor.buffer_mut()[i] = num; - num += 1; - } - - // Iterate over the tensor and check that the numbers are correct - let mut iter = TensorIterator::new(&tensor); - for expected_value in 0..tensor.len() { - assert_eq!(*iter.next().unwrap(), expected_value); - } - - // Ensure the iterator is exhausted - assert!(iter.next().is_none()); - } - - #[test] - fn test_dec_method() { - let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor - let mut index = Idx::zero(&shape); - - // Increment the index to the maximum - for _ in 0..26 { - // 3 * 3 * 3 - 1 = 26 increments to reach the end - index.inc(); - } - - // Check if the index is at the maximum - assert_eq!(index, Idx::new(&shape, [2, 2, 2])); - - // Decrement step by step and check the index - let expected_indices = [ - [2, 2, 2], - [2, 2, 1], - [2, 2, 0], - [2, 1, 2], - [2, 1, 1], - [2, 1, 0], - [2, 0, 2], - [2, 0, 1], - [2, 0, 0], - [1, 2, 2], - [1, 2, 1], - [1, 2, 0], - [1, 1, 2], - [1, 1, 1], - [1, 1, 0], - [1, 0, 2], - [1, 0, 1], - [1, 0, 0], - [0, 2, 2], - [0, 2, 1], - [0, 2, 0], - [0, 1, 2], - [0, 1, 1], - [0, 1, 0], - [0, 0, 2], - [0, 0, 1], - [0, 0, 0], - ]; - - for (i, &expected) in expected_indices.iter().enumerate() { - assert_eq!( - index, - Idx::new(&shape, expected), - "Failed at index {}", - i - ); - index.dec(); - } - - // Finally, the index should reach [0, 0, 0] - index.dec(); - assert_eq!(index, Idx::zero(&shape)); - } +#[macro_export] +macro_rules! shape { + ($array:expr) => { + TensorShape::from($array) + }; +} + +#[macro_export] +macro_rules! index { + ($tensor:expr) => { + TensorIndex::zero($tensor.shape().clone()) + }; + ($tensor:expr, $indices:expr) => { + TensorIndex::from(($tensor.shape().clone(), $indices)) + }; } diff --git a/src/shape.rs b/src/shape.rs index 7d2ff7d..06c65c2 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -1,12 +1,13 @@ use super::*; use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor}; use serde::ser::{Serialize, SerializeTuple, Serializer}; -use std::fmt; +use std::fmt::{Result as FmtResult, Formatter}; +use core::result::Result as SerdeResult; #[derive(Clone, Copy, Debug)] -pub struct Shape([usize; R]); +pub struct TensorShape([usize; R]); -impl Shape { +impl TensorShape { pub const fn new(shape: [usize; R]) -> Self { Self(shape) } @@ -16,7 +17,7 @@ impl Shape { } pub fn reorder(&self, indices: [usize; R]) -> Self { - let mut new_shape = Shape::new([0; R]); + let mut new_shape = TensorShape::new([0; R]); for (new_index, &index) in indices.iter().enumerate() { new_shape.0[new_index] = self.0[index]; } @@ -60,14 +61,18 @@ impl Shape { /// * `flat_index` - The flat index to convert. /// /// # Returns - /// An `Idx` instance representing the multi-dimensional index corresponding to the flat index. + /// An `TensorIndex` instance representing the multi-dimensional index + /// corresponding to the flat index. /// /// # How It Works - /// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order). - /// - In each iteration, it uses the modulo operation to find the index in the current dimension - /// and integer division to reduce the flat index for the next higher dimension. - /// - This process is repeated for each dimension to build the multi-dimensional index. - pub fn index_from_flat(&self, flat_index: usize) -> Idx { + /// - The method iterates over the dimensions of the tensor in reverse order + /// (assuming row-major order). + /// - In each iteration, it uses the modulo operation to find the index in + /// the current dimension and integer division to reduce the flat index + /// for the next higher dimension. + /// - This process is repeated for each dimension to build the + /// multi-dimensional index. + pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex { let mut indices = [0; R]; let mut remaining = flat_index; @@ -77,24 +82,24 @@ impl Shape { } indices.reverse(); // Reverse the indices to match the original dimension order - Idx::new(self, indices) + TensorIndex::new(self.clone(), indices) } - pub const fn index_zero(&self) -> Idx { - Idx::zero(self) + pub fn index_zero(&self) -> TensorIndex { + TensorIndex::zero(self.clone()) } - pub fn index_max(&self) -> Idx { + pub fn index_max(&self) -> TensorIndex { let max_indices = self.0 .map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 }); - Idx::new(self, max_indices) + TensorIndex::new(self.clone(), max_indices) } pub fn remove_dims( &self, dims_to_remove: [usize; NAX], - ) -> Shape<{ R - NAX }> { + ) -> TensorShape<{ R - NAX }> { // Create a new array to store the remaining dimensions let mut new_shape = [0; R - NAX]; let mut new_index = 0; @@ -111,13 +116,13 @@ impl Shape { new_index += 1; } - Shape(new_shape) + TensorShape(new_shape) } pub fn remove_axes<'a, T: Value, const NAX: usize>( &self, - axes_to_remove: &'a [Axis<'a, T, R>; NAX], - ) -> Shape<{ R - NAX }> { + axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX], + ) -> TensorShape<{ R - NAX }> { // Create a new array to store the remaining dimensions let mut new_shape = [0; R - NAX]; let mut new_index = 0; @@ -136,22 +141,22 @@ impl Shape { new_index += 1; } - Shape(new_shape) + TensorShape(new_shape) } } // ---- Serialize and Deserialize ---- -struct ShapeVisitor; +struct TensorShapeVisitor; -impl<'de, const R: usize> Visitor<'de> for ShapeVisitor { - type Value = Shape; +impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor { + type Value = TensorShape; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + fn expecting(&self, formatter: &mut Formatter) -> FmtResult { formatter.write_str(concat!("an array of length ", "{R}")) } - fn visit_seq(self, mut seq: A) -> Result + fn visit_seq(self, mut seq: A) -> SerdeResult where A: SeqAccess<'de>, { @@ -161,21 +166,21 @@ impl<'de, const R: usize> Visitor<'de> for ShapeVisitor { .next_element()? .ok_or_else(|| de::Error::invalid_length(i, &self))?; } - Ok(Shape(arr)) + Ok(TensorShape(arr)) } } -impl<'de, const R: usize> Deserialize<'de> for Shape { - fn deserialize(deserializer: D) -> Result, D::Error> +impl<'de, const R: usize> Deserialize<'de> for TensorShape { + fn deserialize(deserializer: D) -> SerdeResult, D::Error> where D: Deserializer<'de>, { - deserializer.deserialize_tuple(R, ShapeVisitor) + deserializer.deserialize_tuple(R, TensorShapeVisitor) } } -impl Serialize for Shape { - fn serialize(&self, serializer: S) -> Result +impl Serialize for TensorShape { + fn serialize(&self, serializer: S) -> SerdeResult where S: Serializer, { @@ -189,23 +194,23 @@ impl Serialize for Shape { // ---- Blanket Implementations ---- -impl From<[usize; R]> for Shape { +impl From<[usize; R]> for TensorShape { fn from(shape: [usize; R]) -> Self { Self::new(shape) } } -impl PartialEq for Shape { +impl PartialEq for TensorShape { fn eq(&self, other: &Self) -> bool { self.0 == other.0 } } -impl Eq for Shape {} +impl Eq for TensorShape {} // ---- From and Into Implementations ---- -impl From> for Shape +impl From> for TensorShape where T: Value, { diff --git a/src/tensor.rs b/src/tensor.rs index 0413b12..b40784d 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,24 +1,50 @@ use super::*; use crate::error::*; use getset::{Getters, MutGetters}; -use std::fmt; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::{Display, Formatter, Result as FmtResult}, + ops::{Index, IndexMut}, +}; +/// A tensor is a multi-dimensional array of values. The rank of a tensor is the +/// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is +/// a vector, a rank 2 tensor is a matrix, and so on. +/// +/// ``` +/// use manifold::{tensor, Tensor}; +/// +/// let t = tensor!([[1, 2], [3, 4]]); +/// assert_eq!(t.rank(), 2); +/// ``` #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)] pub struct Tensor { #[getset(get = "pub", get_mut = "pub")] buffer: Vec, #[getset(get = "pub")] - shape: Shape, + shape: TensorShape, } +// ---- Construction and Initialization --------------------------------------- + impl Tensor { - pub fn new(shape: Shape) -> Self { + /// Create a new tensor with the given shape. The rank of the tensor is + /// determined by the shape and all elements are initialized to zero. + /// + /// ``` + /// use manifold::Tensor; + /// + /// let t = Tensor::::new([3, 3].into()); + /// assert_eq!(t.shape().as_array(), [3, 3]); + /// ``` + pub fn new(shape: TensorShape) -> Self { // Handle rank 0 tensor (scalar) as a special case let total_size = if R == 0 { // A rank 0 tensor should still have a buffer with one element 1 } else { - // For tensors of rank 1 or higher, calculate the total size normally + // For tensors of rank 1 or higher, calculate the total size + // normally shape.iter().product() }; @@ -26,27 +52,356 @@ impl Tensor { Self { buffer, shape } } - pub fn new_with_buffer(shape: Shape, buffer: Vec) -> Self { + /// Create a new tensor with the given shape and initialize it from the + /// given buffer. The rank of the tensor is determined by the shape. + /// + /// ``` + /// use manifold::Tensor; + /// + /// let buffer = vec![1, 2, 3, 4, 5, 6]; + /// let t = Tensor::::new_with_buffer([2, 3].into(), buffer); + /// assert_eq!(t.shape().as_array(), [2, 3]); + /// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]); + /// ``` + pub fn new_with_buffer(shape: TensorShape, buffer: Vec) -> Self { Self { buffer, shape } } +} - pub fn reshape(self, shape: Shape) -> Result { +// ---- Trivial Getters ------------------------------------------------------- + +impl Tensor { + pub fn rank(&self) -> usize { + R + } + + pub fn len(&self) -> usize { + self.buffer().len() + } +} + +// ---- Get Values ------------------------------------------------------------ + +impl Tensor { + /// Get a reference to a value at the given index. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let i = (t.shape().clone(), [1, 1]).into(); + /// assert_eq!(t.get(i), Some(&4)); + /// ``` + pub fn get(&self, index: TensorIndex) -> Option<&T> { + self.buffer().get(index.flat()) + } + + /// Get a reference to a value at the given index without bounds checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let i = (t.shape().clone(), [1, 1]).into(); + /// unsafe { assert_eq!(t.get_unchecked(i), &4); } + /// ``` + pub unsafe fn get_unchecked(&self, index: TensorIndex) -> &T { + self.buffer().get_unchecked(index.flat()) + } + + /// Get a mutable reference to a value at the given index. + /// + /// ``` + /// use manifold::*; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// assert_eq!(t.get_mut(index!(&t, [1, 1])), Some(&mut 4)); + /// ``` + pub fn get_mut(&mut self, index: TensorIndex) -> Option<&mut T> { + self.buffer_mut().get_mut(index.flat()) + } + + /// Get a mutable reference to a value at the given index without bounds + /// checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// let s = t.shape().clone(); + /// let i = (s, [1, 1]).into(); + /// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); } + /// ``` + pub unsafe fn get_unchecked_mut( + &mut self, + index: TensorIndex, + ) -> &mut T { + self.buffer_mut().get_unchecked_mut(index.flat()) + } + + /// Get a reference to a value at the given flat index. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// assert_eq!(t.get_flat(3), Some(&4)); + /// ``` + pub fn get_flat(&self, index: usize) -> Option<&T> { + self.buffer().get(index) + } + + /// Get a reference to a value at the given flat index without bounds + /// checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// unsafe { assert_eq!(t.get_flat_unchecked(3), &4); } + /// ``` + pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T { + self.buffer().get_unchecked(index) + } + + /// Get a mutable reference to a value at the given flat index. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// assert_eq!(t.get_flat_mut(3), Some(&mut 4)); + /// ``` + pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> { + self.buffer_mut().get_mut(index) + } + + /// Get a mutable reference to a value at the given flat index without + /// bounds checking. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let mut t = tensor!([[1, 2], [3, 4]]); + /// unsafe { assert_eq!(t.get_flat_unchecked_mut(3), &mut 4); } + /// ``` + pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T { + self.buffer_mut().get_unchecked_mut(index) + } +} + +// ---- Arithmetic ------------------------------------------------------------ + +impl Tensor { + /// Elementwise operation on two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_for_each(&a, &b, &mut c, &|a, b| a * b).unwrap(); + /// assert_eq!(c, tensor!([[5, 12], [21, 32]])); + /// ``` + pub fn ew_for_each( + &self, + other: &Tensor, + result: &mut Tensor, + f: &dyn Fn(T, T) -> T, + ) -> Result<()> { + if self.shape() != other.shape() { + return Err(TensorError::InvalidArgument(format!( + "TensorShape mismatch: {:?} != {:?}", + self.shape(), + other.shape() + ))); + } else if self.shape() != result.shape() { + return Err(TensorError::InvalidArgument(format!( + "TensorShape mismatch: {:?} != {:?}", + self.shape(), + result.shape() + ))); + } + + for (i, (a, b)) in + self.buffer().iter().zip(other.buffer().iter()).enumerate() + { + unsafe { + *result.get_flat_unchecked_mut(i) = f(*a, *b); + } + } + + Ok(()) + } + + /// Elementwise multiplication of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_multiply(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[5, 12], [21, 32]])); + /// ``` + pub fn ew_multiply( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a * b) + } + + /// Elementwise addition of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_add(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[6, 8], [10, 12]])); + /// ``` + pub fn ew_add( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a + b) + } + + /// Elementwise subtraction of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[1, 2], [3, 4]]); + /// let b = tensor!([[5, 6], [7, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_subtract(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[-4, -4], [-4, -4]])); + /// ``` + pub fn ew_subtract( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a - b) + } + + /// Elementwise division of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[2, 4], [8, 16]]); + /// let b = tensor!([[2, 2], [4, 8]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_divide(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[1, 2], [2, 2]])); + /// ``` + pub fn ew_divide( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a / b) + } + + /// Elementwise modulo of two tensors. + /// + /// ``` + /// use manifold::{tensor, Tensor}; + /// + /// let a = tensor!([[2, 2], [3, 3]]); + /// let b = tensor!([[4, 4], [6, 9]]); + /// let mut c = Tensor::::new([2, 2].into()); + /// Tensor::ew_modulo(&a, &b, &mut c).unwrap(); + /// assert_eq!(c, tensor!([[2, 2], [3, 3]])); + /// ``` + pub fn ew_modulo( + &self, + other: &Tensor, + result: &mut Tensor, + ) -> Result<()> { + self.ew_for_each(other, result, &|a, b| a % b) + } + + // pub fn product( + // &self, + // other: &Tensor, + // ) -> Tensor { + // let mut new_shape_vec = Vec::new(); + // new_shape_vec.extend_from_slice(&self.shape().as_array()); + // new_shape_vec.extend_from_slice(&other.shape().as_array()); + + // let new_shape_array: [usize; R + S] = new_shape_vec + // .try_into() + // .expect("Failed to create shape array"); + + // let mut new_buffer = + // Vec::with_capacity(self.buffer.len() * other.buffer.len()); + // for &item_self in &self.buffer { + // for &item_other in &other.buffer { + // new_buffer.push(item_self * item_other); + // } + // } + + // Tensor { + // buffer: new_buffer, + // shape: TensorShape::new(new_shape_array), + // } + // } +} + +// ---- Reshape --------------------------------------------------------------- + +impl Tensor { + /// Reshape the tensor to the given shape. The total size of the new shape + /// must be the same as the total size of the old shape. + /// + /// ``` + /// use manifold::{tensor, shape, Tensor, TensorShape}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let s = shape!([4]); + /// let t = t.reshape(s).unwrap(); + /// assert_eq!(t, tensor!([1, 2, 3, 4])); + /// ``` + pub fn reshape( + self, + shape: TensorShape, + ) -> Result> { if self.shape().size() != shape.size() { let (ls, rs) = (self.shape().as_array(), shape.as_array()); let (lsize, rsize) = (self.shape().size(), shape.size()); - Err(Error::InvalidArgument(format!( - "Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", + Err(TensorError::InvalidArgument(format!( + "TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", ))) } else { - Ok(Self { - buffer: self.buffer, - shape, - }) + Ok(Tensor::new_with_buffer(shape, self.buffer)) } } +} +// ---- Transpose ------------------------------------------------------------- + +impl Tensor { + /// Transpose the tensor according to the given order. The order must be a + /// permutation of the tensor's axes. + /// + /// ``` + /// use manifold::{tensor, Tensor, TensorShape}; + /// + /// let t = tensor!([[1, 2], [3, 4]]); + /// let t = t.transpose([1, 0]).unwrap(); + /// assert_eq!(t, tensor!([[1, 3], [2, 4]])); + /// ``` pub fn transpose(self, order: [usize; R]) -> Result { - let buffer = Idx::from(self.shape()) + let buffer = TensorIndex::from(self.shape().clone()) .iter_transposed(order) .map(|index| self.get(index).unwrap().clone()) .collect(); @@ -56,150 +411,22 @@ impl Tensor { shape: self.shape().reorder(order), }) } - - pub fn idx(&self) -> Idx { - Idx::from(self) - } - - pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> { - Axis::new(self, axis) - } - - pub fn get(&self, index: Idx) -> Option<&T> { - self.buffer.get(index.flat()) - } - - pub unsafe fn get_unchecked(&self, index: Idx) -> &T { - self.buffer.get_unchecked(index.flat()) - } - - pub fn get_mut(&mut self, index: Idx) -> Option<&mut T> { - self.buffer.get_mut(index.flat()) - } - - pub unsafe fn get_unchecked_mut(&mut self, index: Idx) -> &mut T { - self.buffer.get_unchecked_mut(index.flat()) - } - - pub fn get_flat(&self, index: usize) -> Option<&T> { - self.buffer.get(index) - } - - pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T { - self.buffer.get_unchecked(index) - } - - pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> { - self.buffer.get_mut(index) - } - - pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T { - self.buffer.get_unchecked_mut(index) - } - - pub fn rank(&self) -> usize { - R - } - - pub fn len(&self) -> usize { - self.buffer.len() - } - - pub fn iter(&self) -> TensorIterator { - TensorIterator::new(self) - } - - pub fn elementwise_multiply(&self, other: &Tensor) -> Tensor { - if self.shape != other.shape { - panic!("Shapes of tensors do not match"); - } - - let mut result_buffer = Vec::with_capacity(self.buffer.len()); - - for (a, b) in self.buffer.iter().zip(other.buffer.iter()) { - result_buffer.push(*a * *b); - } - - Tensor { - buffer: result_buffer, - shape: self.shape, - } - } - - pub fn tensor_product( - &self, - other: &Tensor, - ) -> Tensor { - let mut new_shape_vec = Vec::new(); - new_shape_vec.extend_from_slice(&self.shape.as_array()); - new_shape_vec.extend_from_slice(&other.shape.as_array()); - - let new_shape_array: [usize; R + S] = new_shape_vec - .try_into() - .expect("Failed to create shape array"); - - let mut new_buffer = - Vec::with_capacity(self.buffer.len() * other.buffer.len()); - for &item_self in &self.buffer { - for &item_other in &other.buffer { - new_buffer.push(item_self * item_other); - } - } - - Tensor { - buffer: new_buffer, - shape: Shape::new(new_shape_array), - } - } - - // Retrieve an element based on a specific axis and index - pub fn get_by_axis(&self, axis: usize, index: usize) -> Option { - // Convert axis and index to a flat index - let flat_index = self.axis_to_flat_index(axis, index); - if flat_index >= self.buffer.len() { - return None; - } - - Some(self.buffer[flat_index]) - } - - // Convert axis and index to a flat index in the buffer - fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize { - let mut flat_index = 0; - let mut stride = 1; - - // Ensure the given axis is within the tensor's dimensions - if axis >= R { - panic!("Axis out of bounds"); - } - - // Calculate the stride for each dimension and accumulate the flat index - for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() { - println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride); - if i > axis { - stride *= dim_size; - } else if i == axis { - flat_index += index * stride; - break; // We've reached the target axis - } - } - - flat_index - } } -// ---- Indexing ---- +// ---- Indexing -------------------------------------------------------------- -impl<'a, T: Value, const R: usize> Index> for Tensor { +impl Index> for Tensor { type Output = T; - fn index(&self, index: Idx) -> &Self::Output { + fn index(&self, index: TensorIndex) -> &Self::Output { &self.buffer[index.flat()] } } -impl<'a, T: Value, const R: usize> IndexMut> for Tensor { - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { +impl IndexMut> + for Tensor +{ + fn index_mut(&mut self, index: TensorIndex) -> &mut Self::Output { &mut self.buffer[index.flat()] } } @@ -218,18 +445,18 @@ impl IndexMut for Tensor { } } -// ---- Display ---- +// ---- Display --------------------------------------------------------------- impl Tensor where - T: fmt::Display + Clone, + T: Display + Clone, { fn fmt_helper( buffer: &[T], shape: &[usize], - f: &mut fmt::Formatter<'_>, + f: &mut Formatter<'_>, level: usize, - ) -> fmt::Result { + ) -> FmtResult { if shape.is_empty() { // Base case: print individual elements write!(f, "{}", buffer[0]) @@ -247,24 +474,37 @@ where } } -impl fmt::Display for Tensor +impl Display for Tensor where - T: fmt::Display + Clone, + T: Display + Clone, { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { Tensor::::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1) } } -// ---- Iterator ---- +// ---- Equality -------------------------------------------------------------- + +impl PartialEq for Tensor +where + T: PartialEq, +{ + fn eq(&self, other: &Self) -> bool { + self.shape == other.shape && self.buffer == other.buffer + } +} + +impl Eq for Tensor where T: Eq {} + +// ---- Iterator -------------------------------------------------------------- pub struct TensorIterator<'a, T: Value, const R: usize> { tensor: &'a Tensor, - index: Idx<'a, R>, + index: TensorIndex, } impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> { - pub const fn new(tensor: &'a Tensor) -> Self { + pub fn new(tensor: &'a Tensor) -> Self { Self { tensor, index: tensor.shape.index_zero(), @@ -294,7 +534,7 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor { } } -// ---- Formatting ---- +// ---- Formatting ------------------------------------------------------------ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> { fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult { @@ -323,17 +563,17 @@ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> { } } -// ---- From ---- +// ---- From ------------------------------------------------------------------ -impl From> for Tensor { - fn from(shape: Shape) -> Self { +impl From> for Tensor { + fn from(shape: TensorShape) -> Self { Self::new(shape) } } impl From for Tensor { fn from(value: T) -> Self { - let shape = Shape::new([]); + let shape = TensorShape::new([]); let mut tensor = Tensor::new(shape); tensor.buffer_mut()[0] = value; tensor @@ -342,7 +582,7 @@ impl From for Tensor { impl From<[T; X]> for Tensor { fn from(array: [T; X]) -> Self { - let shape = Shape::new([X]); + let shape = TensorShape::new([X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -358,7 +598,7 @@ impl From<[[T; X]; Y]> for Tensor { fn from(array: [[T; X]; Y]) -> Self { - let shape = Shape::new([Y, X]); + let shape = TensorShape::new([Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -376,7 +616,7 @@ impl From<[[[T; X]; Y]; Z]> for Tensor { fn from(array: [[[T; X]; Y]; Z]) -> Self { - let shape = Shape::new([Z, Y, X]); + let shape = TensorShape::new([Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -401,7 +641,7 @@ impl< > From<[[[[T; X]; Y]; Z]; W]> for Tensor { fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self { - let shape = Shape::new([W, Z, Y, X]); + let shape = TensorShape::new([W, Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -429,7 +669,7 @@ impl< > From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor { fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self { - let shape = Shape::new([V, W, Z, Y, X]); + let shape = TensorShape::new([V, W, Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -464,7 +704,7 @@ impl< > From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor { fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self { - let shape = Shape::new([U, V, W, Z, Y, X]); + let shape = TensorShape::new([U, V, W, Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); diff --git a/src/value.rs b/src/value.rs new file mode 100644 index 0000000..54c9eb2 --- /dev/null +++ b/src/value.rs @@ -0,0 +1,25 @@ +use num::{Num, One, Zero}; +use serde::{Deserialize, Serialize}; +use std::{ + fmt::Display, + iter::Sum, +}; + +/// A trait for types that can be used as values in a tensor. +pub trait Value: + Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static> +{ +} + +impl Value for T where + T: Num + + Zero + + One + + Copy + + Clone + + Display + + Serialize + + Deserialize<'static> + + Sum +{ +} \ No newline at end of file