Update types because change in shape as_array
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
bd40c9c1f2
commit
e3eef62700
29
src/index.rs
29
src/index.rs
@ -1,8 +1,8 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
use getset::{Getters, MutGetters};
|
use getset::{Getters, MutGetters};
|
||||||
use std::{
|
use std::{
|
||||||
ops::{Index, IndexMut, Add, Sub},
|
cmp::Ordering,
|
||||||
cmp::Ordering,
|
ops::{Add, Index, IndexMut, Sub},
|
||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||||
@ -16,7 +16,6 @@ pub struct TensorIndex<const R: usize> {
|
|||||||
// ---- Construction and Initialization ---------------------------------------
|
// ---- Construction and Initialization ---------------------------------------
|
||||||
|
|
||||||
impl<const R: usize> TensorIndex<R> {
|
impl<const R: usize> TensorIndex<R> {
|
||||||
|
|
||||||
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
||||||
if !shape.check_indices(indices) {
|
if !shape.check_indices(indices) {
|
||||||
panic!("indices out of bounds");
|
panic!("indices out of bounds");
|
||||||
@ -65,11 +64,9 @@ impl<const R: usize> TensorIndex<R> {
|
|||||||
if self.indices()[0] >= self.shape().get(0) {
|
if self.indices()[0] >= self.shape().get(0) {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
let shape = self.shape().as_array().clone();
|
let shape = self.shape().as_array().clone();
|
||||||
let mut carry = 1;
|
let mut carry = 1;
|
||||||
for (i, &dim_size) in
|
for (i, &dim_size) in self.indices.iter_mut().zip(&shape).rev() {
|
||||||
self.indices.iter_mut().zip(&shape).rev()
|
|
||||||
{
|
|
||||||
if carry == 1 {
|
if carry == 1 {
|
||||||
*i += 1;
|
*i += 1;
|
||||||
if *i >= dim_size {
|
if *i >= dim_size {
|
||||||
@ -159,10 +156,8 @@ impl<const R: usize> TensorIndex<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut borrow = true;
|
let mut borrow = true;
|
||||||
let shape = self.shape().as_array().clone();
|
let shape = self.shape().as_array().clone();
|
||||||
for (i, &dim_size) in
|
for (i, &dim_size) in self.indices_mut().iter_mut().zip(&shape).rev() {
|
||||||
self.indices_mut().iter_mut().zip(&shape).rev()
|
|
||||||
{
|
|
||||||
if borrow {
|
if borrow {
|
||||||
if *i == 0 {
|
if *i == 0 {
|
||||||
*i = dim_size - 1; // Wrap around to the maximum index of
|
*i = dim_size - 1; // Wrap around to the maximum index of
|
||||||
@ -346,18 +341,14 @@ impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const R: usize> From<(TensorShape<R>, [usize; R])>
|
impl<const R: usize> From<(TensorShape<R>, [usize; R])> for TensorIndex<R> {
|
||||||
for TensorIndex<R>
|
|
||||||
{
|
|
||||||
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
|
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
|
||||||
assert!(shape.check_indices(indices));
|
assert!(shape.check_indices(indices));
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const R: usize> From<(TensorShape<R>, usize)>
|
impl<const R: usize> From<(TensorShape<R>, usize)> for TensorIndex<R> {
|
||||||
for TensorIndex<R>
|
|
||||||
{
|
|
||||||
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
|
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
|
||||||
let indices = shape.index_from_flat(flat_index).indices;
|
let indices = shape.index_from_flat(flat_index).indices;
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
@ -370,9 +361,7 @@ impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Value, const R: usize> From<Tensor<T, R>>
|
impl<T: Value, const R: usize> From<Tensor<T, R>> for TensorIndex<R> {
|
||||||
for TensorIndex<R>
|
|
||||||
{
|
|
||||||
fn from(tensor: Tensor<T, R>) -> Self {
|
fn from(tensor: Tensor<T, R>) -> Self {
|
||||||
Self::zero(tensor.shape().clone())
|
Self::zero(tensor.shape().clone())
|
||||||
}
|
}
|
||||||
|
@ -9,7 +9,7 @@ pub mod shape;
|
|||||||
pub mod tensor;
|
pub mod tensor;
|
||||||
pub mod value;
|
pub mod value;
|
||||||
|
|
||||||
pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*};
|
pub use {axis::*, error::*, index::*, shape::*, tensor::*, value::*};
|
||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! tensor {
|
macro_rules! tensor {
|
||||||
@ -27,9 +27,9 @@ macro_rules! shape {
|
|||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! index {
|
macro_rules! index {
|
||||||
($tensor:expr) => {
|
($tensor:expr) => {
|
||||||
TensorIndex::zero($tensor.shape().clone())
|
TensorIndex::zero($tensor.shape().clone())
|
||||||
};
|
};
|
||||||
($tensor:expr, $indices:expr) => {
|
($tensor:expr, $indices:expr) => {
|
||||||
TensorIndex::from(($tensor.shape().clone(), $indices))
|
TensorIndex::from(($tensor.shape().clone(), $indices))
|
||||||
};
|
};
|
||||||
|
@ -1,8 +1,8 @@
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
use core::result::Result as SerdeResult;
|
||||||
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
||||||
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
||||||
use std::fmt::{Result as FmtResult, Formatter};
|
use std::fmt::{Formatter, Result as FmtResult};
|
||||||
use core::result::Result as SerdeResult;
|
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
pub struct TensorShape<const R: usize>([usize; R]);
|
pub struct TensorShape<const R: usize>([usize; R]);
|
||||||
|
@ -632,7 +632,12 @@ where
|
|||||||
T: Display + Clone,
|
T: Display + Clone,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape().as_array().clone(), f, 1)
|
Tensor::<T, R>::fmt_helper(
|
||||||
|
&self.buffer,
|
||||||
|
&self.shape().as_array().clone(),
|
||||||
|
f,
|
||||||
|
1,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,9 +1,6 @@
|
|||||||
use num::{Num, One, Zero};
|
use num::{Num, One, Zero};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::{
|
use std::{fmt::Display, iter::Sum};
|
||||||
fmt::Display,
|
|
||||||
iter::Sum,
|
|
||||||
};
|
|
||||||
|
|
||||||
/// A trait for types that can be used as values in a tensor.
|
/// A trait for types that can be used as values in a tensor.
|
||||||
pub trait Value:
|
pub trait Value:
|
||||||
@ -22,4 +19,4 @@ impl<T> Value for T where
|
|||||||
+ Deserialize<'static>
|
+ Deserialize<'static>
|
||||||
+ Sum
|
+ Sum
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user