have shape as_array return a reference
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
1f9070378a
commit
08d72af38c
@ -65,9 +65,10 @@ impl<const R: usize> TensorIndex<R> {
|
||||
if self.indices()[0] >= self.shape().get(0) {
|
||||
return false;
|
||||
}
|
||||
let shape = self.shape().as_array().clone();
|
||||
let mut carry = 1;
|
||||
for (i, &dim_size) in
|
||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||
self.indices.iter_mut().zip(&shape).rev()
|
||||
{
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
@ -158,8 +159,9 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
|
||||
let mut borrow = true;
|
||||
let shape = self.shape().as_array().clone();
|
||||
for (i, &dim_size) in
|
||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||
self.indices_mut().iter_mut().zip(&shape).rev()
|
||||
{
|
||||
if borrow {
|
||||
if *i == 0 {
|
||||
@ -271,7 +273,7 @@ impl<const R: usize> TensorIndex<R> {
|
||||
pub fn flat(&self) -> usize {
|
||||
self.indices()
|
||||
.iter()
|
||||
.zip(&self.shape().as_array())
|
||||
.zip(&self.shape().as_array().clone())
|
||||
.rev()
|
||||
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
|
||||
(flat_index + idx * product, product * dim_size)
|
||||
|
@ -24,8 +24,8 @@ impl<const R: usize> TensorShape<R> {
|
||||
new_shape
|
||||
}
|
||||
|
||||
pub const fn as_array(&self) -> [usize; R] {
|
||||
self.0
|
||||
pub const fn as_array(&self) -> &[usize; R] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
pub const fn rank(&self) -> usize {
|
||||
|
163
src/tensor.rs
163
src/tensor.rs
@ -4,7 +4,10 @@ use getset::{Getters, MutGetters};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::{Display, Formatter, Result as FmtResult},
|
||||
ops::{Index, IndexMut},
|
||||
ops::{
|
||||
Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Rem,
|
||||
RemAssign, Sub, SubAssign,
|
||||
},
|
||||
};
|
||||
|
||||
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
|
||||
@ -413,6 +416,158 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Operations ------------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Add for Tensor<T, R> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, other: Self) -> Self::Output {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_add(&self, &other, &mut result).unwrap();
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Sub for Tensor<T, R> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, other: Self) -> Self::Output {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_subtract(&self, &other, &mut result).unwrap();
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Mul for Tensor<T, R> {
|
||||
type Output = Self;
|
||||
|
||||
fn mul(self, other: Self) -> Self::Output {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_multiply(&self, &other, &mut result).unwrap();
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Div for Tensor<T, R> {
|
||||
type Output = Self;
|
||||
|
||||
fn div(self, other: Self) -> Self::Output {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_divide(&self, &other, &mut result).unwrap();
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Rem for Tensor<T, R> {
|
||||
type Output = Self;
|
||||
|
||||
fn rem(self, other: Self) -> Self::Output {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_modulo(&self, &other, &mut result).unwrap();
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> AddAssign for Tensor<T, R> {
|
||||
fn add_assign(&mut self, other: Self) {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_add(&self, &other, &mut result).unwrap();
|
||||
|
||||
*self = result;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> SubAssign for Tensor<T, R> {
|
||||
fn sub_assign(&mut self, other: Self) {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_subtract(&self, &other, &mut result).unwrap();
|
||||
|
||||
*self = result;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> MulAssign for Tensor<T, R> {
|
||||
fn mul_assign(&mut self, other: Self) {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_multiply(&self, &other, &mut result).unwrap();
|
||||
|
||||
*self = result;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> DivAssign for Tensor<T, R> {
|
||||
fn div_assign(&mut self, other: Self) {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_divide(&self, &other, &mut result).unwrap();
|
||||
|
||||
*self = result;
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> RemAssign for Tensor<T, R> {
|
||||
fn rem_assign(&mut self, other: Self) {
|
||||
if self.shape() != other.shape() {
|
||||
todo!("Check for broadcasting");
|
||||
}
|
||||
|
||||
let mut result = Self::new(self.shape().clone());
|
||||
|
||||
Self::ew_modulo(&self, &other, &mut result).unwrap();
|
||||
|
||||
*self = result;
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Indexing --------------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
|
||||
@ -423,9 +578,7 @@ impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
|
||||
for Tensor<T, R>
|
||||
{
|
||||
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>> for Tensor<T, R> {
|
||||
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
|
||||
&mut self.buffer[index.flat()]
|
||||
}
|
||||
@ -479,7 +632,7 @@ where
|
||||
T: Display + Clone,
|
||||
{
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
|
||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user