Make index own Shape and refactoring

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-03 23:10:59 +02:00
parent 7935cf8a1d
commit 22551ffec4
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
5 changed files with 76 additions and 70 deletions

View File

@ -25,7 +25,7 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> { pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
assert!(level < self.len(), "Level out of bounds"); assert!(level < self.len(), "Level out of bounds");
let mut index = TensorIndex::new(self.shape(), [0; R]); let mut index = TensorIndex::new(self.shape().clone(), [0; R]);
index.set_axis(self.dim, level); index.set_axis(self.dim, level);
TensorAxisIterator::new(self) TensorAxisIterator::new(self)
.set_start(level) .set_start(level)
@ -38,7 +38,7 @@ pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
#[getset(get = "pub")] #[getset(get = "pub")]
axis: &'a TensorAxis<'a, T, R>, axis: &'a TensorAxis<'a, T, R>,
#[getset(get = "pub", get_mut = "pub")] #[getset(get = "pub", get_mut = "pub")]
index: TensorIndex<'a, R>, index: TensorIndex<R>,
#[getset(get = "pub")] #[getset(get = "pub")]
end: Option<usize>, end: Option<usize>,
} }
@ -47,14 +47,14 @@ impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self { pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
Self { Self {
axis, axis,
index: TensorIndex::new(axis.shape(), [0; R]), index: TensorIndex::new(axis.shape().clone(), [0; R]),
end: None, end: None,
} }
} }
pub fn set_start(self, start: usize) -> Self { pub fn set_start(self, start: usize) -> Self {
assert!(start < self.axis().len(), "Start out of bounds"); assert!(start < self.axis().len(), "Start out of bounds");
let mut index = TensorIndex::new(self.axis().shape(), [0; R]); let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]);
index.set_axis(self.axis.dim, start); index.set_axis(self.axis.dim, start);
Self { Self {
axis: self.axis(), axis: self.axis(),

View File

@ -6,22 +6,32 @@ use std::{
}; };
#[derive(Clone, Copy, Debug, Getters, MutGetters)] #[derive(Clone, Copy, Debug, Getters, MutGetters)]
pub struct TensorIndex<'a, const R: usize> { pub struct TensorIndex<const R: usize> {
#[getset(get = "pub", get_mut = "pub")] #[getset(get = "pub", get_mut = "pub")]
indices: [usize; R], indices: [usize; R],
#[getset(get = "pub")] #[getset(get = "pub")]
shape: &'a TensorShape<R>, shape: TensorShape<R>,
} }
impl<'a, const R: usize> TensorIndex<'a, R> { // ---- Construction and Initialization ---------------------------------------
pub const fn zero(shape: &'a TensorShape<R>) -> Self {
impl<const R: usize> TensorIndex<R> {
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self { indices, shape }
}
pub const fn zero(shape: TensorShape<R>) -> Self {
Self { Self {
indices: [0; R], indices: [0; R],
shape, shape,
} }
} }
pub fn last(shape: &'a TensorShape<R>) -> Self { pub fn last(shape: TensorShape<R>) -> Self {
let max_indices = let max_indices =
shape.as_array().map(|dim_size| dim_size.saturating_sub(1)); shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
Self { Self {
@ -29,14 +39,9 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
shape, shape,
} }
} }
}
pub fn new(shape: &'a TensorShape<R>, indices: [usize; R]) -> Self { impl<const R: usize> TensorIndex<R> {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self { indices, shape }
}
pub fn is_zero(&self) -> bool { pub fn is_zero(&self) -> bool {
self.indices.iter().all(|&i| i == 0) self.indices.iter().all(|&i| i == 0)
} }
@ -298,34 +303,34 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
pub fn iter_transposed( pub fn iter_transposed(
&self, &self,
order: [usize; R], order: [usize; R],
) -> TensorIndexTransposedIterator<'a, R> { ) -> TensorIndexTransposedIterator<R> {
TensorIndexTransposedIterator::new(self.shape(), order) TensorIndexTransposedIterator::new(self.shape().clone(), order)
} }
} }
// --- blanket impls --- // --- blanket impls ---
impl<'a, const R: usize> PartialEq for TensorIndex<'a, R> { impl<const R: usize> PartialEq for TensorIndex<R> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.flat() == other.flat() self.flat() == other.flat()
} }
} }
impl<'a, const R: usize> Eq for TensorIndex<'a, R> {} impl<const R: usize> Eq for TensorIndex<R> {}
impl<'a, const R: usize> PartialOrd for TensorIndex<'a, R> { impl<const R: usize> PartialOrd for TensorIndex<R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.flat().partial_cmp(&other.flat()) self.flat().partial_cmp(&other.flat())
} }
} }
impl<'a, const R: usize> Ord for TensorIndex<'a, R> { impl<const R: usize> Ord for TensorIndex<R> {
fn cmp(&self, other: &Self) -> Ordering { fn cmp(&self, other: &Self) -> Ordering {
self.flat().cmp(&other.flat()) self.flat().cmp(&other.flat())
} }
} }
impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> { impl<const R: usize> Index<usize> for TensorIndex<R> {
type Output = usize; type Output = usize;
fn index(&self, index: usize) -> &Self::Output { fn index(&self, index: usize) -> &Self::Output {
@ -333,45 +338,45 @@ impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
} }
} }
impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> { impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output { fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.indices[index] &mut self.indices[index]
} }
} }
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])> impl<const R: usize> From<(TensorShape<R>, [usize; R])>
for TensorIndex<'a, R> for TensorIndex<R>
{ {
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self { fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices)); assert!(shape.check_indices(indices));
Self::new(shape, indices) Self::new(shape, indices)
} }
} }
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)> impl<const R: usize> From<(TensorShape<R>, usize)>
for TensorIndex<'a, R> for TensorIndex<R>
{ {
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self { fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
let indices = shape.index_from_flat(flat_index).indices; let indices = shape.index_from_flat(flat_index).indices;
Self::new(shape, indices) Self::new(shape, indices)
} }
} }
impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> { impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
fn from(shape: &'a TensorShape<R>) -> Self { fn from(shape: TensorShape<R>) -> Self {
Self::zero(shape) Self::zero(shape)
} }
} }
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> impl<T: Value, const R: usize> From<Tensor<T, R>>
for TensorIndex<'a, R> for TensorIndex<R>
{ {
fn from(tensor: &'a Tensor<T, R>) -> Self { fn from(tensor: Tensor<T, R>) -> Self {
Self::zero(tensor.shape()) Self::zero(tensor.shape().clone())
} }
} }
impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> { impl<const R: usize> std::fmt::Display for TensorIndex<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?; write!(f, "[")?;
for (i, (&idx, &dim_size)) in self for (i, (&idx, &dim_size)) in self
@ -391,7 +396,7 @@ impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
// ---- Arithmetic Operations ---- // ---- Arithmetic Operations ----
impl<'a, const R: usize> Add for TensorIndex<'a, R> { impl<const R: usize> Add for TensorIndex<R> {
type Output = Self; type Output = Self;
fn add(self, rhs: Self) -> Self::Output { fn add(self, rhs: Self) -> Self::Output {
@ -409,7 +414,7 @@ impl<'a, const R: usize> Add for TensorIndex<'a, R> {
} }
} }
impl<'a, const R: usize> Sub for TensorIndex<'a, R> { impl<const R: usize> Sub for TensorIndex<R> {
type Output = Self; type Output = Self;
fn sub(self, rhs: Self) -> Self::Output { fn sub(self, rhs: Self) -> Self::Output {
@ -429,13 +434,13 @@ impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
// ---- Iterator ---- // ---- Iterator ----
pub struct TensorIndexIterator<'a, const R: usize> { pub struct TensorIndexIterator<const R: usize> {
current: TensorIndex<'a, R>, current: TensorIndex<R>,
end: bool, end: bool,
} }
impl<'a, const R: usize> TensorIndexIterator<'a, R> { impl<const R: usize> TensorIndexIterator<R> {
pub fn new(shape: &'a TensorShape<R>) -> Self { pub fn new(shape: TensorShape<R>) -> Self {
Self { Self {
current: TensorIndex::zero(shape), current: TensorIndex::zero(shape),
end: false, end: false,
@ -443,8 +448,8 @@ impl<'a, const R: usize> TensorIndexIterator<'a, R> {
} }
} }
impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> { impl<const R: usize> Iterator for TensorIndexIterator<R> {
type Item = TensorIndex<'a, R>; type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.end { if self.end {
@ -457,9 +462,9 @@ impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
} }
} }
impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> { impl<const R: usize> IntoIterator for TensorIndex<R> {
type Item = TensorIndex<'a, R>; type Item = TensorIndex<R>;
type IntoIter = TensorIndexIterator<'a, R>; type IntoIter = TensorIndexIterator<R>;
fn into_iter(self) -> Self::IntoIter { fn into_iter(self) -> Self::IntoIter {
TensorIndexIterator { TensorIndexIterator {
@ -469,14 +474,14 @@ impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
} }
} }
pub struct TensorIndexTransposedIterator<'a, const R: usize> { pub struct TensorIndexTransposedIterator<const R: usize> {
current: TensorIndex<'a, R>, current: TensorIndex<R>,
order: [usize; R], order: [usize; R],
end: bool, end: bool,
} }
impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> { impl<const R: usize> TensorIndexTransposedIterator<R> {
pub fn new(shape: &'a TensorShape<R>, order: [usize; R]) -> Self { pub fn new(shape: TensorShape<R>, order: [usize; R]) -> Self {
Self { Self {
current: TensorIndex::zero(shape), current: TensorIndex::zero(shape),
end: false, end: false,
@ -485,8 +490,8 @@ impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
} }
} }
impl<'a, const R: usize> Iterator for TensorIndexTransposedIterator<'a, R> { impl<const R: usize> Iterator for TensorIndexTransposedIterator<R> {
type Item = TensorIndex<'a, R>; type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> { fn next(&mut self) -> Option<Self::Item> {
if self.end { if self.end {

View File

@ -27,6 +27,9 @@ macro_rules! shape {
#[macro_export] #[macro_export]
macro_rules! index { macro_rules! index {
($tensor:expr) => {
TensorIndex::zero($tensor.shape().clone())
};
($tensor:expr, $indices:expr) => { ($tensor:expr, $indices:expr) => {
TensorIndex::from(($tensor.shape().clone(), $indices)) TensorIndex::from(($tensor.shape().clone(), $indices))
}; };

View File

@ -82,18 +82,18 @@ impl<const R: usize> TensorShape<R> {
} }
indices.reverse(); // Reverse the indices to match the original dimension order indices.reverse(); // Reverse the indices to match the original dimension order
TensorIndex::new(self, indices) TensorIndex::new(self.clone(), indices)
} }
pub const fn index_zero(&self) -> TensorIndex<R> { pub fn index_zero(&self) -> TensorIndex<R> {
TensorIndex::zero(self) TensorIndex::zero(self.clone())
} }
pub fn index_max(&self) -> TensorIndex<R> { pub fn index_max(&self) -> TensorIndex<R> {
let max_indices = let max_indices =
self.0 self.0
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 }); .map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
TensorIndex::new(self, max_indices) TensorIndex::new(self.clone(), max_indices)
} }
pub fn remove_dims<const NAX: usize>( pub fn remove_dims<const NAX: usize>(

View File

@ -89,7 +89,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// use manifold::{tensor, Tensor}; /// use manifold::{tensor, Tensor};
/// ///
/// let t = tensor!([[1, 2], [3, 4]]); /// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape(), [1, 1]).into(); /// let i = (t.shape().clone(), [1, 1]).into();
/// assert_eq!(t.get(i), Some(&4)); /// assert_eq!(t.get(i), Some(&4));
/// ``` /// ```
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> { pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
@ -102,7 +102,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// use manifold::{tensor, Tensor}; /// use manifold::{tensor, Tensor};
/// ///
/// let t = tensor!([[1, 2], [3, 4]]); /// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape(), [1, 1]).into(); /// let i = (t.shape().clone(), [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked(i), &4); } /// unsafe { assert_eq!(t.get_unchecked(i), &4); }
/// ``` /// ```
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T { pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
@ -112,12 +112,10 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// Get a mutable reference to a value at the given index. /// Get a mutable reference to a value at the given index.
/// ///
/// ``` /// ```
/// use manifold::{tensor, Tensor}; /// use manifold::*;
/// ///
/// let mut t = tensor!([[1, 2], [3, 4]]); /// let mut t = tensor!([[1, 2], [3, 4]]);
/// let s = t.shape().clone(); /// assert_eq!(t.get_mut(index!(&t, [1, 1])), Some(&mut 4));
/// let i = (&s, [1, 1]).into();
/// assert_eq!(t.get_mut(i), Some(&mut 4));
/// ``` /// ```
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> { pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
self.buffer_mut().get_mut(index.flat()) self.buffer_mut().get_mut(index.flat())
@ -131,7 +129,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// ///
/// let mut t = tensor!([[1, 2], [3, 4]]); /// let mut t = tensor!([[1, 2], [3, 4]]);
/// let s = t.shape().clone(); /// let s = t.shape().clone();
/// let i = (&s, [1, 1]).into(); /// let i = (s, [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); } /// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); }
/// ``` /// ```
pub unsafe fn get_unchecked_mut( pub unsafe fn get_unchecked_mut(
@ -403,7 +401,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// assert_eq!(t, tensor!([[1, 3], [2, 4]])); /// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// ``` /// ```
pub fn transpose(self, order: [usize; R]) -> Result<Self> { pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = TensorIndex::from(self.shape()) let buffer = TensorIndex::from(self.shape().clone())
.iter_transposed(order) .iter_transposed(order)
.map(|index| self.get(index).unwrap().clone()) .map(|index| self.get(index).unwrap().clone())
.collect(); .collect();
@ -417,7 +415,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// ---- Indexing -------------------------------------------------------------- // ---- Indexing --------------------------------------------------------------
impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> { impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
type Output = T; type Output = T;
fn index(&self, index: TensorIndex<R>) -> &Self::Output { fn index(&self, index: TensorIndex<R>) -> &Self::Output {
@ -425,7 +423,7 @@ impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
} }
} }
impl<'a, T: Value, const R: usize> IndexMut<TensorIndex<'a, R>> impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
for Tensor<T, R> for Tensor<T, R>
{ {
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output { fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
@ -502,11 +500,11 @@ impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
pub struct TensorIterator<'a, T: Value, const R: usize> { pub struct TensorIterator<'a, T: Value, const R: usize> {
tensor: &'a Tensor<T, R>, tensor: &'a Tensor<T, R>,
index: TensorIndex<'a, R>, index: TensorIndex<R>,
} }
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> { impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
pub const fn new(tensor: &'a Tensor<T, R>) -> Self { pub fn new(tensor: &'a Tensor<T, R>) -> Self {
Self { Self {
tensor, tensor,
index: tensor.shape.index_zero(), index: tensor.shape.index_zero(),