Refactor and document Tensor-type

- Add documentation to all methods exposed by the Tensor type.
- Remove some tests and methods to simplify structure, some might be
  introduced back later.
- Add elementwise operations.
- Add doctests to Tensor.

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-03 21:28:01 +02:00
parent f14892f0ef
commit d19ce40494
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
4 changed files with 391 additions and 599 deletions

View File

@ -113,277 +113,4 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a TensorAxis<'a, T, R> {
fn into_iter(self) -> Self::IntoIter {
TensorAxisIterator::new(&self)
}
}
pub fn contract<
'a,
T: Value + std::fmt::Debug,
const R: usize,
const S: usize,
const N: usize,
>(
lhs: (&'a Tensor<T, R>, [usize; N]),
rhs: (&'a Tensor<T, S>, [usize; N]),
) -> Tensor<T, { R + S - 2 * N }>
where
[(); R - N]:,
[(); S - N]:,
[(); R + S - 2 * N]:,
{
let (lhs, la) = lhs;
let (rhs, ra) = rhs;
let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();
let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
let mut shape = Vec::new();
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
let shape: [usize; R + S - 2 * N] =
shape.try_into().expect("Failed to create shape array");
let shape = TensorShape::new(shape);
let result = contract_axes(&lnc, &rnc);
Tensor::new_with_buffer(shape, result)
}
pub fn contract_axes<
'a,
T: Value + std::fmt::Debug,
const R: usize,
const S: usize,
const N: usize,
>(
laxes: &'a [TensorAxis<'a, T, R>],
raxes: &'a [TensorAxis<'a, T, S>],
) -> Vec<T>
where
[(); R - N]:,
[(); S - N]:,
{
let mut result = vec![];
let axes = laxes.into_iter().zip(raxes);
for (laxis, raxis) in axes {
let mut axes_result: Vec<T> = vec![];
for i in 0..raxis.len() {
for j in 0..laxis.len() {
let mut sum = T::zero();
let llevel = laxis.into_iter();
let rlevel = raxis.into_iter();
let zip = llevel.level(j).zip(rlevel.level(i));
for (lv, rv) in zip {
sum = sum + *lv * *rv;
}
axes_result.push(sum);
}
}
result.extend_from_slice(&axes_result);
}
result
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tensor_contraction_simple() {
// Define two 2D tensors (matrices)
// Tensor A is 2x3
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
// Tensor B is 1x3x2
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
assert_eq!(contracted_tensor.shape(), &TensorShape::new([2, 2]));
assert_eq!(
contracted_tensor.buffer(),
&[7, 10, 15, 22],
"Contracted tensor buffer does not match expected"
);
}
#[test]
fn test_tensor_contraction_23x32() {
// Define two 2D tensors (matrices)
// Tensor A is 2x3
let b: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
println!("b: {}", b);
// Tensor B is 3x2
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
println!("a: {}", a);
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
println!("contracted_tensor: {}", contracted_tensor);
assert_eq!(contracted_tensor.shape(), &TensorShape::new([3, 3]));
assert_eq!(
contracted_tensor.buffer(),
&[9, 12, 15, 19, 26, 33, 29, 40, 51],
"Contracted tensor buffer does not match expected"
);
}
#[test]
fn test_tensor_contraction_rank3() {
let a: Tensor<i32, 3> =
Tensor::new_with_buffer(TensorShape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24
let b: Tensor<i32, 3> =
Tensor::new_with_buffer(TensorShape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24
let contracted_tensor: Tensor<i32, 4> = contract((&a, [2]), (&b, [0]));
println!("a: {}", a);
println!("b: {}", b);
println!("contracted_tensor: {}", contracted_tensor);
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
// Verify specific elements of contracted_tensor
// assert_eq!(contracted_tensor[0][0][0][0], 50);
// assert_eq!(contracted_tensor[0][0][0][1], 60);
// ... further checks for other elements ...
}
// #[test]
// fn test_axis_iterator_disassemble() {
// // Creating a 2x2 Tensor for testing
// let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// // Testing iteration over the first axis (axis = 0)
// let axis = TensorAxis::new(&tensor, 0);
// let mut axis_iter = axis.into_iter().disassemble();
// assert_eq!(axis_iter[0].next(), Some(&1.0));
// assert_eq!(axis_iter[0].next(), Some(&2.0));
// assert_eq!(axis_iter[0].next(), None);
// assert_eq!(axis_iter[1].next(), Some(&3.0));
// assert_eq!(axis_iter[1].next(), Some(&4.0));
// assert_eq!(axis_iter[1].next(), None);
// // Resetting the iterator for the second axis (axis = 1)
// let axis = TensorAxis::new(&tensor, 1);
// let mut axis_iter = axis.into_iter().disassemble();
// assert_eq!(axis_iter[0].next(), Some(&1.0));
// assert_eq!(axis_iter[0].next(), Some(&3.0));
// assert_eq!(axis_iter[0].next(), None);
// assert_eq!(axis_iter[1].next(), Some(&2.0));
// assert_eq!(axis_iter[1].next(), Some(&4.0));
// assert_eq!(axis_iter[1].next(), None);
// }
#[test]
fn test_axis_iterator() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Testing iteration over the first axis (axis = 0)
let axis = TensorAxis::new(&tensor, 0);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&4.0));
// Resetting the iterator for the second axis (axis = 1)
let axis = TensorAxis::new(&tensor, 1);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&4.0));
let shape = tensor.shape();
let mut a: TensorIndex<2> = (shape, [0, 0]).into();
let b: TensorIndex<2> = (shape, [1, 1]).into();
while a <= b {
println!("a: {}", a);
a.inc();
}
}
#[test]
fn test_3d_tensor_axis_iteration() {
// Create a 3D Tensor with specific values
// Tensor shape is 2x2x2 for simplicity
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
// TensorAxis 0 (Layer-wise):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 3, 4, 5, 6, 7, 8]
//
// This order suggests that for each "layer" (first level of arrays),
// the iterator goes through all rows and columns. It first completes
// the entire first layer, then moves to the second.
let a0 = TensorAxis::new(&t, 0);
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
// TensorAxis 1 (Row-wise within each layer):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 5, 6, 3, 4, 7, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first row across all layers, then the second row
// across all layers.
let a1 = TensorAxis::new(&t, 1);
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
// TensorAxis 2 (Column-wise within each layer):
//
// t[0][0][0] = 1
// t[0][1][0] = 3
// t[1][0][0] = 5
// t[1][1][0] = 7
// t[0][0][1] = 2
// t[0][1][1] = 4
// t[1][0][1] = 6
// t[1][1][1] = 8
// [1, 3, 5, 7, 2, 4, 6, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first column across all layers, then the second
// column across all layers.
let a2 = TensorAxis::new(&t, 2);
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
}
}
}

View File

@ -24,7 +24,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
Self {
indices: max_indices,
shape: shape,
shape,
}
}
@ -34,7 +34,7 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
}
Self {
indices,
shape: shape,
shape,
}
}
@ -252,9 +252,9 @@ impl<'a, const R: usize> TensorIndex<'a, R> {
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
pub fn flat(&self) -> usize {
self.indices
self.indices()
.iter()
.zip(&self.shape.as_array())
.zip(&self.shape().as_array())
.rev()
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
(flat_index + idx * product, product * dim_size)

View File

@ -43,184 +43,16 @@ macro_rules! tensor {
};
}
// ---- Tests ----
#[cfg(test)]
mod tests {
use super::*;
use serde_json;
#[test]
fn test_tensor_product() {
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
let product = tensor1.tensor_product(&tensor2);
// Check shape of the resulting tensor
assert_eq!(*product.shape(), TensorShape::new([2, 2, 2]));
// Check buffer of the resulting tensor
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
assert_eq!(product.buffer(), &expected_buffer);
}
#[test]
fn serde_shape_serialization_test() {
// Create a shape instance
let shape: TensorShape<3> = [1, 2, 3].into();
// Serialize the shape to a JSON string
let serialized =
serde_json::to_string(&shape).expect("Failed to serialize");
// Deserialize the JSON string back into a shape
let deserialized: TensorShape<3> =
serde_json::from_str(&serialized).expect("Failed to deserialize");
// Check that the deserialized shape is equal to the original
assert_eq!(shape, deserialized);
}
#[test]
fn tensor_serde_serialization_test() {
// Create an instance of Tensor
let tensor: Tensor<i32, 2> = Tensor::new(TensorShape::new([2, 2]));
// Serialize the Tensor to a JSON string
let serialized =
serde_json::to_string(&tensor).expect("Failed to serialize");
// Deserialize the JSON string back into a Tensor
let deserialized: Tensor<i32, 2> =
serde_json::from_str(&serialized).expect("Failed to deserialize");
// Check that the deserialized Tensor is equal to the original
assert_eq!(tensor.buffer(), deserialized.buffer());
assert_eq!(tensor.shape(), deserialized.shape());
}
#[test]
fn iterate_3d_tensor() {
let shape = TensorShape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
let mut tensor = Tensor::new(shape);
let mut num = 0;
// Fill the tensor with sequential numbers
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
tensor.buffer_mut()[i * 4 + j * 2 + k] = num;
num += 1;
}
}
}
println!("{}", tensor);
// Iterate over the tensor and check that the numbers are correct
let mut iter = TensorIterator::new(&tensor);
println!("{}", iter);
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), Some(&4));
assert_eq!(iter.next(), Some(&5));
assert_eq!(iter.next(), Some(&6));
assert_eq!(iter.next(), Some(&7));
assert_eq!(iter.next(), None);
assert_eq!(iter.next(), None);
}
#[test]
fn iterate_rank_4_tensor() {
// Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
let shape = TensorShape::new([2, 2, 2, 2]);
let mut tensor = Tensor::new(shape);
let mut num = 0;
// Fill the tensor with sequential numbers
for i in 0..tensor.len() {
tensor.buffer_mut()[i] = num;
num += 1;
}
// Iterate over the tensor and check that the numbers are correct
let mut iter = TensorIterator::new(&tensor);
for expected_value in 0..tensor.len() {
assert_eq!(*iter.next().unwrap(), expected_value);
}
// Ensure the iterator is exhausted
assert!(iter.next().is_none());
}
#[test]
fn test_dec_method() {
let shape = TensorShape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
let mut index = TensorIndex::zero(&shape);
// Increment the index to the maximum
for _ in 0..26 {
// 3 * 3 * 3 - 1 = 26 increments to reach the end
index.inc();
}
// Check if the index is at the maximum
assert_eq!(index, TensorIndex::new(&shape, [2, 2, 2]));
// Decrement step by step and check the index
let expected_indices = [
[2, 2, 2],
[2, 2, 1],
[2, 2, 0],
[2, 1, 2],
[2, 1, 1],
[2, 1, 0],
[2, 0, 2],
[2, 0, 1],
[2, 0, 0],
[1, 2, 2],
[1, 2, 1],
[1, 2, 0],
[1, 1, 2],
[1, 1, 1],
[1, 1, 0],
[1, 0, 2],
[1, 0, 1],
[1, 0, 0],
[0, 2, 2],
[0, 2, 1],
[0, 2, 0],
[0, 1, 2],
[0, 1, 1],
[0, 1, 0],
[0, 0, 2],
[0, 0, 1],
[0, 0, 0],
];
for (i, &expected) in expected_indices.iter().enumerate() {
assert_eq!(
index,
TensorIndex::new(&shape, expected),
"Failed at index {}",
i
);
index.dec();
}
// Finally, the index should reach [0, 0, 0]
index.dec();
assert_eq!(index, TensorIndex::zero(&shape));
}
#[macro_export]
macro_rules! shape {
($array:expr) => {
TensorShape::from($array)
};
}
#[macro_export]
macro_rules! index {
($array:expr) => {
TensorIndex::from($array)
};
}

View File

@ -3,6 +3,16 @@ use crate::error::*;
use getset::{Getters, MutGetters};
use std::fmt;
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the number of
/// dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is a vector, a rank 2 tensor is
/// a matrix, and so on.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// assert_eq!(t.rank(), 2);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
pub struct Tensor<T, const R: usize> {
#[getset(get = "pub", get_mut = "pub")]
@ -11,7 +21,18 @@ pub struct Tensor<T, const R: usize> {
shape: TensorShape<R>,
}
// ---- Construction and Initialization ---------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Create a new tensor with the given shape. The rank of the tensor is determined by the shape
/// and all elements are initialized to zero.
///
/// ```
/// use manifold::Tensor;
///
/// let t = Tensor::<f64, 2>::new([3, 3].into());
/// assert_eq!(t.shape().as_array(), [3, 3]);
/// ```
pub fn new(shape: TensorShape<R>) -> Self {
// Handle rank 0 tensor (scalar) as a special case
let total_size = if R == 0 {
@ -26,11 +47,326 @@ impl<T: Value, const R: usize> Tensor<T, R> {
Self { buffer, shape }
}
/// Create a new tensor with the given shape and initialize it from the given buffer. The rank
/// of the tensor is determined by the shape.
///
/// ```
/// use manifold::Tensor;
///
/// let buffer = vec![1, 2, 3, 4, 5, 6];
/// let t = Tensor::<i32, 2>::new_with_buffer([2, 3].into(), buffer);
/// assert_eq!(t.shape().as_array(), [2, 3]);
/// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]);
/// ```
pub fn new_with_buffer(shape: TensorShape<R>, buffer: Vec<T>) -> Self {
Self { buffer, shape }
}
}
pub fn reshape(self, shape: TensorShape<R>) -> Result<Self> {
// ---- Trivial Getters -------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
pub fn rank(&self) -> usize {
R
}
pub fn len(&self) -> usize {
self.buffer().len()
}
}
// ---- Get Values ------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Get a reference to a value at the given index.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape(), [1, 1]).into();
/// assert_eq!(t.get(i), Some(&4));
/// ```
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
self.buffer().get(index.flat())
}
/// Get a reference to a value at the given index without bounds checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let i = (t.shape(), [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked(i), &4); }
/// ```
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
self.buffer().get_unchecked(index.flat())
}
/// Get a mutable reference to a value at the given index.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// let s = t.shape().clone();
/// let i = (&s, [1, 1]).into();
/// assert_eq!(t.get_mut(i), Some(&mut 4));
/// ```
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
self.buffer_mut().get_mut(index.flat())
}
/// Get a mutable reference to a value at the given index without bounds checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// let s = t.shape().clone();
/// let i = (&s, [1, 1]).into();
/// unsafe { assert_eq!(t.get_unchecked_mut(i), &mut 4); }
/// ```
pub unsafe fn get_unchecked_mut(
&mut self,
index: TensorIndex<R>,
) -> &mut T {
self.buffer_mut().get_unchecked_mut(index.flat())
}
/// Get a reference to a value at the given flat index.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// assert_eq!(t.get_flat(3), Some(&4));
/// ```
pub fn get_flat(&self, index: usize) -> Option<&T> {
self.buffer().get(index)
}
/// Get a reference to a value at the given flat index without bounds checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// unsafe { assert_eq!(t.get_flat_unchecked(3), &4); }
/// ```
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
self.buffer().get_unchecked(index)
}
/// Get a mutable reference to a value at the given flat index.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// assert_eq!(t.get_flat_mut(3), Some(&mut 4));
/// ```
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
self.buffer_mut().get_mut(index)
}
/// Get a mutable reference to a value at the given flat index without bounds checking.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let mut t = tensor!([[1, 2], [3, 4]]);
/// unsafe { assert_eq!(t.get_flat_unchecked_mut(3), &mut 4); }
/// ```
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
self.buffer_mut().get_unchecked_mut(index)
}
}
// ---- Arithmetic ------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Elementwise operation on two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_for_each(&a, &b, &mut c, &|a, b| a * b).unwrap();
/// assert_eq!(c, tensor!([[5, 12], [21, 32]]));
/// ```
pub fn ew_for_each(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
f: &dyn Fn(T, T) -> T,
) -> Result<()> {
if self.shape() != other.shape() {
return Err(TensorError::InvalidArgument(format!(
"TensorShape mismatch: {:?} != {:?}",
self.shape(),
other.shape()
)));
} else if self.shape() != result.shape() {
return Err(TensorError::InvalidArgument(format!(
"TensorShape mismatch: {:?} != {:?}",
self.shape(),
result.shape()
)));
}
for (i, (a, b)) in
self.buffer().iter().zip(other.buffer().iter()).enumerate()
{
unsafe {
*result.get_flat_unchecked_mut(i) = f(*a, *b);
}
}
Ok(())
}
/// Elementwise multiplication of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_multiply(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[5, 12], [21, 32]]));
/// ```
pub fn ew_multiply(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a * b)
}
/// Elementwise addition of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_add(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[6, 8], [10, 12]]));
/// ```
pub fn ew_add(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a + b)
}
/// Elementwise subtraction of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[1, 2], [3, 4]]);
/// let b = tensor!([[5, 6], [7, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_subtract(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[-4, -4], [-4, -4]]));
/// ```
pub fn ew_subtract(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a - b)
}
/// Elementwise division of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[2, 4], [8, 16]]);
/// let b = tensor!([[2, 2], [4, 8]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_divide(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[1, 2], [2, 2]]));
/// ```
pub fn ew_divide(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a / b)
}
/// Elementwise modulo of two tensors.
///
/// ```
/// use manifold::{tensor, Tensor};
///
/// let a = tensor!([[2, 2], [3, 3]]);
/// let b = tensor!([[4, 4], [6, 9]]);
/// let mut c = Tensor::<i32, 2>::new([2, 2].into());
/// Tensor::ew_modulo(&a, &b, &mut c).unwrap();
/// assert_eq!(c, tensor!([[2, 2], [3, 3]]));
/// ```
pub fn ew_modulo(
&self,
other: &Tensor<T, R>,
result: &mut Tensor<T, R>,
) -> Result<()> {
self.ew_for_each(other, result, &|a, b| a % b)
}
// pub fn product<const S: usize>(
// &self,
// other: &Tensor<T, S>,
// ) -> Tensor<T, { R + S }> {
// let mut new_shape_vec = Vec::new();
// new_shape_vec.extend_from_slice(&self.shape().as_array());
// new_shape_vec.extend_from_slice(&other.shape().as_array());
// let new_shape_array: [usize; R + S] = new_shape_vec
// .try_into()
// .expect("Failed to create shape array");
// let mut new_buffer =
// Vec::with_capacity(self.buffer.len() * other.buffer.len());
// for &item_self in &self.buffer {
// for &item_other in &other.buffer {
// new_buffer.push(item_self * item_other);
// }
// }
// Tensor {
// buffer: new_buffer,
// shape: TensorShape::new(new_shape_array),
// }
// }
}
// ---- Reshape ---------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Reshape the tensor to the given shape. The total size of the new shape must be the same as
/// the total size of the old shape.
///
/// ```
/// use manifold::{tensor, shape, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let s = shape!([4]);
/// let t = t.reshape(s).unwrap();
/// assert_eq!(t, tensor!([1, 2, 3, 4]));
/// ```
pub fn reshape<const S: usize>(self, shape: TensorShape<S>) -> Result<Tensor<T, S>> {
if self.shape().size() != shape.size() {
let (ls, rs) = (self.shape().as_array(), shape.as_array());
let (lsize, rsize) = (self.shape().size(), shape.size());
@ -38,13 +374,25 @@ impl<T: Value, const R: usize> Tensor<T, R> {
"TensorShape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
)))
} else {
Ok(Self {
buffer: self.buffer,
shape,
})
Ok(Tensor::new_with_buffer(shape, self.buffer))
}
}
}
// ---- Transpose -------------------------------------------------------------
impl<T: Value, const R: usize> Tensor<T, R> {
/// Transpose the tensor according to the given order. The order must be a permutation of the
/// tensor's axes.
///
/// ```
/// use manifold::{tensor, Tensor, TensorShape};
///
/// let t = tensor!([[1, 2], [3, 4]]);
/// let t = t.transpose([1, 0]).unwrap();
/// assert_eq!(t, tensor!([[1, 3], [2, 4]]));
/// ```
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = TensorIndex::from(self.shape())
.iter_transposed(order)
@ -56,139 +404,9 @@ impl<T: Value, const R: usize> Tensor<T, R> {
shape: self.shape().reorder(order),
})
}
pub fn idx(&self) -> TensorIndex<R> {
TensorIndex::from(self)
}
pub fn axis<'a>(&'a self, axis: usize) -> TensorAxis<'a, T, R> {
TensorAxis::new(self, axis)
}
pub fn get(&self, index: TensorIndex<R>) -> Option<&T> {
self.buffer.get(index.flat())
}
pub unsafe fn get_unchecked(&self, index: TensorIndex<R>) -> &T {
self.buffer.get_unchecked(index.flat())
}
pub fn get_mut(&mut self, index: TensorIndex<R>) -> Option<&mut T> {
self.buffer.get_mut(index.flat())
}
pub unsafe fn get_unchecked_mut(&mut self, index: TensorIndex<R>) -> &mut T {
self.buffer.get_unchecked_mut(index.flat())
}
pub fn get_flat(&self, index: usize) -> Option<&T> {
self.buffer.get(index)
}
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
self.buffer.get_unchecked(index)
}
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
self.buffer.get_mut(index)
}
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
self.buffer.get_unchecked_mut(index)
}
pub fn rank(&self) -> usize {
R
}
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn iter(&self) -> TensorIterator<T, R> {
TensorIterator::new(self)
}
pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
if self.shape != other.shape {
panic!("TensorShapes of tensors do not match");
}
let mut result_buffer = Vec::with_capacity(self.buffer.len());
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
result_buffer.push(*a * *b);
}
Tensor {
buffer: result_buffer,
shape: self.shape,
}
}
pub fn tensor_product<const S: usize>(
&self,
other: &Tensor<T, S>,
) -> Tensor<T, { R + S }> {
let mut new_shape_vec = Vec::new();
new_shape_vec.extend_from_slice(&self.shape.as_array());
new_shape_vec.extend_from_slice(&other.shape.as_array());
let new_shape_array: [usize; R + S] = new_shape_vec
.try_into()
.expect("Failed to create shape array");
let mut new_buffer =
Vec::with_capacity(self.buffer.len() * other.buffer.len());
for &item_self in &self.buffer {
for &item_other in &other.buffer {
new_buffer.push(item_self * item_other);
}
}
Tensor {
buffer: new_buffer,
shape: TensorShape::new(new_shape_array),
}
}
// Retrieve an element based on a specific axis and index
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
// Convert axis and index to a flat index
let flat_index = self.axis_to_flat_index(axis, index);
if flat_index >= self.buffer.len() {
return None;
}
Some(self.buffer[flat_index])
}
// Convert axis and index to a flat index in the buffer
fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize {
let mut flat_index = 0;
let mut stride = 1;
// Ensure the given axis is within the tensor's dimensions
if axis >= R {
panic!("TensorAxis out of bounds");
}
// Calculate the stride for each dimension and accumulate the flat index
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
if i > axis {
stride *= dim_size;
} else if i == axis {
flat_index += index * stride;
break; // We've reached the target axis
}
}
flat_index
}
}
// ---- Indexing ----
// ---- Indexing --------------------------------------------------------------
impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
type Output = T;
@ -198,7 +416,9 @@ impl<'a, T: Value, const R: usize> Index<TensorIndex<'a, R>> for Tensor<T, R> {
}
}
impl<'a, T: Value, const R: usize> IndexMut<TensorIndex<'a, R>> for Tensor<T, R> {
impl<'a, T: Value, const R: usize> IndexMut<TensorIndex<'a, R>>
for Tensor<T, R>
{
fn index_mut(&mut self, index: TensorIndex<R>) -> &mut Self::Output {
&mut self.buffer[index.flat()]
}
@ -218,7 +438,7 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
}
}
// ---- Display ----
// ---- Display ---------------------------------------------------------------
impl<T, const R: usize> Tensor<T, R>
where
@ -256,7 +476,20 @@ where
}
}
// ---- Iterator ----
// ---- Equality --------------------------------------------------------------
impl<T, const R: usize> PartialEq for Tensor<T, R>
where
T: PartialEq,
{
fn eq(&self, other: &Self) -> bool {
self.shape == other.shape && self.buffer == other.buffer
}
}
impl<T, const R: usize> Eq for Tensor<T, R> where T: Eq {}
// ---- Iterator --------------------------------------------------------------
pub struct TensorIterator<'a, T: Value, const R: usize> {
tensor: &'a Tensor<T, R>,
@ -294,7 +527,7 @@ impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
}
}
// ---- Formatting ----
// ---- Formatting ------------------------------------------------------------
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
@ -323,7 +556,7 @@ impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
}
}
// ---- From ----
// ---- From ------------------------------------------------------------------
impl<T: Value, const R: usize> From<TensorShape<R>> for Tensor<T, R> {
fn from(shape: TensorShape<R>) -> Self {