Update types because change in shape as_array
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
bd40c9c1f2
commit
e3eef62700
29
src/index.rs
29
src/index.rs
@ -1,8 +1,8 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::{
|
||||
ops::{Index, IndexMut, Add, Sub},
|
||||
cmp::Ordering,
|
||||
cmp::Ordering,
|
||||
ops::{Add, Index, IndexMut, Sub},
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||
@ -16,7 +16,6 @@ pub struct TensorIndex<const R: usize> {
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
|
||||
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
@ -65,11 +64,9 @@ impl<const R: usize> TensorIndex<R> {
|
||||
if self.indices()[0] >= self.shape().get(0) {
|
||||
return false;
|
||||
}
|
||||
let shape = self.shape().as_array().clone();
|
||||
let shape = self.shape().as_array().clone();
|
||||
let mut carry = 1;
|
||||
for (i, &dim_size) in
|
||||
self.indices.iter_mut().zip(&shape).rev()
|
||||
{
|
||||
for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() {
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
if *i >= dim_size {
|
||||
@ -159,10 +156,8 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
|
||||
let mut borrow = true;
|
||||
let shape = self.shape().as_array().clone();
|
||||
for (i, &dim_size) in
|
||||
self.indices_mut().iter_mut().zip(&shape).rev()
|
||||
{
|
||||
let shape = self.shape().as_array().clone();
|
||||
for (i, &dim_size) in self.indices_mut().iter_mut().zip(&shape).rev() {
|
||||
if borrow {
|
||||
if *i == 0 {
|
||||
*i = dim_size - 1; // Wrap around to the maximum index of
|
||||
@ -346,18 +341,14 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(TensorShape<R>, [usize; R])>
|
||||
for TensorIndex<R>
|
||||
{
|
||||
impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
|
||||
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
|
||||
assert!(shape.check_indices(indices));
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(TensorShape<R>, usize)>
|
||||
for TensorIndex<R>
|
||||
{
|
||||
impl<const R: usize> From<(TensorShape<R>, usize)> for TensorIndex<R> {
|
||||
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
|
||||
let indices = shape.index_from_flat(flat_index).indices;
|
||||
Self::new(shape, indices)
|
||||
@ -370,9 +361,7 @@ impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> From<Tensor<T, R>>
|
||||
for TensorIndex<R>
|
||||
{
|
||||
impl<T: Value, const R: usize> From<Tensor<T, R>> for TensorIndex<R> {
|
||||
fn from(tensor: Tensor<T, R>) -> Self {
|
||||
Self::zero(tensor.shape().clone())
|
||||
}
|
||||
|
@ -9,7 +9,7 @@ pub mod shape;
|
||||
pub mod tensor;
|
||||
pub mod value;
|
||||
|
||||
pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*};
|
||||
pub use {axis::*, error::*, index::*, shape::*, tensor::*, value::*};
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! tensor {
|
||||
@ -27,9 +27,9 @@ macro_rules! shape {
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! index {
|
||||
($tensor:expr) => {
|
||||
TensorIndex::zero($tensor.shape().clone())
|
||||
};
|
||||
($tensor:expr) => {
|
||||
TensorIndex::zero($tensor.shape().clone())
|
||||
};
|
||||
($tensor:expr, $indices:expr) => {
|
||||
TensorIndex::from(($tensor.shape().clone(), $indices))
|
||||
};
|
||||
|
@ -1,8 +1,8 @@
|
||||
use super::*;
|
||||
use core::result::Result as SerdeResult;
|
||||
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
||||
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
||||
use std::fmt::{Result as FmtResult, Formatter};
|
||||
use core::result::Result as SerdeResult;
|
||||
use std::fmt::{Formatter, Result as FmtResult};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct TensorShape<const R: usize>([usize; R]);
|
||||
|
@ -632,7 +632,12 @@ where
|
||||
T: Display + Clone,
|
||||
{
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
|
||||
Tensor::<T, R>::fmt_helper(
|
||||
&self.buffer,
|
||||
&self.shape().as_array().clone(),
|
||||
f,
|
||||
1,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1,9 +1,6 @@
|
||||
use num::{Num, One, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::Display,
|
||||
iter::Sum,
|
||||
};
|
||||
use std::{fmt::Display, iter::Sum};
|
||||
|
||||
/// A trait for types that can be used as values in a tensor.
|
||||
pub trait Value:
|
||||
@ -22,4 +19,4 @@ impl<T> Value for T where
|
||||
+ Deserialize<'static>
|
||||
+ Sum
|
||||
{
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user