Contraction works for simple 2x2 exmaple

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2023-12-29 05:33:56 +02:00
parent 4626504521
commit 6ec13ecb6e
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B

View File

@ -215,25 +215,21 @@ where
{
let mut result = vec![];
// Disassemble each axis iterator into iterators over each level
let laxes = laxes.into_iter().map(|axis| axis.disassemble());
let raxes = raxes.into_iter().map(|axis| axis.disassemble());
let axes = laxes.into_iter().zip(raxes);
for (laxis, raxis) in axes {
let mut axes_result: Vec<T> = vec![];
let levels = laxis.into_iter().zip(raxis.into_iter());
for (ll, rl) in levels {
let mut sum = T::zero();
let values = ll.into_iter().zip(rl.into_iter());
for (lv, rv) in values {
println!("{} + {} * {} = {}", sum, lv, rv, sum + *lv * *rv);
sum = sum + *lv * *rv;
}
axes_result.push(sum);
}
for rlevel in raxis {
for llevel in laxis.clone() {
let mut sum = T::zero();
for (lv, rv) in llevel.clone().into_iter().zip(rlevel.clone().into_iter()) {
sum = sum + *lv * *rv;
}
axes_result.push(sum);
}
}
result.push(axes_result);
}
@ -257,12 +253,12 @@ mod tests {
let contracted_tensor: Tensor<i32, 2> =
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
// Expect: [22, 28, 49, 64]
// Check if the contracted tensor is as expected
println!("A: {}", a);
println!("B: {}", b);
println!("A * B = {}", contracted_tensor);
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
assert_eq!(contracted_tensor.shape(), Shape::new([2, 2]));
assert_eq!(
contracted_tensor.buffer(),
&[7, 10, 15, 22],
"Contracted tensor buffer does not match expected"
);
}
}