From 6ab486fb15a881175496840358f8ac4c6ba3a602 Mon Sep 17 00:00:00 2001 From: Julius Koskela Date: Wed, 3 Jan 2024 17:11:27 +0200 Subject: [PATCH] Change the names of the existing types to align with requirements Signed-off-by: Julius Koskela --- src/axis.rs | 74 +++++++++++++++++++-------------------- src/error.rs | 4 +-- src/index.rs | 96 +++++++++++++++++++++++++-------------------------- src/lib.rs | 26 +++++++------- src/shape.rs | 54 ++++++++++++++--------------- src/tensor.rs | 64 +++++++++++++++++----------------- 6 files changed, 159 insertions(+), 159 deletions(-) diff --git a/src/axis.rs b/src/axis.rs index 97472e7..d01f0f0 100644 --- a/src/axis.rs +++ b/src/axis.rs @@ -2,16 +2,16 @@ use super::*; use getset::{Getters, MutGetters}; #[derive(Clone, Debug, Getters)] -pub struct Axis<'a, T: Value, const R: usize> { +pub struct TensorAxis<'a, T: Value, const R: usize> { #[getset(get = "pub")] tensor: &'a Tensor, #[getset(get = "pub")] dim: usize, } -impl<'a, T: Value, const R: usize> Axis<'a, T, R> { +impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> { pub fn new(tensor: &'a Tensor, dim: usize) -> Self { - assert!(dim < R, "Axis out of bounds"); + assert!(dim < R, "TensorAxis out of bounds"); Self { tensor, dim } } @@ -19,40 +19,40 @@ impl<'a, T: Value, const R: usize> Axis<'a, T, R> { self.tensor.shape().get(self.dim) } - pub fn shape(&self) -> &Shape { + pub fn shape(&self) -> &TensorShape { self.tensor.shape() } - pub fn iter_level(&'a self, level: usize) -> AxisIterator<'a, T, R> { + pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> { assert!(level < self.len(), "Level out of bounds"); - let mut index = Idx::new(self.shape(), [0; R]); + let mut index = TensorIndex::new(self.shape(), [0; R]); index.set_axis(self.dim, level); - AxisIterator::new(self).set_start(level).set_end(level + 1) + TensorAxisIterator::new(self).set_start(level).set_end(level + 1) } } #[derive(Clone, Debug, Getters, MutGetters)] -pub struct AxisIterator<'a, T: Value, const R: usize> { +pub struct TensorAxisIterator<'a, T: Value, const R: usize> { #[getset(get = "pub")] - axis: &'a Axis<'a, T, R>, + axis: &'a TensorAxis<'a, T, R>, #[getset(get = "pub", get_mut = "pub")] - index: Idx<'a, R>, + index: TensorIndex<'a, R>, #[getset(get = "pub")] end: Option, } -impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> { - pub fn new(axis: &'a Axis<'a, T, R>) -> Self { +impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> { + pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self { Self { axis, - index: Idx::new(axis.shape(), [0; R]), + index: TensorIndex::new(axis.shape(), [0; R]), end: None, } } pub fn set_start(self, start: usize) -> Self { assert!(start < self.axis().len(), "Start out of bounds"); - let mut index = Idx::new(self.axis().shape(), [0; R]); + let mut index = TensorIndex::new(self.axis().shape(), [0; R]); index.set_axis(self.axis.dim, start); Self { axis: self.axis(), @@ -92,7 +92,7 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> { } } -impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> { +impl<'a, T: Value, const R: usize> Iterator for TensorAxisIterator<'a, T, R> { type Item = &'a T; fn next(&mut self) -> Option { @@ -106,12 +106,12 @@ impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> { } } -impl<'a, T: Value, const R: usize> IntoIterator for &'a Axis<'a, T, R> { +impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> { type Item = &'a T; - type IntoIter = AxisIterator<'a, T, R>; + type IntoIter = TensorAxisIterator<'a, T, R>; fn into_iter(self) -> Self::IntoIter { - AxisIterator::new(&self) + TensorAxisIterator::new(&self) } } @@ -144,7 +144,7 @@ where let shape: [usize; R + S - 2 * N] = shape.try_into().expect("Failed to create shape array"); - let shape = Shape::new(shape); + let shape = TensorShape::new(shape); let result = contract_axes(&lnc, &rnc); @@ -158,8 +158,8 @@ pub fn contract_axes< const S: usize, const N: usize, >( - laxes: &'a [Axis<'a, T, R>], - raxes: &'a [Axis<'a, T, S>], + laxes: &'a [TensorAxis<'a, T, R>], + raxes: &'a [TensorAxis<'a, T, S>], ) -> Vec where [(); R - N]:, @@ -204,7 +204,7 @@ mod tests { // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); - assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2])); + assert_eq!(contracted_tensor.shape(), &TensorShape::new([2, 2])); assert_eq!( contracted_tensor.buffer(), &[7, 10, 15, 22], @@ -228,7 +228,7 @@ mod tests { let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); println!("contracted_tensor: {}", contracted_tensor); - assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); + assert_eq!(contracted_tensor.shape(), &TensorShape::new([3, 3])); assert_eq!( contracted_tensor.buffer(), &[9, 12, 15, 19, 26, 33, 29, 40, 51], @@ -239,9 +239,9 @@ mod tests { #[test] fn test_tensor_contraction_rank3() { let a: Tensor = - Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24 + Tensor::new_with_buffer(TensorShape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24 let b: Tensor = - Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24 + Tensor::new_with_buffer(TensorShape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24 let contracted_tensor: Tensor = contract((&a, [2]), (&b, [0])); println!("a: {}", a); @@ -260,7 +260,7 @@ mod tests { // let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); // // Testing iteration over the first axis (axis = 0) - // let axis = Axis::new(&tensor, 0); + // let axis = TensorAxis::new(&tensor, 0); // let mut axis_iter = axis.into_iter().disassemble(); @@ -272,7 +272,7 @@ mod tests { // assert_eq!(axis_iter[1].next(), None); // // Resetting the iterator for the second axis (axis = 1) - // let axis = Axis::new(&tensor, 1); + // let axis = TensorAxis::new(&tensor, 1); // let mut axis_iter = axis.into_iter().disassemble(); @@ -290,7 +290,7 @@ mod tests { let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); // Testing iteration over the first axis (axis = 0) - let axis = Axis::new(&tensor, 0); + let axis = TensorAxis::new(&tensor, 0); let mut axis_iter = axis.into_iter(); @@ -300,7 +300,7 @@ mod tests { assert_eq!(axis_iter.next(), Some(&4.0)); // Resetting the iterator for the second axis (axis = 1) - let axis = Axis::new(&tensor, 1); + let axis = TensorAxis::new(&tensor, 1); let mut axis_iter = axis.into_iter(); @@ -311,8 +311,8 @@ mod tests { let shape = tensor.shape(); - let mut a: Idx<2> = (shape, [0, 0]).into(); - let b: Idx<2> = (shape, [1, 1]).into(); + let mut a: TensorIndex<2> = (shape, [0, 0]).into(); + let b: TensorIndex<2> = (shape, [1, 1]).into(); while a <= b { println!("a: {}", a); @@ -326,7 +326,7 @@ mod tests { // Tensor shape is 2x2x2 for simplicity let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); - // Axis 0 (Layer-wise): + // TensorAxis 0 (Layer-wise): // // t[0][0][0] = 1 // t[0][0][1] = 2 @@ -342,11 +342,11 @@ mod tests { // the iterator goes through all rows and columns. It first completes // the entire first layer, then moves to the second. - let a0 = Axis::new(&t, 0); + let a0 = TensorAxis::new(&t, 0); let a0_order = a0.into_iter().cloned().collect::>(); assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); - // Axis 1 (Row-wise within each layer): + // TensorAxis 1 (Row-wise within each layer): // // t[0][0][0] = 1 // t[0][0][1] = 2 @@ -362,11 +362,11 @@ mod tests { // completes the first row across all layers, then the second row // across all layers. - let a1 = Axis::new(&t, 1); + let a1 = TensorAxis::new(&t, 1); let a1_order = a1.into_iter().cloned().collect::>(); assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); - // Axis 2 (Column-wise within each layer): + // TensorAxis 2 (Column-wise within each layer): // // t[0][0][0] = 1 // t[0][1][0] = 3 @@ -382,7 +382,7 @@ mod tests { // completes the first column across all layers, then the second // column across all layers. - let a2 = Axis::new(&t, 2); + let a2 = TensorAxis::new(&t, 2); let a2_order = a2.into_iter().cloned().collect::>(); assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); } diff --git a/src/error.rs b/src/error.rs index 76f24a8..66bc0c9 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,9 +1,9 @@ use thiserror::Error; -pub type Result = std::result::Result; +pub type Result = std::result::Result; #[derive(Error, Debug)] -pub enum Error { +pub enum TensorError { #[error("Invalid argument: {0}")] InvalidArgument(String), } diff --git a/src/index.rs b/src/index.rs index cde85cd..cc93a08 100644 --- a/src/index.rs +++ b/src/index.rs @@ -4,22 +4,22 @@ use std::cmp::Ordering; use std::ops::{Add, Sub}; #[derive(Clone, Copy, Debug, Getters, MutGetters)] -pub struct Idx<'a, const R: usize> { +pub struct TensorIndex<'a, const R: usize> { #[getset(get = "pub", get_mut = "pub")] indices: [usize; R], #[getset(get = "pub")] - shape: &'a Shape, + shape: &'a TensorShape, } -impl<'a, const R: usize> Idx<'a, R> { - pub const fn zero(shape: &'a Shape) -> Self { +impl<'a, const R: usize> TensorIndex<'a, R> { + pub const fn zero(shape: &'a TensorShape) -> Self { Self { indices: [0; R], shape, } } - pub fn last(shape: &'a Shape) -> Self { + pub fn last(shape: &'a TensorShape) -> Self { let max_indices = shape.as_array().map(|dim_size| dim_size.saturating_sub(1)); Self { @@ -28,7 +28,7 @@ impl<'a, const R: usize> Idx<'a, R> { } } - pub fn new(shape: &'a Shape, indices: [usize; R]) -> Self { + pub fn new(shape: &'a TensorShape, indices: [usize; R]) -> Self { if !shape.check_indices(indices) { panic!("indices out of bounds"); } @@ -87,7 +87,7 @@ impl<'a, const R: usize> Idx<'a, R> { // fn inc_axis pub fn inc_axis(&mut self, fixed_axis: usize) { - assert!(fixed_axis < R, "Axis out of bounds"); + assert!(fixed_axis < R, "TensorAxis out of bounds"); assert!( self.indices()[fixed_axis] < self.shape().get(fixed_axis), "Index out of bounds" @@ -263,13 +263,13 @@ impl<'a, const R: usize> Idx<'a, R> { } pub fn set_axis(&mut self, axis: usize, value: usize) { - assert!(axis < R, "Axis out of bounds"); + assert!(axis < R, "TensorAxis out of bounds"); // assert!(value < self.shape.get(axis), "Value out of bounds"); self.indices[axis] = value; } pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool { - assert!(axis < R, "Axis out of bounds"); + assert!(axis < R, "TensorAxis out of bounds"); if value < self.shape.get(axis) { self.indices[axis] = value; true @@ -279,41 +279,41 @@ impl<'a, const R: usize> Idx<'a, R> { } pub fn get_axis(&self, axis: usize) -> usize { - assert!(axis < R, "Axis out of bounds"); + assert!(axis < R, "TensorAxis out of bounds"); self.indices[axis] } pub fn iter_transposed( &self, order: [usize; R], - ) -> IdxTransposedIterator<'a, R> { - IdxTransposedIterator::new(self.shape(), order) + ) -> TensorIndexTransposedIterator<'a, R> { + TensorIndexTransposedIterator::new(self.shape(), order) } } // --- blanket impls --- -impl<'a, const R: usize> PartialEq for Idx<'a, R> { +impl<'a, const R: usize> PartialEq for TensorIndex<'a, R> { fn eq(&self, other: &Self) -> bool { self.flat() == other.flat() } } -impl<'a, const R: usize> Eq for Idx<'a, R> {} +impl<'a, const R: usize> Eq for TensorIndex<'a, R> {} -impl<'a, const R: usize> PartialOrd for Idx<'a, R> { +impl<'a, const R: usize> PartialOrd for TensorIndex<'a, R> { fn partial_cmp(&self, other: &Self) -> Option { self.flat().partial_cmp(&other.flat()) } } -impl<'a, const R: usize> Ord for Idx<'a, R> { +impl<'a, const R: usize> Ord for TensorIndex<'a, R> { fn cmp(&self, other: &Self) -> Ordering { self.flat().cmp(&other.flat()) } } -impl<'a, const R: usize> Index for Idx<'a, R> { +impl<'a, const R: usize> Index for TensorIndex<'a, R> { type Output = usize; fn index(&self, index: usize) -> &Self::Output { @@ -321,39 +321,39 @@ impl<'a, const R: usize> Index for Idx<'a, R> { } } -impl<'a, const R: usize> IndexMut for Idx<'a, R> { +impl<'a, const R: usize> IndexMut for TensorIndex<'a, R> { fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.indices[index] } } -impl<'a, const R: usize> From<(&'a Shape, [usize; R])> for Idx<'a, R> { - fn from((shape, indices): (&'a Shape, [usize; R])) -> Self { +impl<'a, const R: usize> From<(&'a TensorShape, [usize; R])> for TensorIndex<'a, R> { + fn from((shape, indices): (&'a TensorShape, [usize; R])) -> Self { assert!(shape.check_indices(indices)); Self::new(shape, indices) } } -impl<'a, const R: usize> From<(&'a Shape, usize)> for Idx<'a, R> { - fn from((shape, flat_index): (&'a Shape, usize)) -> Self { +impl<'a, const R: usize> From<(&'a TensorShape, usize)> for TensorIndex<'a, R> { + fn from((shape, flat_index): (&'a TensorShape, usize)) -> Self { let indices = shape.index_from_flat(flat_index).indices; Self::new(shape, indices) } } -impl<'a, const R: usize> From<&'a Shape> for Idx<'a, R> { - fn from(shape: &'a Shape) -> Self { +impl<'a, const R: usize> From<&'a TensorShape> for TensorIndex<'a, R> { + fn from(shape: &'a TensorShape) -> Self { Self::zero(shape) } } -impl<'a, T: Value, const R: usize> From<&'a Tensor> for Idx<'a, R> { +impl<'a, T: Value, const R: usize> From<&'a Tensor> for TensorIndex<'a, R> { fn from(tensor: &'a Tensor) -> Self { Self::zero(tensor.shape()) } } -impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> { +impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[")?; for (i, (&idx, &dim_size)) in self @@ -373,11 +373,11 @@ impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> { // ---- Arithmetic Operations ---- -impl<'a, const R: usize> Add for Idx<'a, R> { +impl<'a, const R: usize> Add for TensorIndex<'a, R> { type Output = Self; fn add(self, rhs: Self) -> Self::Output { - assert_eq!(self.shape, rhs.shape, "Shape mismatch"); + assert_eq!(self.shape, rhs.shape, "TensorShape mismatch"); let mut result_indices = [0; R]; for i in 0..R { @@ -391,11 +391,11 @@ impl<'a, const R: usize> Add for Idx<'a, R> { } } -impl<'a, const R: usize> Sub for Idx<'a, R> { +impl<'a, const R: usize> Sub for TensorIndex<'a, R> { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { - assert_eq!(self.shape, rhs.shape, "Shape mismatch"); + assert_eq!(self.shape, rhs.shape, "TensorShape mismatch"); let mut result_indices = [0; R]; for i in 0..R { @@ -411,22 +411,22 @@ impl<'a, const R: usize> Sub for Idx<'a, R> { // ---- Iterator ---- -pub struct IdxIterator<'a, const R: usize> { - current: Idx<'a, R>, +pub struct TensorIndexIterator<'a, const R: usize> { + current: TensorIndex<'a, R>, end: bool, } -impl<'a, const R: usize> IdxIterator<'a, R> { - pub fn new(shape: &'a Shape) -> Self { +impl<'a, const R: usize> TensorIndexIterator<'a, R> { + pub fn new(shape: &'a TensorShape) -> Self { Self { - current: Idx::zero(shape), + current: TensorIndex::zero(shape), end: false, } } } -impl<'a, const R: usize> Iterator for IdxIterator<'a, R> { - type Item = Idx<'a, R>; +impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> { + type Item = TensorIndex<'a, R>; fn next(&mut self) -> Option { if self.end { @@ -439,36 +439,36 @@ impl<'a, const R: usize> Iterator for IdxIterator<'a, R> { } } -impl<'a, const R: usize> IntoIterator for Idx<'a, R> { - type Item = Idx<'a, R>; - type IntoIter = IdxIterator<'a, R>; +impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> { + type Item = TensorIndex<'a, R>; + type IntoIter = TensorIndexIterator<'a, R>; fn into_iter(self) -> Self::IntoIter { - IdxIterator { + TensorIndexIterator { current: self, end: false, } } } -pub struct IdxTransposedIterator<'a, const R: usize> { - current: Idx<'a, R>, +pub struct TensorIndexTransposedIterator<'a, const R: usize> { + current: TensorIndex<'a, R>, order: [usize; R], end: bool, } -impl<'a, const R: usize> IdxTransposedIterator<'a, R> { - pub fn new(shape: &'a Shape, order: [usize; R]) -> Self { +impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> { + pub fn new(shape: &'a TensorShape, order: [usize; R]) -> Self { Self { - current: Idx::zero(shape), + current: TensorIndex::zero(shape), end: false, order, } } } -impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> { - type Item = Idx<'a, R>; +impl<'a, const R: usize> Iterator for TensorIndexTransposedIterator<'a, R> { + type Item = TensorIndex<'a, R>; fn next(&mut self) -> Option { if self.end { diff --git a/src/lib.rs b/src/lib.rs index 52e6cf6..3407622 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -7,11 +7,11 @@ pub mod shape; pub mod tensor; pub use axis::*; -pub use index::Idx; +pub use index::TensorIndex; pub use itertools::Itertools; use num::{Num, One, Zero}; pub use serde::{Deserialize, Serialize}; -pub use shape::Shape; +pub use shape::TensorShape; pub use static_assertions::const_assert; pub use std::fmt::{Display, Formatter, Result as FmtResult}; use std::ops::{Index, IndexMut}; @@ -62,7 +62,7 @@ mod tests { let product = tensor1.tensor_product(&tensor2); // Check shape of the resulting tensor - assert_eq!(*product.shape(), Shape::new([2, 2, 2])); + assert_eq!(*product.shape(), TensorShape::new([2, 2, 2])); // Check buffer of the resulting tensor let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24]; @@ -72,14 +72,14 @@ mod tests { #[test] fn serde_shape_serialization_test() { // Create a shape instance - let shape: Shape<3> = [1, 2, 3].into(); + let shape: TensorShape<3> = [1, 2, 3].into(); // Serialize the shape to a JSON string let serialized = serde_json::to_string(&shape).expect("Failed to serialize"); // Deserialize the JSON string back into a shape - let deserialized: Shape<3> = + let deserialized: TensorShape<3> = serde_json::from_str(&serialized).expect("Failed to deserialize"); // Check that the deserialized shape is equal to the original @@ -89,7 +89,7 @@ mod tests { #[test] fn tensor_serde_serialization_test() { // Create an instance of Tensor - let tensor: Tensor = Tensor::new(Shape::new([2, 2])); + let tensor: Tensor = Tensor::new(TensorShape::new([2, 2])); // Serialize the Tensor to a JSON string let serialized = @@ -106,7 +106,7 @@ mod tests { #[test] fn iterate_3d_tensor() { - let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2 + let shape = TensorShape::new([2, 2, 2]); // 3D tensor with shape 2x2x2 let mut tensor = Tensor::new(shape); let mut num = 0; @@ -144,7 +144,7 @@ mod tests { #[test] fn iterate_rank_4_tensor() { // Define the shape of the rank-4 tensor (e.g., 2x2x2x2) - let shape = Shape::new([2, 2, 2, 2]); + let shape = TensorShape::new([2, 2, 2, 2]); let mut tensor = Tensor::new(shape); let mut num = 0; @@ -166,8 +166,8 @@ mod tests { #[test] fn test_dec_method() { - let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor - let mut index = Idx::zero(&shape); + let shape = TensorShape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor + let mut index = TensorIndex::zero(&shape); // Increment the index to the maximum for _ in 0..26 { @@ -176,7 +176,7 @@ mod tests { } // Check if the index is at the maximum - assert_eq!(index, Idx::new(&shape, [2, 2, 2])); + assert_eq!(index, TensorIndex::new(&shape, [2, 2, 2])); // Decrement step by step and check the index let expected_indices = [ @@ -212,7 +212,7 @@ mod tests { for (i, &expected) in expected_indices.iter().enumerate() { assert_eq!( index, - Idx::new(&shape, expected), + TensorIndex::new(&shape, expected), "Failed at index {}", i ); @@ -221,6 +221,6 @@ mod tests { // Finally, the index should reach [0, 0, 0] index.dec(); - assert_eq!(index, Idx::zero(&shape)); + assert_eq!(index, TensorIndex::zero(&shape)); } } diff --git a/src/shape.rs b/src/shape.rs index 7d2ff7d..d1ffda6 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -4,9 +4,9 @@ use serde::ser::{Serialize, SerializeTuple, Serializer}; use std::fmt; #[derive(Clone, Copy, Debug)] -pub struct Shape([usize; R]); +pub struct TensorShape([usize; R]); -impl Shape { +impl TensorShape { pub const fn new(shape: [usize; R]) -> Self { Self(shape) } @@ -16,7 +16,7 @@ impl Shape { } pub fn reorder(&self, indices: [usize; R]) -> Self { - let mut new_shape = Shape::new([0; R]); + let mut new_shape = TensorShape::new([0; R]); for (new_index, &index) in indices.iter().enumerate() { new_shape.0[new_index] = self.0[index]; } @@ -60,14 +60,14 @@ impl Shape { /// * `flat_index` - The flat index to convert. /// /// # Returns - /// An `Idx` instance representing the multi-dimensional index corresponding to the flat index. + /// An `TensorIndex` instance representing the multi-dimensional index corresponding to the flat index. /// /// # How It Works /// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order). /// - In each iteration, it uses the modulo operation to find the index in the current dimension /// and integer division to reduce the flat index for the next higher dimension. /// - This process is repeated for each dimension to build the multi-dimensional index. - pub fn index_from_flat(&self, flat_index: usize) -> Idx { + pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex { let mut indices = [0; R]; let mut remaining = flat_index; @@ -77,24 +77,24 @@ impl Shape { } indices.reverse(); // Reverse the indices to match the original dimension order - Idx::new(self, indices) + TensorIndex::new(self, indices) } - pub const fn index_zero(&self) -> Idx { - Idx::zero(self) + pub const fn index_zero(&self) -> TensorIndex { + TensorIndex::zero(self) } - pub fn index_max(&self) -> Idx { + pub fn index_max(&self) -> TensorIndex { let max_indices = self.0 .map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 }); - Idx::new(self, max_indices) + TensorIndex::new(self, max_indices) } pub fn remove_dims( &self, dims_to_remove: [usize; NAX], - ) -> Shape<{ R - NAX }> { + ) -> TensorShape<{ R - NAX }> { // Create a new array to store the remaining dimensions let mut new_shape = [0; R - NAX]; let mut new_index = 0; @@ -111,13 +111,13 @@ impl Shape { new_index += 1; } - Shape(new_shape) + TensorShape(new_shape) } pub fn remove_axes<'a, T: Value, const NAX: usize>( &self, - axes_to_remove: &'a [Axis<'a, T, R>; NAX], - ) -> Shape<{ R - NAX }> { + axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX], + ) -> TensorShape<{ R - NAX }> { // Create a new array to store the remaining dimensions let mut new_shape = [0; R - NAX]; let mut new_index = 0; @@ -136,16 +136,16 @@ impl Shape { new_index += 1; } - Shape(new_shape) + TensorShape(new_shape) } } // ---- Serialize and Deserialize ---- -struct ShapeVisitor; +struct TensorShapeVisitor; -impl<'de, const R: usize> Visitor<'de> for ShapeVisitor { - type Value = Shape; +impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor { + type Value = TensorShape; fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { formatter.write_str(concat!("an array of length ", "{R}")) @@ -161,20 +161,20 @@ impl<'de, const R: usize> Visitor<'de> for ShapeVisitor { .next_element()? .ok_or_else(|| de::Error::invalid_length(i, &self))?; } - Ok(Shape(arr)) + Ok(TensorShape(arr)) } } -impl<'de, const R: usize> Deserialize<'de> for Shape { - fn deserialize(deserializer: D) -> Result, D::Error> +impl<'de, const R: usize> Deserialize<'de> for TensorShape { + fn deserialize(deserializer: D) -> Result, D::Error> where D: Deserializer<'de>, { - deserializer.deserialize_tuple(R, ShapeVisitor) + deserializer.deserialize_tuple(R, TensorShapeVisitor) } } -impl Serialize for Shape { +impl Serialize for TensorShape { fn serialize(&self, serializer: S) -> Result where S: Serializer, @@ -189,23 +189,23 @@ impl Serialize for Shape { // ---- Blanket Implementations ---- -impl From<[usize; R]> for Shape { +impl From<[usize; R]> for TensorShape { fn from(shape: [usize; R]) -> Self { Self::new(shape) } } -impl PartialEq for Shape { +impl PartialEq for TensorShape { fn eq(&self, other: &Self) -> bool { self.0 == other.0 } } -impl Eq for Shape {} +impl Eq for TensorShape {} // ---- From and Into Implementations ---- -impl From> for Shape +impl From> for TensorShape where T: Value, { diff --git a/src/tensor.rs b/src/tensor.rs index 0413b12..b788a6c 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -8,11 +8,11 @@ pub struct Tensor { #[getset(get = "pub", get_mut = "pub")] buffer: Vec, #[getset(get = "pub")] - shape: Shape, + shape: TensorShape, } impl Tensor { - pub fn new(shape: Shape) -> Self { + pub fn new(shape: TensorShape) -> Self { // Handle rank 0 tensor (scalar) as a special case let total_size = if R == 0 { // A rank 0 tensor should still have a buffer with one element @@ -26,16 +26,16 @@ impl Tensor { Self { buffer, shape } } - pub fn new_with_buffer(shape: Shape, buffer: Vec) -> Self { + pub fn new_with_buffer(shape: TensorShape, buffer: Vec) -> Self { Self { buffer, shape } } - pub fn reshape(self, shape: Shape) -> Result { + pub fn reshape(self, shape: TensorShape) -> Result { if self.shape().size() != shape.size() { let (ls, rs) = (self.shape().as_array(), shape.as_array()); let (lsize, rsize) = (self.shape().size(), shape.size()); - Err(Error::InvalidArgument(format!( - "Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", + Err(TensorError::InvalidArgument(format!( + "TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", ))) } else { Ok(Self { @@ -46,7 +46,7 @@ impl Tensor { } pub fn transpose(self, order: [usize; R]) -> Result { - let buffer = Idx::from(self.shape()) + let buffer = TensorIndex::from(self.shape()) .iter_transposed(order) .map(|index| self.get(index).unwrap().clone()) .collect(); @@ -57,27 +57,27 @@ impl Tensor { }) } - pub fn idx(&self) -> Idx { - Idx::from(self) + pub fn idx(&self) -> TensorIndex { + TensorIndex::from(self) } - pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> { - Axis::new(self, axis) + pub fn axis<'a>(&'a self, axis: usize) -> TensorAxis<'a, T, R> { + TensorAxis::new(self, axis) } - pub fn get(&self, index: Idx) -> Option<&T> { + pub fn get(&self, index: TensorIndex) -> Option<&T> { self.buffer.get(index.flat()) } - pub unsafe fn get_unchecked(&self, index: Idx) -> &T { + pub unsafe fn get_unchecked(&self, index: TensorIndex) -> &T { self.buffer.get_unchecked(index.flat()) } - pub fn get_mut(&mut self, index: Idx) -> Option<&mut T> { + pub fn get_mut(&mut self, index: TensorIndex) -> Option<&mut T> { self.buffer.get_mut(index.flat()) } - pub unsafe fn get_unchecked_mut(&mut self, index: Idx) -> &mut T { + pub unsafe fn get_unchecked_mut(&mut self, index: TensorIndex) -> &mut T { self.buffer.get_unchecked_mut(index.flat()) } @@ -111,7 +111,7 @@ impl Tensor { pub fn elementwise_multiply(&self, other: &Tensor) -> Tensor { if self.shape != other.shape { - panic!("Shapes of tensors do not match"); + panic!("TensorShapes of tensors do not match"); } let mut result_buffer = Vec::with_capacity(self.buffer.len()); @@ -148,7 +148,7 @@ impl Tensor { Tensor { buffer: new_buffer, - shape: Shape::new(new_shape_array), + shape: TensorShape::new(new_shape_array), } } @@ -170,7 +170,7 @@ impl Tensor { // Ensure the given axis is within the tensor's dimensions if axis >= R { - panic!("Axis out of bounds"); + panic!("TensorAxis out of bounds"); } // Calculate the stride for each dimension and accumulate the flat index @@ -190,16 +190,16 @@ impl Tensor { // ---- Indexing ---- -impl<'a, T: Value, const R: usize> Index> for Tensor { +impl<'a, T: Value, const R: usize> Index> for Tensor { type Output = T; - fn index(&self, index: Idx) -> &Self::Output { + fn index(&self, index: TensorIndex) -> &Self::Output { &self.buffer[index.flat()] } } -impl<'a, T: Value, const R: usize> IndexMut> for Tensor { - fn index_mut(&mut self, index: Idx) -> &mut Self::Output { +impl<'a, T: Value, const R: usize> IndexMut> for Tensor { + fn index_mut(&mut self, index: TensorIndex) -> &mut Self::Output { &mut self.buffer[index.flat()] } } @@ -260,7 +260,7 @@ where pub struct TensorIterator<'a, T: Value, const R: usize> { tensor: &'a Tensor, - index: Idx<'a, R>, + index: TensorIndex<'a, R>, } impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> { @@ -325,15 +325,15 @@ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> { // ---- From ---- -impl From> for Tensor { - fn from(shape: Shape) -> Self { +impl From> for Tensor { + fn from(shape: TensorShape) -> Self { Self::new(shape) } } impl From for Tensor { fn from(value: T) -> Self { - let shape = Shape::new([]); + let shape = TensorShape::new([]); let mut tensor = Tensor::new(shape); tensor.buffer_mut()[0] = value; tensor @@ -342,7 +342,7 @@ impl From for Tensor { impl From<[T; X]> for Tensor { fn from(array: [T; X]) -> Self { - let shape = Shape::new([X]); + let shape = TensorShape::new([X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -358,7 +358,7 @@ impl From<[[T; X]; Y]> for Tensor { fn from(array: [[T; X]; Y]) -> Self { - let shape = Shape::new([Y, X]); + let shape = TensorShape::new([Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -376,7 +376,7 @@ impl From<[[[T; X]; Y]; Z]> for Tensor { fn from(array: [[[T; X]; Y]; Z]) -> Self { - let shape = Shape::new([Z, Y, X]); + let shape = TensorShape::new([Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -401,7 +401,7 @@ impl< > From<[[[[T; X]; Y]; Z]; W]> for Tensor { fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self { - let shape = Shape::new([W, Z, Y, X]); + let shape = TensorShape::new([W, Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -429,7 +429,7 @@ impl< > From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor { fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self { - let shape = Shape::new([V, W, Z, Y, X]); + let shape = TensorShape::new([V, W, Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut(); @@ -464,7 +464,7 @@ impl< > From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor { fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self { - let shape = Shape::new([U, V, W, Z, Y, X]); + let shape = TensorShape::new([U, V, W, Z, Y, X]); let mut tensor = Tensor::new(shape); let buffer = tensor.buffer_mut();