More cases of contraction working, still some issues with misaligned symmetries
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
27384a3044
commit
33c056f2b9
34
docs/tensor-contraction.md
Normal file
34
docs/tensor-contraction.md
Normal file
@ -0,0 +1,34 @@
|
|||||||
|
To understand how the tensor contraction should work for the given tensors `a` and `b`, let's first clarify their shapes and then walk through the contraction steps:
|
||||||
|
|
||||||
|
1. **Tensor Shapes**:
|
||||||
|
- Tensor `a` is a 3x2 matrix (3 rows and 2 columns): \[\begin{matrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{matrix}\]
|
||||||
|
- Tensor `b` is a 2x3 matrix (2 rows and 3 columns): \[\begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix}\]
|
||||||
|
|
||||||
|
2. **Tensor Contraction Operation**:
|
||||||
|
- The contraction operation in this case involves multiplying corresponding elements along the shared dimension (the second dimension of `a` and the first dimension of `b`) and summing the results.
|
||||||
|
- The resulting tensor will have the shape determined by the other dimensions of the original tensors, which in this case is 3x3.
|
||||||
|
|
||||||
|
3. **Contraction Steps**:
|
||||||
|
|
||||||
|
- Step 1: Multiply each element of the first row of `a` by the corresponding element of the first column of `b`, then sum these products. This forms the first element of the first row of the resulting matrix.
|
||||||
|
- \( (1 \times 1) + (2 \times 4) = 1 + 8 = 9 \)
|
||||||
|
- Step 2: Multiply each element of the first row of `a` by the corresponding element of the second column of `b`, then sum these products. This forms the second element of the first row of the resulting matrix.
|
||||||
|
- \( (1 \times 2) + (2 \times 5) = 2 + 10 = 12 \)
|
||||||
|
- Step 3: Multiply each element of the first row of `a` by the corresponding element of the third column of `b`, then sum these products. This forms the third element of the first row of the resulting matrix.
|
||||||
|
- \( (1 \times 3) + (2 \times 6) = 3 + 12 = 15 \)
|
||||||
|
|
||||||
|
- Continue this process for the remaining rows of `a` and columns of `b`:
|
||||||
|
- For the second row of `a`:
|
||||||
|
- \( (3 \times 1) + (4 \times 4) = 3 + 16 = 19 \)
|
||||||
|
- \( (3 \times 2) + (4 \times 5) = 6 + 20 = 26 \)
|
||||||
|
- \( (3 \times 3) + (4 \times 6) = 9 + 24 = 33 \)
|
||||||
|
- For the third row of `a`:
|
||||||
|
- \( (5 \times 1) + (6 \times 4) = 5 + 24 = 29 \)
|
||||||
|
- \( (5 \times 2) + (6 \times 5) = 10 + 30 = 40 \)
|
||||||
|
- \( (5 \times 3) + (6 \times 6) = 15 + 36 = 51 \)
|
||||||
|
|
||||||
|
4. **Resulting Tensor**:
|
||||||
|
- The resulting 3x3 tensor from the contraction of `a` and `b` will be:
|
||||||
|
\[\begin{matrix} 9 & 12 & 15 \\ 19 & 26 & 33 \\ 29 & 40 & 51 \end{matrix}\]
|
||||||
|
|
||||||
|
These steps provide the detailed calculations for each element of the resulting tensor after contracting tensors `a` and `b`.
|
@ -2,6 +2,7 @@
|
|||||||
#![allow(non_snake_case)]
|
#![allow(non_snake_case)]
|
||||||
use bytemuck::cast_slice;
|
use bytemuck::cast_slice;
|
||||||
use manifold::*;
|
use manifold::*;
|
||||||
|
use manifold::contract;
|
||||||
|
|
||||||
fn tensor_product() {
|
fn tensor_product() {
|
||||||
println!("Tensor Product\n");
|
println!("Tensor Product\n");
|
||||||
@ -20,7 +21,7 @@ fn tensor_product() {
|
|||||||
println!("T1 * T2 = {}", product);
|
println!("T1 * T2 = {}", product);
|
||||||
|
|
||||||
// Check shape of the resulting tensor
|
// Check shape of the resulting tensor
|
||||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
assert_eq!(product.shape(), &Shape::new([2, 2, 2]));
|
||||||
|
|
||||||
// Check buffer of the resulting tensor
|
// Check buffer of the resulting tensor
|
||||||
let expect: &[i32] =
|
let expect: &[i32] =
|
||||||
@ -28,39 +29,34 @@ fn tensor_product() {
|
|||||||
assert_eq!(product.buffer(), expect);
|
assert_eq!(product.buffer(), expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tensor_contraction() {
|
fn test_tensor_contraction_23x32() {
|
||||||
println!("Tensor Contraction\n");
|
// Define two 2D tensors (matrices)
|
||||||
// Create two tensors
|
|
||||||
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
|
|
||||||
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
|
|
||||||
|
|
||||||
// Specify axes for contraction
|
// Tensor A is 2x3
|
||||||
let axis_lhs = [1]; // Contract over the second dimension of tensor1
|
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||||
let axis_rhs = [0]; // Contract over the first dimension of tensor2
|
println!("a: {}", a);
|
||||||
|
|
||||||
// Perform contraction
|
// Tensor B is 3x2
|
||||||
// let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
||||||
|
println!("b: {}", b);
|
||||||
|
|
||||||
// println!("T1: {}", tensor1);
|
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||||
// println!("T2: {}", tensor2);
|
let ctr10 = contract((&a, [1]), (&b, [0]));
|
||||||
// println!("T1 * T2 = {}", result);
|
|
||||||
|
|
||||||
// Expected result, for example, could be a single number or a new tensor,
|
println!("[1, 0]: {}", ctr10);
|
||||||
// depending on how you defined the contraction operation.
|
|
||||||
// Assert the result is as expected
|
|
||||||
// assert_eq!(result, expected_result);
|
|
||||||
|
|
||||||
// let Λ = Tensor::<f64, 2>::from([
|
let ctr01 = contract((&a, [0]), (&b, [1]));
|
||||||
// [1.0, 0.0, 0.0, 0.0],
|
|
||||||
// [0.0, 1.0, 0.0 ,0.0],
|
|
||||||
// [0.0, 0.0, 1.0, 0.0],
|
|
||||||
// [0.0, 0.0, 0.0, 1.0]
|
|
||||||
// ]);
|
|
||||||
|
|
||||||
// println!("Λ: {}", Λ);
|
println!("[0, 1]: {}", ctr01);
|
||||||
|
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
||||||
|
// assert_eq!(
|
||||||
|
// contracted_tensor.buffer(),
|
||||||
|
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
||||||
|
// "Contracted tensor buffer does not match expected"
|
||||||
|
// );
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
tensor_product();
|
// tensor_product();
|
||||||
tensor_contraction();
|
test_tensor_contraction_23x32();
|
||||||
}
|
}
|
||||||
|
102
src/axis.rs
102
src/axis.rs
@ -134,37 +134,25 @@ pub fn contract<
|
|||||||
const S: usize,
|
const S: usize,
|
||||||
const N: usize,
|
const N: usize,
|
||||||
>(
|
>(
|
||||||
lhs: &'a Tensor<T, R>,
|
lhs: (&'a Tensor<T, R>, [usize; N]),
|
||||||
rhs: &'a Tensor<T, S>,
|
rhs: (&'a Tensor<T, S>, [usize; N]),
|
||||||
laxes: [Axis<'a, T, R>; N],
|
|
||||||
raxes: [Axis<'a, T, S>; N],
|
|
||||||
) -> Tensor<T, { R + S - 2 * N }>
|
) -> Tensor<T, { R + S - 2 * N }>
|
||||||
where
|
where
|
||||||
[(); R - N]:,
|
[(); R - N]:,
|
||||||
[(); S - N]:,
|
[(); S - N]:,
|
||||||
[(); R + S - 2 * N]:,
|
[(); R + S - 2 * N]:,
|
||||||
{
|
{
|
||||||
let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
|
let (lhs, la) = lhs;
|
||||||
laxes
|
let (rhs, ra) = rhs;
|
||||||
.iter()
|
let raxes = ra.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
|
||||||
.map(|axis| *axis.dim())
|
let raxes: [Axis<'a, T, S>; N] =
|
||||||
.collect::<Vec<_>>()
|
raxes.try_into().expect("Failed to create raxes array");
|
||||||
.try_into()
|
let laxes = la.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
|
||||||
.unwrap(),
|
let laxes: [Axis<'a, T, R>; N] =
|
||||||
);
|
laxes.try_into().expect("Failed to create laxes array");
|
||||||
let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
|
|
||||||
raxes
|
|
||||||
.iter()
|
|
||||||
.map(|axis| *axis.dim())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.try_into()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut shape = Vec::new();
|
let mut shape = Vec::new();
|
||||||
shape.extend_from_slice(&lhs_shape_reduced.as_array());
|
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
|
||||||
shape.extend_from_slice(&rhs_shape_reduced.as_array());
|
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
|
||||||
|
|
||||||
let shape: [usize; R + S - 2 * N] =
|
let shape: [usize; R + S - 2 * N] =
|
||||||
shape.try_into().expect("Failed to create shape array");
|
shape.try_into().expect("Failed to create shape array");
|
||||||
|
|
||||||
@ -205,27 +193,20 @@ where
|
|||||||
|
|
||||||
let axes = laxes.into_iter().zip(raxes);
|
let axes = laxes.into_iter().zip(raxes);
|
||||||
|
|
||||||
for (laxis, raxis) in axes {
|
for (laxis, raxis) in axes {
|
||||||
let mut axes_result: Vec<T> = vec![];
|
let mut axes_result: Vec<T> = vec![];
|
||||||
for i in 0..raxis.len() {
|
for i in 0..raxis.len() {
|
||||||
println!("raxis: {}", i);
|
for j in 0..laxis.len() {
|
||||||
for j in 0..laxis.len() {
|
let mut sum = T::zero();
|
||||||
println!("laxis: {}", j);
|
let llevel = laxis.into_iter();
|
||||||
let mut sum = T::zero();
|
let rlevel = raxis.into_iter();
|
||||||
let llevel = laxis.into_iter();
|
let zip = llevel.level(j).zip(rlevel.level(i));
|
||||||
let llevel = llevel.level(j);
|
for (lv, rv) in zip {
|
||||||
let rlevel = raxis.into_iter();
|
sum = sum + *lv * *rv;
|
||||||
let rlevel = rlevel.level(i);
|
}
|
||||||
let zip = llevel.zip(rlevel);
|
axes_result.push(sum);
|
||||||
for (lv, rv) in zip {
|
}
|
||||||
println!("{} * {} = {}", lv, rv, *lv * *rv);
|
}
|
||||||
println!("{} + {} = {}", sum, *lv * *rv, sum + *lv * *rv);
|
|
||||||
sum = sum + *lv * *rv;
|
|
||||||
}
|
|
||||||
println!("sum: {}", sum);
|
|
||||||
axes_result.push(sum);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
result.push(axes_result);
|
result.push(axes_result);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -246,9 +227,7 @@ mod tests {
|
|||||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
||||||
|
|
||||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||||
let contracted_tensor: Tensor<i32, 2> =
|
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
||||||
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
|
|
||||||
|
|
||||||
assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2]));
|
assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2]));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
contracted_tensor.buffer(),
|
contracted_tensor.buffer(),
|
||||||
@ -260,16 +239,19 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_tensor_contraction_23x32() {
|
fn test_tensor_contraction_23x32() {
|
||||||
// Define two 2D tensors (matrices)
|
// Define two 2D tensors (matrices)
|
||||||
// Tensor A is 2x3
|
|
||||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
|
||||||
|
|
||||||
// Tensor B is 1x3x2
|
// Tensor B is 2x3
|
||||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
let b: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||||
|
println!("b: {}", b);
|
||||||
|
|
||||||
|
// Tensor A is 3x2
|
||||||
|
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
||||||
|
println!("a: {}", a);
|
||||||
|
|
||||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||||
let contracted_tensor: Tensor<i32, 2> =
|
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
||||||
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
|
|
||||||
|
|
||||||
|
println!("contracted_tensor: {}", contracted_tensor);
|
||||||
assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
contracted_tensor.buffer(),
|
contracted_tensor.buffer(),
|
||||||
@ -278,6 +260,24 @@ mod tests {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tensor_contraction_rank3() {
|
||||||
|
let a: Tensor<i32, 3> =
|
||||||
|
Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24
|
||||||
|
let b: Tensor<i32, 3> =
|
||||||
|
Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24
|
||||||
|
let contracted_tensor: Tensor<i32, 4> = contract((&a, [2]), (&b, [0]));
|
||||||
|
|
||||||
|
println!("a: {}", a);
|
||||||
|
println!("b: {}", b);
|
||||||
|
println!("contracted_tensor: {}", contracted_tensor);
|
||||||
|
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
|
||||||
|
// Verify specific elements of contracted_tensor
|
||||||
|
// assert_eq!(contracted_tensor[0][0][0][0], 50);
|
||||||
|
// assert_eq!(contracted_tensor[0][0][0][1], 60);
|
||||||
|
// ... further checks for other elements ...
|
||||||
|
}
|
||||||
|
|
||||||
// #[test]
|
// #[test]
|
||||||
// fn test_axis_iterator_disassemble() {
|
// fn test_axis_iterator_disassemble() {
|
||||||
// // Creating a 2x2 Tensor for testing
|
// // Creating a 2x2 Tensor for testing
|
||||||
|
25
src/shape.rs
25
src/shape.rs
@ -99,6 +99,31 @@ impl<const R: usize> Shape<R> {
|
|||||||
new_index += 1;
|
new_index += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Shape(new_shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_axes<'a, T: Value, const NAX: usize>(
|
||||||
|
&self,
|
||||||
|
axes_to_remove: &'a [Axis<'a, T, R>; NAX],
|
||||||
|
) -> Shape<{ R - NAX }> {
|
||||||
|
// Create a new array to store the remaining dimensions
|
||||||
|
let mut new_shape = [0; R - NAX];
|
||||||
|
let mut new_index = 0;
|
||||||
|
|
||||||
|
// Iterate over the original dimensions
|
||||||
|
for (index, &dim) in self.0.iter().enumerate() {
|
||||||
|
// Skip dimensions that are in the axes_to_remove array
|
||||||
|
for axis in axes_to_remove {
|
||||||
|
if *axis.dim() == index {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the dimension to the new shape array
|
||||||
|
new_shape[new_index] = dim;
|
||||||
|
new_index += 1;
|
||||||
|
}
|
||||||
|
|
||||||
Shape(new_shape)
|
Shape(new_shape)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -51,6 +51,10 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
&self.buffer[index.flat()]
|
&self.buffer[index.flat()]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
|
||||||
|
Axis::new(self, axis)
|
||||||
|
}
|
||||||
|
|
||||||
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
|
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
|
||||||
self.buffer.get_unchecked(index.flat())
|
self.buffer.get_unchecked(index.flat())
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user