diff --git a/docs/tensor-contraction.md b/docs/tensor-contraction.md new file mode 100644 index 0000000..0361388 --- /dev/null +++ b/docs/tensor-contraction.md @@ -0,0 +1,34 @@ +To understand how the tensor contraction should work for the given tensors `a` and `b`, let's first clarify their shapes and then walk through the contraction steps: + +1. **Tensor Shapes**: + - Tensor `a` is a 3x2 matrix (3 rows and 2 columns): \[\begin{matrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{matrix}\] + - Tensor `b` is a 2x3 matrix (2 rows and 3 columns): \[\begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix}\] + +2. **Tensor Contraction Operation**: + - The contraction operation in this case involves multiplying corresponding elements along the shared dimension (the second dimension of `a` and the first dimension of `b`) and summing the results. + - The resulting tensor will have the shape determined by the remaining (uncontracted) dimensions of the original tensors, which in this case is 3x3. + +3. **Contraction Steps**: + + - Step 1: Multiply each element of the first row of `a` with the corresponding element of the first column of `b`, then sum these products. This forms the first element of the first row of the resulting matrix. + - \( (1 \times 1) + (2 \times 4) = 1 + 8 = 9 \) + - Step 2: Multiply each element of the first row of `a` with the corresponding element of the second column of `b`, then sum these products. This forms the second element of the first row of the resulting matrix. + - \( (1 \times 2) + (2 \times 5) = 2 + 10 = 12 \) + - Step 3: Multiply each element of the first row of `a` with the corresponding element of the third column of `b`, then sum these products. This forms the third element of the first row of the resulting matrix. 
+ - \( (1 \times 3) + (2 \times 6) = 3 + 12 = 15 \) + + - Continue this process for the remaining rows of `a` and columns of `b`: + - For the second row of `a`: + - \( (3 \times 1) + (4 \times 4) = 3 + 16 = 19 \) + - \( (3 \times 2) + (4 \times 5) = 6 + 20 = 26 \) + - \( (3 \times 3) + (4 \times 6) = 9 + 24 = 33 \) + - For the third row of `a`: + - \( (5 \times 1) + (6 \times 4) = 5 + 24 = 29 \) + - \( (5 \times 2) + (6 \times 5) = 10 + 30 = 40 \) + - \( (5 \times 3) + (6 \times 6) = 15 + 36 = 51 \) + +4. **Resulting Tensor**: + - The resulting 3x3 tensor from the contraction of `a` and `b` will be: + \[\begin{matrix} 9 & 12 & 15 \\ 19 & 26 & 33 \\ 29 & 40 & 51 \end{matrix}\] + +These steps provide the detailed calculations for each element of the resulting tensor after contracting tensors `a` and `b`. \ No newline at end of file diff --git a/examples/operations.rs b/examples/operations.rs index 3f1cf9c..63bd179 100644 --- a/examples/operations.rs +++ b/examples/operations.rs @@ -2,6 +2,7 @@ #![allow(non_snake_case)] use bytemuck::cast_slice; use manifold::*; +use manifold::contract; fn tensor_product() { println!("Tensor Product\n"); @@ -20,7 +21,7 @@ fn tensor_product() { println!("T1 * T2 = {}", product); // Check shape of the resulting tensor - assert_eq!(product.shape(), Shape::new([2, 2, 2])); + assert_eq!(product.shape(), &Shape::new([2, 2, 2])); // Check buffer of the resulting tensor let expect: &[i32] = @@ -28,39 +29,34 @@ fn tensor_product() { assert_eq!(product.buffer(), expect); } -fn tensor_contraction() { - println!("Tensor Contraction\n"); - // Create two tensors - let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor - let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor +fn test_tensor_contraction_23x32() { + // Define two 2D tensors (matrices) - // Specify axes for contraction - let axis_lhs = [1]; // Contract over the second dimension of tensor1 - let axis_rhs = [0]; // Contract over the first dimension of tensor2 + // Tensor A is 
2x3 + let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + println!("a: {}", a); - // Perform contraction - // let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs); + // Tensor B is 3x2 + let b: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); + println!("b: {}", b); - // println!("T1: {}", tensor1); - // println!("T2: {}", tensor2); - // println!("T1 * T2 = {}", result); + // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) + let ctr10 = contract((&a, [1]), (&b, [0])); - // Expected result, for example, could be a single number or a new tensor, - // depending on how you defined the contraction operation. - // Assert the result is as expected - // assert_eq!(result, expected_result); + println!("[1, 0]: {}", ctr10); - // let Λ = Tensor::::from([ - // [1.0, 0.0, 0.0, 0.0], - // [0.0, 1.0, 0.0 ,0.0], - // [0.0, 0.0, 1.0, 0.0], - // [0.0, 0.0, 0.0, 1.0] - // ]); + let ctr01 = contract((&a, [0]), (&b, [1])); - // println!("Λ: {}", Λ); + println!("[0, 1]: {}", ctr01); + // assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); + // assert_eq!( + // contracted_tensor.buffer(), + // &[9, 12, 15, 19, 26, 33, 29, 40, 51], + // "Contracted tensor buffer does not match expected" + // ); } fn main() { - tensor_product(); - tensor_contraction(); + // tensor_product(); + test_tensor_contraction_23x32(); } diff --git a/src/axis.rs b/src/axis.rs index c05c9a8..f98b755 100644 --- a/src/axis.rs +++ b/src/axis.rs @@ -134,37 +134,25 @@ pub fn contract< const S: usize, const N: usize, >( - lhs: &'a Tensor, - rhs: &'a Tensor, - laxes: [Axis<'a, T, R>; N], - raxes: [Axis<'a, T, S>; N], + lhs: (&'a Tensor, [usize; N]), + rhs: (&'a Tensor, [usize; N]), ) -> Tensor where [(); R - N]:, [(); S - N]:, [(); R + S - 2 * N]:, { - let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>( - laxes - .iter() - .map(|axis| *axis.dim()) - .collect::>() - .try_into() - .unwrap(), - ); - let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>( - raxes - .iter() - 
.map(|axis| *axis.dim()) - .collect::>() - .try_into() - .unwrap(), - ); - + let (lhs, la) = lhs; + let (rhs, ra) = rhs; + let raxes = ra.into_iter().map(|i| rhs.axis(i)).collect::>(); + let raxes: [Axis<'a, T, S>; N] = + raxes.try_into().expect("Failed to create raxes array"); + let laxes = la.into_iter().map(|i| lhs.axis(i)).collect::>(); + let laxes: [Axis<'a, T, R>; N] = + laxes.try_into().expect("Failed to create laxes array"); let mut shape = Vec::new(); - shape.extend_from_slice(&lhs_shape_reduced.as_array()); - shape.extend_from_slice(&rhs_shape_reduced.as_array()); - + shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); + shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); let shape: [usize; R + S - 2 * N] = shape.try_into().expect("Failed to create shape array"); @@ -205,27 +193,20 @@ where let axes = laxes.into_iter().zip(raxes); - for (laxis, raxis) in axes { + for (laxis, raxis) in axes { let mut axes_result: Vec = vec![]; for i in 0..raxis.len() { - println!("raxis: {}", i); - for j in 0..laxis.len() { - println!("laxis: {}", j); - let mut sum = T::zero(); - let llevel = laxis.into_iter(); - let llevel = llevel.level(j); - let rlevel = raxis.into_iter(); - let rlevel = rlevel.level(i); - let zip = llevel.zip(rlevel); - for (lv, rv) in zip { - println!("{} * {} = {}", lv, rv, *lv * *rv); - println!("{} + {} = {}", sum, *lv * *rv, sum + *lv * *rv); - sum = sum + *lv * *rv; - } - println!("sum: {}", sum); - axes_result.push(sum); - } - } + for j in 0..laxis.len() { + let mut sum = T::zero(); + let llevel = laxis.into_iter(); + let rlevel = raxis.into_iter(); + let zip = llevel.level(j).zip(rlevel.level(i)); + for (lv, rv) in zip { + sum = sum + *lv * *rv; + } + axes_result.push(sum); + } + } result.push(axes_result); } @@ -246,9 +227,7 @@ mod tests { let b: Tensor = Tensor::from([[1, 2], [3, 4]]); // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let contracted_tensor: Tensor = - 
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]); - + let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2])); assert_eq!( contracted_tensor.buffer(), @@ -260,16 +239,19 @@ mod tests { #[test] fn test_tensor_contraction_23x32() { // Define two 2D tensors (matrices) - // Tensor A is 2x3 - let a: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); - // Tensor B is 1x3x2 + // Tensor B is 2x3 let b: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + println!("b: {}", b); + + // Tensor A is 3x2 + let a: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); + println!("a: {}", a); // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) - let contracted_tensor: Tensor = - contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]); + let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); + println!("contracted_tensor: {}", contracted_tensor); assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); assert_eq!( contracted_tensor.buffer(), @@ -278,6 +260,24 @@ mod tests { ); } + #[test] + fn test_tensor_contraction_rank3() { + let a: Tensor = + Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24 + let b: Tensor = + Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24 + let contracted_tensor: Tensor = contract((&a, [2]), (&b, [0])); + + println!("a: {}", a); + println!("b: {}", b); + println!("contracted_tensor: {}", contracted_tensor); + // assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]); // NOTE(review): a is 2x3x4 and b is 4x3x2, so the uncontracted dims are {2, 3} and {3, 2}; [2, 4, 3, 2] looks wrong - confirm before enabling + // Verify specific elements of contracted_tensor + // assert_eq!(contracted_tensor[0][0][0][0], 50); + // assert_eq!(contracted_tensor[0][0][0][1], 60); + // ... further checks for other elements ... 
+ } + // #[test] // fn test_axis_iterator_disassemble() { // // Creating a 2x2 Tensor for testing diff --git a/src/shape.rs b/src/shape.rs index c5488f7..1dd7c14 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -99,6 +99,34 @@ impl Shape { new_index += 1; } + Shape(new_shape) + } + + /// Returns a new shape containing the dimensions of `self` that are not + /// selected by any axis in `axes_to_remove`, in their original order. + /// Panics if the axes do not select exactly `NAX` distinct dimensions. + pub fn remove_axes<'a, T: Value, const NAX: usize>( + &self, + axes_to_remove: &'a [Axis<'a, T, R>; NAX], + ) -> Shape<{ R - NAX }> { + // Create a new array to store the remaining dimensions + let mut new_shape = [0; R - NAX]; + let mut new_index = 0; + + // Iterate over the original dimensions + for (index, &dim) in self.0.iter().enumerate() { + // Test membership first so that `continue` skips this outer iteration + let removed = + axes_to_remove.iter().any(|axis| *axis.dim() == index); + if removed { + continue; + } + + // Add the dimension to the new shape array + new_shape[new_index] = dim; + new_index += 1; + } + Shape(new_shape) } } diff --git a/src/tensor.rs b/src/tensor.rs index f344450..d8f2f09 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -51,6 +51,10 @@ impl Tensor { &self.buffer[index.flat()] } + pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> { + Axis::new(self, axis) + } + pub unsafe fn get_unchecked(&self, index: Idx) -> &T { self.buffer.get_unchecked(index.flat()) }