Make index own Shape and refactoring
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
7935cf8a1d
commit
22551ffec4
@ -25,7 +25,7 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
|
||||
|
||||
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
|
||||
assert!(level < self.len(), "Level out of bounds");
|
||||
let mut index = TensorIndex::new(self.shape(), [0; R]);
|
||||
let mut index = TensorIndex::new(self.shape().clone(), [0; R]);
|
||||
index.set_axis(self.dim, level);
|
||||
TensorAxisIterator::new(self)
|
||||
.set_start(level)
|
||||
@ -38,7 +38,7 @@ pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
|
||||
#[getset(get = "pub")]
|
||||
axis: &'a TensorAxis<'a, T, R>,
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
index: TensorIndex<'a, R>,
|
||||
index: TensorIndex<R>,
|
||||
#[getset(get = "pub")]
|
||||
end: Option<usize>,
|
||||
}
|
||||
@ -47,14 +47,14 @@ impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
|
||||
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
|
||||
Self {
|
||||
axis,
|
||||
index: TensorIndex::new(axis.shape(), [0; R]),
|
||||
index: TensorIndex::new(axis.shape().clone(), [0; R]),
|
||||
end: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_start(self, start: usize) -> Self {
|
||||
assert!(start < self.axis().len(), "Start out of bounds");
|
||||
let mut index = TensorIndex::new(self.axis().shape(), [0; R]);
|
||||
let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]);
|
||||
index.set_axis(self.axis.dim, start);
|
||||
Self {
|
||||
axis: self.axis(),
|
||||
|
105
src/index.rs
105
src/index.rs
@ -6,22 +6,32 @@ use std::{
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||
pub struct TensorIndex<'a, const R: usize> {
|
||||
pub struct TensorIndex<const R: usize> {
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
indices: [usize; R],
|
||||
#[getset(get = "pub")]
|
||||
shape: &'a TensorShape<R>,
|
||||
shape: TensorShape<R>,
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
pub const fn zero(shape: &'a TensorShape<R>) -> Self {
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
|
||||
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self { indices, shape }
|
||||
}
|
||||
|
||||
pub const fn zero(shape: TensorShape<R>) -> Self {
|
||||
Self {
|
||||
indices: [0; R],
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn last(shape: &'a TensorShape<R>) -> Self {
|
||||
pub fn last(shape: TensorShape<R>) -> Self {
|
||||
let max_indices =
|
||||
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
||||
Self {
|
||||
@ -29,14 +39,9 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(shape: &'a TensorShape<R>, indices: [usize; R]) -> Self {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self { indices, shape }
|
||||
}
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.indices.iter().all(|&i| i == 0)
|
||||
}
|
||||
@ -298,34 +303,34 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||
pub fn iter_transposed(
|
||||
&self,
|
||||
order: [usize; R],
|
||||
) -> TensorIndexTransposedIterator<'a, R> {
|
||||
TensorIndexTransposedIterator::new(self.shape(), order)
|
||||
) -> TensorIndexTransposedIterator<R> {
|
||||
TensorIndexTransposedIterator::new(self.shape().clone(), order)
|
||||
}
|
||||
}
|
||||
|
||||
// --- blanket impls ---
|
||||
|
||||
impl<'a, const R: usize> PartialEq for TensorIndex<'a, R> {
|
||||
impl<const R: usize> PartialEq for TensorIndex<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.flat() == other.flat()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Eq for TensorIndex<'a, R> {}
|
||||
impl<const R: usize> Eq for TensorIndex<R> {}
|
||||
|
||||
impl<'a, const R: usize> PartialOrd for TensorIndex<'a, R> {
|
||||
impl<const R: usize> PartialOrd for TensorIndex<R> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.flat().partial_cmp(&other.flat())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Ord for TensorIndex<'a, R> {
|
||||
impl<const R: usize> Ord for TensorIndex<R> {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.flat().cmp(&other.flat())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
|
||||
impl<const R: usize> Index<usize> for TensorIndex<R> {
|
||||
type Output = usize;
|
||||
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
@ -333,45 +338,45 @@ impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
|
||||
impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
&mut self.indices[index]
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])>
|
||||
for TensorIndex<'a, R>
|
||||
impl<const R: usize> From<(TensorShape<R>, [usize; R])>
|
||||
for TensorIndex<R>
|
||||
{
|
||||
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
|
||||
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
|
||||
assert!(shape.check_indices(indices));
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)>
|
||||
for TensorIndex<'a, R>
|
||||
impl<const R: usize> From<(TensorShape<R>, usize)>
|
||||
for TensorIndex<R>
|
||||
{
|
||||
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
|
||||
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
|
||||
let indices = shape.index_from_flat(flat_index).indices;
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
|
||||
fn from(shape: &'a TensorShape<R>) -> Self {
|
||||
impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
|
||||
fn from(shape: TensorShape<R>) -> Self {
|
||||
Self::zero(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>>
|
||||
for TensorIndex<'a, R>
|
||||
impl<T: Value, const R: usize> From<Tensor<T, R>>
|
||||
for TensorIndex<R>
|
||||
{
|
||||
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
||||
Self::zero(tensor.shape())
|
||||
fn from(tensor: Tensor<T, R>) -> Self {
|
||||
Self::zero(tensor.shape().clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
|
||||
impl<const R: usize> std::fmt::Display for TensorIndex<R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[")?;
|
||||
for (i, (&idx, &dim_size)) in self
|
||||
@ -391,7 +396,7 @@ impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
|
||||
|
||||
// ---- Arithmetic Operations ----
|
||||
|
||||
impl<'a, const R: usize> Add for TensorIndex<'a, R> {
|
||||
impl<const R: usize> Add for TensorIndex<R> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
@ -409,7 +414,7 @@ impl<'a, const R: usize> Add for TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
|
||||
impl<const R: usize> Sub for TensorIndex<R> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
@ -429,13 +434,13 @@ impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
|
||||
|
||||
// ---- Iterator ----
|
||||
|
||||
pub struct TensorIndexIterator<'a, const R: usize> {
|
||||
current: TensorIndex<'a, R>,
|
||||
pub struct TensorIndexIterator<const R: usize> {
|
||||
current: TensorIndex<R>,
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> TensorIndexIterator<'a, R> {
|
||||
pub fn new(shape: &'a TensorShape<R>) -> Self {
|
||||
impl<const R: usize> TensorIndexIterator<R> {
|
||||
pub fn new(shape: TensorShape<R>) -> Self {
|
||||
Self {
|
||||
current: TensorIndex::zero(shape),
|
||||
end: false,
|
||||
@ -443,8 +448,8 @@ impl<'a, const R: usize> TensorIndexIterator<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
|
||||
type Item = TensorIndex<'a, R>;
|
||||
impl<const R: usize> Iterator for TensorIndexIterator<R> {
|
||||
type Item = TensorIndex<R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
@ -457,9 +462,9 @@ impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
|
||||
type Item = TensorIndex<'a, R>;
|
||||
type IntoIter = TensorIndexIterator<'a, R>;
|
||||
impl<const R: usize> IntoIterator for TensorIndex<R> {
|
||||
type Item = TensorIndex<R>;
|
||||
type IntoIter = TensorIndexIterator<R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
TensorIndexIterator {
|
||||
@ -469,14 +474,14 @@ impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub struct TensorIndexTransposedIterator<'a, const R: usize> {
|
||||
current: TensorIndex<'a, R>,
|
||||
pub struct TensorIndexTransposedIterator<const R: usize> {
|
||||
current: TensorIndex<R>,
|
||||
order: [usize; R],
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
|
||||
pub fn new(shape: &'a TensorShape<R>, order: [usize; R]) -> Self {
|
||||
impl<const R: usize> TensorIndexTransposedIterator<R> {
|
||||
pub fn new(shape: TensorShape<R>, order: [usize; R]) -> Self {
|
||||
Self {
|
||||
current: TensorIndex::zero(shape),
|
||||
end: false,
|
||||
@ -485,8 +490,8 @@ impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Iterator for TensorIndexTransposedIterator<'a, R> {
|
||||
type Item = TensorIndex<'a, R>;
|
||||
impl<const R: usize> Iterator for TensorIndexTransposedIterator<R> {
|
||||
type Item = TensorIndex<R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
|
@ -27,6 +27,9 @@ macro_rules! shape {
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! index {
|
||||
($tensor:expr) => {
|
||||
TensorIndex::zero($tensor.shape().clone())
|
||||
};
|
||||
($tensor:expr, $indices:expr) => {
|
||||
TensorIndex::from(($tensor.shape().clone(), $indices))
|
||||
};
|
||||
|
@ -82,18 +82,18 @@ impl<const R: usize> TensorShape<R> {
|
||||
}
|
||||
|
||||
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||
TensorIndex::new(self, indices)
|
||||
TensorIndex::new(self.clone(), indices)
|
||||
}
|
||||
|
||||
pub const fn index_zero(&self) -> TensorIndex<R> {
|
||||
TensorIndex::zero(self)
|
||||
pub fn index_zero(&self) -> TensorIndex<R> {
|
||||
TensorIndex::zero(self.clone())
|
||||
}
|
||||
|
||||
pub fn index_max(&self) -> TensorIndex<R> {
|
||||
let max_indices =
|
||||
self.0
|
||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||
TensorIndex::new(self, max_indices)
|
||||
TensorIndex::new(self.clone(), max_indices)
|
||||
}
|
||||
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
|
@ -89,7 +89,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// use manifold::{tensor, Tensor};
|
||||
///
|
||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let i = (t.shape(), [1, 1]).into();
|
||||
/// let i = (t.shape().clone(), [1, 1]).into();
|
||||
/// assert_eq!(t.get(i), Some(&4));
|
||||
/// ```
|
||||
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
|
||||
@ -102,7 +102,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// use manifold::{tensor, Tensor};
|
||||
///
|
||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let i = (t.shape(), [1, 1]).into();
|
||||
/// let i = (t.shape().clone(), [1, 1]).into();
|
||||
/// unsafe { assert_eq!(t.get_unchecked(i), &4); }
|
||||
/// ```
|
||||
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
|
||||
@ -112,12 +112,10 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// Get a mutable reference to a value at the given index.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::{tensor, Tensor};
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let mut t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let s = t.shape().clone();
|
||||
/// let i = (&s, [1, 1]).into();
|
||||
/// assert_eq!(t.get_mut(i), Some(&mut 4));
|
||||
/// assert_eq!(t.get_mut(index!(&t, [1, 1])), Some(&mut 4));
|
||||
/// ```
|
||||
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
|
||||
self.buffer_mut().get_mut(index.flat())
|
||||
@ -131,7 +129,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
///
|
||||
/// let mut t = tensor!([[1, 2], [3, 4]]);
|
||||
/// let s = t.shape().clone();
|
||||
/// let i = (&s, [1, 1]).into();
|
||||
/// let i = (s, [1, 1]).into();
|
||||
/// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); }
|
||||
/// ```
|
||||
pub unsafe fn get_unchecked_mut(
|
||||
@ -403,7 +401,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
|
||||
/// ```
|
||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||
let buffer = TensorIndex::from(self.shape())
|
||||
let buffer = TensorIndex::from(self.shape().clone())
|
||||
.iter_transposed(order)
|
||||
.map(|index| self.get(index).unwrap().clone())
|
||||
.collect();
|
||||
@ -417,7 +415,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
|
||||
// ---- Indexing --------------------------------------------------------------
|
||||
|
||||
impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
|
||||
impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
|
||||
type Output = T;
|
||||
|
||||
fn index(&self, index: TensorIndex<R>) -> &Self::Output {
|
||||
@ -425,7 +423,7 @@ impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> IndexMut<TensorIndex<'a, R>>
|
||||
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
|
||||
for Tensor<T, R>
|
||||
{
|
||||
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
|
||||
@ -502,11 +500,11 @@ impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
|
||||
|
||||
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||
tensor: &'a Tensor<T, R>,
|
||||
index: TensorIndex<'a, R>,
|
||||
index: TensorIndex<R>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||
pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
|
||||
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
|
||||
Self {
|
||||
tensor,
|
||||
index: tensor.shape.index_zero(),
|
||||
|
Loading…
Reference in New Issue
Block a user