Update types after a change to shape's as_array

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
Julius Koskela 2024-01-04 13:55:51 +02:00
parent bd40c9c1f2
commit e3eef62700
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
5 changed files with 23 additions and 32 deletions
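
Note: the change to TensorShape::as_array mentioned in the title is not itself visible in the hunks below, which only touch call sites and formatting. A minimal sketch consistent with the `.clone()` calls at those call sites, assuming as_array hands out a borrow of the backing array:

// Hypothetical sketch only: the accessor's real signature is not shown in
// this commit. The struct definition matches the one in the shape file below.
#[derive(Clone, Copy, Debug)]
pub struct TensorShape<const R: usize>([usize; R]);

impl<const R: usize> TensorShape<R> {
    // Assumed to hand out a borrow of the backing array, so call sites like
    // `self.shape().as_array().clone()` end up with an owned [usize; R].
    pub fn as_array(&self) -> &[usize; R] {
        &self.0
    }
}

fn main() {
    let shape = TensorShape([2usize, 3, 4]);
    let dims: [usize; 3] = shape.as_array().clone();
    println!("{:?}", dims); // [2, 3, 4]
}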

View File

@@ -1,8 +1,8 @@
 use super::*;
 use getset::{Getters, MutGetters};
 use std::{
-    ops::{Index, IndexMut, Add, Sub},
-    cmp::Ordering,
+    cmp::Ordering,
+    ops::{Add, Index, IndexMut, Sub},
 };

 #[derive(Clone, Copy, Debug, Getters, MutGetters)]
@@ -16,7 +16,6 @@ pub struct TensorIndex<const R: usize> {
 // ---- Construction and Initialization ---------------------------------------
 impl<const R: usize> TensorIndex<R> {
     pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
         if !shape.check_indices(indices) {
             panic!("indices out of bounds");
@@ -65,11 +64,9 @@ impl<const R: usize> TensorIndex<R> {
         if self.indices()[0] >= self.shape().get(0) {
             return false;
         }
         let shape = self.shape().as_array().clone();
         let mut carry = 1;
-        for (i, &dim_size) in
-            self.indices.iter_mut().zip(&shape).rev()
-        {
+        for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() {
             if carry == 1 {
                 *i += 1;
                 if *i >= dim_size {
@@ -159,10 +156,8 @@ impl<const R: usize> TensorIndex<R> {
         }
         let mut borrow = true;
         let shape = self.shape().as_array().clone();
-        for (i, &dim_size) in
-            self.indices_mut().iter_mut().zip(&shape).rev()
-        {
+        for (i, &dim_size) in self.indices_mut().iter_mut().zip(&shape).rev() {
             if borrow {
                 if *i == 0 {
                     *i = dim_size - 1; // Wrap around to the maximum index of
@@ -346,18 +341,14 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
     }
 }

-impl<const R: usize> From<(TensorShape<R>, [usize; R])>
-    for TensorIndex<R>
-{
+impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
     fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
         assert!(shape.check_indices(indices));
         Self::new(shape, indices)
     }
 }

-impl<const R: usize> From<(TensorShape<R>, usize)>
-    for TensorIndex<R>
-{
+impl<const R: usize> From<(TensorShape<R>, usize)> for TensorIndex<R> {
     fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
         let indices = shape.index_from_flat(flat_index).indices;
         Self::new(shape, indices)
@@ -370,9 +361,7 @@ impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
     }
 }

-impl<T: Value, const R: usize> From<Tensor<T, R>>
-    for TensorIndex<R>
-{
+impl<T: Value, const R: usize> From<Tensor<T, R>> for TensorIndex<R> {
     fn from(tensor: Tensor<T, R>) -> Self {
         Self::zero(tensor.shape().clone())
     }
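
The increment/decrement hunks above only reflow the for-loops, but the carry/borrow walk over the shape they contain is the core of index iteration. A standalone sketch of the carry-based increment, with names and bounds handling chosen for illustration rather than taken from the crate:

// Standalone sketch of the carry-based increment seen in the hunk around
// line 64; function and variable names here are illustrative only.
fn increment<const R: usize>(indices: &mut [usize; R], shape: &[usize; R]) -> bool {
    let mut carry = 1;
    for (i, &dim_size) in indices.iter_mut().zip(shape).rev() {
        if carry == 1 {
            *i += 1;
            if *i >= dim_size {
                *i = 0; // wrap this axis and carry into the next one
            } else {
                carry = 0;
            }
        }
    }
    carry == 0 // false once the index has wrapped past the last element
}

fn main() {
    // Walks a 2 x 3 shape in row-major order: [0, 0], [0, 1], ... [1, 2].
    let shape = [2usize, 3];
    let mut idx = [0usize; 2];
    loop {
        println!("{:?}", idx);
        if !increment(&mut idx, &shape) {
            break;
        }
    }
}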

View File

@@ -9,7 +9,7 @@ pub mod shape;
 pub mod tensor;
 pub mod value;

-pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*};
+pub use {axis::*, error::*, index::*, shape::*, tensor::*, value::*};

 #[macro_export]
 macro_rules! tensor {
@@ -27,9 +27,9 @@ macro_rules! shape {
 #[macro_export]
 macro_rules! index {
     ($tensor:expr) => {
         TensorIndex::zero($tensor.shape().clone())
     };
     ($tensor:expr, $indices:expr) => {
         TensorIndex::from(($tensor.shape().clone(), $indices))
     };
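
The two index! arms above dispatch on argument count: one argument builds a zeroed index from the tensor's shape, two arguments convert a (shape, indices) pair. A self-contained sketch of the same dispatch pattern, using simplified stand-in types rather than the crate's TensorShape/TensorIndex and taking a shape directly instead of a tensor expression:

// Self-contained sketch of the two-arm dispatch used by index!; Shape and
// Index are simplified stand-ins for illustration only.
#[derive(Clone, Debug)]
struct Shape(Vec<usize>);

#[derive(Debug)]
struct Index {
    shape: Shape,
    indices: Vec<usize>,
}

impl Index {
    fn zero(shape: Shape) -> Self {
        let indices = vec![0; shape.0.len()];
        Self { shape, indices }
    }
}

macro_rules! index {
    ($shape:expr) => {
        Index::zero($shape.clone())
    };
    ($shape:expr, $indices:expr) => {
        Index { shape: $shape.clone(), indices: $indices.to_vec() }
    };
}

fn main() {
    let shape = Shape(vec![2, 3]);
    println!("{:?}", index!(shape));         // zeroed index
    println!("{:?}", index!(shape, [1, 2])); // explicit indices
}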

View File

@@ -1,8 +1,8 @@
 use super::*;
+use core::result::Result as SerdeResult;
 use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
 use serde::ser::{Serialize, SerializeTuple, Serializer};
-use std::fmt::{Result as FmtResult, Formatter};
-use core::result::Result as SerdeResult;
+use std::fmt::{Formatter, Result as FmtResult};

 #[derive(Clone, Copy, Debug)]
 pub struct TensorShape<const R: usize>([usize; R]);
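
The serde::ser imports above (Serialize, SerializeTuple, Serializer) suggest the shape is written out as a fixed-length tuple of its dimensions. A standalone sketch of that pattern, not the crate's actual impl, needing only serde and serde_json:

// Standalone sketch of the SerializeTuple pattern suggested by the imports;
// names and behaviour here are assumptions, not the crate's code.
use serde::ser::{Serialize, SerializeTuple, Serializer};

struct Shape<const R: usize>([usize; R]);

impl<const R: usize> Serialize for Shape<R> {
    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        // A rank-R shape serializes as a tuple of exactly R dimensions.
        let mut tup = serializer.serialize_tuple(R)?;
        for dim in &self.0 {
            tup.serialize_element(dim)?;
        }
        tup.end()
    }
}

fn main() {
    let json = serde_json::to_string(&Shape([2usize, 3, 4])).unwrap();
    println!("{}", json); // [2,3,4]
}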

View File

@@ -632,7 +632,12 @@ where
     T: Display + Clone,
 {
     fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
-        Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
+        Tensor::<T, R>::fmt_helper(
+            &self.buffer,
+            &self.shape().as_array().clone(),
+            f,
+            1,
+        )
     }
 }

View File

@@ -1,9 +1,6 @@
 use num::{Num, One, Zero};
 use serde::{Deserialize, Serialize};
-use std::{
-    fmt::Display,
-    iter::Sum,
-};
+use std::{fmt::Display, iter::Sum};

 /// A trait for types that can be used as values in a tensor.
 pub trait Value:
@@ -22,4 +19,4 @@ impl<T> Value for T where
     + Deserialize<'static>
     + Sum
 {
 }
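
The trailing hunk shows the blanket impl that makes every qualifying type a Value automatically. A reduced, self-contained sketch of the same pattern; the real trait also carries num and serde bounds that are dropped here so the example compiles on its own:

// Sketch of the blanket-impl pattern with a reduced bound set.
use std::fmt::Display;
use std::iter::Sum;

pub trait Value: Copy + Display + Sum {}

// Any type satisfying the bounds is a Value automatically.
impl<T> Value for T where T: Copy + Display + Sum {}

fn total<T: Value>(values: &[T]) -> T {
    values.iter().copied().sum()
}

fn main() {
    println!("{}", total(&[1.0f64, 2.0, 3.5])); // 6.5
}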