have shape as_array return a reference
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
1f9070378a
commit
08d72af38c
@ -65,9 +65,10 @@ impl<const R: usize> TensorIndex<R> {
|
|||||||
if self.indices()[0] >= self.shape().get(0) {
|
if self.indices()[0] >= self.shape().get(0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
let shape = self.shape().as_array().clone();
|
||||||
let mut carry = 1;
|
let mut carry = 1;
|
||||||
for (i, &dim_size) in
|
for (i, &dim_size) in
|
||||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
self.indices.iter_mut().zip(&shape).rev()
|
||||||
{
|
{
|
||||||
if carry == 1 {
|
if carry == 1 {
|
||||||
*i += 1;
|
*i += 1;
|
||||||
@ -158,8 +159,9 @@ impl<const R: usize> TensorIndex<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut borrow = true;
|
let mut borrow = true;
|
||||||
|
let shape = self.shape().as_array().clone();
|
||||||
for (i, &dim_size) in
|
for (i, &dim_size) in
|
||||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
self.indices_mut().iter_mut().zip(&shape).rev()
|
||||||
{
|
{
|
||||||
if borrow {
|
if borrow {
|
||||||
if *i == 0 {
|
if *i == 0 {
|
||||||
@ -271,7 +273,7 @@ impl<const R: usize> TensorIndex<R> {
|
|||||||
pub fn flat(&self) -> usize {
|
pub fn flat(&self) -> usize {
|
||||||
self.indices()
|
self.indices()
|
||||||
.iter()
|
.iter()
|
||||||
.zip(&self.shape().as_array())
|
.zip(&self.shape().as_array().clone())
|
||||||
.rev()
|
.rev()
|
||||||
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
|
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
|
||||||
(flat_index + idx * product, product * dim_size)
|
(flat_index + idx * product, product * dim_size)
|
||||||
|
@ -24,8 +24,8 @@ impl<const R: usize> TensorShape<R> {
|
|||||||
new_shape
|
new_shape
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn as_array(&self) -> [usize; R] {
|
pub const fn as_array(&self) -> &[usize; R] {
|
||||||
self.0
|
&self.0
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn rank(&self) -> usize {
|
pub const fn rank(&self) -> usize {
|
||||||
|
163
src/tensor.rs
163
src/tensor.rs
@ -4,7 +4,10 @@ use getset::{Getters, MutGetters};
|
|||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{
|
||||||
fmt::{Display, Formatter, Result as FmtResult},
|
fmt::{Display, Formatter, Result as FmtResult},
|
||||||
ops::{Index, IndexMut},
|
ops::{
|
||||||
|
Add, AddAssign, Div, DivAssign, Index, IndexMut, Mul, MulAssign, Rem,
|
||||||
|
RemAssign, Sub, SubAssign,
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
|
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
|
||||||
@ -413,6 +416,158 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---- Operations ------------------------------------------------------------
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> Add for Tensor<T, R> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn add(self, other: Self) -> Self::Output {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_add(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> Sub for Tensor<T, R> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn sub(self, other: Self) -> Self::Output {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_subtract(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> Mul for Tensor<T, R> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn mul(self, other: Self) -> Self::Output {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_multiply(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> Div for Tensor<T, R> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn div(self, other: Self) -> Self::Output {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_divide(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> Rem for Tensor<T, R> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn rem(self, other: Self) -> Self::Output {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_modulo(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> AddAssign for Tensor<T, R> {
|
||||||
|
fn add_assign(&mut self, other: Self) {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_add(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
*self = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> SubAssign for Tensor<T, R> {
|
||||||
|
fn sub_assign(&mut self, other: Self) {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_subtract(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
*self = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> MulAssign for Tensor<T, R> {
|
||||||
|
fn mul_assign(&mut self, other: Self) {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_multiply(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
*self = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> DivAssign for Tensor<T, R> {
|
||||||
|
fn div_assign(&mut self, other: Self) {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_divide(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
*self = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> RemAssign for Tensor<T, R> {
|
||||||
|
fn rem_assign(&mut self, other: Self) {
|
||||||
|
if self.shape() != other.shape() {
|
||||||
|
todo!("Check for broadcasting");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = Self::new(self.shape().clone());
|
||||||
|
|
||||||
|
Self::ew_modulo(&self, &other, &mut result).unwrap();
|
||||||
|
|
||||||
|
*self = result;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ---- Indexing --------------------------------------------------------------
|
// ---- Indexing --------------------------------------------------------------
|
||||||
|
|
||||||
impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
|
impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
|
||||||
@ -423,9 +578,7 @@ impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
|
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>> for Tensor<T, R> {
|
||||||
for Tensor<T, R>
|
|
||||||
{
|
|
||||||
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
|
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
|
||||||
&mut self.buffer[index.flat()]
|
&mut self.buffer[index.flat()]
|
||||||
}
|
}
|
||||||
@ -479,7 +632,7 @@ where
|
|||||||
T: Display + Clone,
|
T: Display + Clone,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
|
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user