From 08d72af38cd47a2dcb473c61183d7d86b9e57575 Mon Sep 17 00:00:00 2001
From: Julius Koskela
Date: Thu, 4 Jan 2024 13:34:48 +0200
Subject: [PATCH] have shape as_array return a reference

Signed-off-by: Julius Koskela
---
 src/index.rs  |   8 ++-
 src/shape.rs  |   4 +-
 src/tensor.rs | 163 ++++++++++++++++++++++++++++++++++++++++++++++++--
 3 files changed, 165 insertions(+), 10 deletions(-)

diff --git a/src/index.rs b/src/index.rs
index ac91c6c..a3bd3d8 100644
--- a/src/index.rs
+++ b/src/index.rs
@@ -65,9 +65,10 @@ impl<const R: usize> TensorIndex<R> {
         if self.indices()[0] >= self.shape().get(0) {
             return false;
         }
+        let shape = self.shape().as_array().clone();
         let mut carry = 1;
         for (i, &dim_size) in
-            self.indices.iter_mut().zip(&self.shape.as_array()).rev()
+            self.indices.iter_mut().zip(&shape).rev()
         {
             if carry == 1 {
                 *i += 1;
@@ -158,8 +159,9 @@
         }
 
         let mut borrow = true;
+        let shape = self.shape().as_array().clone();
         for (i, &dim_size) in
-            self.indices.iter_mut().zip(&self.shape.as_array()).rev()
+            self.indices_mut().iter_mut().zip(&shape).rev()
         {
             if borrow {
                 if *i == 0 {
@@ -271,7 +273,7 @@ pub fn flat(&self) -> usize {
         self.indices()
             .iter()
-            .zip(&self.shape().as_array())
+            .zip(&self.shape().as_array().clone())
             .rev()
             .fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
                 (flat_index + idx * product, product * dim_size)
             })
diff --git a/src/shape.rs b/src/shape.rs
index 06c65c2..841ba37 100644
--- a/src/shape.rs
+++ b/src/shape.rs
@@ -24,8 +24,8 @@ impl<const R: usize> TensorShape<R> {
         new_shape
     }
 
-    pub const fn as_array(&self) -> [usize; R] {
-        self.0
+    pub const fn as_array(&self) -> &[usize; R] {
+        &self.0
     }
 
     pub const fn rank(&self) -> usize {
diff --git a/src/tensor.rs b/src/tensor.rs
index b40784d..3069822 100644
--- a/src/tensor.rs
+++ b/src/tensor.rs
@@ -4,7 +4,10 @@ use getset::{Getters, MutGetters};
 use serde::{Deserialize, Serialize};
 use std::{
     fmt::{Display, Formatter, Result as FmtResult},
-    ops::{Index, IndexMut},
+    ops::{
+        Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Rem,
+        RemAssign, Sub, SubAssign,
+    },
 };
 
 /// A tensor is a multi-dimensional array of values. The rank of a tensor is the
@@ -413,6 +416,158 @@ impl<T, const R: usize> Tensor<T, R> {
     }
 }
 
+// ---- Operations ------------------------------------------------------------
+
+impl<T, const R: usize> Add for Tensor<T, R> {
+    type Output = Self;
+
+    fn add(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_add(&self, &other, &mut result).unwrap();
+
+        result
+    }
+}
+
+impl<T, const R: usize> Sub for Tensor<T, R> {
+    type Output = Self;
+
+    fn sub(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_subtract(&self, &other, &mut result).unwrap();
+
+        result
+    }
+}
+
+impl<T, const R: usize> Mul for Tensor<T, R> {
+    type Output = Self;
+
+    fn mul(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_multiply(&self, &other, &mut result).unwrap();
+
+        result
+    }
+}
+
+impl<T, const R: usize> Div for Tensor<T, R> {
+    type Output = Self;
+
+    fn div(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_divide(&self, &other, &mut result).unwrap();
+
+        result
+    }
+}
+
+impl<T, const R: usize> Rem for Tensor<T, R> {
+    type Output = Self;
+
+    fn rem(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_modulo(&self, &other, &mut result).unwrap();
+
+        result
+    }
+}
+
+impl<T, const R: usize> AddAssign for Tensor<T, R> {
+    fn add_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_add(&self, &other, &mut result).unwrap();
+
+        *self = result;
+    }
+}
+
+impl<T, const R: usize> SubAssign for Tensor<T, R> {
+    fn sub_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_subtract(&self, &other, &mut result).unwrap();
+
+        *self = result;
+    }
+}
+
+impl<T, const R: usize> MulAssign for Tensor<T, R> {
+    fn mul_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_multiply(&self, &other, &mut result).unwrap();
+
+        *self = result;
+    }
+}
+
+impl<T, const R: usize> DivAssign for Tensor<T, R> {
+    fn div_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_divide(&self, &other, &mut result).unwrap();
+
+        *self = result;
+    }
+}
+
+impl<T, const R: usize> RemAssign for Tensor<T, R> {
+    fn rem_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+
+        let mut result = Self::new(self.shape().clone());
+
+        Self::ew_modulo(&self, &other, &mut result).unwrap();
+
+        *self = result;
+    }
+}
+
 // ---- Indexing --------------------------------------------------------------
 
 impl<T, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
@@ -423,9 +578,7 @@ impl<T, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
     }
 }
 
-impl<T, const R: usize> IndexMut<TensorIndex<R>>
-    for Tensor<T, R>
-{
+impl<T, const R: usize> IndexMut<TensorIndex<R>> for Tensor<T, R> {
     fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
         &mut self.buffer[index.flat()]
     }
@@ -479,7 +632,7 @@ impl<T, const R: usize> Display for Tensor<T, R>
 where
     T: Display + Clone,
 {
     fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
-        Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
+        Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
     }
 }
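
A note for reviewers on the borrow pattern this patch leans on: `as_array`
now returns `&[usize; R]` instead of an owned copy, so any call site that
needs the array by value clones it explicitly up front. The prefetch matters
because `self.shape()` is an accessor method that takes a shared borrow of
all of `self`, which must end before `self.indices` is borrowed mutably in
the same method. The standalone sketch below mirrors that pattern; `Shape`
and `Counter` are illustrative stand-ins, not types from this crate:

    // Stand-ins for TensorShape / TensorIndex, for illustration only.
    struct Shape<const R: usize>([usize; R]);

    impl<const R: usize> Shape<R> {
        // As in the patch: hand out a reference instead of copying the array.
        const fn as_array(&self) -> &[usize; R] {
            &self.0
        }
    }

    struct Counter<const R: usize> {
        indices: [usize; R],
        shape: Shape<R>,
    }

    impl<const R: usize> Counter<R> {
        // Accessor method borrows all of `self`, like the getset-generated
        // getters in the crate, so its result cannot outlive this statement
        // once `self.indices` is borrowed mutably.
        fn shape(&self) -> &Shape<R> {
            &self.shape
        }

        // Mirrors the incrementing loop above: clone the dimensions first so
        // the shared borrow of `self` ends before `iter_mut` begins.
        fn inc(&mut self) -> bool {
            let shape = self.shape().as_array().clone();
            let mut carry = 1;
            for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() {
                if carry == 1 {
                    *i += 1;
                    if *i == dim_size {
                        *i = 0; // wrap this digit and keep carrying left
                    } else {
                        carry = 0;
                    }
                }
            }
            carry == 0 // false once every digit has wrapped
        }
    }

    fn main() {
        let mut c = Counter { indices: [0; 2], shape: Shape([2, 3]) };
        while c.inc() {
            println!("{:?}", c.indices); // [0, 1] .. [1, 2], row-major
        }
    }

Returning a reference keeps the `const fn` free of a per-call copy of the
R-element array, and the explicit `.clone()` calls make the remaining copies
visible at exactly the call sites that still need an owned value.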