Refactoring
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
f734fa6434
commit
d8a8551016
344
src/index.rs
344
src/index.rs
@ -25,30 +25,30 @@ pub struct TensorIndex<const R: usize> {
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Creates a new `TensorIndex` instance.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
/// Creates a new `TensorIndex` instance.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
|
||||
let index = Self { indices, shape };
|
||||
assert!(index.check_indices(indices));
|
||||
index
|
||||
assert!(index.check_indices(indices));
|
||||
index
|
||||
}
|
||||
|
||||
/// Creates a new `TensorIndex` instance with all indices set to zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// ```
|
||||
/// Creates a new `TensorIndex` instance with all indices set to zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// ```
|
||||
pub const fn zero(shape: TensorShape<R>) -> Self {
|
||||
Self {
|
||||
indices: [0; R],
|
||||
@ -56,184 +56,185 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn from_flat(shape: TensorShape<R>, flat_index: usize) -> Self {
|
||||
let mut indices = [0; R];
|
||||
pub fn from_flat(shape: TensorShape<R>, flat_index: usize) -> Self {
|
||||
let mut indices = [0; R];
|
||||
let mut remaining = flat_index;
|
||||
|
||||
// - The method iterates over the dimensions of the tensor in reverse order
|
||||
// (assuming row-major order).
|
||||
// - In each iteration, it uses the modulo operation to find the index in
|
||||
// the current dimension and integer division to reduce the flat index
|
||||
// for the next higher dimension.
|
||||
// - This process is repeated for each dimension to build the
|
||||
// multi-dimensional index.
|
||||
// - The method iterates over the dimensions of the tensor in reverse
|
||||
// order (assuming row-major order).
|
||||
// - In each iteration, it uses the modulo operation to find the index
|
||||
// in the current dimension and integer division to reduce the flat
|
||||
// index for the next higher dimension.
|
||||
// - This process is repeated for each dimension to build the
|
||||
// multi-dimensional index.
|
||||
for (idx, &dim_size) in indices.iter_mut().zip(shape.0.iter()).rev() {
|
||||
*idx = remaining % dim_size;
|
||||
remaining /= dim_size;
|
||||
}
|
||||
|
||||
// Reverse the indices to match the original dimension order
|
||||
// Reverse the indices to match the original dimension order
|
||||
indices.reverse();
|
||||
Self::new(shape.clone(), indices)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Trivial Functions -----------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Returns `true` if all indices are zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::zero(shape);
|
||||
/// assert!(index.is_zero());
|
||||
/// ```
|
||||
/// Returns `true` if all indices are zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::zero(shape);
|
||||
/// assert!(index.is_zero());
|
||||
/// ```
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.indices().iter().all(|&i| i == 0)
|
||||
}
|
||||
|
||||
/// Returns `true` if the last index is equal to the size of the last
|
||||
/// dimension.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([1, 1]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert!(!index.is_max());
|
||||
/// index.inc();
|
||||
/// assert!(index.is_max());
|
||||
/// ```
|
||||
/// Returns `true` if the last index is equal to the size of the last
|
||||
/// dimension.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([1, 1]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert!(!index.is_max());
|
||||
/// index.inc();
|
||||
/// assert!(index.is_max());
|
||||
/// ```
|
||||
pub fn is_max(&self) -> bool {
|
||||
self.indices()[0] == self.shape().get(0)
|
||||
self.indices().get(0) == self.shape().dim_size(0)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Utils -----------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Resets the index to zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// index.reset();
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// ```
|
||||
pub fn reset(&mut self) {
|
||||
/// Resets the index to zero.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// index.reset();
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// ```
|
||||
pub fn reset(&mut self) {
|
||||
*self.indices_mut() = [0; R];
|
||||
}
|
||||
|
||||
fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||
indices
|
||||
.iter()
|
||||
.zip(self.shape().as_array().iter())
|
||||
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||
}
|
||||
fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||
indices
|
||||
.iter()
|
||||
.zip(self.shape().as_array().iter())
|
||||
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Increment and Decrement -----------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Increments the index by one.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc();
|
||||
/// assert_eq!(index.flat(), 1);
|
||||
/// ```
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc();
|
||||
/// assert_eq!(index.flat(), 1);
|
||||
/// ```
|
||||
pub fn inc(&mut self) {
|
||||
if self.indices()[0] == self.shape().get(0) {
|
||||
return;
|
||||
}
|
||||
if self.indices().get(0) == self.shape().dim_size(0) {
|
||||
return;
|
||||
}
|
||||
|
||||
let shape = self.shape().as_array().clone();
|
||||
let shape = self.shape().as_array().clone();
|
||||
|
||||
for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
|
||||
*dim += 1;
|
||||
if *dim < shape[i] {
|
||||
return;
|
||||
}
|
||||
*dim = 0;
|
||||
}
|
||||
for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
|
||||
*dim += 1;
|
||||
if *dim < shape[i] {
|
||||
return;
|
||||
}
|
||||
*dim = 0;
|
||||
}
|
||||
|
||||
self.indices_mut()[0] = self.shape().get(0);
|
||||
}
|
||||
self.indices_mut()[0] = *self.shape().dim_size(0).unwrap();
|
||||
}
|
||||
|
||||
/// Increments the index by one, with the specified axis fixed,
|
||||
/// i.e. the index of the fixed axis is incremented only if the
|
||||
/// index of the other axes reaches the maximum.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc_fixed_axis(Axis(0));
|
||||
/// assert_eq!(index.flat(), 4);
|
||||
/// index.inc_fixed_axis(Axis(1));
|
||||
/// assert_eq!(index.flat(), 5);
|
||||
/// ```
|
||||
/// Increments the index by one, with the specified axis fixed,
|
||||
/// i.e. the index of the fixed axis is incremented only if the
|
||||
/// index of the other axes reaches the maximum.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc_fixed_axis(Axis(0));
|
||||
/// assert_eq!(index.flat(), 4);
|
||||
/// index.inc_fixed_axis(Axis(1));
|
||||
/// assert_eq!(index.flat(), 5);
|
||||
/// ```
|
||||
pub fn inc_fixed_axis(&mut self, Axis(ax): Axis) {
|
||||
let shape = self.shape().as_array().clone();
|
||||
assert!(ax < R, "TensorAxis out of bounds");
|
||||
if self.indices()[ax] == self.shape().get(ax) {
|
||||
return;
|
||||
}
|
||||
let shape = self.shape().as_array().clone();
|
||||
assert!(ax < R, "TensorAxis out of bounds");
|
||||
if self.indices().get(ax) == self.shape().dim_size(ax) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Iterate over all axes, skipping the fixed axis 'ax'
|
||||
for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
|
||||
if i != ax { // Skip the fixed axis
|
||||
*dim += 1;
|
||||
if *dim < shape[i] {
|
||||
return; // No carry over needed
|
||||
}
|
||||
*dim = 0; // Reset the current axis and carry over to the next
|
||||
}
|
||||
}
|
||||
// Iterate over all axes, skipping the fixed axis 'ax'
|
||||
for (i, dim) in self.indices_mut().iter_mut().rev().enumerate() {
|
||||
if i != ax {
|
||||
// Skip the fixed axis
|
||||
*dim += 1;
|
||||
if *dim < shape[i] {
|
||||
return; // No carry over needed
|
||||
}
|
||||
*dim = 0; // Reset the current axis and carry over to the next
|
||||
}
|
||||
}
|
||||
|
||||
// Handle the case where incrementing has reached the end
|
||||
if self.indices()[ax] < self.shape().get(ax) {
|
||||
self.indices_mut()[ax] += 1;
|
||||
} else {
|
||||
// Reset if the fixed axis also overflows
|
||||
self.indices_mut()[ax] = self.shape().get(ax);
|
||||
}
|
||||
}
|
||||
// Handle the case where incrementing has reached the end
|
||||
if self.indices().get(ax) < self.shape().dim_size(ax) {
|
||||
self.indices_mut()[ax] += 1;
|
||||
} else {
|
||||
// Reset if the fixed axis also overflows
|
||||
self.indices_mut()[ax] = *self.shape().dim_size(ax).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Increments the index by one, reordering the order in which the
|
||||
/// axes are incremented.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc_transposed(&[2, 1, 0]);
|
||||
/// assert_eq!(index.flat(), 12);
|
||||
/// index.inc_transposed(&[0, 1, 2]);
|
||||
/// assert_eq!(index.flat(), 13);
|
||||
/// ```
|
||||
/// Increments the index by one, reordering the order in which the
|
||||
/// axes are incremented.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::zero(shape);
|
||||
/// assert_eq!(index.flat(), 0);
|
||||
/// index.inc_transposed(&[2, 1, 0]);
|
||||
/// assert_eq!(index.flat(), 12);
|
||||
/// index.inc_transposed(&[0, 1, 2]);
|
||||
/// assert_eq!(index.flat(), 13);
|
||||
/// ```
|
||||
pub fn inc_transposed(&mut self, order: &[usize; R]) {
|
||||
if self.indices()[order[0]] >= self.shape().get(order[0]) {
|
||||
return ;
|
||||
if self.indices().get(order[0]) == self.shape().dim_size(order[0]) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut carry = 1;
|
||||
|
||||
for i in order.iter().rev() {
|
||||
let dim_size = self.shape().get(*i);
|
||||
let dim_size = self.shape().dim_size(*i).unwrap().clone();
|
||||
let i = self.index_mut(*i);
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
@ -246,21 +247,21 @@ impl<const R: usize> TensorIndex<R> {
|
||||
}
|
||||
|
||||
if carry == 1 {
|
||||
self.indices_mut()[order[0]] = self.shape().get(order[0]);
|
||||
self.indices_mut()[order[0]] = *self.shape().dim_size(order[0]).unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
/// Decrements the index by one.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// index.dec();
|
||||
/// assert_eq!(index.flat(), 22);
|
||||
/// ```
|
||||
/// Decrements the index by one.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let mut index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// index.dec();
|
||||
/// assert_eq!(index.flat(), 22);
|
||||
/// ```
|
||||
pub fn dec(&mut self) {
|
||||
// Check if already at the start
|
||||
if self.indices.iter().all(|&i| i == 0) {
|
||||
@ -285,19 +286,18 @@ impl<const R: usize> TensorIndex<R> {
|
||||
|
||||
// ---- Conversion to Flat Index ----------------------------------------------
|
||||
|
||||
impl <const R: usize> TensorIndex<R> {
|
||||
|
||||
impl<const R: usize> TensorIndex<R> {
|
||||
/// Converts the multi-dimensional index to a flat index.
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let index = TensorIndex::new(shape, [1, 2, 3]);
|
||||
/// assert_eq!(index.flat(), 23);
|
||||
/// ```
|
||||
pub fn flat(&self) -> usize {
|
||||
self.indices()
|
||||
.iter()
|
||||
|
@ -4,11 +4,11 @@
|
||||
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod iterators;
|
||||
pub mod shape;
|
||||
pub mod tensor;
|
||||
pub mod iterators;
|
||||
|
||||
pub use {iterators::*, error::*, index::*, shape::*, tensor::*};
|
||||
pub use {error::*, index::*, iterators::*, shape::*, tensor::*};
|
||||
|
||||
use num::{Num, One, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
134
src/shape.rs
134
src/shape.rs
@ -20,16 +20,16 @@ pub struct TensorShape<const R: usize>(pub(crate) [usize; R]);
|
||||
// ---- Construction and Initialization ---------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Creates a new `TensorShape` instance.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0), Some(&2));
|
||||
/// assert_eq!(shape.dim_size(1), Some(&3));
|
||||
/// assert_eq!(shape.dim_size(2), None);
|
||||
/// ```
|
||||
/// Creates a new `TensorShape` instance.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0), Some(&2));
|
||||
/// assert_eq!(shape.dim_size(1), Some(&3));
|
||||
/// assert_eq!(shape.dim_size(2), None);
|
||||
/// ```
|
||||
pub const fn new(shape: [usize; R]) -> Self {
|
||||
Self(shape)
|
||||
}
|
||||
@ -38,39 +38,39 @@ impl<const R: usize> TensorShape<R> {
|
||||
// ---- Getters ---------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Get the size of the specified dimension.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.get(0), 2);
|
||||
/// assert_eq!(shape.get(1), 3);
|
||||
/// ```
|
||||
pub const fn get(&self, index: usize) -> usize {
|
||||
self.0[index]
|
||||
/// Get the size of the specified dimension.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.dim_size(0).unwrap(), &2);
|
||||
/// assert_eq!(shape.dim_size(1).unwrap(), &3);
|
||||
/// ```
|
||||
pub fn dim_size(&self, index: usize) -> Option<&usize> {
|
||||
self.0.get(index)
|
||||
}
|
||||
|
||||
/// Returns the shape as an array.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.as_array(), &[2, 3]);
|
||||
/// ```
|
||||
/// Returns the shape as an array.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.as_array(), &[2, 3]);
|
||||
/// ```
|
||||
pub const fn as_array(&self) -> &[usize; R] {
|
||||
&self.0
|
||||
}
|
||||
|
||||
/// Returns the size of the shape, meaning the product of all dimensions.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.size(), 6);
|
||||
/// ```
|
||||
/// Returns the size of the shape, meaning the product of all dimensions.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// assert_eq!(shape.size(), 6);
|
||||
/// ```
|
||||
pub fn size(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
@ -79,15 +79,15 @@ impl<const R: usize> TensorShape<R> {
|
||||
// ---- Manipulation ----------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Reorders the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.reorder([2, 0, 1]);
|
||||
/// assert_eq!(new_shape, shape!([4, 2, 3]));
|
||||
/// ```
|
||||
/// Reorders the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.reorder([2, 0, 1]);
|
||||
/// assert_eq!(new_shape, shape!([4, 2, 3]));
|
||||
/// ```
|
||||
pub fn reorder(&self, indices: [usize; R]) -> Self {
|
||||
let mut new_shape = TensorShape::new([0; R]);
|
||||
for (new_index, &index) in indices.iter().enumerate() {
|
||||
@ -96,16 +96,16 @@ impl<const R: usize> TensorShape<R> {
|
||||
new_shape
|
||||
}
|
||||
|
||||
/// Creates a new shape by removing the specified dimensions.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.remove_dims([0, 2]);
|
||||
/// assert_eq!(new_shape, shape!([3]));
|
||||
/// ```
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
/// Creates a new shape by removing the specified dimensions.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3, 4]);
|
||||
/// let new_shape = shape.remove_dims([0, 2]);
|
||||
/// assert_eq!(new_shape, shape!([3]));
|
||||
/// ```
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
&self,
|
||||
dims_to_remove: [usize; NAX],
|
||||
) -> TensorShape<{ R - NAX }> {
|
||||
@ -132,17 +132,17 @@ impl<const R: usize> TensorShape<R> {
|
||||
// ---- Iterators -------------------------------------------------------------
|
||||
|
||||
impl<const R: usize> TensorShape<R> {
|
||||
/// Returns an iterator over the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// let mut iter = shape.iter();
|
||||
/// assert_eq!(iter.next(), Some(&2));
|
||||
/// assert_eq!(iter.next(), Some(&3));
|
||||
/// assert_eq!(iter.next(), None);
|
||||
/// ```
|
||||
/// Returns an iterator over the dimensions of the shape.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::*;
|
||||
///
|
||||
/// let shape = TensorShape::new([2, 3]);
|
||||
/// let mut iter = shape.iter();
|
||||
/// assert_eq!(iter.next(), Some(&2));
|
||||
/// assert_eq!(iter.next(), Some(&3));
|
||||
/// assert_eq!(iter.next(), None);
|
||||
/// ```
|
||||
pub fn iter(&self) -> impl Iterator<Item = &usize> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
@ -65,39 +65,49 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let buffer = vec![1, 2, 3, 4, 5, 6];
|
||||
/// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer);
|
||||
/// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer).unwrap();
|
||||
/// assert_eq!(t.shape().as_array(), &[2, 3]);
|
||||
/// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]);
|
||||
/// ```
|
||||
pub fn new_with_buffer(shape: TensorShape<R>, buffer: Vec<T>) -> Self {
|
||||
Self { buffer, shape }
|
||||
pub fn new_with_buffer(
|
||||
shape: TensorShape<R>,
|
||||
buffer: Vec<T>,
|
||||
) -> Result<Self> {
|
||||
if buffer.len() != shape.size() {
|
||||
Err(TensorError::InvalidArgument(format!(
|
||||
"Provided buffer has length {} but shape has size {}",
|
||||
buffer.len(),
|
||||
shape.size()
|
||||
)))
|
||||
} else {
|
||||
Ok(Self { buffer, shape })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Trivial Getters -------------------------------------------------------
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
|
||||
/// Get the rank of the tensor.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let t = Tensor::<f64, 2>::new([3, 3].into());
|
||||
/// assert_eq!(t.rank(), 2);
|
||||
/// ```
|
||||
pub fn rank(&self) -> usize {
|
||||
/// Get the rank of the tensor.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let t = Tensor::<f64, 2>::new([3, 3].into());
|
||||
/// assert_eq!(t.rank(), 2);
|
||||
/// ```
|
||||
pub const fn rank(&self) -> usize {
|
||||
R
|
||||
}
|
||||
|
||||
/// Get the length of the tensor's buffer.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let t = Tensor::<f64, 2>::new([3, 3].into());
|
||||
/// assert_eq!(t.len(), 9);
|
||||
/// ```
|
||||
/// Get the length of the tensor's buffer.
|
||||
///
|
||||
/// ```
|
||||
/// use manifold::Tensor;
|
||||
///
|
||||
/// let t = Tensor::<f64, 2>::new([3, 3].into());
|
||||
/// assert_eq!(t.len(), 9);
|
||||
/// ```
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer().len()
|
||||
}
|
||||
@ -352,32 +362,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
) -> Result<()> {
|
||||
self.ew_for_each(other, result, &|a, b| a % b)
|
||||
}
|
||||
|
||||
// pub fn product<const S: usize>(
|
||||
// &self,
|
||||
// other: &Tensor<T, S>,
|
||||
// ) -> Tensor<T, { R + S }> {
|
||||
// let mut new_shape_vec = Vec::new();
|
||||
// new_shape_vec.extend_from_slice(&self.shape().as_array());
|
||||
// new_shape_vec.extend_from_slice(&other.shape().as_array());
|
||||
|
||||
// let new_shape_array: [usize; R + S] = new_shape_vec
|
||||
// .try_into()
|
||||
// .expect("Failed to create shape array");
|
||||
|
||||
// let mut new_buffer =
|
||||
// Vec::with_capacity(self.buffer.len() * other.buffer.len());
|
||||
// for &item_self in &self.buffer {
|
||||
// for &item_other in &other.buffer {
|
||||
// new_buffer.push(item_self * item_other);
|
||||
// }
|
||||
// }
|
||||
|
||||
// Tensor {
|
||||
// buffer: new_buffer,
|
||||
// shape: TensorShape::new(new_shape_array),
|
||||
// }
|
||||
// }
|
||||
}
|
||||
|
||||
// ---- Reshape ---------------------------------------------------------------
|
||||
@ -405,7 +389,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
"TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
|
||||
)))
|
||||
} else {
|
||||
Ok(Tensor::new_with_buffer(shape, self.buffer))
|
||||
Ok(Tensor::new_with_buffer(shape, self.buffer).unwrap())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -53,14 +53,10 @@ fn test_iterating_3d_tensor() {
|
||||
}
|
||||
}
|
||||
|
||||
println!("{}", tensor);
|
||||
|
||||
// Iterate over the tensor and check that the numbers are correct
|
||||
|
||||
let mut iter = TensorIterator::new(&tensor);
|
||||
|
||||
println!("{}", iter);
|
||||
|
||||
assert_eq!(iter.next(), Some(&0));
|
||||
|
||||
assert_eq!(iter.next(), Some(&1));
|
||||
|
Loading…
Reference in New Issue
Block a user