From f9c29aefd51e0cd2dc6912106170228734e67ac4 Mon Sep 17 00:00:00 2001 From: Julius Koskela Date: Sat, 30 Dec 2023 01:39:57 +0200 Subject: [PATCH] Rank 3 contraction not working yet Signed-off-by: Julius Koskela --- examples/operations.rs | 54 +++++++++++++++++++++++++++--------------- src/axis.rs | 44 +++++++++++++--------------------- src/lib.rs | 1 + src/tensor.rs | 43 ++++++++++++++++++++------------- src/tensor_view.rs | 8 +++++++ 5 files changed, 86 insertions(+), 64 deletions(-) create mode 100644 src/tensor_view.rs diff --git a/examples/operations.rs b/examples/operations.rs index 63bd179..290b61f 100644 --- a/examples/operations.rs +++ b/examples/operations.rs @@ -1,8 +1,8 @@ #![allow(mixed_script_confusables)] #![allow(non_snake_case)] use bytemuck::cast_slice; -use manifold::*; use manifold::contract; +use manifold::*; fn tensor_product() { println!("Tensor Product\n"); @@ -30,33 +30,49 @@ fn tensor_product() { } fn test_tensor_contraction_23x32() { - // Define two 2D tensors (matrices) + // Define two 2D tensors (matrices) - // Tensor A is 2x3 - let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); - println!("a: {}", a); + // Tensor A is 2x3 + let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + println!("a: {:?}\n{}\n", a.shape(), a); - // Tensor B is 3x2 - let b: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); - println!("b: {}", b); + // Tensor B is 3x2 + let b: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); + println!("b: {:?}\n{}\n", b.shape(), b); - // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let ctr10 = contract((&a, [1]), (&b, [0])); + // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) + let ctr10 = contract((&a, [1]), (&b, [0])); - println!("[1, 0]: {}", ctr10); + println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10); - let ctr01 = contract((&a, [0]), (&b, [1])); + let ctr01 = contract((&a, [0]), (&b, [1])); - println!("[0, 1]: {}", ctr01); - // assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); - // assert_eq!( - // contracted_tensor.buffer(), - // &[9, 12, 15, 19, 26, 33, 29, 40, 51], - // "Contracted tensor buffer does not match expected" - // ); + println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01); + // assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); + // assert_eq!( + // contracted_tensor.buffer(), + // &[9, 12, 15, 19, 26, 33, 29, 40, 51], + // "Contracted tensor buffer does not match expected" + // ); +} + +fn test_tensor_contraction_rank3() { + let a = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + let b = Tensor::from([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]); + let contracted_tensor = contract((&a, [2]), (&b, [0])); + + println!("a: {}", a); + println!("b: {}", b); + println!("contracted_tensor: {}", contracted_tensor); + // assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]); + // Verify specific elements of contracted_tensor + // assert_eq!(contracted_tensor[0][0][0][0], 50); + // assert_eq!(contracted_tensor[0][0][0][1], 60); + // ... further checks for other elements ... } fn main() { // tensor_product(); test_tensor_contraction_23x32(); + test_tensor_contraction_rank3(); } diff --git a/src/axis.rs b/src/axis.rs index f98b755..82eedb6 100644 --- a/src/axis.rs +++ b/src/axis.rs @@ -144,35 +144,23 @@ where { let (lhs, la) = lhs; let (rhs, ra) = rhs; - let raxes = ra.into_iter().map(|i| rhs.axis(i)).collect::>(); - let raxes: [Axis<'a, T, S>; N] = - raxes.try_into().expect("Failed to create raxes array"); - let laxes = la.into_iter().map(|i| lhs.axis(i)).collect::>(); - let laxes: [Axis<'a, T, R>; N] = - laxes.try_into().expect("Failed to create laxes array"); - let mut shape = Vec::new(); - shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); - shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); - let shape: [usize; R + S - 2 * N] = - shape.try_into().expect("Failed to create shape array"); + let lnc = (0..R).filter(|i| !la.contains(i)).collect::>(); + let rnc = (0..S).filter(|i| !ra.contains(i)).collect::>(); - let shape = Shape::new(shape); + let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::>(); + let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::>(); - let result = contract_axes(laxes, raxes); + let mut shape = Vec::new(); + shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); + shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); + let shape: [usize; R + S - 2 * N] = + shape.try_into().expect("Failed to create shape array"); - let mut t = Tensor::new_with_buffer(shape, result[0].clone()); + let shape = Shape::new(shape); - // Iterate over the remaining axes results (skipping the first one already used) - for axis_result in result.iter().skip(1) { - // Iterate through each element of the axis_result and add it to the corresponding element in t's buffer - for (i, val) in axis_result.iter().enumerate() { - unsafe { - *t.get_flat_unchecked_mut(i) = *t.get_flat(i) + *val; - } - } - } + let result = contract_axes(&lnc, &rnc); - t + Tensor::new_with_buffer(shape, result) } pub fn contract_axes< @@ -182,9 +170,9 @@ pub fn contract_axes< const S: usize, const N: usize, >( - laxes: [Axis<'a, T, R>; N], - raxes: [Axis<'a, T, S>; N], -) -> Vec> + laxes: &'a [Axis<'a, T, R>], + raxes: &'a [Axis<'a, T, S>], +) -> Vec where [(); R - N]:, [(); S - N]:, @@ -207,7 +195,7 @@ where axes_result.push(sum); } } - result.push(axes_result); + result.extend_from_slice(&axes_result); } result diff --git a/src/lib.rs b/src/lib.rs index 1faaf74..b4350c4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -4,6 +4,7 @@ pub mod index; pub mod shape; pub mod axis; pub mod tensor; +pub mod tensor_view; pub use index::Idx; use num::{Num, One, Zero}; diff --git a/src/tensor.rs b/src/tensor.rs index d8f2f09..43234f4 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -207,27 +207,36 @@ impl IndexMut for Tensor { // ---- Display ---- -impl std::fmt::Display for Tensor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // Print the shape of the tensor - write!(f, "Shape: [")?; - for (i, dim_size) in self.shape.as_array().iter().enumerate() { - write!(f, "{}", dim_size)?; - if i < R - 1 { - write!(f, ", ")?; - } - } - write!(f, "], Elements: [")?; +use std::fmt; - // Print the elements in a flattened form - for (i, elem) in self.buffer.iter().enumerate() { - write!(f, "{}", elem)?; - if i < self.buffer.len() - 1 { - write!(f, ", ")?; +impl Tensor +where + T: fmt::Display + Clone, +{ + fn fmt_helper(buffer: &[T], shape: &[usize], f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result { + if shape.is_empty() { + // Base case: print individual elements + write!(f, "{}", buffer[0]) + } else { + let sub_len = shape[1..].iter().product::(); + write!(f, "[")?; + for (i, chunk) in buffer.chunks(sub_len).enumerate() { + if i > 0 { + write!(f, ",")?; + } + Tensor::::fmt_helper(chunk, &shape[1..], f, level + 1)?; } + write!(f, "]") } + } +} - write!(f, "]") +impl fmt::Display for Tensor +where + T: fmt::Display + Clone, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Tensor::::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1) } } diff --git a/src/tensor_view.rs b/src/tensor_view.rs new file mode 100644 index 0000000..5a8b18e --- /dev/null +++ b/src/tensor_view.rs @@ -0,0 +1,8 @@ +use super::*; +use getset::{Getters, MutGetters}; + +// #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)] +pub struct TensorView<'a, T, const R: usize> { + tensor: &'a Tensor, + shape: Shape, +} \ No newline at end of file