Update types because change in shape as_array

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-04 13:55:51 +02:00
parent bd40c9c1f2
commit e3eef62700
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
5 changed files with 23 additions and 32 deletions

View File

@ -1,8 +1,8 @@
use super::*;
use getset::{Getters, MutGetters};
use std::{
ops::{Index, IndexMut, Add, Sub},
cmp::Ordering,
cmp::Ordering,
ops::{Add, Index, IndexMut, Sub},
};
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
@ -16,7 +16,6 @@ pub struct TensorIndex<const R: usize> {
// ---- Construction and Initialization ---------------------------------------
impl<const R: usize> TensorIndex<R> {
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
@ -65,11 +64,9 @@ impl<const R: usize> TensorIndex<R> {
if self.indices()[0] >= self.shape().get(0) {
return false;
}
let shape = self.shape().as_array().clone();
let shape = self.shape().as_array().clone();
let mut carry = 1;
for (i, &dim_size) in
self.indices.iter_mut().zip(&shape).rev()
{
for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() {
if carry == 1 {
*i += 1;
if *i >= dim_size {
@ -159,10 +156,8 @@ impl<const R: usize> TensorIndex<R> {
}
let mut borrow = true;
let shape = self.shape().as_array().clone();
for (i, &dim_size) in
self.indices_mut().iter_mut().zip(&shape).rev()
{
let shape = self.shape().as_array().clone();
for (i, &dim_size) in self.indices_mut().iter_mut().zip(&shape).rev() {
if borrow {
if *i == 0 {
*i = dim_size - 1; // Wrap around to the maximum index of
@ -346,18 +341,14 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
}
}
impl<const R: usize> From<(TensorShape<R>, [usize; R])>
for TensorIndex<R>
{
impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices));
Self::new(shape, indices)
}
}
impl<const R: usize> From<(TensorShape<R>, usize)>
for TensorIndex<R>
{
impl<const R: usize> From<(TensorShape<R>, usize)> for TensorIndex<R> {
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
let indices = shape.index_from_flat(flat_index).indices;
Self::new(shape, indices)
@ -370,9 +361,7 @@ impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
}
}
impl<T: Value, const R: usize> From<Tensor<T, R>>
for TensorIndex<R>
{
impl<T: Value, const R: usize> From<Tensor<T, R>> for TensorIndex<R> {
fn from(tensor: Tensor<T, R>) -> Self {
Self::zero(tensor.shape().clone())
}

View File

@ -9,7 +9,7 @@ pub mod shape;
pub mod tensor;
pub mod value;
pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*};
pub use {axis::*, error::*, index::*, shape::*, tensor::*, value::*};
#[macro_export]
macro_rules! tensor {
@ -27,9 +27,9 @@ macro_rules! shape {
#[macro_export]
macro_rules! index {
($tensor:expr) => {
TensorIndex::zero($tensor.shape().clone())
};
($tensor:expr) => {
TensorIndex::zero($tensor.shape().clone())
};
($tensor:expr, $indices:expr) => {
TensorIndex::from(($tensor.shape().clone(), $indices))
};

View File

@ -1,8 +1,8 @@
use super::*;
use core::result::Result as SerdeResult;
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeTuple, Serializer};
use std::fmt::{Result as FmtResult, Formatter};
use core::result::Result as SerdeResult;
use std::fmt::{Formatter, Result as FmtResult};
#[derive(Clone, Copy, Debug)]
pub struct TensorShape<const R: usize>([usize; R]);

View File

@ -632,7 +632,12 @@ where
T: Display + Clone,
{
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
Tensor::<T, R>::fmt_helper(
&self.buffer,
&self.shape().as_array().clone(),
f,
1,
)
}
}

View File

@ -1,9 +1,6 @@
use num::{Num, One, Zero};
use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
iter::Sum,
};
use std::{fmt::Display, iter::Sum};
/// A trait for types that can be used as values in a tensor.
pub trait Value:
@ -22,4 +19,4 @@ impl<T> Value for T where
+ Deserialize<'static>
+ Sum
{
}
}