Make index own Shape and refactoring

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-03 23:10:59 +02:00
parent 7935cf8a1d
commit 22551ffec4
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
5 changed files with 76 additions and 70 deletions

View File

@ -25,7 +25,7 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
assert!(level < self.len(), "Level out of bounds");
let mut index = TensorIndex::new(self.shape(), [0; R]);
let mut index = TensorIndex::new(self.shape().clone(), [0; R]);
index.set_axis(self.dim, level);
TensorAxisIterator::new(self)
.set_start(level)
@ -38,7 +38,7 @@ pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
#[getset(get = "pub")]
axis: &'a TensorAxis<'a, T, R>,
#[getset(get = "pub", get_mut = "pub")]
index: TensorIndex<'a, R>,
index: TensorIndex<R>,
#[getset(get = "pub")]
end: Option<usize>,
}
@ -47,14 +47,14 @@ impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
Self {
axis,
index: TensorIndex::new(axis.shape(), [0; R]),
index: TensorIndex::new(axis.shape().clone(), [0; R]),
end: None,
}
}
pub fn set_start(self, start: usize) -> Self {
assert!(start < self.axis().len(), "Start out of bounds");
let mut index = TensorIndex::new(self.axis().shape(), [0; R]);
let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]);
index.set_axis(self.axis.dim, start);
Self {
axis: self.axis(),

View File

@ -6,22 +6,32 @@ use std::{
};
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
pub struct TensorIndex<'a, const R: usize> {
pub struct TensorIndex<const R: usize> {
#[getset(get = "pub", get_mut = "pub")]
indices: [usize; R],
#[getset(get = "pub")]
shape: &'a TensorShape<R>,
shape: TensorShape<R>,
}
impl<'a, const R: usize> TensorIndex<'a, R> {
pub const fn zero(shape: &'a TensorShape<R>) -> Self {
// ---- Construction and Initialization ---------------------------------------
impl<const R: usize> TensorIndex<R> {
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self { indices, shape }
}
pub const fn zero(shape: TensorShape<R>) -> Self {
Self {
indices: [0; R],
shape,
}
}
pub fn last(shape: &'a TensorShape<R>) -> Self {
pub fn last(shape: TensorShape<R>) -> Self {
let max_indices =
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
Self {
@ -29,14 +39,9 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
shape,
}
}
}
pub fn new(shape: &'a TensorShape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self { indices, shape }
}
impl<const R: usize> TensorIndex<R> {
pub fn is_zero(&self) -> bool {
self.indices.iter().all(|&i| i == 0)
}
@ -298,34 +303,34 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
pub fn iter_transposed(
&self,
order: [usize; R],
) -> TensorIndexTransposedIterator<'a, R> {
TensorIndexTransposedIterator::new(self.shape(), order)
) -> TensorIndexTransposedIterator<R> {
TensorIndexTransposedIterator::new(self.shape().clone(), order)
}
}
// --- blanket impls ---
impl<'a, const R: usize> PartialEq for TensorIndex<'a, R> {
impl<const R: usize> PartialEq for TensorIndex<R> {
fn eq(&self, other: &Self) -> bool {
self.flat() == other.flat()
}
}
impl<'a, const R: usize> Eq for TensorIndex<'a, R> {}
impl<const R: usize> Eq for TensorIndex<R> {}
impl<'a, const R: usize> PartialOrd for TensorIndex<'a, R> {
impl<const R: usize> PartialOrd for TensorIndex<R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.flat().partial_cmp(&other.flat())
}
}
impl<'a, const R: usize> Ord for TensorIndex<'a, R> {
impl<const R: usize> Ord for TensorIndex<R> {
fn cmp(&self, other: &Self) -> Ordering {
self.flat().cmp(&other.flat())
}
}
impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
impl<const R: usize> Index<usize> for TensorIndex<R> {
type Output = usize;
fn index(&self, index: usize) -> &Self::Output {
@ -333,45 +338,45 @@ impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
}
}
impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.indices[index]
}
}
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])>
for TensorIndex<'a, R>
impl<const R: usize> From<(TensorShape<R>, [usize; R])>
for TensorIndex<R>
{
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices));
Self::new(shape, indices)
}
}
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)>
for TensorIndex<'a, R>
impl<const R: usize> From<(TensorShape<R>, usize)>
for TensorIndex<R>
{
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
let indices = shape.index_from_flat(flat_index).indices;
Self::new(shape, indices)
}
}
impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
fn from(shape: &'a TensorShape<R>) -> Self {
impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
fn from(shape: TensorShape<R>) -> Self {
Self::zero(shape)
}
}
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>>
for TensorIndex<'a, R>
impl<T: Value, const R: usize> From<Tensor<T, R>>
for TensorIndex<R>
{
fn from(tensor: &'a Tensor<T, R>) -> Self {
Self::zero(tensor.shape())
fn from(tensor: Tensor<T, R>) -> Self {
Self::zero(tensor.shape().clone())
}
}
impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
impl<const R: usize> std::fmt::Display for TensorIndex<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
for (i, (&idx, &dim_size)) in self
@ -391,7 +396,7 @@ impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
// ---- Arithmetic Operations ----
impl<'a, const R: usize> Add for TensorIndex<'a, R> {
impl<const R: usize> Add for TensorIndex<R> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
@ -409,7 +414,7 @@ impl<'a, const R: usize> Add for TensorIndex<'a, R> {
}
}
impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
impl<const R: usize> Sub for TensorIndex<R> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
@ -429,13 +434,13 @@ impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
// ---- Iterator ----
pub struct TensorIndexIterator<'a, const R: usize> {
current: TensorIndex<'a, R>,
pub struct TensorIndexIterator<const R: usize> {
current: TensorIndex<R>,
end: bool,
}
impl<'a, const R: usize> TensorIndexIterator<'a, R> {
pub fn new(shape: &'a TensorShape<R>) -> Self {
impl<const R: usize> TensorIndexIterator<R> {
pub fn new(shape: TensorShape<R>) -> Self {
Self {
current: TensorIndex::zero(shape),
end: false,
@ -443,8 +448,8 @@ impl<'a, const R: usize> TensorIndexIterator<'a, R> {
}
}
impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
type Item = TensorIndex<'a, R>;
impl<const R: usize> Iterator for TensorIndexIterator<R> {
type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> {
if self.end {
@ -457,9 +462,9 @@ impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
}
}
impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
type Item = TensorIndex<'a, R>;
type IntoIter = TensorIndexIterator<'a, R>;
impl<const R: usize> IntoIterator for TensorIndex<R> {
type Item = TensorIndex<R>;
type IntoIter = TensorIndexIterator<R>;
fn into_iter(self) -> Self::IntoIter {
TensorIndexIterator {
@ -469,14 +474,14 @@ impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
}
}
pub struct TensorIndexTransposedIterator<'a, const R: usize> {
current: TensorIndex<'a, R>,
pub struct TensorIndexTransposedIterator<const R: usize> {
current: TensorIndex<R>,
order: [usize; R],
end: bool,
}
impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
pub fn new(shape: &'a TensorShape<R>, order: [usize; R]) -> Self {
impl<const R: usize> TensorIndexTransposedIterator<R> {
pub fn new(shape: TensorShape<R>, order: [usize; R]) -> Self {
Self {
current: TensorIndex::zero(shape),
end: false,
@ -485,8 +490,8 @@ impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
}
}
impl<'a, const R: usize> Iterator for TensorIndexTransposedIterator<'a, R> {
type Item = TensorIndex<'a, R>;
impl<const R: usize> Iterator for TensorIndexTransposedIterator<R> {
type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> {
if self.end {

View File

@ -27,6 +27,9 @@ macro_rules! shape {
#[macro_export]
macro_rules! index {
($tensor:expr) => {
TensorIndex::zero($tensor.shape().clone())
};
($tensor:expr, $indices:expr) => {
TensorIndex::from(($tensor.shape().clone(), $indices))
};

View File

@ -82,18 +82,18 @@ impl<const R: usize> TensorShape<R> {
}
indices.reverse(); // Reverse the indices to match the original dimension order
TensorIndex::new(self, indices)
TensorIndex::new(self.clone(), indices)
}
pub const fn index_zero(&self) -> TensorIndex<R> {
TensorIndex::zero(self)
pub fn index_zero(&self) -> TensorIndex<R> {
TensorIndex::zero(self.clone())
}
pub fn index_max(&self) -> TensorIndex<R> {
let max_indices =
self.0
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
TensorIndex::new(self, max_indices)
TensorIndex::new(self.clone(), max_indices)
}
pub fn remove_dims<const NAX: usize>(

View File

@ -89,7 +89,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape(), [1, 1]).into();
/// let i = (t.shape().clone(), [1, 1]).into();
/// assert_eq!(t.get(i), Some(&4));
/// ```
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
@ -102,7 +102,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape(), [1, 1]).into();
/// let i = (t.shape().clone(), [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked(i), &4); }
/// ```
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
@ -112,12 +112,10 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// Get a mutable reference to a value at the given index.
///
/// ```
/// use manifold::{tensor, Tensor};
/// use manifold::*;
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// let s = t.shape().clone();
/// let i = (&s, [1, 1]).into();
/// assert_eq!(t.get_mut(i), Some(&mut 4));
/// assert_eq!(t.get_mut(index!(&t, [1, 1])), Some(&mut 4));
/// ```
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
self.buffer_mut().get_mut(index.flat())
@ -131,7 +129,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// let s = t.shape().clone();
/// let i = (&s, [1, 1]).into();
/// let i = (s, [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); }
/// ```
pub unsafe fn get_unchecked_mut(
@ -403,7 +401,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// ```
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = TensorIndex::from(self.shape())
let buffer = TensorIndex::from(self.shape().clone())
.iter_transposed(order)
.map(|index| self.get(index).unwrap().clone())
.collect();
@ -417,7 +415,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
// ---- Indexing --------------------------------------------------------------
impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
type Output = T;
fn index(&self, index: TensorIndex<R>) -> &Self::Output {
@ -425,7 +423,7 @@ impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
}
}
impl<'a, T: Value, const R: usize> IndexMut<TensorIndex<'a, R>>
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
for Tensor<T, R>
{
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
@ -502,11 +500,11 @@ impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
pub struct TensorIterator<'a, T: Value, const R: usize> {
tensor: &'a Tensor<T, R>,
index: TensorIndex<'a, R>,
index: TensorIndex<R>,
}
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
Self {
tensor,
index: tensor.shape.index_zero(),