diff --git a/benches/manifold_benchmark.rs b/benches/manifold_benchmark.rs index c88f9d2..587c869 100644 --- a/benches/manifold_benchmark.rs +++ b/benches/manifold_benchmark.rs @@ -57,4 +57,5 @@ criterion_group!( benches, tensor_product ); + criterion_main!(benches); \ No newline at end of file diff --git a/src/tensor.rs b/src/tensor.rs index 3069822..70c003c 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -38,7 +38,7 @@ impl Tensor { /// use manifold::Tensor; /// /// let t = Tensor::::new([3, 3].into()); - /// assert_eq!(t.shape().as_array(), [3, 3]); + /// assert_eq!(t.shape().as_array(), &[3, 3]); /// ``` pub fn new(shape: TensorShape) -> Self { // Handle rank 0 tensor (scalar) as a special case @@ -63,7 +63,7 @@ impl Tensor { /// /// let buffer = vec![1, 2, 3, 4, 5, 6]; /// let t = Tensor::::new_with_buffer([2, 3].into(), buffer); - /// assert_eq!(t.shape().as_array(), [2, 3]); + /// assert_eq!(t.shape().as_array(), &[2, 3]); /// assert_eq!(t.buffer(), &[1, 2, 3, 4, 5, 6]); /// ``` pub fn new_with_buffer(shape: TensorShape, buffer: Vec) -> Self { diff --git a/tests/basic_tests.rs b/tests/basic_tests.rs new file mode 100644 index 0000000..c96996f --- /dev/null +++ b/tests/basic_tests.rs @@ -0,0 +1,261 @@ +use manifold::*; + +use serde_json; + +#[test] +fn test_serde_shape_serialization() { + // Create a shape instance + let shape: TensorShape<3> = [1, 2, 3].into(); + + // Serialize the shape to a JSON string + let serialized = + serde_json::to_string(&shape).expect("Failed to serialize"); + + // Deserialize the JSON string back into a shape + let deserialized: TensorShape<3> = + serde_json::from_str(&serialized).expect("Failed to deserialize"); + + // Check that the deserialized shape is equal to the original + assert_eq!(shape, deserialized); +} + +#[test] +fn test_tensor_serde_serialization() { + // Create an instance of Tensor + let tensor: Tensor = Tensor::new(TensorShape::new([2, 2])); + + // Serialize the Tensor to a JSON string + let serialized = + serde_json::to_string(&tensor).expect("Failed to serialize"); + + // Deserialize the JSON string back into a Tensor + let deserialized: Tensor = + serde_json::from_str(&serialized).expect("Failed to deserialize"); + + // Check that the deserialized Tensor is equal to the original + assert_eq!(tensor.buffer(), deserialized.buffer()); + assert_eq!(tensor.shape(), deserialized.shape()); +} + +#[test] +fn test_iterating_3d_tensor() { + let shape = TensorShape::new([2, 2, 2]); // 3D tensor with shape 2x2x2 + let mut tensor = Tensor::new(shape); + let mut num = 0; + + // Fill the tensor with sequential numbers + for i in 0..2 { + for j in 0..2 { + for k in 0..2 { + tensor.buffer_mut()[i * 4 + j * 2 + k] = num; + num += 1; + } + } + } + + println!("{}", tensor); + + // Iterate over the tensor and check that the numbers are correct + + let mut iter = TensorIterator::new(&tensor); + + println!("{}", iter); + + assert_eq!(iter.next(), Some(&0)); + + assert_eq!(iter.next(), Some(&1)); + assert_eq!(iter.next(), Some(&2)); + assert_eq!(iter.next(), Some(&3)); + assert_eq!(iter.next(), Some(&4)); + assert_eq!(iter.next(), Some(&5)); + assert_eq!(iter.next(), Some(&6)); + assert_eq!(iter.next(), Some(&7)); + assert_eq!(iter.next(), None); + assert_eq!(iter.next(), None); +} + +#[test] +fn test_iterating_rank_4_tensor() { + // Define the shape of the rank-4 tensor (e.g., 2x2x2x2) + let shape = TensorShape::new([2, 2, 2, 2]); + let mut tensor = Tensor::new(shape); + let mut num = 0; + + // Fill the tensor with sequential numbers + for i in 0..tensor.len() { + tensor.buffer_mut()[i] = num; + num += 1; + } + + // Iterate over the tensor and check that the numbers are correct + let mut iter = TensorIterator::new(&tensor); + for expected_value in 0..tensor.len() { + assert_eq!(*iter.next().unwrap(), expected_value); + } + + // Ensure the iterator is exhausted + assert!(iter.next().is_none()); +} + +#[test] +fn test_index_dec_method() { + let shape = TensorShape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor + let mut index = TensorIndex::zero(shape); + + // Increment the index to the maximum + for _ in 0..26 { + // 3 * 3 * 3 - 1 = 26 increments to reach the end + index.inc(); + } + + // Check if the index is at the maximum + assert_eq!(index, TensorIndex::new(shape, [2, 2, 2])); + + // Decrement step by step and check the index + let expected_indices = [ + [2, 2, 2], + [2, 2, 1], + [2, 2, 0], + [2, 1, 2], + [2, 1, 1], + [2, 1, 0], + [2, 0, 2], + [2, 0, 1], + [2, 0, 0], + [1, 2, 2], + [1, 2, 1], + [1, 2, 0], + [1, 1, 2], + [1, 1, 1], + [1, 1, 0], + [1, 0, 2], + [1, 0, 1], + [1, 0, 0], + [0, 2, 2], + [0, 2, 1], + [0, 2, 0], + [0, 1, 2], + [0, 1, 1], + [0, 1, 0], + [0, 0, 2], + [0, 0, 1], + [0, 0, 0], + ]; + + for (i, &expected) in expected_indices.iter().enumerate() { + assert_eq!( + index, + TensorIndex::new(shape, expected), + "Failed at index {}", + i + ); + index.dec(); + } + + // Finally, the index should reach [0, 0, 0] + index.dec(); + assert_eq!(index, TensorIndex::zero(shape)); +} + +#[test] +fn test_axis_iterator() { + // Creating a 2x2 Tensor for testing + let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); + + // Testing iteration over the first axis (axis = 0) + let axis = TensorAxis::new(&tensor, 0); + + let mut axis_iter = axis.into_iter(); + + assert_eq!(axis_iter.next(), Some(&1.0)); + assert_eq!(axis_iter.next(), Some(&2.0)); + assert_eq!(axis_iter.next(), Some(&3.0)); + assert_eq!(axis_iter.next(), Some(&4.0)); + + // Resetting the iterator for the second axis (axis = 1) + let axis = TensorAxis::new(&tensor, 1); + + let mut axis_iter = axis.into_iter(); + + assert_eq!(axis_iter.next(), Some(&1.0)); + assert_eq!(axis_iter.next(), Some(&3.0)); + assert_eq!(axis_iter.next(), Some(&2.0)); + assert_eq!(axis_iter.next(), Some(&4.0)); + + let shape = tensor.shape(); + + let mut a: TensorIndex<2> = (shape.clone(), [0, 0]).into(); + let b: TensorIndex<2> = (shape.clone(), [1, 1]).into(); + + while a <= b { + println!("a: {}", a); + a.inc(); + } +} + +#[test] +fn test_3d_tensor_axis_iteration() { + // Create a 3D Tensor with specific values + // Tensor shape is 2x2x2 for simplicity + let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + + // TensorAxis 0 (Layer-wise): + // + // t[0][0][0] = 1 + // t[0][0][1] = 2 + // t[0][1][0] = 3 + // t[0][1][1] = 4 + // t[1][0][0] = 5 + // t[1][0][1] = 6 + // t[1][1][0] = 7 + // t[1][1][1] = 8 + // [1, 2, 3, 4, 5, 6, 7, 8] + // + // This order suggests that for each "layer" (first level of arrays), + // the iterator goes through all rows and columns. It first completes + // the entire first layer, then moves to the second. + + let a0 = TensorAxis::new(&t, 0); + let a0_order = a0.into_iter().cloned().collect::>(); + assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); + + // TensorAxis 1 (Row-wise within each layer): + // + // t[0][0][0] = 1 + // t[0][0][1] = 2 + // t[1][0][0] = 5 + // t[1][0][1] = 6 + // t[0][1][0] = 3 + // t[0][1][1] = 4 + // t[1][1][0] = 7 + // t[1][1][1] = 8 + // [1, 2, 5, 6, 3, 4, 7, 8] + // + // This indicates that within each "layer", the iterator first + // completes the first row across all layers, then the second row + // across all layers. + + let a1 = TensorAxis::new(&t, 1); + let a1_order = a1.into_iter().cloned().collect::>(); + assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); + + // TensorAxis 2 (Column-wise within each layer): + // + // t[0][0][0] = 1 + // t[0][1][0] = 3 + // t[1][0][0] = 5 + // t[1][1][0] = 7 + // t[0][0][1] = 2 + // t[0][1][1] = 4 + // t[1][0][1] = 6 + // t[1][1][1] = 8 + // [1, 3, 5, 7, 2, 4, 6, 8] + // + // This indicates that within each "layer", the iterator first + // completes the first column across all layers, then the second + // column across all layers. + + let a2 = TensorAxis::new(&t, 2); + let a2_order = a2.into_iter().cloned().collect::>(); + assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); +} diff --git a/tests/mod.rs b/tests/mod.rs new file mode 100644 index 0000000..e23c050 --- /dev/null +++ b/tests/mod.rs @@ -0,0 +1 @@ +mod basic_tests; \ No newline at end of file