have shape as_array return a reference

commit 08d72af38c (parent 1f9070378a)
Author: Julius Koskela 2024-01-04 13:34:48 +02:00
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
GPG-signed by julius, key ID 5A7B7F4897C2914B
3 changed files with 165 additions and 10 deletions

File 1 of 3 (TensorIndex):

@@ -65,9 +65,10 @@ impl<const R: usize> TensorIndex<R> {
         if self.indices()[0] >= self.shape().get(0) {
             return false;
         }
+        let shape = self.shape().as_array().clone();
         let mut carry = 1;
         for (i, &dim_size) in
-            self.indices.iter_mut().zip(&self.shape.as_array()).rev()
+            self.indices.iter_mut().zip(&shape).rev()
         {
             if carry == 1 {
                 *i += 1;
@@ -158,8 +159,9 @@ impl<const R: usize> TensorIndex<R> {
         }
         let mut borrow = true;
+        let shape = self.shape().as_array().clone();
         for (i, &dim_size) in
-            self.indices.iter_mut().zip(&self.shape.as_array()).rev()
+            self.indices_mut().iter_mut().zip(&shape).rev()
         {
             if borrow {
                 if *i == 0 {
@@ -271,7 +273,7 @@ impl<const R: usize> TensorIndex<R> {
     pub fn flat(&self) -> usize {
         self.indices()
             .iter()
-            .zip(&self.shape().as_array())
+            .zip(&self.shape().as_array().clone())
             .rev()
             .fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
                 (flat_index + idx * product, product * dim_size)
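A note on the clone() calls above: once as_array() returns a reference (see the
TensorShape change below), an expression like
self.indices_mut().iter_mut().zip(self.shape().as_array()) borrows self mutably
and immutably at the same time, which the borrow checker rejects. Copying the
small [usize; R] out first ends the shared borrow. A minimal standalone sketch
(the names Ix, shape_ref, and step are illustrative, not from this codebase):

    struct Ix<const R: usize> {
        indices: [usize; R],
        shape: [usize; R],
    }

    impl<const R: usize> Ix<R> {
        fn shape_ref(&self) -> &[usize; R] { &self.shape }
        fn indices_mut(&mut self) -> &mut [usize; R] { &mut self.indices }

        fn step(&mut self) {
            // Rejected (E0502): `self.shape_ref()` needs a shared borrow of
            // `self` while `self.indices_mut()` holds an exclusive one:
            // for (i, &d) in self.indices_mut().iter_mut().zip(self.shape_ref()) {}

            // Accepted: copy the small array out first, ending the borrow.
            let shape = *self.shape_ref(); // equivalent to .clone() here
            for (i, &d) in self.indices_mut().iter_mut().zip(&shape) {
                *i = (*i + 1) % d;
            }
        }
    }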

File 2 of 3 (TensorShape):

@@ -24,8 +24,8 @@ impl<const R: usize> TensorShape<R> {
         new_shape
     }

-    pub const fn as_array(&self) -> [usize; R] {
-        self.0
+    pub const fn as_array(&self) -> &[usize; R] {
+        &self.0
     }

     pub const fn rank(&self) -> usize {
File 3 of 3 (Tensor):

@@ -4,7 +4,10 @@ use getset::{Getters, MutGetters};
 use serde::{Deserialize, Serialize};
 use std::{
     fmt::{Display, Formatter, Result as FmtResult},
-    ops::{Index, IndexMut},
+    ops::{
+        Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Rem,
+        RemAssign, Sub, SubAssign,
+    },
 };

 /// A tensor is a multi-dimensional array of values. The rank of a tensor is the
@@ -413,6 +416,158 @@ impl<T: Value, const R: usize> Tensor<T, R> {
     }
 }

+// ---- Operations ------------------------------------------------------------
+
+impl<T: Value, const R: usize> Add for Tensor<T, R> {
+    type Output = Self;
+
+    fn add(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_add(&self, &other, &mut result).unwrap();
+        result
+    }
+}
+
+impl<T: Value, const R: usize> Sub for Tensor<T, R> {
+    type Output = Self;
+
+    fn sub(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_subtract(&self, &other, &mut result).unwrap();
+        result
+    }
+}
+
+impl<T: Value, const R: usize> Mul for Tensor<T, R> {
+    type Output = Self;
+
+    fn mul(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_multiply(&self, &other, &mut result).unwrap();
+        result
+    }
+}
+
+impl<T: Value, const R: usize> Div for Tensor<T, R> {
+    type Output = Self;
+
+    fn div(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_divide(&self, &other, &mut result).unwrap();
+        result
+    }
+}
+
+impl<T: Value, const R: usize> Rem for Tensor<T, R> {
+    type Output = Self;
+
+    fn rem(self, other: Self) -> Self::Output {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_modulo(&self, &other, &mut result).unwrap();
+        result
+    }
+}
+
+impl<T: Value, const R: usize> AddAssign for Tensor<T, R> {
+    fn add_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_add(&self, &other, &mut result).unwrap();
+        *self = result;
+    }
+}
+
+impl<T: Value, const R: usize> SubAssign for Tensor<T, R> {
+    fn sub_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_subtract(&self, &other, &mut result).unwrap();
+        *self = result;
+    }
+}
+
+impl<T: Value, const R: usize> MulAssign for Tensor<T, R> {
+    fn mul_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_multiply(&self, &other, &mut result).unwrap();
+        *self = result;
+    }
+}
+
+impl<T: Value, const R: usize> DivAssign for Tensor<T, R> {
+    fn div_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_divide(&self, &other, &mut result).unwrap();
+        *self = result;
+    }
+}
+
+impl<T: Value, const R: usize> RemAssign for Tensor<T, R> {
+    fn rem_assign(&mut self, other: Self) {
+        if self.shape() != other.shape() {
+            todo!("Check for broadcasting");
+        }
+        let mut result = Self::new(self.shape().clone());
+        Self::ew_modulo(&self, &other, &mut result).unwrap();
+        *self = result;
+    }
+}
+
 // ---- Indexing --------------------------------------------------------------

 impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
@@ -423,9 +578,7 @@ impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
     }
 }

-impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
-    for Tensor<T, R>
-{
+impl<T: Value, const R: usize> IndexMut<TensorIndex<R>> for Tensor<T, R> {
     fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
         &mut self.buffer[index.flat()]
     }
@@ -479,7 +632,7 @@ where
     T: Display + Clone,
 {
     fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
-        Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
+        Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
     }
 }
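The new operator impls all follow one shape: check shapes, delegate to an
existing element-wise helper (ew_add and friends), and return or assign the
result; on mismatched shapes the todo!("Check for broadcasting") currently
panics at runtime. A self-contained miniature of the same pattern, with a toy
type standing in for Tensor (Tiny, and the assert in place of todo!, are
illustrative only):

    use std::ops::Add;

    #[derive(Clone, Debug, PartialEq)]
    struct Tiny<const R: usize> {
        shape: [usize; R],
        buffer: Vec<f64>,
    }

    impl<const R: usize> Add for Tiny<R> {
        type Output = Self;

        fn add(self, other: Self) -> Self::Output {
            // Mirrors the commit: identical shapes only, no broadcasting yet.
            assert_eq!(self.shape, other.shape, "shape mismatch");
            let buffer = self
                .buffer
                .iter()
                .zip(&other.buffer)
                .map(|(a, b)| a + b)
                .collect();
            Tiny { shape: self.shape, buffer }
        }
    }

    fn main() {
        let a = Tiny { shape: [2, 2], buffer: vec![1.0, 2.0, 3.0, 4.0] };
        let b = Tiny { shape: [2, 2], buffer: vec![10.0, 20.0, 30.0, 40.0] };
        assert_eq!((a + b).buffer, vec![11.0, 22.0, 33.0, 44.0]);
    }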