🚀 Implement Tensor-type and basic methods #15

Merged
julius merged 8 commits from core-types into master 2024-01-03 21:52:54 +00:00
12 changed files with 631 additions and 1183 deletions

View File

@ -1,22 +1,3 @@
# Mainfold
# Manifold
```rust
// Create two tensors with different ranks and shapes
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
// Calculate tensor product
let product = tensor1.tensor_product(&tensor2);
println!("T1 * T2 = {}", product);
// Check shape of the resulting tensor
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
// Check buffer of the resulting tensor
assert_eq!(product.buffer(), &[5, 6, 10, 12, 15, 18, 20, 24]);
```
A tensor implementation in Rust.

View File

@ -1,34 +0,0 @@
To understand how the tensor contraction should work for the given tensors `a` and `b`, let's first clarify their shapes and then walk through the contraction steps:
1. **Tensor Shapes**:
- Tensor `a` is a 2x3 matrix (3 rows and 2 columns): \[\begin{matrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{matrix}\]
- Tensor `b` is a 3x2 matrix (2 rows and 3 columns): \[\begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix}\]
2. **Tensor Contraction Operation**:
- The contraction operation in this case involves multiplying corresponding elements along the shared dimension (the second dimension of `a` and the first dimension of `b`) and summing the results.
- The resulting tensor will have the shape determined by the other dimensions of the original tensors, which in this case is 3x3.
3. **Contraction Steps**:
- Step 1: Multiply each element of the first row of `a` with each element of the first column of `b`, then sum these products. This forms the first element of the resulting matrix.
- \( (1 \times 1) + (2 \times 4) = 1 + 8 = 9 \)
- Step 2: Multiply each element of the first row of `a` with each element of the second column of `b`, then sum these products. This forms the second element of the first row of the resulting matrix.
- \( (1 \times 2) + (2 \times 5) = 2 + 10 = 12 \)
- Step 3: Multiply each element of the first row of `a` with each element of the third column of `b`, then sum these products. This forms the third element of the first row of the resulting matrix.
- \( (1 \times 3) + (2 \times 6) = 3 + 12 = 15 \)
- Continue this process for the remaining rows of `a` and columns of `b`:
- For the second row of `a`:
- \( (3 \times 1) + (4 \times 4) = 3 + 16 = 19 \)
- \( (3 \times 2) + (4 \times 5) = 6 + 20 = 26 \)
- \( (3 \times 3) + (4 \times 6) = 9 + 24 = 33 \)
- For the third row of `a`:
- \( (5 \times 1) + (6 \times 4) = 5 + 24 = 29 \)
- \( (5 \times 2) + (6 \times 5) = 10 + 30 = 40 \)
- \( (5 \times 3) + (6 \times 6) = 15 + 36 = 51 \)
4. **Resulting Tensor**:
- The resulting 3x3 tensor from the contraction of `a` and `b` will be:
\[\begin{matrix} 9 & 12 & 15 \\ 19 & 26 & 33 \\ 29 & 40 & 51 \end{matrix}\]
These steps provide the detailed calculations for each element of the resulting tensor after contracting tensors `a` and `b`.

View File

@ -1,239 +0,0 @@
# Operations Index
## 1. Addition
Element-wize addition of two tensors.
\( C = A + B \) where \( C_{ijk...} = A_{ijk...} + B_{ijk...} \) for all indices \( i, j, k, ... \).
```rust
let t1 = tensor!([[1, 2], [3, 4]]);
let t2 = tensor!([[5, 6], [7, 8]]);
let sum = t1 + t2;
```
```sh
[[7, 8], [10, 12]]
```
## 2. Subtraction
Element-wize substraction of two tensors.
\( C = A - B \) where \( C_{ijk...} = A_{ijk...} - B_{ijk...} \).
```rust
let t1 = tensor!([[1, 2], [3, 4]]);
let t2 = tensor!([[5, 6], [7, 8]]);
let diff = i1 - t2;
```
```sh
[[-4, -4], [-4, -4]]
```
## 3. Multiplication
Element-wize multiplication of two tensors.
\( C = A \odot B \) where \( C_{ijk...} = A_{ijk...} \times B_{ijk...} \).
```rust
let t1 = tensor!([[1, 2], [3, 4]]);
let t2 = tensor!([[5, 6], [7, 8]]);
let prod = t1 * t2;
```
```sh
[[5, 12], [21, 32]]
```
## 4. Division
Element-wize division of two tensors.
\( C = A \div B \) where \( C_{ijk...} = A_{ijk...} \div B_{ijk...} \).
```rust
let t1 = tensor!([[1, 2], [3, 4]]);
let t2 = tensor!([[1, 2], [3, 4]]);
let quot = t1 / t2;
```
```sh
[[1, 1], [1, 1]]
```
## 5. Contraction
Contract two tensors over given axes.
For matrices \( A \) and \( B \), \( C = AB \) where \( C_{ij} = \sum_k A_{ik} B_{kj} \).
```rust
let t1 = tensor!([[1, 2], [3, 4], [5, 6]]);
let t2 = tensor!([[1, 2, 3], [4, 5, 6]]);
let cont = contract((t1, [1]), (t2, [0]));
```
```sh
TODO!
```
## 6. Reduction (e.g., Sum)
\( \text{sum}(A) \) where sum over all elements of A.
```rust
let t1 = tensor!([[1, 2], [3, 4]]);
let total = t1.sum();
```
```sh
10
```
## 7. Broadcasting
Adjusts tensors with different shapes to make them compatible for element-wise operations automatically
when using supported functions.
## 8. Reshape
Changing the shape of a tensor without altering its data.
```rust
let t1 = tensor!([1, 2, 3, 4, 5, 6]);
let tr = t1.reshape([2, 3]);
```
```sh
[[1, 2, 3], [4, 5, 6]]
```
## 9. Transpose
Transpose a tensor over given axes.
\( B = A^T \) where \( B_{ij} = A_{ji} \).
```rust
let t1 = tensor!([1, 2, 3, 4]);
let transposed = t1.transpose();
```
```sh
TODO!
```
## 10. Concatenation
Joining tensors along a specified dimension.
```rust
let t1 = tensor!([1, 2, 3]);
let t2 = tensor!([4, 5, 6]);
let cat = t1.concat(&t2, 0);
```
```sh
TODO!
```
## 11. Slicing and Indexing
Extracting parts of tensors based on indices.
```rust
let t1 = tensor!([1, 2, 3, 4, 5, 6]);
let slice = t1.slice(s![1, ..]);
```
```sh
TODO!
```
## 12. Element-wise Functions (e.g., Sigmoid)
**Mathematical Definition**:
Applying a function to each element of a tensor, like \( \sigma(x) = \frac{1}{1 + e^{-x}} \) for sigmoid.
**Rust Code Example**:
```rust
let tensor = Tensor::<f32, 2>::from([-1.0, 0.0, 1.0, 2.0]); // 2x2 tensor
let sigmoid_tensor = tensor.map(|x| 1.0 / (1.0 + (-x).exp())); // Apply sigmoid element-wise
```
## 13. Gradient Computation/Automatic Differentiation
**Description**:
Calculating the derivatives of tensors, crucial for training machine learning models.
**Rust Code Example**: Depends on if your tensor library supports automatic differentiation. This is typically more complex and may involve constructing computational graphs.
## 14. Normalization Operations (e.g., Batch Normalization)
**Description**: Standardizing the inputs of a model across the batch dimension.
**Rust Code Example**: This is specific to deep learning libraries and may not be directly supported in a general-purpose tensor library.
## 15. Convolution Operations
**Description**: Essential for image processing and CNNs.
**Rust Code Example**: If your library supports it, convolutions typically involve using a specialized function that takes the input tensor and a kernel tensor.
## 16. Pooling Operations (e.g., Max Pooling)
**Description**: Reducing the spatial dimensions of
a tensor, commonly used in CNNs.
**Rust Code Example**: Again, this depends on your library's support for such operations.
## 17. Tensor Slicing and Joining
**Description**: Operations to slice a tensor into sub-tensors or join multiple tensors into a larger tensor.
**Rust Code Example**: Similar to the slicing and concatenation examples provided above.
## 18. Dimension Permutation
**Description**: Rearranging the dimensions of a tensor.
**Rust Code Example**:
```rust
let tensor = Tensor::<i32, 3>::from([...]); // 3D tensor
let permuted_tensor = tensor.permute_dims([2, 0, 1]); // Permute dimensions
```
## 19. Expand and Squeeze Operations
**Description**: Increasing or decreasing the dimensions of a tensor (adding/removing singleton dimensions).
**Rust Code Example**: Depends on the specific functions provided by your library.
## 20. Data Type Conversions
**Description**: Converting tensors from one data type to another.
**Rust Code Example**:
```rust
let tensor = Tensor::<i32, 2>::from([1, 2, 3, 4]); // 2x2 tensor
let converted_tensor = tensor.to_type::<f32>(); // Convert to f32 tensor
```
These examples provide a general guide. The actual implementation details may vary depending on the specific features and capabilities of the Rust tensor library you're using.
## 21. Tensor Decompositions
**CANDECOMP/PARAFAC (CP) Decomposition**: This decomposes a tensor into a sum of component rank-one tensors. For a third-order tensor, it's like expressing it as a sum of outer products of vectors. This is useful in applications like signal processing, psychometrics, and chemometrics.
**Tucker Decomposition**: Similar to PCA for matrices, Tucker Decomposition decomposes a tensor into a core tensor multiplied by a matrix along each mode (dimension). It's more general than CP Decomposition and is useful in areas like data compression and tensor completion.
**Higher-Order Singular Value Decomposition (HOSVD)**: A generalization of SVD for higher-order tensors, HOSVD decomposes a tensor into a core tensor and a set of orthogonal matrices for each mode. It's used in image processing, computer vision, and multilinear subspace learning.

View File

@ -1,94 +0,0 @@
#![allow(mixed_script_confusables)]
#![allow(non_snake_case)]
use bytemuck::cast_slice;
use manifold::contract;
use manifold::*;
fn tensor_product() {
println!("Tensor Product\n");
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
println!("T1: {}", tensor1);
println!("T2: {}", tensor2);
let product = tensor1.tensor_product(&tensor2);
println!("T1 * T2 = {}", product);
// Check shape of the resulting tensor
assert_eq!(product.shape(), &Shape::new([2, 2, 2]));
// Check buffer of the resulting tensor
let expect: &[i32] =
cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]);
assert_eq!(product.buffer(), expect);
}
fn test_tensor_contraction_23x32() {
// Define two 2D tensors (matrices)
// Tensor A is 2x3
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
println!("a: {:?}\n{}\n", a.shape(), a);
// Tensor B is 3x2
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
println!("b: {:?}\n{}\n", b.shape(), b);
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let ctr10 = contract((&a, [1]), (&b, [0]));
println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10);
let ctr01 = contract((&a, [0]), (&b, [1]));
println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01);
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
// assert_eq!(
// contracted_tensor.buffer(),
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
// "Contracted tensor buffer does not match expected"
// );
}
fn test_tensor_contraction_rank3() {
let a = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
let b = tensor!([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
let contracted_tensor = contract((&a, [2]), (&b, [0]));
println!("a: {}", a);
println!("b: {}", b);
println!("contracted_tensor: {}", contracted_tensor);
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
// Verify specific elements of contracted_tensor
// assert_eq!(contracted_tensor[0][0][0][0], 50);
// assert_eq!(contracted_tensor[0][0][0][1], 60);
// ... further checks for other elements ...
}
fn transpose() {
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let b = tensor!([[1, 2, 3], [4, 5, 6]]);
// let iter = a.idx().iter_transposed([1, 0]);
// for idx in iter {
// println!("{idx}");
// }
let b = a.clone().transpose([1, 0]).unwrap();
println!("a: {}", a);
println!("ta: {}", b);
}
fn main() {
// tensor_product();
// test_tensor_contraction_23x32();
// test_tensor_contraction_rank3();
transpose();
}

View File

@ -1 +1,3 @@
max_width = 80
wrap_comments = true
comment_width = 80

View File

@ -2,16 +2,16 @@ use super::*;
use getset::{Getters, MutGetters};
#[derive(Clone, Debug, Getters)]
pub struct Axis<'a, T: Value, const R: usize> {
pub struct TensorAxis<'a, T: Value, const R: usize> {
#[getset(get = "pub")]
tensor: &'a Tensor<T, R>,
#[getset(get = "pub")]
dim: usize,
}
impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
impl<'a, T: Value, const R: usize> TensorAxis<'a, T, R> {
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
assert!(dim < R, "Axis out of bounds");
assert!(dim < R, "TensorAxis out of bounds");
Self { tensor, dim }
}
@ -19,40 +19,42 @@ impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
self.tensor.shape().get(self.dim)
}
pub fn shape(&self) -> &Shape<R> {
pub fn shape(&self) -> &TensorShape<R> {
self.tensor.shape()
}
pub fn iter_level(&'a self, level: usize) -> AxisIterator<'a, T, R> {
pub fn iter_level(&'a self, level: usize) -> TensorAxisIterator<'a, T, R> {
assert!(level < self.len(), "Level out of bounds");
let mut index = Idx::new(self.shape(), [0; R]);
let mut index = TensorIndex::new(self.shape().clone(), [0; R]);
index.set_axis(self.dim, level);
AxisIterator::new(self).set_start(level).set_end(level + 1)
TensorAxisIterator::new(self)
.set_start(level)
.set_end(level + 1)
}
}
#[derive(Clone, Debug, Getters, MutGetters)]
pub struct AxisIterator<'a, T: Value, const R: usize> {
pub struct TensorAxisIterator<'a, T: Value, const R: usize> {
#[getset(get = "pub")]
axis: &'a Axis<'a, T, R>,
axis: &'a TensorAxis<'a, T, R>,
#[getset(get = "pub", get_mut = "pub")]
index: Idx<'a, R>,
index: TensorIndex<R>,
#[getset(get = "pub")]
end: Option<usize>,
}
impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
pub fn new(axis: &'a Axis<'a, T, R>) -> Self {
impl<'a, T: Value, const R: usize> TensorAxisIterator<'a, T, R> {
pub fn new(axis: &'a TensorAxis<'a, T, R>) -> Self {
Self {
axis,
index: Idx::new(axis.shape(), [0; R]),
index: TensorIndex::new(axis.shape().clone(), [0; R]),
end: None,
}
}
pub fn set_start(self, start: usize) -> Self {
assert!(start < self.axis().len(), "Start out of bounds");
let mut index = Idx::new(self.axis().shape(), [0; R]);
let mut index = TensorIndex::new(self.axis().shape().clone(), [0; R]);
index.set_axis(self.axis.dim, start);
Self {
axis: self.axis(),
@ -92,7 +94,7 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
}
}
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
impl<'a, T: Value, const R: usize> Iterator for TensorAxisIterator<'a, T, R> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
@ -106,284 +108,11 @@ impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
}
}
impl<'a, T: Value, const R: usize> IntoIterator for &'a Axis<'a, T, R> {
impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
type Item = &'a T;
type IntoIter = AxisIterator<'a, T, R>;
type IntoIter = TensorAxisIterator<'a, T, R>;
fn into_iter(self) -> Self::IntoIter {
AxisIterator::new(&self)
}
}
pub fn contract<
'a,
T: Value + std::fmt::Debug,
const R: usize,
const S: usize,
const N: usize,
>(
lhs: (&'a Tensor<T, R>, [usize; N]),
rhs: (&'a Tensor<T, S>, [usize; N]),
) -> Tensor<T, { R + S - 2 * N }>
where
[(); R - N]:,
[(); S - N]:,
[(); R + S - 2 * N]:,
{
let (lhs, la) = lhs;
let (rhs, ra) = rhs;
let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();
let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
let mut shape = Vec::new();
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
let shape: [usize; R + S - 2 * N] =
shape.try_into().expect("Failed to create shape array");
let shape = Shape::new(shape);
let result = contract_axes(&lnc, &rnc);
Tensor::new_with_buffer(shape, result)
}
pub fn contract_axes<
'a,
T: Value + std::fmt::Debug,
const R: usize,
const S: usize,
const N: usize,
>(
laxes: &'a [Axis<'a, T, R>],
raxes: &'a [Axis<'a, T, S>],
) -> Vec<T>
where
[(); R - N]:,
[(); S - N]:,
{
let mut result = vec![];
let axes = laxes.into_iter().zip(raxes);
for (laxis, raxis) in axes {
let mut axes_result: Vec<T> = vec![];
for i in 0..raxis.len() {
for j in 0..laxis.len() {
let mut sum = T::zero();
let llevel = laxis.into_iter();
let rlevel = raxis.into_iter();
let zip = llevel.level(j).zip(rlevel.level(i));
for (lv, rv) in zip {
sum = sum + *lv * *rv;
}
axes_result.push(sum);
}
}
result.extend_from_slice(&axes_result);
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tensor_contraction_simple() {
// Define two 2D tensors (matrices)
// Tensor A is 2x3
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
// Tensor B is 1x3x2
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2]));
assert_eq!(
contracted_tensor.buffer(),
&[7, 10, 15, 22],
"Contracted tensor buffer does not match expected"
);
}
#[test]
fn test_tensor_contraction_23x32() {
// Define two 2D tensors (matrices)
// Tensor A is 2x3
let b: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
println!("b: {}", b);
// Tensor B is 3x2
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
println!("a: {}", a);
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
println!("contracted_tensor: {}", contracted_tensor);
assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
assert_eq!(
contracted_tensor.buffer(),
&[9, 12, 15, 19, 26, 33, 29, 40, 51],
"Contracted tensor buffer does not match expected"
);
}
#[test]
fn test_tensor_contraction_rank3() {
let a: Tensor<i32, 3> =
Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24
let b: Tensor<i32, 3> =
Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24
let contracted_tensor: Tensor<i32, 4> = contract((&a, [2]), (&b, [0]));
println!("a: {}", a);
println!("b: {}", b);
println!("contracted_tensor: {}", contracted_tensor);
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
// Verify specific elements of contracted_tensor
// assert_eq!(contracted_tensor[0][0][0][0], 50);
// assert_eq!(contracted_tensor[0][0][0][1], 60);
// ... further checks for other elements ...
}
// #[test]
// fn test_axis_iterator_disassemble() {
// // Creating a 2x2 Tensor for testing
// let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// // Testing iteration over the first axis (axis = 0)
// let axis = Axis::new(&tensor, 0);
// let mut axis_iter = axis.into_iter().disassemble();
// assert_eq!(axis_iter[0].next(), Some(&1.0));
// assert_eq!(axis_iter[0].next(), Some(&2.0));
// assert_eq!(axis_iter[0].next(), None);
// assert_eq!(axis_iter[1].next(), Some(&3.0));
// assert_eq!(axis_iter[1].next(), Some(&4.0));
// assert_eq!(axis_iter[1].next(), None);
// // Resetting the iterator for the second axis (axis = 1)
// let axis = Axis::new(&tensor, 1);
// let mut axis_iter = axis.into_iter().disassemble();
// assert_eq!(axis_iter[0].next(), Some(&1.0));
// assert_eq!(axis_iter[0].next(), Some(&3.0));
// assert_eq!(axis_iter[0].next(), None);
// assert_eq!(axis_iter[1].next(), Some(&2.0));
// assert_eq!(axis_iter[1].next(), Some(&4.0));
// assert_eq!(axis_iter[1].next(), None);
// }
#[test]
fn test_axis_iterator() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&4.0));
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&4.0));
let shape = tensor.shape();
let mut a: Idx<2> = (shape, [0, 0]).into();
let b: Idx<2> = (shape, [1, 1]).into();
while a <= b {
println!("a: {}", a);
a.inc();
}
}
#[test]
fn test_3d_tensor_axis_iteration() {
// Create a 3D Tensor with specific values
// Tensor shape is 2x2x2 for simplicity
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
// Axis 0 (Layer-wise):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 3, 4, 5, 6, 7, 8]
//
// This order suggests that for each "layer" (first level of arrays),
// the iterator goes through all rows and columns. It first completes
// the entire first layer, then moves to the second.
let a0 = Axis::new(&t, 0);
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
// Axis 1 (Row-wise within each layer):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 5, 6, 3, 4, 7, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first row across all layers, then the second row
// across all layers.
let a1 = Axis::new(&t, 1);
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
// Axis 2 (Column-wise within each layer):
//
// t[0][0][0] = 1
// t[0][1][0] = 3
// t[1][0][0] = 5
// t[1][1][0] = 7
// t[0][0][1] = 2
// t[0][1][1] = 4
// t[1][0][1] = 6
// t[1][1][1] = 8
// [1, 3, 5, 7, 2, 4, 6, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first column across all layers, then the second
// column across all layers.
let a2 = Axis::new(&t, 2);
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
TensorAxisIterator::new(&self)
}
}

View File

@ -1,9 +1,9 @@
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
pub type Result<T> = std::result::Result<T, TensorError>;
#[derive(Error, Debug)]
pub enum Error {
pub enum TensorError {
#[error("Invalid argument: {0}")]
InvalidArgument(String),
}

View File

@ -1,43 +1,47 @@
use super::*;
use getset::{Getters, MutGetters};
use std::cmp::Ordering;
use std::ops::{Add, Sub};
use std::{
ops::{Index, IndexMut, Add, Sub},
cmp::Ordering,
};
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
pub struct Idx<'a, const R: usize> {
pub struct TensorIndex<const R: usize> {
#[getset(get = "pub", get_mut = "pub")]
indices: [usize; R],
#[getset(get = "pub")]
shape: &'a Shape<R>,
shape: TensorShape<R>,
}
impl<'a, const R: usize> Idx<'a, R> {
pub const fn zero(shape: &'a Shape<R>) -> Self {
// ---- Construction and Initialization ---------------------------------------
impl<const R: usize> TensorIndex<R> {
pub fn new(shape: TensorShape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self { indices, shape }
}
pub const fn zero(shape: TensorShape<R>) -> Self {
Self {
indices: [0; R],
shape,
}
}
pub fn last(shape: &'a Shape<R>) -> Self {
pub fn last(shape: TensorShape<R>) -> Self {
let max_indices =
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
Self {
indices: max_indices,
shape: shape,
}
}
pub fn new(shape: &'a Shape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self {
indices,
shape: shape,
shape,
}
}
}
impl<const R: usize> TensorIndex<R> {
pub fn is_zero(&self) -> bool {
self.indices.iter().all(|&i| i == 0)
}
@ -51,7 +55,8 @@ impl<'a, const R: usize> Idx<'a, R> {
self.indices = [0; R];
}
/// Increments the index and returns a boolean indicating whether the end has been reached.
/// Increments the index and returns a boolean indicating whether the end
/// has been reached.
///
/// # Returns
/// `true` if the increment does not overflow and is still within bounds;
@ -74,10 +79,12 @@ impl<'a, const R: usize> Idx<'a, R> {
}
}
// If carry is still 1 after the loop, it means we've incremented past the last dimension
// If carry is still 1 after the loop, it means we've incremented past
// the last dimension
if carry == 1 {
// Set the index to an invalid state to indicate the end of the iteration indicated
// by setting the first index to the size of the first dimension
// Set the index to an invalid state to indicate the end of the
// iteration indicated by setting the first index to the
// size of the first dimension
self.indices[0] = self.shape.as_array()[0];
return true; // Indicate that the iteration is complete
}
@ -87,7 +94,7 @@ impl<'a, const R: usize> Idx<'a, R> {
// fn inc_axis
pub fn inc_axis(&mut self, fixed_axis: usize) {
assert!(fixed_axis < R, "Axis out of bounds");
assert!(fixed_axis < R, "TensorAxis out of bounds");
assert!(
self.indices()[fixed_axis] < self.shape().get(fixed_axis),
"Index out of bounds"
@ -156,7 +163,8 @@ impl<'a, const R: usize> Idx<'a, R> {
{
if borrow {
if *i == 0 {
*i = dim_size - 1; // Wrap around to the maximum index of this dimension
*i = dim_size - 1; // Wrap around to the maximum index of
// this dimension
} else {
*i -= 1; // Decrement the index
borrow = false; // No more borrowing needed
@ -183,7 +191,8 @@ impl<'a, const R: usize> Idx<'a, R> {
}
}
// Decrement the fixed axis if possible and reset other axes to their max
// Decrement the fixed axis if possible and reset other axes to their
// max
if self.indices[fixed_axis] > 0 {
self.indices[fixed_axis] -= 1;
for i in 0..R {
@ -216,45 +225,53 @@ impl<'a, const R: usize> Idx<'a, R> {
}
}
// If no axis can be decremented, set the first axis in the order to indicate overflow
// If no axis can be decremented, set the first axis in the order to
// indicate overflow
self.indices[order[0]] = self.shape.get(order[0]);
}
/// Converts the multi-dimensional index to a flat index.
///
/// This method calculates the flat index corresponding to the multi-dimensional index
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`.
/// The calculation is based on the assumption that the tensor is stored in row-major order,
/// This method calculates the flat index corresponding to the
/// multi-dimensional index stored in `self.indices`, given the shape of
/// the tensor stored in `self.shape`. The calculation is based on the
/// assumption that the tensor is stored in row-major order,
/// where the last dimension varies the fastest.
///
/// # Returns
/// The flat index corresponding to the multi-dimensional index.
///
/// # How It Works
/// - The method iterates over each pair of corresponding index and dimension size,
/// starting from the last dimension (hence `rev()` is used for reverse iteration).
/// - The method iterates over each pair of corresponding index and
/// dimension size, starting from the last dimension (hence `rev()` is
/// used for reverse iteration).
/// - In each iteration, it performs two main operations:
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product
/// of dimension sizes (`product`). This calculates the contribution of the current index
/// to the overall flat index.
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`).
/// This updates `product` for the next iteration, as each dimension contributes to the
/// flat index based on the sizes of all preceding dimensions.
/// - The `fold` operation accumulates these results, starting with an initial state of
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product.
/// - The final flat index is obtained after the last iteration, which is the first element
/// of the tuple resulting from the `fold`.
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a
/// running product of dimension sizes (`product`). This calculates the
/// contribution of the current index to the overall flat index.
/// 2. **Product Update**: Multiplies `product` by the current dimension
/// size (`dim_size`). This updates `product` for the next iteration,
/// as each dimension contributes to the flat index based on the sizes
/// of all preceding dimensions.
/// - The `fold` operation accumulates these results, starting with an
/// initial state of `(0, 1)` where `0` is the initial flat index and `1`
/// is the initial product.
/// - The final flat index is obtained after the last iteration, which is
/// the first element of the tuple resulting from the `fold`.
///
/// # Example
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
/// - Starting with a flat index of 0 and a product of 1,
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5.
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update
/// the product to 1 * 5 = 5.
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update
/// the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The
/// final flat index is 3 + 10 + 20 = 33.
pub fn flat(&self) -> usize {
self.indices
self.indices()
.iter()
.zip(&self.shape.as_array())
.zip(&self.shape().as_array())
.rev()
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
(flat_index + idx * product, product * dim_size)
@ -263,13 +280,13 @@ impl<'a, const R: usize> Idx<'a, R> {
}
pub fn set_axis(&mut self, axis: usize, value: usize) {
assert!(axis < R, "Axis out of bounds");
assert!(axis < R, "TensorAxis out of bounds");
// assert!(value < self.shape.get(axis), "Value out of bounds");
self.indices[axis] = value;
}
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
assert!(axis < R, "Axis out of bounds");
assert!(axis < R, "TensorAxis out of bounds");
if value < self.shape.get(axis) {
self.indices[axis] = value;
true
@ -279,41 +296,41 @@ impl<'a, const R: usize> Idx<'a, R> {
}
pub fn get_axis(&self, axis: usize) -> usize {
assert!(axis < R, "Axis out of bounds");
assert!(axis < R, "TensorAxis out of bounds");
self.indices[axis]
}
pub fn iter_transposed(
&self,
order: [usize; R],
) -> IdxTransposedIterator<'a, R> {
IdxTransposedIterator::new(self.shape(), order)
) -> TensorIndexTransposedIterator<R> {
TensorIndexTransposedIterator::new(self.shape().clone(), order)
}
}
// --- blanket impls ---
impl<'a, const R: usize> PartialEq for Idx<'a, R> {
impl<const R: usize> PartialEq for TensorIndex<R> {
fn eq(&self, other: &Self) -> bool {
self.flat() == other.flat()
}
}
impl<'a, const R: usize> Eq for Idx<'a, R> {}
impl<const R: usize> Eq for TensorIndex<R> {}
impl<'a, const R: usize> PartialOrd for Idx<'a, R> {
impl<const R: usize> PartialOrd for TensorIndex<R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.flat().partial_cmp(&other.flat())
}
}
impl<'a, const R: usize> Ord for Idx<'a, R> {
impl<const R: usize> Ord for TensorIndex<R> {
fn cmp(&self, other: &Self) -> Ordering {
self.flat().cmp(&other.flat())
}
}
impl<'a, const R: usize> Index<usize> for Idx<'a, R> {
impl<const R: usize> Index<usize> for TensorIndex<R> {
type Output = usize;
fn index(&self, index: usize) -> &Self::Output {
@ -321,39 +338,45 @@ impl<'a, const R: usize> Index<usize> for Idx<'a, R> {
}
}
impl<'a, const R: usize> IndexMut<usize> for Idx<'a, R> {
impl<const R: usize> IndexMut<usize> for TensorIndex<R> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.indices[index]
}
}
impl<'a, const R: usize> From<(&'a Shape<R>, [usize; R])> for Idx<'a, R> {
fn from((shape, indices): (&'a Shape<R>, [usize; R])) -> Self {
impl<const R: usize> From<(TensorShape<R>, [usize; R])>
for TensorIndex<R>
{
fn from((shape, indices): (TensorShape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices));
Self::new(shape, indices)
}
}
impl<'a, const R: usize> From<(&'a Shape<R>, usize)> for Idx<'a, R> {
fn from((shape, flat_index): (&'a Shape<R>, usize)) -> Self {
impl<const R: usize> From<(TensorShape<R>, usize)>
for TensorIndex<R>
{
fn from((shape, flat_index): (TensorShape<R>, usize)) -> Self {
let indices = shape.index_from_flat(flat_index).indices;
Self::new(shape, indices)
}
}
impl<'a, const R: usize> From<&'a Shape<R>> for Idx<'a, R> {
fn from(shape: &'a Shape<R>) -> Self {
impl<const R: usize> From<TensorShape<R>> for TensorIndex<R> {
fn from(shape: TensorShape<R>) -> Self {
Self::zero(shape)
}
}
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for Idx<'a, R> {
fn from(tensor: &'a Tensor<T, R>) -> Self {
Self::zero(tensor.shape())
impl<T: Value, const R: usize> From<Tensor<T, R>>
for TensorIndex<R>
{
fn from(tensor: Tensor<T, R>) -> Self {
Self::zero(tensor.shape().clone())
}
}
impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> {
impl<const R: usize> std::fmt::Display for TensorIndex<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
for (i, (&idx, &dim_size)) in self
@ -373,11 +396,11 @@ impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> {
// ---- Arithmetic Operations ----
impl<'a, const R: usize> Add for Idx<'a, R> {
impl<const R: usize> Add for TensorIndex<R> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
let mut result_indices = [0; R];
for i in 0..R {
@ -391,11 +414,11 @@ impl<'a, const R: usize> Add for Idx<'a, R> {
}
}
impl<'a, const R: usize> Sub for Idx<'a, R> {
impl<const R: usize> Sub for TensorIndex<R> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
assert_eq!(self.shape, rhs.shape, "TensorShape mismatch");
let mut result_indices = [0; R];
for i in 0..R {
@ -411,22 +434,22 @@ impl<'a, const R: usize> Sub for Idx<'a, R> {
// ---- Iterator ----
pub struct IdxIterator<'a, const R: usize> {
current: Idx<'a, R>,
pub struct TensorIndexIterator<const R: usize> {
current: TensorIndex<R>,
end: bool,
}
impl<'a, const R: usize> IdxIterator<'a, R> {
pub fn new(shape: &'a Shape<R>) -> Self {
impl<const R: usize> TensorIndexIterator<R> {
pub fn new(shape: TensorShape<R>) -> Self {
Self {
current: Idx::zero(shape),
current: TensorIndex::zero(shape),
end: false,
}
}
}
impl<'a, const R: usize> Iterator for IdxIterator<'a, R> {
type Item = Idx<'a, R>;
impl<const R: usize> Iterator for TensorIndexIterator<R> {
type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> {
if self.end {
@ -439,36 +462,36 @@ impl<'a, const R: usize> Iterator for IdxIterator<'a, R> {
}
}
impl<'a, const R: usize> IntoIterator for Idx<'a, R> {
type Item = Idx<'a, R>;
type IntoIter = IdxIterator<'a, R>;
impl<const R: usize> IntoIterator for TensorIndex<R> {
type Item = TensorIndex<R>;
type IntoIter = TensorIndexIterator<R>;
fn into_iter(self) -> Self::IntoIter {
IdxIterator {
TensorIndexIterator {
current: self,
end: false,
}
}
}
pub struct IdxTransposedIterator<'a, const R: usize> {
current: Idx<'a, R>,
pub struct TensorIndexTransposedIterator<const R: usize> {
current: TensorIndex<R>,
order: [usize; R],
end: bool,
}
impl<'a, const R: usize> IdxTransposedIterator<'a, R> {
pub fn new(shape: &'a Shape<R>, order: [usize; R]) -> Self {
impl<const R: usize> TensorIndexTransposedIterator<R> {
pub fn new(shape: TensorShape<R>, order: [usize; R]) -> Self {
Self {
current: Idx::zero(shape),
current: TensorIndex::zero(shape),
end: false,
order,
}
}
}
impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> {
type Item = Idx<'a, R>;
impl<const R: usize> Iterator for TensorIndexTransposedIterator<R> {
type Item = TensorIndex<R>;
fn next(&mut self) -> Option<Self::Item> {
if self.end {

View File

@ -1,40 +1,15 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
#![warn(clippy::all)]
pub mod axis;
pub mod error;
pub mod index;
pub mod shape;
pub mod tensor;
pub mod value;
pub use axis::*;
pub use index::Idx;
pub use itertools::Itertools;
use num::{Num, One, Zero};
pub use serde::{Deserialize, Serialize};
pub use shape::Shape;
pub use static_assertions::const_assert;
pub use std::fmt::{Display, Formatter, Result as FmtResult};
use std::ops::{Index, IndexMut};
pub use std::sync::Arc;
pub use tensor::{Tensor, TensorIterator};
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
{
}
impl<T> Value for T where
T: Num
+ Zero
+ One
+ Copy
+ Clone
+ Display
+ Serialize
+ Deserialize<'static>
+ std::iter::Sum
{
}
pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*};
#[macro_export]
macro_rules! tensor {
@ -43,184 +18,19 @@ macro_rules! tensor {
};
}
// ---- Tests ----
#[cfg(test)]
mod tests {
use super::*;
use serde_json;
#[test]
fn test_tensor_product() {
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
let product = tensor1.tensor_product(&tensor2);
// Check shape of the resulting tensor
assert_eq!(*product.shape(), Shape::new([2, 2, 2]));
// Check buffer of the resulting tensor
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
assert_eq!(product.buffer(), &expected_buffer);
#[macro_export]
macro_rules! shape {
($array:expr) => {
TensorShape::from($array)
};
}
#[test]
fn serde_shape_serialization_test() {
// Create a shape instance
let shape: Shape<3> = [1, 2, 3].into();
// Serialize the shape to a JSON string
let serialized =
serde_json::to_string(&shape).expect("Failed to serialize");
// Deserialize the JSON string back into a shape
let deserialized: Shape<3> =
serde_json::from_str(&serialized).expect("Failed to deserialize");
// Check that the deserialized shape is equal to the original
assert_eq!(shape, deserialized);
}
#[test]
fn tensor_serde_serialization_test() {
// Create an instance of Tensor
let tensor: Tensor<i32, 2> = Tensor::new(Shape::new([2, 2]));
// Serialize the Tensor to a JSON string
let serialized =
serde_json::to_string(&tensor).expect("Failed to serialize");
// Deserialize the JSON string back into a Tensor
let deserialized: Tensor<i32, 2> =
serde_json::from_str(&serialized).expect("Failed to deserialize");
// Check that the deserialized Tensor is equal to the original
assert_eq!(tensor.buffer(), deserialized.buffer());
assert_eq!(tensor.shape(), deserialized.shape());
}
#[test]
fn iterate_3d_tensor() {
let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
let mut tensor = Tensor::new(shape);
let mut num = 0;
// Fill the tensor with sequential numbers
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
tensor.buffer_mut()[i * 4 + j * 2 + k] = num;
num += 1;
}
}
}
println!("{}", tensor);
// Iterate over the tensor and check that the numbers are correct
let mut iter = TensorIterator::new(&tensor);
println!("{}", iter);
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), Some(&4));
assert_eq!(iter.next(), Some(&5));
assert_eq!(iter.next(), Some(&6));
assert_eq!(iter.next(), Some(&7));
assert_eq!(iter.next(), None);
assert_eq!(iter.next(), None);
}
#[test]
fn iterate_rank_4_tensor() {
// Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
let shape = Shape::new([2, 2, 2, 2]);
let mut tensor = Tensor::new(shape);
let mut num = 0;
// Fill the tensor with sequential numbers
for i in 0..tensor.len() {
tensor.buffer_mut()[i] = num;
num += 1;
}
// Iterate over the tensor and check that the numbers are correct
let mut iter = TensorIterator::new(&tensor);
for expected_value in 0..tensor.len() {
assert_eq!(*iter.next().unwrap(), expected_value);
}
// Ensure the iterator is exhausted
assert!(iter.next().is_none());
}
#[test]
fn test_dec_method() {
let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
let mut index = Idx::zero(&shape);
// Increment the index to the maximum
for _ in 0..26 {
// 3 * 3 * 3 - 1 = 26 increments to reach the end
index.inc();
}
// Check if the index is at the maximum
assert_eq!(index, Idx::new(&shape, [2, 2, 2]));
// Decrement step by step and check the index
let expected_indices = [
[2, 2, 2],
[2, 2, 1],
[2, 2, 0],
[2, 1, 2],
[2, 1, 1],
[2, 1, 0],
[2, 0, 2],
[2, 0, 1],
[2, 0, 0],
[1, 2, 2],
[1, 2, 1],
[1, 2, 0],
[1, 1, 2],
[1, 1, 1],
[1, 1, 0],
[1, 0, 2],
[1, 0, 1],
[1, 0, 0],
[0, 2, 2],
[0, 2, 1],
[0, 2, 0],
[0, 1, 2],
[0, 1, 1],
[0, 1, 0],
[0, 0, 2],
[0, 0, 1],
[0, 0, 0],
];
for (i, &expected) in expected_indices.iter().enumerate() {
assert_eq!(
index,
Idx::new(&shape, expected),
"Failed at index {}",
i
);
index.dec();
}
// Finally, the index should reach [0, 0, 0]
index.dec();
assert_eq!(index, Idx::zero(&shape));
}
#[macro_export]
macro_rules! index {
($tensor:expr) => {
TensorIndex::zero($tensor.shape().clone())
};
($tensor:expr, $indices:expr) => {
TensorIndex::from(($tensor.shape().clone(), $indices))
};
}

View File

@ -1,12 +1,13 @@
use super::*;
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeTuple, Serializer};
use std::fmt;
use std::fmt::{Result as FmtResult, Formatter};
use core::result::Result as SerdeResult;
#[derive(Clone, Copy, Debug)]
pub struct Shape<const R: usize>([usize; R]);
pub struct TensorShape<const R: usize>([usize; R]);
impl<const R: usize> Shape<R> {
impl<const R: usize> TensorShape<R> {
pub const fn new(shape: [usize; R]) -> Self {
Self(shape)
}
@ -16,7 +17,7 @@ impl<const R: usize> Shape<R> {
}
pub fn reorder(&self, indices: [usize; R]) -> Self {
let mut new_shape = Shape::new([0; R]);
let mut new_shape = TensorShape::new([0; R]);
for (new_index, &index) in indices.iter().enumerate() {
new_shape.0[new_index] = self.0[index];
}
@ -60,14 +61,18 @@ impl<const R: usize> Shape<R> {
/// * `flat_index` - The flat index to convert.
///
/// # Returns
/// An `Idx<R>` instance representing the multi-dimensional index corresponding to the flat index.
/// An `TensorIndex<R>` instance representing the multi-dimensional index
/// corresponding to the flat index.
///
/// # How It Works
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
/// and integer division to reduce the flat index for the next higher dimension.
/// - This process is repeated for each dimension to build the multi-dimensional index.
pub fn index_from_flat(&self, flat_index: usize) -> Idx<R> {
/// - The method iterates over the dimensions of the tensor in reverse order
/// (assuming row-major order).
/// - In each iteration, it uses the modulo operation to find the index in
/// the current dimension and integer division to reduce the flat index
/// for the next higher dimension.
/// - This process is repeated for each dimension to build the
/// multi-dimensional index.
pub fn index_from_flat(&self, flat_index: usize) -> TensorIndex<R> {
let mut indices = [0; R];
let mut remaining = flat_index;
@ -77,24 +82,24 @@ impl<const R: usize> Shape<R> {
}
indices.reverse(); // Reverse the indices to match the original dimension order
Idx::new(self, indices)
TensorIndex::new(self.clone(), indices)
}
pub const fn index_zero(&self) -> Idx<R> {
Idx::zero(self)
pub fn index_zero(&self) -> TensorIndex<R> {
TensorIndex::zero(self.clone())
}
pub fn index_max(&self) -> Idx<R> {
pub fn index_max(&self) -> TensorIndex<R> {
let max_indices =
self.0
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
Idx::new(self, max_indices)
TensorIndex::new(self.clone(), max_indices)
}
pub fn remove_dims<const NAX: usize>(
&self,
dims_to_remove: [usize; NAX],
) -> Shape<{ R - NAX }> {
) -> TensorShape<{ R - NAX }> {
// Create a new array to store the remaining dimensions
let mut new_shape = [0; R - NAX];
let mut new_index = 0;
@ -111,13 +116,13 @@ impl<const R: usize> Shape<R> {
new_index += 1;
}
Shape(new_shape)
TensorShape(new_shape)
}
pub fn remove_axes<'a, T: Value, const NAX: usize>(
&self,
axes_to_remove: &'a [Axis<'a, T, R>; NAX],
) -> Shape<{ R - NAX }> {
axes_to_remove: &'a [TensorAxis<'a, T, R>; NAX],
) -> TensorShape<{ R - NAX }> {
// Create a new array to store the remaining dimensions
let mut new_shape = [0; R - NAX];
let mut new_index = 0;
@ -136,22 +141,22 @@ impl<const R: usize> Shape<R> {
new_index += 1;
}
Shape(new_shape)
TensorShape(new_shape)
}
}
// ---- Serialize and Deserialize ----
struct ShapeVisitor<const R: usize>;
struct TensorShapeVisitor<const R: usize>;
impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
type Value = Shape<R>;
impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor<R> {
type Value = TensorShape<R>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
formatter.write_str(concat!("an array of length ", "{R}"))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
fn visit_seq<A>(self, mut seq: A) -> SerdeResult<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
@ -161,21 +166,21 @@ impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
.next_element()?
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
}
Ok(Shape(arr))
Ok(TensorShape(arr))
}
}
impl<'de, const R: usize> Deserialize<'de> for Shape<R> {
fn deserialize<D>(deserializer: D) -> Result<Shape<R>, D::Error>
impl<'de, const R: usize> Deserialize<'de> for TensorShape<R> {
fn deserialize<D>(deserializer: D) -> SerdeResult<TensorShape<R>, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_tuple(R, ShapeVisitor)
deserializer.deserialize_tuple(R, TensorShapeVisitor)
}
}
impl<const R: usize> Serialize for Shape<R> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
impl<const R: usize> Serialize for TensorShape<R> {
fn serialize<S>(&self, serializer: S) -> SerdeResult<S::Ok, S::Error>
where
S: Serializer,
{
@ -189,23 +194,23 @@ impl<const R: usize> Serialize for Shape<R> {
// ---- Blanket Implementations ----
impl<const R: usize> From<[usize; R]> for Shape<R> {
impl<const R: usize> From<[usize; R]> for TensorShape<R> {
fn from(shape: [usize; R]) -> Self {
Self::new(shape)
}
}
impl<const R: usize> PartialEq for Shape<R> {
impl<const R: usize> PartialEq for TensorShape<R> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<const R: usize> Eq for Shape<R> {}
impl<const R: usize> Eq for TensorShape<R> {}
// ---- From and Into Implementations ----
impl<T, const R: usize> From<Tensor<T, R>> for Shape<R>
impl<T, const R: usize> From<Tensor<T, R>> for TensorShape<R>
where
T: Value,
{

View File

@ -1,24 +1,50 @@
use super::*;
use crate::error::*;
use getset::{Getters, MutGetters};
use std::fmt;
use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Formatter, Result as FmtResult},
ops::{Index, IndexMut},
};
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
/// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is
/// a vector, a rank 2 tensor is a matrix, and so on.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// assert_eq!(t.rank(), 2);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
pub struct Tensor<T, const R: usize> {
#[getset(get = "pub", get_mut = "pub")]
buffer: Vec<T>,
#[getset(get = "pub")]
shape: Shape<R>,
shape: TensorShape<R>,
}
// ---- Construction and Initialization ---------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
pub fn new(shape: Shape<R>) -> Self {
/// Create a new tensor with the given shape. The rank of the tensor is
/// determined by the shape and all elements are initialized to zero.
///
/// ```
/// use manifold::Tensor;
///
/// let t = Tensor::<f64, 2>::new([3, 3].into());
/// assert_eq!(t.shape().as_array(), [3, 3]);
/// ```
pub fn new(shape: TensorShape<R>) -> Self {
// Handle rank 0 tensor (scalar) as a special case
let total_size = if R == 0 {
// A rank 0 tensor should still have a buffer with one element
1
} else {
// For tensors of rank 1 or higher, calculate the total size normally
// For tensors of rank 1 or higher, calculate the total size
// normally
shape.iter().product()
};
@ -26,27 +52,356 @@ impl<T: Value, const R: usize> Tensor<T, R> {
Self { buffer, shape }
}
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
/// Create a new tensor with the given shape and initialize it from the
/// given buffer. The rank of the tensor is determined by the shape.
///
/// ```
/// use manifold::Tensor;
///
/// let buffer = vec![1, 2, 3, 4, 5, 6];
/// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer);
/// assert_eq!(t.shape().as_array(), [2, 3]);
/// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]);
/// ```
pub fn new_with_buffer(shape: TensorShape<R>, buffer: Vec<T>) -> Self {
Self { buffer, shape }
}
}
pub fn reshape(self, shape: Shape<R>) -> Result<Self> {
// ---- Trivial Getters -------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
pub fn rank(&self) -> usize {
R
}
pub fn len(&self) -> usize {
self.buffer().len()
}
}
// ---- Get Values ------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Get a reference to a value at the given index.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape().clone(), [1, 1]).into();
/// assert_eq!(t.get(i), Some(&4));
/// ```
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
self.buffer().get(index.flat())
}
/// Get a reference to a value at the given index without bounds checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape().clone(), [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked(i), &4); }
/// ```
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
self.buffer().get_unchecked(index.flat())
}
/// Get a mutable reference to a value at the given index.
///
/// ```
/// use manifold::*;
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// assert_eq!(t.get_mut(index!(&t, [1, 1])), Some(&mut 4));
/// ```
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
self.buffer_mut().get_mut(index.flat())
}
/// Get a mutable reference to a value at the given index without bounds
/// checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// let s = t.shape().clone();
/// let i = (s, [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); }
/// ```
pub unsafe fn get_unchecked_mut(
&mut self,
index: TensorIndex<R>,
) -> &mut T {
self.buffer_mut().get_unchecked_mut(index.flat())
}
/// Get a reference to a value at the given flat index.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// assert_eq!(t.get_flat(3), Some(&4));
/// ```
pub fn get_flat(&self, index: usize) -> Option<&T> {
self.buffer().get(index)
}
/// Get a reference to a value at the given flat index without bounds
/// checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// unsafe { assert_eq!(t.get_flat_unchecked(3), &4); }
/// ```
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
self.buffer().get_unchecked(index)
}
/// Get a mutable reference to a value at the given flat index.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// assert_eq!(t.get_flat_mut(3), Some(&mut 4));
/// ```
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
self.buffer_mut().get_mut(index)
}
/// Get a mutable reference to a value at the given flat index without
/// bounds checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// unsafe { assert_eq!(t.get_flat_unchecked_mut(3), &mut 4); }
/// ```
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
self.buffer_mut().get_unchecked_mut(index)
}
}
// ---- Arithmetic ------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Elementwise operation on two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_for_each(&a, &b, &mut c, &|a, b| a * b).unwrap();
/// assert_eq!(c, tensor!([[5, 12], [21, 32]]));
/// ```
pub fn ew_for_each(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
f: &dyn Fn(T, T) -> T,
) -> Result<()> {
if self.shape() != other.shape() {
return Err(TensorError::InvalidArgument(format!(
"TensorShape mismatch: {:?} != {:?}",
self.shape(),
other.shape()
)));
} else if self.shape() != result.shape() {
return Err(TensorError::InvalidArgument(format!(
"TensorShape mismatch: {:?} != {:?}",
self.shape(),
result.shape()
)));
}
for (i, (a, b)) in
self.buffer().iter().zip(other.buffer().iter()).enumerate()
{
unsafe {
*result.get_flat_unchecked_mut(i) = f(*a, *b);
}
}
Ok(())
}
/// Elementwise multiplication of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_multiply(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[5, 12], [21, 32]]));
/// ```
pub fn ew_multiply(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a * b)
}
/// Elementwise addition of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_add(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[6, 8], [10, 12]]));
/// ```
pub fn ew_add(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a + b)
}
/// Elementwise subtraction of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_subtract(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[-4, -4], [-4, -4]]));
/// ```
pub fn ew_subtract(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a - b)
}
/// Elementwise division of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[2, 4], [8, 16]]);
/// let b = tensor!([[2, 2], [4, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_divide(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[1, 2], [2, 2]]));
/// ```
pub fn ew_divide(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a / b)
}
/// Elementwise modulo of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[2, 2], [3, 3]]);
/// let b = tensor!([[4, 4], [6, 9]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_modulo(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[2, 2], [3, 3]]));
/// ```
pub fn ew_modulo(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a % b)
}
// pub fn product<const S: usize>(
// &self,
// other: &Tensor<T, S>,
// ) -> Tensor<T, { R + S }> {
// let mut new_shape_vec = Vec::new();
// new_shape_vec.extend_from_slice(&self.shape().as_array());
// new_shape_vec.extend_from_slice(&other.shape().as_array());
// let new_shape_array: [usize; R + S] = new_shape_vec
// .try_into()
// .expect("Failed to create shape array");
// let mut new_buffer =
// Vec::with_capacity(self.buffer.len() * other.buffer.len());
// for &item_self in &self.buffer {
// for &item_other in &other.buffer {
// new_buffer.push(item_self * item_other);
// }
// }
// Tensor {
// buffer: new_buffer,
// shape: TensorShape::new(new_shape_array),
// }
// }
}
// ---- Reshape ---------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Reshape the tensor to the given shape. The total size of the new shape
/// must be the same as the total size of the old shape.
///
/// ```
/// use manifold::{tensor, shape, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let s = shape!([4]);
/// let t = t.reshape(s).unwrap();
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
/// ```
pub fn reshape<const S: usize>(
self,
shape: TensorShape<S>,
) -> Result<Tensor<T, S>> {
if self.shape().size() != shape.size() {
let (ls, rs) = (self.shape().as_array(), shape.as_array());
let (lsize, rsize) = (self.shape().size(), shape.size());
Err(Error::InvalidArgument(format!(
"Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
Err(TensorError::InvalidArgument(format!(
"TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
)))
} else {
Ok(Self {
buffer: self.buffer,
shape,
})
Ok(Tensor::new_with_buffer(shape, self.buffer))
}
}
}
// ---- Transpose -------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Transpose the tensor according to the given order. The order must be a
/// permutation of the tensor's axes.
///
/// ```
/// use manifold::{tensor, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let t = t.transpose([1, 0]).unwrap();
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// ```
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = Idx::from(self.shape())
let buffer = TensorIndex::from(self.shape().clone())
.iter_transposed(order)
.map(|index| self.get(index).unwrap().clone())
.collect();
@ -56,150 +411,22 @@ impl<T: Value, const R: usize> Tensor<T, R> {
shape: self.shape().reorder(order),
})
}
pub fn idx(&self) -> Idx<R> {
Idx::from(self)
}
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
Axis::new(self, axis)
}
// ---- Indexing --------------------------------------------------------------
pub fn get(&self, index: Idx<R>) -> Option<&T> {
self.buffer.get(index.flat())
}
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
self.buffer.get_unchecked(index.flat())
}
pub fn get_mut(&mut self, index: Idx<R>) -> Option<&mut T> {
self.buffer.get_mut(index.flat())
}
pub unsafe fn get_unchecked_mut(&mut self, index: Idx<R>) -> &mut T {
self.buffer.get_unchecked_mut(index.flat())
}
pub fn get_flat(&self, index: usize) -> Option<&T> {
self.buffer.get(index)
}
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
self.buffer.get_unchecked(index)
}
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
self.buffer.get_mut(index)
}
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
self.buffer.get_unchecked_mut(index)
}
pub fn rank(&self) -> usize {
R
}
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn iter(&self) -> TensorIterator<T, R> {
TensorIterator::new(self)
}
pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
if self.shape != other.shape {
panic!("Shapes of tensors do not match");
}
let mut result_buffer = Vec::with_capacity(self.buffer.len());
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
result_buffer.push(*a * *b);
}
Tensor {
buffer: result_buffer,
shape: self.shape,
}
}
pub fn tensor_product<const S: usize>(
&self,
other: &Tensor<T, S>,
) -> Tensor<T, { R + S }> {
let mut new_shape_vec = Vec::new();
new_shape_vec.extend_from_slice(&self.shape.as_array());
new_shape_vec.extend_from_slice(&other.shape.as_array());
let new_shape_array: [usize; R + S] = new_shape_vec
.try_into()
.expect("Failed to create shape array");
let mut new_buffer =
Vec::with_capacity(self.buffer.len() * other.buffer.len());
for &item_self in &self.buffer {
for &item_other in &other.buffer {
new_buffer.push(item_self * item_other);
}
}
Tensor {
buffer: new_buffer,
shape: Shape::new(new_shape_array),
}
}
// Retrieve an element based on a specific axis and index
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
// Convert axis and index to a flat index
let flat_index = self.axis_to_flat_index(axis, index);
if flat_index >= self.buffer.len() {
return None;
}
Some(self.buffer[flat_index])
}
// Convert axis and index to a flat index in the buffer
fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize {
let mut flat_index = 0;
let mut stride = 1;
// Ensure the given axis is within the tensor's dimensions
if axis >= R {
panic!("Axis out of bounds");
}
// Calculate the stride for each dimension and accumulate the flat index
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
if i > axis {
stride *= dim_size;
} else if i == axis {
flat_index += index * stride;
break; // We've reached the target axis
}
}
flat_index
}
}
// ---- Indexing ----
impl<'a, T: Value, const R: usize> Index<Idx<'a, R>> for Tensor<T, R> {
impl<T: Value, const R: usize> Index<TensorIndex<R>> for Tensor<T, R> {
type Output = T;
fn index(&self, index: Idx<R>) -> &Self::Output {
fn index(&self, index: TensorIndex<R>) -> &Self::Output {
&self.buffer[index.flat()]
}
}
impl<'a, T: Value, const R: usize> IndexMut<Idx<'a, R>> for Tensor<T, R> {
fn index_mut(&mut self, index: Idx<R>) -> &mut Self::Output {
impl<T: Value, const R: usize> IndexMut<TensorIndex<R>>
for Tensor<T, R>
{
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
&mut self.buffer[index.flat()]
}
}
@ -218,18 +445,18 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
}
}
// ---- Display ----
// ---- Display ---------------------------------------------------------------
impl<T, const R: usize> Tensor<T, R>
where
T: fmt::Display + Clone,
T: Display + Clone,
{
fn fmt_helper(
buffer: &[T],
shape: &[usize],
f: &mut fmt::Formatter<'_>,
f: &mut Formatter<'_>,
level: usize,
) -> fmt::Result {
) -> FmtResult {
if shape.is_empty() {
// Base case: print individual elements
write!(f, "{}", buffer[0])
@ -247,24 +474,37 @@ where
}
}
impl<T, const R: usize> fmt::Display for Tensor<T, R>
impl<T, const R: usize> Display for Tensor<T, R>
where
T: fmt::Display + Clone,
T: Display + Clone,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
}
}
// ---- Iterator ----
// ---- Equality --------------------------------------------------------------
impl<T, const R: usize> PartialEq for Tensor<T, R>
where
T: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.shape == other.shape && self.buffer == other.buffer
}
}
impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
// ---- Iterator --------------------------------------------------------------
pub struct TensorIterator<'a, T: Value, const R: usize> {
tensor: &'a Tensor<T, R>,
index: Idx<'a, R>,
index: TensorIndex<R>,
}
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
pub fn new(tensor: &'a Tensor<T, R>) -> Self {
Self {
tensor,
index: tensor.shape.index_zero(),
@ -294,7 +534,7 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
}
}
// ---- Formatting ----
// ---- Formatting ------------------------------------------------------------
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
@ -323,17 +563,17 @@ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
}
}
// ---- From ----
// ---- From ------------------------------------------------------------------
impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
fn from(shape: Shape<R>) -> Self {
impl<T: Value, const R: usize> From<TensorShape<R>> for Tensor<T, R> {
fn from(shape: TensorShape<R>) -> Self {
Self::new(shape)
}
}
impl<T: Value> From<T> for Tensor<T, 0> {
fn from(value: T) -> Self {
let shape = Shape::new([]);
let shape = TensorShape::new([]);
let mut tensor = Tensor::new(shape);
tensor.buffer_mut()[0] = value;
tensor
@ -342,7 +582,7 @@ impl<T: Value> From<T> for Tensor<T, 0> {
impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
fn from(array: [T; X]) -> Self {
let shape = Shape::new([X]);
let shape = TensorShape::new([X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
@ -358,7 +598,7 @@ impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
for Tensor<T, 2>
{
fn from(array: [[T; X]; Y]) -> Self {
let shape = Shape::new([Y, X]);
let shape = TensorShape::new([Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
@ -376,7 +616,7 @@ impl<T: Value, const X: usize, const Y: usize, const Z: usize>
From<[[[T; X]; Y]; Z]> for Tensor<T, 3>
{
fn from(array: [[[T; X]; Y]; Z]) -> Self {
let shape = Shape::new([Z, Y, X]);
let shape = TensorShape::new([Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
@ -401,7 +641,7 @@ impl<
> From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
{
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
let shape = Shape::new([W, Z, Y, X]);
let shape = TensorShape::new([W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
@ -429,7 +669,7 @@ impl<
> From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
{
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
let shape = Shape::new([V, W, Z, Y, X]);
let shape = TensorShape::new([V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
@ -464,7 +704,7 @@ impl<
> From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
{
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
let shape = Shape::new([U, V, W, Z, Y, X]);
let shape = TensorShape::new([U, V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();

25
src/value.rs Normal file
View File

@ -0,0 +1,25 @@
use num::{Num, One, Zero};
use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
iter::Sum,
};
/// A trait for types that can be used as values in a tensor.
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
{
}
impl<T> Value for T where
T: Num
+ Zero
+ One
+ Copy
+ Clone
+ Display
+ Serialize
+ Deserialize<'static>
+ Sum
{
}