Contraction works for simple 2x2 exmaple
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
4626504521
commit
6ec13ecb6e
34
src/axis.rs
34
src/axis.rs
@ -215,25 +215,21 @@ where
|
||||
{
|
||||
let mut result = vec![];
|
||||
|
||||
// Disassemble each axis iterator into iterators over each level
|
||||
let laxes = laxes.into_iter().map(|axis| axis.disassemble());
|
||||
let raxes = raxes.into_iter().map(|axis| axis.disassemble());
|
||||
let axes = laxes.into_iter().zip(raxes);
|
||||
|
||||
for (laxis, raxis) in axes {
|
||||
let mut axes_result: Vec<T> = vec![];
|
||||
let levels = laxis.into_iter().zip(raxis.into_iter());
|
||||
for (ll, rl) in levels {
|
||||
let mut sum = T::zero();
|
||||
let values = ll.into_iter().zip(rl.into_iter());
|
||||
|
||||
for (lv, rv) in values {
|
||||
println!("{} + {} * {} = {}", sum, lv, rv, sum + *lv * *rv);
|
||||
sum = sum + *lv * *rv;
|
||||
}
|
||||
|
||||
axes_result.push(sum);
|
||||
}
|
||||
for rlevel in raxis {
|
||||
for llevel in laxis.clone() {
|
||||
let mut sum = T::zero();
|
||||
for (lv, rv) in llevel.clone().into_iter().zip(rlevel.clone().into_iter()) {
|
||||
sum = sum + *lv * *rv;
|
||||
}
|
||||
axes_result.push(sum);
|
||||
}
|
||||
}
|
||||
result.push(axes_result);
|
||||
}
|
||||
|
||||
@ -257,12 +253,12 @@ mod tests {
|
||||
let contracted_tensor: Tensor<i32, 2> =
|
||||
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
|
||||
|
||||
// Expect: [22, 28, 49, 64]
|
||||
// Check if the contracted tensor is as expected
|
||||
println!("A: {}", a);
|
||||
println!("B: {}", b);
|
||||
println!("A * B = {}", contracted_tensor);
|
||||
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
|
||||
assert_eq!(contracted_tensor.shape(), Shape::new([2, 2]));
|
||||
assert_eq!(
|
||||
contracted_tensor.buffer(),
|
||||
&[7, 10, 15, 22],
|
||||
"Contracted tensor buffer does not match expected"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user