From d099625b97d1872a4203c7da7f6e336bc0f0f4c9 Mon Sep 17 00:00:00 2001 From: Julius Koskela Date: Tue, 2 Jan 2024 21:16:56 +0200 Subject: [PATCH] Formatting Signed-off-by: Julius Koskela --- examples/operations.rs | 23 +++-- src/axis.rs | 24 +++--- src/error.rs | 6 +- src/lib.rs | 20 ++--- src/shape.rs | 30 +++---- src/tensor.rs | 187 ++++++++++++++++++++++++----------------- 6 files changed, 158 insertions(+), 132 deletions(-) diff --git a/examples/operations.rs b/examples/operations.rs index 5ebb0a9..b0babae 100644 --- a/examples/operations.rs +++ b/examples/operations.rs @@ -72,20 +72,17 @@ fn test_tensor_contraction_rank3() { } fn transpose() { - let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); - let b = tensor!( - [[1, 2, 3], - [4, 5, 6]] - ); + let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); + let b = tensor!([[1, 2, 3], [4, 5, 6]]); - // let iter = a.idx().iter_transposed([1, 0]); + // let iter = a.idx().iter_transposed([1, 0]); - // for idx in iter { - // println!("{idx}"); - // } - let b = a.clone().transpose([1, 0]).unwrap(); - println!("a: {}", a); - println!("ta: {}", b); + // for idx in iter { + // println!("{idx}"); + // } + let b = a.clone().transpose([1, 0]).unwrap(); + println!("a: {}", a); + println!("ta: {}", b); } fn main() { @@ -93,5 +90,5 @@ fn main() { // test_tensor_contraction_23x32(); // test_tensor_contraction_rank3(); - transpose(); + transpose(); } diff --git a/src/axis.rs b/src/axis.rs index db63940..97472e7 100644 --- a/src/axis.rs +++ b/src/axis.rs @@ -132,23 +132,23 @@ where { let (lhs, la) = lhs; let (rhs, ra) = rhs; - let lnc = (0..R).filter(|i| !la.contains(i)).collect::>(); - let rnc = (0..S).filter(|i| !ra.contains(i)).collect::>(); + let lnc = (0..R).filter(|i| !la.contains(i)).collect::>(); + let rnc = (0..S).filter(|i| !ra.contains(i)).collect::>(); - let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::>(); - let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::>(); + let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::>(); + let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::>(); - let mut shape = Vec::new(); - shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); - shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); - let shape: [usize; R + S - 2 * N] = - shape.try_into().expect("Failed to create shape array"); + let mut shape = Vec::new(); + shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); + shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); + let shape: [usize; R + S - 2 * N] = + shape.try_into().expect("Failed to create shape array"); - let shape = Shape::new(shape); + let shape = Shape::new(shape); - let result = contract_axes(&lnc, &rnc); + let result = contract_axes(&lnc, &rnc); - Tensor::new_with_buffer(shape, result) + Tensor::new_with_buffer(shape, result) } pub fn contract_axes< diff --git a/src/error.rs b/src/error.rs index db4524b..76f24a8 100644 --- a/src/error.rs +++ b/src/error.rs @@ -4,6 +4,6 @@ pub type Result = std::result::Result; #[derive(Error, Debug)] pub enum Error { - #[error("Invalid argument: {0}")] - InvalidArgument(String), -} \ No newline at end of file + #[error("Invalid argument: {0}")] + InvalidArgument(String), +} diff --git a/src/lib.rs b/src/lib.rs index 87bdcd5..52e6cf6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,22 +1,22 @@ #![allow(incomplete_features)] #![feature(generic_const_exprs)] +pub mod axis; +pub mod error; pub mod index; pub mod shape; -pub mod axis; pub mod tensor; -pub mod error; +pub use axis::*; pub use index::Idx; +pub use itertools::Itertools; use num::{Num, One, Zero}; pub use serde::{Deserialize, Serialize}; pub use shape::Shape; +pub use static_assertions::const_assert; pub use std::fmt::{Display, Formatter, Result as FmtResult}; use std::ops::{Index, IndexMut}; -pub use tensor::{Tensor, TensorIterator}; -pub use static_assertions::const_assert; -pub use itertools::Itertools; pub use std::sync::Arc; -pub use axis::*; +pub use tensor::{Tensor, TensorIterator}; pub trait Value: Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static> @@ -32,15 +32,15 @@ impl Value for T where + Display + Serialize + Deserialize<'static> - + std::iter::Sum + + std::iter::Sum { } #[macro_export] macro_rules! tensor { - ($array:expr) => { - Tensor::from($array) - }; + ($array:expr) => { + Tensor::from($array) + }; } // ---- Tests ---- diff --git a/src/shape.rs b/src/shape.rs index 982ec96..7d2ff7d 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -11,17 +11,17 @@ impl Shape { Self(shape) } - pub fn axis(&self, index: usize) -> Option<&usize> { - self.0.get(index) - } + pub fn axis(&self, index: usize) -> Option<&usize> { + self.0.get(index) + } - pub fn reorder(&self, indices: [usize; R]) -> Self { - let mut new_shape = Shape::new([0; R]); - for (new_index, &index) in indices.iter().enumerate() { - new_shape.0[new_index] = self.0[index]; - } - new_shape - } + pub fn reorder(&self, indices: [usize; R]) -> Self { + let mut new_shape = Shape::new([0; R]); + for (new_index, &index) in indices.iter().enumerate() { + new_shape.0[new_index] = self.0[index]; + } + new_shape + } pub const fn as_array(&self) -> [usize; R] { self.0 @@ -114,7 +114,7 @@ impl Shape { Shape(new_shape) } - pub fn remove_axes<'a, T: Value, const NAX: usize>( + pub fn remove_axes<'a, T: Value, const NAX: usize>( &self, axes_to_remove: &'a [Axis<'a, T, R>; NAX], ) -> Shape<{ R - NAX }> { @@ -126,10 +126,10 @@ impl Shape { for (index, &dim) in self.0.iter().enumerate() { // Skip dimensions that are in the axes_to_remove array for axis in axes_to_remove { - if *axis.dim() == index { - continue; - } - } + if *axis.dim() == index { + continue; + } + } // Add the dimension to the new shape array new_shape[new_index] = dim; diff --git a/src/tensor.rs b/src/tensor.rs index b470826..0413b12 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -46,14 +46,15 @@ impl Tensor { } pub fn transpose(self, order: [usize; R]) -> Result { - let buffer = Idx::from(self.shape()).iter_transposed(order) - .map(|index| { - println!("index: {}", index); - self.get(index).unwrap().clone() - }) - .collect(); + let buffer = Idx::from(self.shape()) + .iter_transposed(order) + .map(|index| self.get(index).unwrap().clone()) + .collect(); - Ok(Tensor { buffer, shape: self.shape().reorder(order) }) + Ok(Tensor { + buffer, + shape: self.shape().reorder(order), + }) } pub fn idx(&self) -> Idx { @@ -331,26 +332,26 @@ impl From> for Tensor { } impl From for Tensor { - fn from(value: T) -> Self { - let shape = Shape::new([]); - let mut tensor = Tensor::new(shape); - tensor.buffer_mut()[0] = value; - tensor - } + fn from(value: T) -> Self { + let shape = Shape::new([]); + let mut tensor = Tensor::new(shape); + tensor.buffer_mut()[0] = value; + tensor + } } impl From<[T; X]> for Tensor { - fn from(array: [T; X]) -> Self { - let shape = Shape::new([X]); - let mut tensor = Tensor::new(shape); - let buffer = tensor.buffer_mut(); + fn from(array: [T; X]) -> Self { + let shape = Shape::new([X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); - for (i, &elem) in array.iter().enumerate() { - buffer[i] = elem; - } + for (i, &elem) in array.iter().enumerate() { + buffer[i] = elem; + } - tensor - } + tensor + } } impl From<[[T; X]; Y]> @@ -391,74 +392,102 @@ impl } } -impl - From<[[[[T; X]; Y]; Z]; W]> for Tensor +impl< + T: Value, + const X: usize, + const Y: usize, + const Z: usize, + const W: usize, + > From<[[[[T; X]; Y]; Z]; W]> for Tensor { - fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self { - let shape = Shape::new([W, Z, Y, X]); - let mut tensor = Tensor::new(shape); - let buffer = tensor.buffer_mut(); + fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self { + let shape = Shape::new([W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); - for (i, hyperplane) in array.iter().enumerate() { - for (j, plane) in hyperplane.iter().enumerate() { - for (k, row) in plane.iter().enumerate() { - for (l, &elem) in row.iter().enumerate() { - buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem; - } - } - } - } + for (i, hyperplane) in array.iter().enumerate() { + for (j, plane) in hyperplane.iter().enumerate() { + for (k, row) in plane.iter().enumerate() { + for (l, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem; + } + } + } + } - tensor - } + tensor + } } -impl - From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor +impl< + T: Value, + const X: usize, + const Y: usize, + const Z: usize, + const W: usize, + const V: usize, + > From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor { - fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self { - let shape = Shape::new([V, W, Z, Y, X]); - let mut tensor = Tensor::new(shape); - let buffer = tensor.buffer_mut(); + fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self { + let shape = Shape::new([V, W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); - for (i, hyperhyperplane) in array.iter().enumerate() { - for (j, hyperplane) in hyperhyperplane.iter().enumerate() { - for (k, plane) in hyperplane.iter().enumerate() { - for (l, row) in plane.iter().enumerate() { - for (m, &elem) in row.iter().enumerate() { - buffer[i * X * Y * Z * W + j * X * Y * Z + k * X * Y + l * X + m] = elem; - } - } - } - } - } + for (i, hyperhyperplane) in array.iter().enumerate() { + for (j, hyperplane) in hyperhyperplane.iter().enumerate() { + for (k, plane) in hyperplane.iter().enumerate() { + for (l, row) in plane.iter().enumerate() { + for (m, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z * W + + j * X * Y * Z + + k * X * Y + + l * X + + m] = elem; + } + } + } + } + } - tensor - } + tensor + } } -impl - From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor +impl< + T: Value, + const X: usize, + const Y: usize, + const Z: usize, + const W: usize, + const V: usize, + const U: usize, + > From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor { - fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self { - let shape = Shape::new([U, V, W, Z, Y, X]); - let mut tensor = Tensor::new(shape); - let buffer = tensor.buffer_mut(); + fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self { + let shape = Shape::new([U, V, W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); - for (i, hyperhyperhyperplane) in array.iter().enumerate() { - for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() { - for (k, hyperplane) in hyperhyperplane.iter().enumerate() { - for (l, plane) in hyperplane.iter().enumerate() { - for (m, row) in plane.iter().enumerate() { - for (n, &elem) in row.iter().enumerate() { - buffer[i * X * Y * Z * W * V + j * X * Y * Z * W + k * X * Y * Z + l * X * Y + m * X + n] = elem; - } - } - } - } - } - } + for (i, hyperhyperhyperplane) in array.iter().enumerate() { + for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() + { + for (k, hyperplane) in hyperhyperplane.iter().enumerate() { + for (l, plane) in hyperplane.iter().enumerate() { + for (m, row) in plane.iter().enumerate() { + for (n, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z * W * V + + j * X * Y * Z * W + + k * X * Y * Z + + l * X * Y + + m * X + + n] = elem; + } + } + } + } + } + } - tensor - } + tensor + } }