Change the names of the existing types to align with requirements
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
35da61c619
commit
6ab486fb15
74
src/axis.rs
74
src/axis.rs
@ -2,16 +2,16 @@ use super::*;
|
|||||||
use getset::{Getters, MutGetters};
|
use getset::{Getters, MutGetters};
|
||||||
|
|
||||||
#[derive(Clone, Debug, Getters)]
|
#[derive(Clone, Debug, Getters)]
|
||||||
pub struct Axis<'a, T: Value, const R: usize> {
|
pub struct TensorAxis<'a, T: Value, const R: usize> {
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
tensor: &'a Tensor<T, R>,
|
tensor: &'a Tensor<T, R>,
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
dim: usize,
|
dim: usize,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
|
impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
|
||||||
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
|
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
|
||||||
assert!(dim < R, "Axis out of bounds");
|
assert!(dim < R, "TensorAxis out of bounds");
|
||||||
Self { tensor, dim }
|
Self { tensor, dim }
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -19,40 +19,40 @@ impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
|
|||||||
self.tensor.shape().get(self.dim)
|
self.tensor.shape().get(self.dim)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn shape(&self) -> &Shape<R> {
|
pub fn shape(&self) -> &TensorShape<R> {
|
||||||
self.tensor.shape()
|
self.tensor.shape()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn iter_level(&'a self, level: usize) -> AxisIterator<'a, T, R> {
|
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
|
||||||
assert!(level < self.len(), "Level out of bounds");
|
assert!(level < self.len(), "Level out of bounds");
|
||||||
let mut index = Idx::new(self.shape(), [0; R]);
|
let mut index = TensorIndex::new(self.shape(), [0; R]);
|
||||||
index.set_axis(self.dim, level);
|
index.set_axis(self.dim, level);
|
||||||
AxisIterator::new(self).set_start(level).set_end(level + 1)
|
TensorAxisIterator::new(self).set_start(level).set_end(level + 1)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone, Debug, Getters, MutGetters)]
|
#[derive(Clone, Debug, Getters, MutGetters)]
|
||||||
pub struct AxisIterator<'a, T: Value, const R: usize> {
|
pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
axis: &'a Axis<'a, T, R>,
|
axis: &'a TensorAxis<'a, T, R>,
|
||||||
#[getset(get = "pub", get_mut = "pub")]
|
#[getset(get = "pub", get_mut = "pub")]
|
||||||
index: Idx<'a, R>,
|
index: TensorIndex<'a, R>,
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
end: Option<usize>,
|
end: Option<usize>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
|
impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
|
||||||
pub fn new(axis: &'a Axis<'a, T, R>) -> Self {
|
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
axis,
|
axis,
|
||||||
index: Idx::new(axis.shape(), [0; R]),
|
index: TensorIndex::new(axis.shape(), [0; R]),
|
||||||
end: None,
|
end: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_start(self, start: usize) -> Self {
|
pub fn set_start(self, start: usize) -> Self {
|
||||||
assert!(start < self.axis().len(), "Start out of bounds");
|
assert!(start < self.axis().len(), "Start out of bounds");
|
||||||
let mut index = Idx::new(self.axis().shape(), [0; R]);
|
let mut index = TensorIndex::new(self.axis().shape(), [0; R]);
|
||||||
index.set_axis(self.axis.dim, start);
|
index.set_axis(self.axis.dim, start);
|
||||||
Self {
|
Self {
|
||||||
axis: self.axis(),
|
axis: self.axis(),
|
||||||
@ -92,7 +92,7 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
|
impl<'a, T: Value, const R: usize> Iterator for TensorAxisIterator<'a, T, R> {
|
||||||
type Item = &'a T;
|
type Item = &'a T;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
@ -106,12 +106,12 @@ impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> IntoIterator for &'a Axis<'a, T, R> {
|
impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
|
||||||
type Item = &'a T;
|
type Item = &'a T;
|
||||||
type IntoIter = AxisIterator<'a, T, R>;
|
type IntoIter = TensorAxisIterator<'a, T, R>;
|
||||||
|
|
||||||
fn into_iter(self) -> Self::IntoIter {
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
AxisIterator::new(&self)
|
TensorAxisIterator::new(&self)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ where
|
|||||||
let shape: [usize; R + S - 2 * N] =
|
let shape: [usize; R + S - 2 * N] =
|
||||||
shape.try_into().expect("Failed to create shape array");
|
shape.try_into().expect("Failed to create shape array");
|
||||||
|
|
||||||
let shape = Shape::new(shape);
|
let shape = TensorShape::new(shape);
|
||||||
|
|
||||||
let result = contract_axes(&lnc, &rnc);
|
let result = contract_axes(&lnc, &rnc);
|
||||||
|
|
||||||
@ -158,8 +158,8 @@ pub fn contract_axes<
|
|||||||
const S: usize,
|
const S: usize,
|
||||||
const N: usize,
|
const N: usize,
|
||||||
>(
|
>(
|
||||||
laxes: &'a [Axis<'a, T, R>],
|
laxes: &'a [TensorAxis<'a, T, R>],
|
||||||
raxes: &'a [Axis<'a, T, S>],
|
raxes: &'a [TensorAxis<'a, T, S>],
|
||||||
) -> Vec<T>
|
) -> Vec<T>
|
||||||
where
|
where
|
||||||
[(); R - N]:,
|
[(); R - N]:,
|
||||||
@ -204,7 +204,7 @@ mod tests {
|
|||||||
|
|
||||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||||
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
||||||
assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2]));
|
assert_eq!(contracted_tensor.shape(), &TensorShape::new([2, 2]));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
contracted_tensor.buffer(),
|
contracted_tensor.buffer(),
|
||||||
&[7, 10, 15, 22],
|
&[7, 10, 15, 22],
|
||||||
@ -228,7 +228,7 @@ mod tests {
|
|||||||
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
||||||
|
|
||||||
println!("contracted_tensor: {}", contracted_tensor);
|
println!("contracted_tensor: {}", contracted_tensor);
|
||||||
assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
assert_eq!(contracted_tensor.shape(), &TensorShape::new([3, 3]));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
contracted_tensor.buffer(),
|
contracted_tensor.buffer(),
|
||||||
&[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
&[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
||||||
@ -239,9 +239,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_tensor_contraction_rank3() {
|
fn test_tensor_contraction_rank3() {
|
||||||
let a: Tensor<i32, 3> =
|
let a: Tensor<i32, 3> =
|
||||||
Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24
|
Tensor::new_with_buffer(TensorShape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24
|
||||||
let b: Tensor<i32, 3> =
|
let b: Tensor<i32, 3> =
|
||||||
Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24
|
Tensor::new_with_buffer(TensorShape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24
|
||||||
let contracted_tensor: Tensor<i32, 4> = contract((&a, [2]), (&b, [0]));
|
let contracted_tensor: Tensor<i32, 4> = contract((&a, [2]), (&b, [0]));
|
||||||
|
|
||||||
println!("a: {}", a);
|
println!("a: {}", a);
|
||||||
@ -260,7 +260,7 @@ mod tests {
|
|||||||
// let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
// let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||||
|
|
||||||
// // Testing iteration over the first axis (axis = 0)
|
// // Testing iteration over the first axis (axis = 0)
|
||||||
// let axis = Axis::new(&tensor, 0);
|
// let axis = TensorAxis::new(&tensor, 0);
|
||||||
|
|
||||||
// let mut axis_iter = axis.into_iter().disassemble();
|
// let mut axis_iter = axis.into_iter().disassemble();
|
||||||
|
|
||||||
@ -272,7 +272,7 @@ mod tests {
|
|||||||
// assert_eq!(axis_iter[1].next(), None);
|
// assert_eq!(axis_iter[1].next(), None);
|
||||||
|
|
||||||
// // Resetting the iterator for the second axis (axis = 1)
|
// // Resetting the iterator for the second axis (axis = 1)
|
||||||
// let axis = Axis::new(&tensor, 1);
|
// let axis = TensorAxis::new(&tensor, 1);
|
||||||
|
|
||||||
// let mut axis_iter = axis.into_iter().disassemble();
|
// let mut axis_iter = axis.into_iter().disassemble();
|
||||||
|
|
||||||
@ -290,7 +290,7 @@ mod tests {
|
|||||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||||
|
|
||||||
// Testing iteration over the first axis (axis = 0)
|
// Testing iteration over the first axis (axis = 0)
|
||||||
let axis = Axis::new(&tensor, 0);
|
let axis = TensorAxis::new(&tensor, 0);
|
||||||
|
|
||||||
let mut axis_iter = axis.into_iter();
|
let mut axis_iter = axis.into_iter();
|
||||||
|
|
||||||
@ -300,7 +300,7 @@ mod tests {
|
|||||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||||
|
|
||||||
// Resetting the iterator for the second axis (axis = 1)
|
// Resetting the iterator for the second axis (axis = 1)
|
||||||
let axis = Axis::new(&tensor, 1);
|
let axis = TensorAxis::new(&tensor, 1);
|
||||||
|
|
||||||
let mut axis_iter = axis.into_iter();
|
let mut axis_iter = axis.into_iter();
|
||||||
|
|
||||||
@ -311,8 +311,8 @@ mod tests {
|
|||||||
|
|
||||||
let shape = tensor.shape();
|
let shape = tensor.shape();
|
||||||
|
|
||||||
let mut a: Idx<2> = (shape, [0, 0]).into();
|
let mut a: TensorIndex<2> = (shape, [0, 0]).into();
|
||||||
let b: Idx<2> = (shape, [1, 1]).into();
|
let b: TensorIndex<2> = (shape, [1, 1]).into();
|
||||||
|
|
||||||
while a <= b {
|
while a <= b {
|
||||||
println!("a: {}", a);
|
println!("a: {}", a);
|
||||||
@ -326,7 +326,7 @@ mod tests {
|
|||||||
// Tensor shape is 2x2x2 for simplicity
|
// Tensor shape is 2x2x2 for simplicity
|
||||||
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||||
|
|
||||||
// Axis 0 (Layer-wise):
|
// TensorAxis 0 (Layer-wise):
|
||||||
//
|
//
|
||||||
// t[0][0][0] = 1
|
// t[0][0][0] = 1
|
||||||
// t[0][0][1] = 2
|
// t[0][0][1] = 2
|
||||||
@ -342,11 +342,11 @@ mod tests {
|
|||||||
// the iterator goes through all rows and columns. It first completes
|
// the iterator goes through all rows and columns. It first completes
|
||||||
// the entire first layer, then moves to the second.
|
// the entire first layer, then moves to the second.
|
||||||
|
|
||||||
let a0 = Axis::new(&t, 0);
|
let a0 = TensorAxis::new(&t, 0);
|
||||||
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
||||||
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
|
|
||||||
// Axis 1 (Row-wise within each layer):
|
// TensorAxis 1 (Row-wise within each layer):
|
||||||
//
|
//
|
||||||
// t[0][0][0] = 1
|
// t[0][0][0] = 1
|
||||||
// t[0][0][1] = 2
|
// t[0][0][1] = 2
|
||||||
@ -362,11 +362,11 @@ mod tests {
|
|||||||
// completes the first row across all layers, then the second row
|
// completes the first row across all layers, then the second row
|
||||||
// across all layers.
|
// across all layers.
|
||||||
|
|
||||||
let a1 = Axis::new(&t, 1);
|
let a1 = TensorAxis::new(&t, 1);
|
||||||
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
||||||
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||||
|
|
||||||
// Axis 2 (Column-wise within each layer):
|
// TensorAxis 2 (Column-wise within each layer):
|
||||||
//
|
//
|
||||||
// t[0][0][0] = 1
|
// t[0][0][0] = 1
|
||||||
// t[0][1][0] = 3
|
// t[0][1][0] = 3
|
||||||
@ -382,7 +382,7 @@ mod tests {
|
|||||||
// completes the first column across all layers, then the second
|
// completes the first column across all layers, then the second
|
||||||
// column across all layers.
|
// column across all layers.
|
||||||
|
|
||||||
let a2 = Axis::new(&t, 2);
|
let a2 = TensorAxis::new(&t, 2);
|
||||||
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
||||||
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||||
}
|
}
|
||||||
|
@ -1,9 +1,9 @@
|
|||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, TensorError>;
|
||||||
|
|
||||||
#[derive(Error, Debug)]
|
#[derive(Error, Debug)]
|
||||||
pub enum Error {
|
pub enum TensorError {
|
||||||
#[error("Invalid argument: {0}")]
|
#[error("Invalid argument: {0}")]
|
||||||
InvalidArgument(String),
|
InvalidArgument(String),
|
||||||
}
|
}
|
||||||
|
96
src/index.rs
96
src/index.rs
@ -4,22 +4,22 @@ use std::cmp::Ordering;
|
|||||||
use std::ops::{Add, Sub};
|
use std::ops::{Add, Sub};
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||||
pub struct Idx<'a, const R: usize> {
|
pub struct TensorIndex<'a, const R: usize> {
|
||||||
#[getset(get = "pub", get_mut = "pub")]
|
#[getset(get = "pub", get_mut = "pub")]
|
||||||
indices: [usize; R],
|
indices: [usize; R],
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
shape: &'a Shape<R>,
|
shape: &'a TensorShape<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Idx<'a, R> {
|
impl<'a, const R: usize> TensorIndex<'a, R> {
|
||||||
pub const fn zero(shape: &'a Shape<R>) -> Self {
|
pub const fn zero(shape: &'a TensorShape<R>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
indices: [0; R],
|
indices: [0; R],
|
||||||
shape,
|
shape,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn last(shape: &'a Shape<R>) -> Self {
|
pub fn last(shape: &'a TensorShape<R>) -> Self {
|
||||||
let max_indices =
|
let max_indices =
|
||||||
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
||||||
Self {
|
Self {
|
||||||
@ -28,7 +28,7 @@ impl<'a, const R: usize> Idx<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new(shape: &'a Shape<R>, indices: [usize; R]) -> Self {
|
pub fn new(shape: &'a TensorShape<R>, indices: [usize; R]) -> Self {
|
||||||
if !shape.check_indices(indices) {
|
if !shape.check_indices(indices) {
|
||||||
panic!("indices out of bounds");
|
panic!("indices out of bounds");
|
||||||
}
|
}
|
||||||
@ -87,7 +87,7 @@ impl<'a, const R: usize> Idx<'a, R> {
|
|||||||
// fn inc_axis
|
// fn inc_axis
|
||||||
|
|
||||||
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
||||||
assert!(fixed_axis < R, "Axis out of bounds");
|
assert!(fixed_axis < R, "TensorAxis out of bounds");
|
||||||
assert!(
|
assert!(
|
||||||
self.indices()[fixed_axis] < self.shape().get(fixed_axis),
|
self.indices()[fixed_axis] < self.shape().get(fixed_axis),
|
||||||
"Index out of bounds"
|
"Index out of bounds"
|
||||||
@ -263,13 +263,13 @@ impl<'a, const R: usize> Idx<'a, R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn set_axis(&mut self, axis: usize, value: usize) {
|
pub fn set_axis(&mut self, axis: usize, value: usize) {
|
||||||
assert!(axis < R, "Axis out of bounds");
|
assert!(axis < R, "TensorAxis out of bounds");
|
||||||
// assert!(value < self.shape.get(axis), "Value out of bounds");
|
// assert!(value < self.shape.get(axis), "Value out of bounds");
|
||||||
self.indices[axis] = value;
|
self.indices[axis] = value;
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
|
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
|
||||||
assert!(axis < R, "Axis out of bounds");
|
assert!(axis < R, "TensorAxis out of bounds");
|
||||||
if value < self.shape.get(axis) {
|
if value < self.shape.get(axis) {
|
||||||
self.indices[axis] = value;
|
self.indices[axis] = value;
|
||||||
true
|
true
|
||||||
@ -279,41 +279,41 @@ impl<'a, const R: usize> Idx<'a, R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_axis(&self, axis: usize) -> usize {
|
pub fn get_axis(&self, axis: usize) -> usize {
|
||||||
assert!(axis < R, "Axis out of bounds");
|
assert!(axis < R, "TensorAxis out of bounds");
|
||||||
self.indices[axis]
|
self.indices[axis]
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn iter_transposed(
|
pub fn iter_transposed(
|
||||||
&self,
|
&self,
|
||||||
order: [usize; R],
|
order: [usize; R],
|
||||||
) -> IdxTransposedIterator<'a, R> {
|
) -> TensorIndexTransposedIterator<'a, R> {
|
||||||
IdxTransposedIterator::new(self.shape(), order)
|
TensorIndexTransposedIterator::new(self.shape(), order)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- blanket impls ---
|
// --- blanket impls ---
|
||||||
|
|
||||||
impl<'a, const R: usize> PartialEq for Idx<'a, R> {
|
impl<'a, const R: usize> PartialEq for TensorIndex<'a, R> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.flat() == other.flat()
|
self.flat() == other.flat()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Eq for Idx<'a, R> {}
|
impl<'a, const R: usize> Eq for TensorIndex<'a, R> {}
|
||||||
|
|
||||||
impl<'a, const R: usize> PartialOrd for Idx<'a, R> {
|
impl<'a, const R: usize> PartialOrd for TensorIndex<'a, R> {
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
self.flat().partial_cmp(&other.flat())
|
self.flat().partial_cmp(&other.flat())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Ord for Idx<'a, R> {
|
impl<'a, const R: usize> Ord for TensorIndex<'a, R> {
|
||||||
fn cmp(&self, other: &Self) -> Ordering {
|
fn cmp(&self, other: &Self) -> Ordering {
|
||||||
self.flat().cmp(&other.flat())
|
self.flat().cmp(&other.flat())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Index<usize> for Idx<'a, R> {
|
impl<'a, const R: usize> Index<usize> for TensorIndex<'a, R> {
|
||||||
type Output = usize;
|
type Output = usize;
|
||||||
|
|
||||||
fn index(&self, index: usize) -> &Self::Output {
|
fn index(&self, index: usize) -> &Self::Output {
|
||||||
@ -321,39 +321,39 @@ impl<'a, const R: usize> Index<usize> for Idx<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> IndexMut<usize> for Idx<'a, R> {
|
impl<'a, const R: usize> IndexMut<usize> for TensorIndex<'a, R> {
|
||||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||||
&mut self.indices[index]
|
&mut self.indices[index]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<(&'a Shape<R>, [usize; R])> for Idx<'a, R> {
|
impl<'a, const R: usize> From<(&'a TensorShape<R>, [usize; R])> for TensorIndex<'a, R> {
|
||||||
fn from((shape, indices): (&'a Shape<R>, [usize; R])) -> Self {
|
fn from((shape, indices): (&'a TensorShape<R>, [usize; R])) -> Self {
|
||||||
assert!(shape.check_indices(indices));
|
assert!(shape.check_indices(indices));
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<(&'a Shape<R>, usize)> for Idx<'a, R> {
|
impl<'a, const R: usize> From<(&'a TensorShape<R>, usize)> for TensorIndex<'a, R> {
|
||||||
fn from((shape, flat_index): (&'a Shape<R>, usize)) -> Self {
|
fn from((shape, flat_index): (&'a TensorShape<R>, usize)) -> Self {
|
||||||
let indices = shape.index_from_flat(flat_index).indices;
|
let indices = shape.index_from_flat(flat_index).indices;
|
||||||
Self::new(shape, indices)
|
Self::new(shape, indices)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> From<&'a Shape<R>> for Idx<'a, R> {
|
impl<'a, const R: usize> From<&'a TensorShape<R>> for TensorIndex<'a, R> {
|
||||||
fn from(shape: &'a Shape<R>) -> Self {
|
fn from(shape: &'a TensorShape<R>) -> Self {
|
||||||
Self::zero(shape)
|
Self::zero(shape)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for Idx<'a, R> {
|
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for TensorIndex<'a, R> {
|
||||||
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
||||||
Self::zero(tensor.shape())
|
Self::zero(tensor.shape())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> {
|
impl<'a, const R: usize> std::fmt::Display for TensorIndex<'a, R> {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "[")?;
|
write!(f, "[")?;
|
||||||
for (i, (&idx, &dim_size)) in self
|
for (i, (&idx, &dim_size)) in self
|
||||||
@ -373,11 +373,11 @@ impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> {
|
|||||||
|
|
||||||
// ---- Arithmetic Operations ----
|
// ---- Arithmetic Operations ----
|
||||||
|
|
||||||
impl<'a, const R: usize> Add for Idx<'a, R> {
|
impl<'a, const R: usize> Add for TensorIndex<'a, R> {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn add(self, rhs: Self) -> Self::Output {
|
fn add(self, rhs: Self) -> Self::Output {
|
||||||
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
|
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
|
||||||
|
|
||||||
let mut result_indices = [0; R];
|
let mut result_indices = [0; R];
|
||||||
for i in 0..R {
|
for i in 0..R {
|
||||||
@ -391,11 +391,11 @@ impl<'a, const R: usize> Add for Idx<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Sub for Idx<'a, R> {
|
impl<'a, const R: usize> Sub for TensorIndex<'a, R> {
|
||||||
type Output = Self;
|
type Output = Self;
|
||||||
|
|
||||||
fn sub(self, rhs: Self) -> Self::Output {
|
fn sub(self, rhs: Self) -> Self::Output {
|
||||||
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
|
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
|
||||||
|
|
||||||
let mut result_indices = [0; R];
|
let mut result_indices = [0; R];
|
||||||
for i in 0..R {
|
for i in 0..R {
|
||||||
@ -411,22 +411,22 @@ impl<'a, const R: usize> Sub for Idx<'a, R> {
|
|||||||
|
|
||||||
// ---- Iterator ----
|
// ---- Iterator ----
|
||||||
|
|
||||||
pub struct IdxIterator<'a, const R: usize> {
|
pub struct TensorIndexIterator<'a, const R: usize> {
|
||||||
current: Idx<'a, R>,
|
current: TensorIndex<'a, R>,
|
||||||
end: bool,
|
end: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> IdxIterator<'a, R> {
|
impl<'a, const R: usize> TensorIndexIterator<'a, R> {
|
||||||
pub fn new(shape: &'a Shape<R>) -> Self {
|
pub fn new(shape: &'a TensorShape<R>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
current: Idx::zero(shape),
|
current: TensorIndex::zero(shape),
|
||||||
end: false,
|
end: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Iterator for IdxIterator<'a, R> {
|
impl<'a, const R: usize> Iterator for TensorIndexIterator<'a, R> {
|
||||||
type Item = Idx<'a, R>;
|
type Item = TensorIndex<'a, R>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.end {
|
if self.end {
|
||||||
@ -439,36 +439,36 @@ impl<'a, const R: usize> Iterator for IdxIterator<'a, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> IntoIterator for Idx<'a, R> {
|
impl<'a, const R: usize> IntoIterator for TensorIndex<'a, R> {
|
||||||
type Item = Idx<'a, R>;
|
type Item = TensorIndex<'a, R>;
|
||||||
type IntoIter = IdxIterator<'a, R>;
|
type IntoIter = TensorIndexIterator<'a, R>;
|
||||||
|
|
||||||
fn into_iter(self) -> Self::IntoIter {
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
IdxIterator {
|
TensorIndexIterator {
|
||||||
current: self,
|
current: self,
|
||||||
end: false,
|
end: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct IdxTransposedIterator<'a, const R: usize> {
|
pub struct TensorIndexTransposedIterator<'a, const R: usize> {
|
||||||
current: Idx<'a, R>,
|
current: TensorIndex<'a, R>,
|
||||||
order: [usize; R],
|
order: [usize; R],
|
||||||
end: bool,
|
end: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> IdxTransposedIterator<'a, R> {
|
impl<'a, const R: usize> TensorIndexTransposedIterator<'a, R> {
|
||||||
pub fn new(shape: &'a Shape<R>, order: [usize; R]) -> Self {
|
pub fn new(shape: &'a TensorShape<R>, order: [usize; R]) -> Self {
|
||||||
Self {
|
Self {
|
||||||
current: Idx::zero(shape),
|
current: TensorIndex::zero(shape),
|
||||||
end: false,
|
end: false,
|
||||||
order,
|
order,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> {
|
impl<'a, const R: usize> Iterator for TensorIndexTransposedIterator<'a, R> {
|
||||||
type Item = Idx<'a, R>;
|
type Item = TensorIndex<'a, R>;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
if self.end {
|
if self.end {
|
||||||
|
26
src/lib.rs
26
src/lib.rs
@ -7,11 +7,11 @@ pub mod shape;
|
|||||||
pub mod tensor;
|
pub mod tensor;
|
||||||
|
|
||||||
pub use axis::*;
|
pub use axis::*;
|
||||||
pub use index::Idx;
|
pub use index::TensorIndex;
|
||||||
pub use itertools::Itertools;
|
pub use itertools::Itertools;
|
||||||
use num::{Num, One, Zero};
|
use num::{Num, One, Zero};
|
||||||
pub use serde::{Deserialize, Serialize};
|
pub use serde::{Deserialize, Serialize};
|
||||||
pub use shape::Shape;
|
pub use shape::TensorShape;
|
||||||
pub use static_assertions::const_assert;
|
pub use static_assertions::const_assert;
|
||||||
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||||
use std::ops::{Index, IndexMut};
|
use std::ops::{Index, IndexMut};
|
||||||
@ -62,7 +62,7 @@ mod tests {
|
|||||||
let product = tensor1.tensor_product(&tensor2);
|
let product = tensor1.tensor_product(&tensor2);
|
||||||
|
|
||||||
// Check shape of the resulting tensor
|
// Check shape of the resulting tensor
|
||||||
assert_eq!(*product.shape(), Shape::new([2, 2, 2]));
|
assert_eq!(*product.shape(), TensorShape::new([2, 2, 2]));
|
||||||
|
|
||||||
// Check buffer of the resulting tensor
|
// Check buffer of the resulting tensor
|
||||||
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||||
@ -72,14 +72,14 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn serde_shape_serialization_test() {
|
fn serde_shape_serialization_test() {
|
||||||
// Create a shape instance
|
// Create a shape instance
|
||||||
let shape: Shape<3> = [1, 2, 3].into();
|
let shape: TensorShape<3> = [1, 2, 3].into();
|
||||||
|
|
||||||
// Serialize the shape to a JSON string
|
// Serialize the shape to a JSON string
|
||||||
let serialized =
|
let serialized =
|
||||||
serde_json::to_string(&shape).expect("Failed to serialize");
|
serde_json::to_string(&shape).expect("Failed to serialize");
|
||||||
|
|
||||||
// Deserialize the JSON string back into a shape
|
// Deserialize the JSON string back into a shape
|
||||||
let deserialized: Shape<3> =
|
let deserialized: TensorShape<3> =
|
||||||
serde_json::from_str(&serialized).expect("Failed to deserialize");
|
serde_json::from_str(&serialized).expect("Failed to deserialize");
|
||||||
|
|
||||||
// Check that the deserialized shape is equal to the original
|
// Check that the deserialized shape is equal to the original
|
||||||
@ -89,7 +89,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn tensor_serde_serialization_test() {
|
fn tensor_serde_serialization_test() {
|
||||||
// Create an instance of Tensor
|
// Create an instance of Tensor
|
||||||
let tensor: Tensor<i32, 2> = Tensor::new(Shape::new([2, 2]));
|
let tensor: Tensor<i32, 2> = Tensor::new(TensorShape::new([2, 2]));
|
||||||
|
|
||||||
// Serialize the Tensor to a JSON string
|
// Serialize the Tensor to a JSON string
|
||||||
let serialized =
|
let serialized =
|
||||||
@ -106,7 +106,7 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn iterate_3d_tensor() {
|
fn iterate_3d_tensor() {
|
||||||
let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
|
let shape = TensorShape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let mut num = 0;
|
let mut num = 0;
|
||||||
|
|
||||||
@ -144,7 +144,7 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn iterate_rank_4_tensor() {
|
fn iterate_rank_4_tensor() {
|
||||||
// Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
|
// Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
|
||||||
let shape = Shape::new([2, 2, 2, 2]);
|
let shape = TensorShape::new([2, 2, 2, 2]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let mut num = 0;
|
let mut num = 0;
|
||||||
|
|
||||||
@ -166,8 +166,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_dec_method() {
|
fn test_dec_method() {
|
||||||
let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
|
let shape = TensorShape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
|
||||||
let mut index = Idx::zero(&shape);
|
let mut index = TensorIndex::zero(&shape);
|
||||||
|
|
||||||
// Increment the index to the maximum
|
// Increment the index to the maximum
|
||||||
for _ in 0..26 {
|
for _ in 0..26 {
|
||||||
@ -176,7 +176,7 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Check if the index is at the maximum
|
// Check if the index is at the maximum
|
||||||
assert_eq!(index, Idx::new(&shape, [2, 2, 2]));
|
assert_eq!(index, TensorIndex::new(&shape, [2, 2, 2]));
|
||||||
|
|
||||||
// Decrement step by step and check the index
|
// Decrement step by step and check the index
|
||||||
let expected_indices = [
|
let expected_indices = [
|
||||||
@ -212,7 +212,7 @@ mod tests {
|
|||||||
for (i, &expected) in expected_indices.iter().enumerate() {
|
for (i, &expected) in expected_indices.iter().enumerate() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
index,
|
index,
|
||||||
Idx::new(&shape, expected),
|
TensorIndex::new(&shape, expected),
|
||||||
"Failed at index {}",
|
"Failed at index {}",
|
||||||
i
|
i
|
||||||
);
|
);
|
||||||
@ -221,6 +221,6 @@ mod tests {
|
|||||||
|
|
||||||
// Finally, the index should reach [0, 0, 0]
|
// Finally, the index should reach [0, 0, 0]
|
||||||
index.dec();
|
index.dec();
|
||||||
assert_eq!(index, Idx::zero(&shape));
|
assert_eq!(index, TensorIndex::zero(&shape));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
54
src/shape.rs
54
src/shape.rs
@ -4,9 +4,9 @@ use serde::ser::{Serialize, SerializeTuple, Serializer};
|
|||||||
use std::fmt;
|
use std::fmt;
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
pub struct Shape<const R: usize>([usize; R]);
|
pub struct TensorShape<const R: usize>([usize; R]);
|
||||||
|
|
||||||
impl<const R: usize> Shape<R> {
|
impl<const R: usize> TensorShape<R> {
|
||||||
pub const fn new(shape: [usize; R]) -> Self {
|
pub const fn new(shape: [usize; R]) -> Self {
|
||||||
Self(shape)
|
Self(shape)
|
||||||
}
|
}
|
||||||
@ -16,7 +16,7 @@ impl<const R: usize> Shape<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn reorder(&self, indices: [usize; R]) -> Self {
|
pub fn reorder(&self, indices: [usize; R]) -> Self {
|
||||||
let mut new_shape = Shape::new([0; R]);
|
let mut new_shape = TensorShape::new([0; R]);
|
||||||
for (new_index, &index) in indices.iter().enumerate() {
|
for (new_index, &index) in indices.iter().enumerate() {
|
||||||
new_shape.0[new_index] = self.0[index];
|
new_shape.0[new_index] = self.0[index];
|
||||||
}
|
}
|
||||||
@ -60,14 +60,14 @@ impl<const R: usize> Shape<R> {
|
|||||||
/// * `flat_index` - The flat index to convert.
|
/// * `flat_index` - The flat index to convert.
|
||||||
///
|
///
|
||||||
/// # Returns
|
/// # Returns
|
||||||
/// An `Idx<R>` instance representing the multi-dimensional index corresponding to the flat index.
|
/// An `TensorIndex<R>` instance representing the multi-dimensional index corresponding to the flat index.
|
||||||
///
|
///
|
||||||
/// # How It Works
|
/// # How It Works
|
||||||
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
|
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
|
||||||
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
|
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
|
||||||
/// and integer division to reduce the flat index for the next higher dimension.
|
/// and integer division to reduce the flat index for the next higher dimension.
|
||||||
/// - This process is repeated for each dimension to build the multi-dimensional index.
|
/// - This process is repeated for each dimension to build the multi-dimensional index.
|
||||||
pub fn index_from_flat(&self, flat_index: usize) -> Idx<R> {
|
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
|
||||||
let mut indices = [0; R];
|
let mut indices = [0; R];
|
||||||
let mut remaining = flat_index;
|
let mut remaining = flat_index;
|
||||||
|
|
||||||
@ -77,24 +77,24 @@ impl<const R: usize> Shape<R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
indices.reverse(); // Reverse the indices to match the original dimension order
|
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||||
Idx::new(self, indices)
|
TensorIndex::new(self, indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub const fn index_zero(&self) -> Idx<R> {
|
pub const fn index_zero(&self) -> TensorIndex<R> {
|
||||||
Idx::zero(self)
|
TensorIndex::zero(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn index_max(&self) -> Idx<R> {
|
pub fn index_max(&self) -> TensorIndex<R> {
|
||||||
let max_indices =
|
let max_indices =
|
||||||
self.0
|
self.0
|
||||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||||
Idx::new(self, max_indices)
|
TensorIndex::new(self, max_indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remove_dims<const NAX: usize>(
|
pub fn remove_dims<const NAX: usize>(
|
||||||
&self,
|
&self,
|
||||||
dims_to_remove: [usize; NAX],
|
dims_to_remove: [usize; NAX],
|
||||||
) -> Shape<{ R - NAX }> {
|
) -> TensorShape<{ R - NAX }> {
|
||||||
// Create a new array to store the remaining dimensions
|
// Create a new array to store the remaining dimensions
|
||||||
let mut new_shape = [0; R - NAX];
|
let mut new_shape = [0; R - NAX];
|
||||||
let mut new_index = 0;
|
let mut new_index = 0;
|
||||||
@ -111,13 +111,13 @@ impl<const R: usize> Shape<R> {
|
|||||||
new_index += 1;
|
new_index += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Shape(new_shape)
|
TensorShape(new_shape)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn remove_axes<'a, T: Value, const NAX: usize>(
|
pub fn remove_axes<'a, T: Value, const NAX: usize>(
|
||||||
&self,
|
&self,
|
||||||
axes_to_remove: &'a [Axis<'a, T, R>; NAX],
|
axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX],
|
||||||
) -> Shape<{ R - NAX }> {
|
) -> TensorShape<{ R - NAX }> {
|
||||||
// Create a new array to store the remaining dimensions
|
// Create a new array to store the remaining dimensions
|
||||||
let mut new_shape = [0; R - NAX];
|
let mut new_shape = [0; R - NAX];
|
||||||
let mut new_index = 0;
|
let mut new_index = 0;
|
||||||
@ -136,16 +136,16 @@ impl<const R: usize> Shape<R> {
|
|||||||
new_index += 1;
|
new_index += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
Shape(new_shape)
|
TensorShape(new_shape)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---- Serialize and Deserialize ----
|
// ---- Serialize and Deserialize ----
|
||||||
|
|
||||||
struct ShapeVisitor<const R: usize>;
|
struct TensorShapeVisitor<const R: usize>;
|
||||||
|
|
||||||
impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
|
impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor<R> {
|
||||||
type Value = Shape<R>;
|
type Value = TensorShape<R>;
|
||||||
|
|
||||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||||
formatter.write_str(concat!("an array of length ", "{R}"))
|
formatter.write_str(concat!("an array of length ", "{R}"))
|
||||||
@ -161,20 +161,20 @@ impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
|
|||||||
.next_element()?
|
.next_element()?
|
||||||
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
|
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
|
||||||
}
|
}
|
||||||
Ok(Shape(arr))
|
Ok(TensorShape(arr))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'de, const R: usize> Deserialize<'de> for Shape<R> {
|
impl<'de, const R: usize> Deserialize<'de> for TensorShape<R> {
|
||||||
fn deserialize<D>(deserializer: D) -> Result<Shape<R>, D::Error>
|
fn deserialize<D>(deserializer: D) -> Result<TensorShape<R>, D::Error>
|
||||||
where
|
where
|
||||||
D: Deserializer<'de>,
|
D: Deserializer<'de>,
|
||||||
{
|
{
|
||||||
deserializer.deserialize_tuple(R, ShapeVisitor)
|
deserializer.deserialize_tuple(R, TensorShapeVisitor)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const R: usize> Serialize for Shape<R> {
|
impl<const R: usize> Serialize for TensorShape<R> {
|
||||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
where
|
where
|
||||||
S: Serializer,
|
S: Serializer,
|
||||||
@ -189,23 +189,23 @@ impl<const R: usize> Serialize for Shape<R> {
|
|||||||
|
|
||||||
// ---- Blanket Implementations ----
|
// ---- Blanket Implementations ----
|
||||||
|
|
||||||
impl<const R: usize> From<[usize; R]> for Shape<R> {
|
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
|
||||||
fn from(shape: [usize; R]) -> Self {
|
fn from(shape: [usize; R]) -> Self {
|
||||||
Self::new(shape)
|
Self::new(shape)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const R: usize> PartialEq for Shape<R> {
|
impl<const R: usize> PartialEq for TensorShape<R> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.0 == other.0
|
self.0 == other.0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<const R: usize> Eq for Shape<R> {}
|
impl<const R: usize> Eq for TensorShape<R> {}
|
||||||
|
|
||||||
// ---- From and Into Implementations ----
|
// ---- From and Into Implementations ----
|
||||||
|
|
||||||
impl<T, const R: usize> From<Tensor<T, R>> for Shape<R>
|
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
|
||||||
where
|
where
|
||||||
T: Value,
|
T: Value,
|
||||||
{
|
{
|
||||||
|
@ -8,11 +8,11 @@ pub struct Tensor<T, const R: usize> {
|
|||||||
#[getset(get = "pub", get_mut = "pub")]
|
#[getset(get = "pub", get_mut = "pub")]
|
||||||
buffer: Vec<T>,
|
buffer: Vec<T>,
|
||||||
#[getset(get = "pub")]
|
#[getset(get = "pub")]
|
||||||
shape: Shape<R>,
|
shape: TensorShape<R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||||
pub fn new(shape: Shape<R>) -> Self {
|
pub fn new(shape: TensorShape<R>) -> Self {
|
||||||
// Handle rank 0 tensor (scalar) as a special case
|
// Handle rank 0 tensor (scalar) as a special case
|
||||||
let total_size = if R == 0 {
|
let total_size = if R == 0 {
|
||||||
// A rank 0 tensor should still have a buffer with one element
|
// A rank 0 tensor should still have a buffer with one element
|
||||||
@ -26,16 +26,16 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
Self { buffer, shape }
|
Self { buffer, shape }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
|
pub fn new_with_buffer(shape: TensorShape<R>, buffer: Vec<T>) -> Self {
|
||||||
Self { buffer, shape }
|
Self { buffer, shape }
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn reshape(self, shape: Shape<R>) -> Result<Self> {
|
pub fn reshape(self, shape: TensorShape<R>) -> Result<Self> {
|
||||||
if self.shape().size() != shape.size() {
|
if self.shape().size() != shape.size() {
|
||||||
let (ls, rs) = (self.shape().as_array(), shape.as_array());
|
let (ls, rs) = (self.shape().as_array(), shape.as_array());
|
||||||
let (lsize, rsize) = (self.shape().size(), shape.size());
|
let (lsize, rsize) = (self.shape().size(), shape.size());
|
||||||
Err(Error::InvalidArgument(format!(
|
Err(TensorError::InvalidArgument(format!(
|
||||||
"Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
|
"TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
|
||||||
)))
|
)))
|
||||||
} else {
|
} else {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
@ -46,7 +46,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||||
let buffer = Idx::from(self.shape())
|
let buffer = TensorIndex::from(self.shape())
|
||||||
.iter_transposed(order)
|
.iter_transposed(order)
|
||||||
.map(|index| self.get(index).unwrap().clone())
|
.map(|index| self.get(index).unwrap().clone())
|
||||||
.collect();
|
.collect();
|
||||||
@ -57,27 +57,27 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn idx(&self) -> Idx<R> {
|
pub fn idx(&self) -> TensorIndex<R> {
|
||||||
Idx::from(self)
|
TensorIndex::from(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
|
pub fn axis<'a>(&'a self, axis: usize) -> TensorAxis<'a, T, R> {
|
||||||
Axis::new(self, axis)
|
TensorAxis::new(self, axis)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, index: Idx<R>) -> Option<&T> {
|
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
|
||||||
self.buffer.get(index.flat())
|
self.buffer.get(index.flat())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
|
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
|
||||||
self.buffer.get_unchecked(index.flat())
|
self.buffer.get_unchecked(index.flat())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_mut(&mut self, index: Idx<R>) -> Option<&mut T> {
|
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
|
||||||
self.buffer.get_mut(index.flat())
|
self.buffer.get_mut(index.flat())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub unsafe fn get_unchecked_mut(&mut self, index: Idx<R>) -> &mut T {
|
pub unsafe fn get_unchecked_mut(&mut self, index: TensorIndex<R>) -> &mut T {
|
||||||
self.buffer.get_unchecked_mut(index.flat())
|
self.buffer.get_unchecked_mut(index.flat())
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -111,7 +111,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
|
|
||||||
pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
|
pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
|
||||||
if self.shape != other.shape {
|
if self.shape != other.shape {
|
||||||
panic!("Shapes of tensors do not match");
|
panic!("TensorShapes of tensors do not match");
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut result_buffer = Vec::with_capacity(self.buffer.len());
|
let mut result_buffer = Vec::with_capacity(self.buffer.len());
|
||||||
@ -148,7 +148,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
|
|
||||||
Tensor {
|
Tensor {
|
||||||
buffer: new_buffer,
|
buffer: new_buffer,
|
||||||
shape: Shape::new(new_shape_array),
|
shape: TensorShape::new(new_shape_array),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -170,7 +170,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
|
|
||||||
// Ensure the given axis is within the tensor's dimensions
|
// Ensure the given axis is within the tensor's dimensions
|
||||||
if axis >= R {
|
if axis >= R {
|
||||||
panic!("Axis out of bounds");
|
panic!("TensorAxis out of bounds");
|
||||||
}
|
}
|
||||||
|
|
||||||
// Calculate the stride for each dimension and accumulate the flat index
|
// Calculate the stride for each dimension and accumulate the flat index
|
||||||
@ -190,16 +190,16 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
|
|
||||||
// ---- Indexing ----
|
// ---- Indexing ----
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> Index<Idx<'a, R>> for Tensor<T, R> {
|
impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
|
||||||
type Output = T;
|
type Output = T;
|
||||||
|
|
||||||
fn index(&self, index: Idx<R>) -> &Self::Output {
|
fn index(&self, index: TensorIndex<R>) -> &Self::Output {
|
||||||
&self.buffer[index.flat()]
|
&self.buffer[index.flat()]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> IndexMut<Idx<'a, R>> for Tensor<T, R> {
|
impl<'a, T: Value, const R: usize> IndexMut<TensorIndex<'a, R>> for Tensor<T, R> {
|
||||||
fn index_mut(&mut self, index: Idx<R>) -> &mut Self::Output {
|
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
|
||||||
&mut self.buffer[index.flat()]
|
&mut self.buffer[index.flat()]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -260,7 +260,7 @@ where
|
|||||||
|
|
||||||
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||||
tensor: &'a Tensor<T, R>,
|
tensor: &'a Tensor<T, R>,
|
||||||
index: Idx<'a, R>,
|
index: TensorIndex<'a, R>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||||
@ -325,15 +325,15 @@ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
|
|||||||
|
|
||||||
// ---- From ----
|
// ---- From ----
|
||||||
|
|
||||||
impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
|
impl<T: Value, const R: usize> From<TensorShape<R>> for Tensor<T, R> {
|
||||||
fn from(shape: Shape<R>) -> Self {
|
fn from(shape: TensorShape<R>) -> Self {
|
||||||
Self::new(shape)
|
Self::new(shape)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Value> From<T> for Tensor<T, 0> {
|
impl<T: Value> From<T> for Tensor<T, 0> {
|
||||||
fn from(value: T) -> Self {
|
fn from(value: T) -> Self {
|
||||||
let shape = Shape::new([]);
|
let shape = TensorShape::new([]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
tensor.buffer_mut()[0] = value;
|
tensor.buffer_mut()[0] = value;
|
||||||
tensor
|
tensor
|
||||||
@ -342,7 +342,7 @@ impl<T: Value> From<T> for Tensor<T, 0> {
|
|||||||
|
|
||||||
impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
|
impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
|
||||||
fn from(array: [T; X]) -> Self {
|
fn from(array: [T; X]) -> Self {
|
||||||
let shape = Shape::new([X]);
|
let shape = TensorShape::new([X]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let buffer = tensor.buffer_mut();
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
@ -358,7 +358,7 @@ impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
|
|||||||
for Tensor<T, 2>
|
for Tensor<T, 2>
|
||||||
{
|
{
|
||||||
fn from(array: [[T; X]; Y]) -> Self {
|
fn from(array: [[T; X]; Y]) -> Self {
|
||||||
let shape = Shape::new([Y, X]);
|
let shape = TensorShape::new([Y, X]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let buffer = tensor.buffer_mut();
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
@ -376,7 +376,7 @@ impl<T: Value, const X: usize, const Y: usize, const Z: usize>
|
|||||||
From<[[[T; X]; Y]; Z]> for Tensor<T, 3>
|
From<[[[T; X]; Y]; Z]> for Tensor<T, 3>
|
||||||
{
|
{
|
||||||
fn from(array: [[[T; X]; Y]; Z]) -> Self {
|
fn from(array: [[[T; X]; Y]; Z]) -> Self {
|
||||||
let shape = Shape::new([Z, Y, X]);
|
let shape = TensorShape::new([Z, Y, X]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let buffer = tensor.buffer_mut();
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
@ -401,7 +401,7 @@ impl<
|
|||||||
> From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
|
> From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
|
||||||
{
|
{
|
||||||
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
|
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
|
||||||
let shape = Shape::new([W, Z, Y, X]);
|
let shape = TensorShape::new([W, Z, Y, X]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let buffer = tensor.buffer_mut();
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
@ -429,7 +429,7 @@ impl<
|
|||||||
> From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
|
> From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
|
||||||
{
|
{
|
||||||
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
|
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
|
||||||
let shape = Shape::new([V, W, Z, Y, X]);
|
let shape = TensorShape::new([V, W, Z, Y, X]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let buffer = tensor.buffer_mut();
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
@ -464,7 +464,7 @@ impl<
|
|||||||
> From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
|
> From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
|
||||||
{
|
{
|
||||||
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
|
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
|
||||||
let shape = Shape::new([U, V, W, Z, Y, X]);
|
let shape = TensorShape::new([U, V, W, Z, Y, X]);
|
||||||
let mut tensor = Tensor::new(shape);
|
let mut tensor = Tensor::new(shape);
|
||||||
let buffer = tensor.buffer_mut();
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user