Make index own Shape and refactoring
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
7935cf8a1d
commit
22551ffec4
@ -25,7 +25,7 @@ impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
|
|||||||
|
|
||||||
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
|
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
|
||||||
assert!(level < self.len(), "Level out of bounds");
|
assert!(level < self.len(), "Level out of bounds");
|
||||||
let mut index = TensorIndex::new(self.shape(), [0; R]);
|
let mut index = TensorIndex::new(self.shape().clone(), [0; R]);
|
||||||
index.set_axis(self.dim, level);
|
index.set_axis(self.dim, level);
|
||||||
TensorAxisIterator::new(self)
|
TensorAxisIterator::new(self)
|
||||||
.set_start(level)
|
.set_start(level)
|
||||||
@ -38,7 +38,7 @@ pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
|
|||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
axis: &'a TensorAxis<'a, T, R>,
|
axis: &'a TensorAxis<'a, T, R>,
|
||||||
#[getset(get = "pub", get_mut = "pub")]
|
#[getset(get = "pub", get_mut = "pub")]
|
||||||
index: TensorIndex<'a, R>,
|
index: TensorIndex<R>,
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
end: Option<usize>,
|
end: Option<usize>,
|
||||||
}
|
}
|
||||||
@ -47,14 +47,14 @@ impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
|
|||||||
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
|
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
axis,
|
axis,
|
||||||
index: TensorIndex::new(axis.shape(), [0; R]),
|
index: TensorIndex::new(axis.shape().clone(), [0; R]),
|
||||||
end: None,
|
end: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_start(self, start: usize) -> Self {
|
pub fn set_start(self, start: usize) -> Self {
|
||||||
assert!(start < self.axis().len(), "Start out of bounds");
|
assert!(start < self.axis().len(), "Start out of bounds");
|
||||||
let mut index = TensorIndex::new(self.axis().shape(), [0; R]);
|
let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]);
|
||||||
index.set_axis(self.axis.dim, start);
|
index.set_axis(self.axis.dim, start);
|
||||||
Self {
|
Self {
|
||||||
axis: self.axis(),
|
axis: self.axis(),
|
||||||
|
105
src/index.rs
105
src/index.rs
@ -6,22 +6,32 @@ use std::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||||
pub struct TensorIndex<'a, const R: usize> {
|
pub struct TensorIndex<const R: usize> {
|
||||||
#[getset(get = "pub", get_mut = "pub")]
|
#[getset(get = "pub", get_mut = "pub")]
|
||||||
indices: [usize; R],
|
indices: [usize; R],
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
shape: &'a TensorShape<R>,
|
shape: TensorShape<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> TensorIndex<'a, R> {
|
// ---- Construction and Initialization ---------------------------------------
|
||||||
pub const fn zero(shape: &'a TensorShape<R>) -> Self {
|
|
||||||
|
impl<const R: usize> TensorIndex<R> {
|
||||||
|
|
||||||
|
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
||||||
|
if !shape.check_indices(indices) {
|
||||||
|
panic!("indices out of bounds");
|
||||||
|
}
|
||||||
|
Self { indices, shape }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn zero(shape: TensorShape<R>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
indices: [0; R],
|
indices: [0; R],
|
||||||
shape,
|
shape,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn last(shape: &'a TensorShape<R>) -> Self {
|
pub fn last(shape: TensorShape<R>) -> Self {
|
||||||
let max_indices =
|
let max_indices =
|
||||||
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
||||||
Self {
|
Self {
|
||||||
@ -29,14 +39,9 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
shape,
|
shape,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn new(shape: &'a TensorShape<R>, indices: [usize; R]) -> Self {
|
impl<const R: usize> TensorIndex<R> {
|
||||||
if !shape.check_indices(indices) {
|
|
||||||
panic!("indices out of bounds");
|
|
||||||
}
|
|
||||||
Self { indices, shape }
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn is_zero(&self) -> bool {
|
pub fn is_zero(&self) -> bool {
|
||||||
self.indices.iter().all(|&i| i == 0)
|
self.indices.iter().all(|&i| i == 0)
|
||||||
}
|
}
|
||||||
@ -298,34 +303,34 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
|
|||||||
pub fn iter_transposed(
|
pub fn iter_transposed(
|
||||||
&self,
|
&self,
|
||||||
order: [usize; R],
|
order: [usize; R],
|
||||||
) -> TensorIndexTransposedIterator<'a, R> {
|
) -> TensorIndexTransposedIterator<R> {
|
||||||
TensorIndexTransposedIterator::new(self.shape(), order)
|
TensorIndexTransposedIterator::new(self.shape().clone(), order)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- blanket impls ---
|
// --- blanket impls ---
|
||||||
|
|
||||||
impl<'a, const R: usize> PartialEq for TensorIndex<'a, R> {
|
impl<const R: usize> PartialEq for TensorIndex<R> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.flat() == other.flat()
|
self.flat() == other.flat()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Eq for TensorIndex<'a, R> {}
|
impl<const R: usize> Eq for TensorIndex<R> {}
|
||||||
|
|
||||||
impl<'a, const R: usize> PartialOrd for TensorIndex<'a, R> {
|
impl<const R: usize> PartialOrd for TensorIndex<R> {
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
self.flat().partial_cmp(&other.flat())
|
self.flat().partial_cmp(&other.flat())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Ord for TensorIndex<'a, R> {
|
impl<const R: usize> Ord for TensorIndex<R> {
|
||||||
fn cmp(&self, other: &Self) -> Ordering {
|
fn cmp(&self, other: &Self) -> Ordering {
|
||||||
self.flat().cmp(&other.flat())
|
self.flat().cmp(&other.flat())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
|
impl<const R: usize> Index<usize> for TensorIndex<R> {
|
||||||
type Output = usize;
|
type Output = usize;
|
||||||
|
|
||||||
fn index(&self, index: usize) -> &Self::Output {
|
fn index(&self, index: usize) -> &Self::Output {
|
||||||
@ -333,45 +338,45 @@ impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
|
impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
|
||||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||||
&mut self.indices[index]
|
&mut self.indices[index]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])>
|
impl<const R: usize> From<(TensorShape<R>, [usize; R])>
|
||||||
for TensorIndex<'a, R>
|
for TensorIndex<R>
|
||||||
{
|
{
|
||||||
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
|
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
|
||||||
assert!(shape.check_indices(indices));
|
assert!(shape.check_indices(indices));
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)>
|
impl<const R: usize> From<(TensorShape<R>, usize)>
|
||||||
for TensorIndex<'a, R>
|
for TensorIndex<R>
|
||||||
{
|
{
|
||||||
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
|
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
|
||||||
let indices = shape.index_from_flat(flat_index).indices;
|
let indices = shape.index_from_flat(flat_index).indices;
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
|
impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
|
||||||
fn from(shape: &'a TensorShape<R>) -> Self {
|
fn from(shape: TensorShape<R>) -> Self {
|
||||||
Self::zero(shape)
|
Self::zero(shape)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>>
|
impl<T: Value, const R: usize> From<Tensor<T, R>>
|
||||||
for TensorIndex<'a, R>
|
for TensorIndex<R>
|
||||||
{
|
{
|
||||||
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
fn from(tensor: Tensor<T, R>) -> Self {
|
||||||
Self::zero(tensor.shape())
|
Self::zero(tensor.shape().clone())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
|
impl<const R: usize> std::fmt::Display for TensorIndex<R> {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "[")?;
|
write!(f, "[")?;
|
||||||
for (i, (&idx, &dim_size)) in self
|
for (i, (&idx, &dim_size)) in self
|
||||||
@ -391,7 +396,7 @@ impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
|
|||||||
|
|
||||||
// ---- Arithmetic Operations ----
|
// ---- Arithmetic Operations ----
|
||||||
|
|
||||||
impl<'a, const R: usize> Add for TensorIndex<'a, R> {
|
impl<const R: usize> Add for TensorIndex<R> {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn add(self, rhs: Self) -> Self::Output {
|
fn add(self, rhs: Self) -> Self::Output {
|
||||||
@ -409,7 +414,7 @@ impl<'a, const R: usize> Add for TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
|
impl<const R: usize> Sub for TensorIndex<R> {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn sub(self, rhs: Self) -> Self::Output {
|
fn sub(self, rhs: Self) -> Self::Output {
|
||||||
@ -429,13 +434,13 @@ impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
|
|||||||
|
|
||||||
// ---- Iterator ----
|
// ---- Iterator ----
|
||||||
|
|
||||||
pub struct TensorIndexIterator<'a, const R: usize> {
|
pub struct TensorIndexIterator<const R: usize> {
|
||||||
current: TensorIndex<'a, R>,
|
current: TensorIndex<R>,
|
||||||
end: bool,
|
end: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> TensorIndexIterator<'a, R> {
|
impl<const R: usize> TensorIndexIterator<R> {
|
||||||
pub fn new(shape: &'a TensorShape<R>) -> Self {
|
pub fn new(shape: TensorShape<R>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
current: TensorIndex::zero(shape),
|
current: TensorIndex::zero(shape),
|
||||||
end: false,
|
end: false,
|
||||||
@ -443,8 +448,8 @@ impl<'a, const R: usize> TensorIndexIterator<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
|
impl<const R: usize> Iterator for TensorIndexIterator<R> {
|
||||||
type Item = TensorIndex<'a, R>;
|
type Item = TensorIndex<R>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.end {
|
if self.end {
|
||||||
@ -457,9 +462,9 @@ impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
|
impl<const R: usize> IntoIterator for TensorIndex<R> {
|
||||||
type Item = TensorIndex<'a, R>;
|
type Item = TensorIndex<R>;
|
||||||
type IntoIter = TensorIndexIterator<'a, R>;
|
type IntoIter = TensorIndexIterator<R>;
|
||||||
|
|
||||||
fn into_iter(self) -> Self::IntoIter {
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
TensorIndexIterator {
|
TensorIndexIterator {
|
||||||
@ -469,14 +474,14 @@ impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct TensorIndexTransposedIterator<'a, const R: usize> {
|
pub struct TensorIndexTransposedIterator<const R: usize> {
|
||||||
current: TensorIndex<'a, R>,
|
current: TensorIndex<R>,
|
||||||
order: [usize; R],
|
order: [usize; R],
|
||||||
end: bool,
|
end: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
|
impl<const R: usize> TensorIndexTransposedIterator<R> {
|
||||||
pub fn new(shape: &'a TensorShape<R>, order: [usize; R]) -> Self {
|
pub fn new(shape: TensorShape<R>, order: [usize; R]) -> Self {
|
||||||
Self {
|
Self {
|
||||||
current: TensorIndex::zero(shape),
|
current: TensorIndex::zero(shape),
|
||||||
end: false,
|
end: false,
|
||||||
@ -485,8 +490,8 @@ impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Iterator for TensorIndexTransposedIterator<'a, R> {
|
impl<const R: usize> Iterator for TensorIndexTransposedIterator<R> {
|
||||||
type Item = TensorIndex<'a, R>;
|
type Item = TensorIndex<R>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.end {
|
if self.end {
|
||||||
|
@ -27,6 +27,9 @@ macro_rules! shape {
|
|||||||
|
|
||||||
#[macro_export]
|
#[macro_export]
|
||||||
macro_rules! index {
|
macro_rules! index {
|
||||||
|
($tensor:expr) => {
|
||||||
|
TensorIndex::zero($tensor.shape().clone())
|
||||||
|
};
|
||||||
($tensor:expr, $indices:expr) => {
|
($tensor:expr, $indices:expr) => {
|
||||||
TensorIndex::from(($tensor.shape().clone(), $indices))
|
TensorIndex::from(($tensor.shape().clone(), $indices))
|
||||||
};
|
};
|
||||||
|
@ -82,18 +82,18 @@ impl<const R: usize> TensorShape<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
indices.reverse(); // Reverse the indices to match the original dimension order
|
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||||
TensorIndex::new(self, indices)
|
TensorIndex::new(self.clone(), indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn index_zero(&self) -> TensorIndex<R> {
|
pub fn index_zero(&self) -> TensorIndex<R> {
|
||||||
TensorIndex::zero(self)
|
TensorIndex::zero(self.clone())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn index_max(&self) -> TensorIndex<R> {
|
pub fn index_max(&self) -> TensorIndex<R> {
|
||||||
let max_indices =
|
let max_indices =
|
||||||
self.0
|
self.0
|
||||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||||
TensorIndex::new(self, max_indices)
|
TensorIndex::new(self.clone(), max_indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remove_dims<const NAX: usize>(
|
pub fn remove_dims<const NAX: usize>(
|
||||||
|
@ -89,7 +89,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
/// use manifold::{tensor, Tensor};
|
/// use manifold::{tensor, Tensor};
|
||||||
///
|
///
|
||||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||||
/// let i = (t.shape(), [1, 1]).into();
|
/// let i = (t.shape().clone(), [1, 1]).into();
|
||||||
/// assert_eq!(t.get(i), Some(&4));
|
/// assert_eq!(t.get(i), Some(&4));
|
||||||
/// ```
|
/// ```
|
||||||
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
|
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
|
||||||
@ -102,7 +102,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
/// use manifold::{tensor, Tensor};
|
/// use manifold::{tensor, Tensor};
|
||||||
///
|
///
|
||||||
/// let t = tensor!([[1, 2], [3, 4]]);
|
/// let t = tensor!([[1, 2], [3, 4]]);
|
||||||
/// let i = (t.shape(), [1, 1]).into();
|
/// let i = (t.shape().clone(), [1, 1]).into();
|
||||||
/// unsafe { assert_eq!(t.get_unchecked(i), &4); }
|
/// unsafe { assert_eq!(t.get_unchecked(i), &4); }
|
||||||
/// ```
|
/// ```
|
||||||
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
|
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
|
||||||
@ -112,12 +112,10 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
/// Get a mutable reference to a value at the given index.
|
/// Get a mutable reference to a value at the given index.
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use manifold::{tensor, Tensor};
|
/// use manifold::*;
|
||||||
///
|
///
|
||||||
/// let mut t = tensor!([[1, 2], [3, 4]]);
|
/// let mut t = tensor!([[1, 2], [3, 4]]);
|
||||||
/// let s = t.shape().clone();
|
/// assert_eq!(t.get_mut(index!(&t, [1, 1])), Some(&mut 4));
|
||||||
/// let i = (&s, [1, 1]).into();
|
|
||||||
/// assert_eq!(t.get_mut(i), Some(&mut 4));
|
|
||||||
/// ```
|
/// ```
|
||||||
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
|
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
|
||||||
self.buffer_mut().get_mut(index.flat())
|
self.buffer_mut().get_mut(index.flat())
|
||||||
@ -131,7 +129,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
///
|
///
|
||||||
/// let mut t = tensor!([[1, 2], [3, 4]]);
|
/// let mut t = tensor!([[1, 2], [3, 4]]);
|
||||||
/// let s = t.shape().clone();
|
/// let s = t.shape().clone();
|
||||||
/// let i = (&s, [1, 1]).into();
|
/// let i = (s, [1, 1]).into();
|
||||||
/// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); }
|
/// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); }
|
||||||
/// ```
|
/// ```
|
||||||
pub unsafe fn get_unchecked_mut(
|
pub unsafe fn get_unchecked_mut(
|
||||||
@ -403,7 +401,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
|
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
|
||||||
/// ```
|
/// ```
|
||||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||||
let buffer = TensorIndex::from(self.shape())
|
let buffer = TensorIndex::from(self.shape().clone())
|
||||||
.iter_transposed(order)
|
.iter_transposed(order)
|
||||||
.map(|index| self.get(index).unwrap().clone())
|
.map(|index| self.get(index).unwrap().clone())
|
||||||
.collect();
|
.collect();
|
||||||
@ -417,7 +415,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
|
|
||||||
// ---- Indexing --------------------------------------------------------------
|
// ---- Indexing --------------------------------------------------------------
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
|
impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
|
||||||
type Output = T;
|
type Output = T;
|
||||||
|
|
||||||
fn index(&self, index: TensorIndex<R>) -> &Self::Output {
|
fn index(&self, index: TensorIndex<R>) -> &Self::Output {
|
||||||
@ -425,7 +423,7 @@ impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> IndexMut<TensorIndex<'a, R>>
|
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
|
||||||
for Tensor<T, R>
|
for Tensor<T, R>
|
||||||
{
|
{
|
||||||
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
|
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
|
||||||
@ -502,11 +500,11 @@ impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
|
|||||||
|
|
||||||
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||||
tensor: &'a Tensor<T, R>,
|
tensor: &'a Tensor<T, R>,
|
||||||
index: TensorIndex<'a, R>,
|
index: TensorIndex<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||||
pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
|
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
tensor,
|
tensor,
|
||||||
index: tensor.shape.index_zero(),
|
index: tensor.shape.index_zero(),
|
||||||
|
Loading…
Reference in New Issue
Block a user