Contracting matrices of different shapes fails

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2023-12-29 05:52:18 +02:00
parent 6ec13ecb6e
commit 1c85daeb76
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
2 changed files with 198 additions and 164 deletions

View File

@ -176,26 +176,22 @@ where
.map(|axis| axis.into_iter())
.collect::<Vec<_>>();
let result = contract_axes(
li.try_into().unwrap(),
ri.try_into().unwrap(),
);
let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap());
println!("result: {:?}", result);
println!("result: {:?}", result);
let mut t = Tensor::new_with_buffer(shape, result[0].clone());
// // Iterate over the remaining axes results (skipping the first one which is already in the tensor)
// for axes_result in result.iter().skip(1) {
// // Ensure the current axes_result has the same length as the tensor's buffer
// // assert_eq!(t.shape().size(), axes_result.len(), "Buffer size mismatch");
// Iterate over the remaining axes results (skipping the first one already used)
for axis_result in result.iter().skip(1) {
// Iterate through each element of the axis_result and add it to the corresponding element in t's buffer
for (i, val) in axis_result.iter().enumerate() {
unsafe {
*t.get_flat_unchecked_mut(i) = *t.get_flat(i) + *val;
}
}
}
// // Perform the arithmetic operation for each element
// for (i, &value) in axes_result.iter().enumerate() {
// // Modify this line to choose between summing or multiplying
// t[i] = t[i] + value;
// }
// }
t
}
@ -217,19 +213,21 @@ where
let laxes = laxes.into_iter().map(|axis| axis.disassemble());
let raxes = raxes.into_iter().map(|axis| axis.disassemble());
let axes = laxes.into_iter().zip(raxes);
let axes = laxes.zip(raxes);
for (laxis, raxis) in axes {
let mut axes_result: Vec<T> = vec![];
for rlevel in raxis {
for llevel in laxis.clone() {
let mut sum = T::zero();
for (lv, rv) in llevel.clone().into_iter().zip(rlevel.clone().into_iter()) {
sum = sum + *lv * *rv;
}
axes_result.push(sum);
}
}
for rlevel in raxis {
for llevel in laxis.clone() {
let mut sum = T::zero();
for (lv, rv) in
llevel.clone().into_iter().zip(rlevel.clone().into_iter())
{
sum = sum + *lv * *rv;
}
axes_result.push(sum);
}
}
result.push(axes_result);
}
@ -254,144 +252,164 @@ mod tests {
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
assert_eq!(contracted_tensor.shape(), Shape::new([2, 2]));
assert_eq!(
contracted_tensor.buffer(),
&[7, 10, 15, 22],
"Contracted tensor buffer does not match expected"
);
assert_eq!(
contracted_tensor.buffer(),
&[7, 10, 15, 22],
"Contracted tensor buffer does not match expected"
);
}
#[test]
fn test_tensor_contraction_23x32() {
// Define two 2D tensors (matrices)
// Tensor A is 2x3
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
// Tensor B is 1x3x2
let b: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let contracted_tensor: Tensor<i32, 2> =
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
assert_eq!(contracted_tensor.shape(), Shape::new([3, 3]));
assert_eq!(
contracted_tensor.buffer(),
&[9, 12, 15, 19, 26, 33, 29, 40, 51],
"Contracted tensor buffer does not match expected"
);
}
#[test]
fn test_axis_iterator_disassemble() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
let mut axis_iter = axis.into_iter().disassemble();
assert_eq!(axis_iter[0].next(), Some(&1.0));
assert_eq!(axis_iter[0].next(), Some(&2.0));
assert_eq!(axis_iter[0].next(), None);
assert_eq!(axis_iter[1].next(), Some(&3.0));
assert_eq!(axis_iter[1].next(), Some(&4.0));
assert_eq!(axis_iter[1].next(), None);
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
let mut axis_iter = axis.into_iter().disassemble();
assert_eq!(axis_iter[0].next(), Some(&1.0));
assert_eq!(axis_iter[0].next(), Some(&3.0));
assert_eq!(axis_iter[0].next(), None);
assert_eq!(axis_iter[1].next(), Some(&2.0));
assert_eq!(axis_iter[1].next(), Some(&4.0));
assert_eq!(axis_iter[1].next(), None);
}
#[test]
fn test_axis_iterator() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&4.0));
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&4.0));
let shape = tensor.shape();
let mut a: Idx<2> = (shape, [0, 0]).into();
let b: Idx<2> = (shape, [1, 1]).into();
while a <= b {
println!("a: {}", a);
a.inc();
}
}
#[test]
fn test_3d_tensor_axis_iteration() {
// Create a 3D Tensor with specific values
// Tensor shape is 2x2x2 for simplicity
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
// Axis 0 (Layer-wise):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 3, 4, 5, 6, 7, 8]
//
// This order suggests that for each "layer" (first level of arrays),
// the iterator goes through all rows and columns. It first completes
// the entire first layer, then moves to the second.
let a0 = Axis::new(&t, 0);
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
// Axis 1 (Row-wise within each layer):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 5, 6, 3, 4, 7, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first row across all layers, then the second row
// across all layers.
let a1 = Axis::new(&t, 1);
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
// Axis 2 (Column-wise within each layer):
//
// t[0][0][0] = 1
// t[0][1][0] = 3
// t[1][0][0] = 5
// t[1][1][0] = 7
// t[0][0][1] = 2
// t[0][1][1] = 4
// t[1][0][1] = 6
// t[1][1][1] = 8
// [1, 3, 5, 7, 2, 4, 6, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first column across all layers, then the second
// column across all layers.
let a2 = Axis::new(&t, 2);
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
}
}
#[test]
fn test_axis_iterator_disassemble() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
let mut axis_iter = axis.into_iter().disassemble();
assert_eq!(axis_iter[0].next(), Some(&1.0));
assert_eq!(axis_iter[0].next(), Some(&2.0));
assert_eq!(axis_iter[0].next(), None);
assert_eq!(axis_iter[1].next(), Some(&3.0));
assert_eq!(axis_iter[1].next(), Some(&4.0));
assert_eq!(axis_iter[1].next(), None);
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
let mut axis_iter = axis.into_iter().disassemble();
assert_eq!(axis_iter[0].next(), Some(&1.0));
assert_eq!(axis_iter[0].next(), Some(&3.0));
assert_eq!(axis_iter[0].next(), None);
assert_eq!(axis_iter[1].next(), Some(&2.0));
assert_eq!(axis_iter[1].next(), Some(&4.0));
assert_eq!(axis_iter[1].next(), None);
}
#[test]
fn test_axis_iterator() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&4.0));
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&4.0));
let shape = tensor.shape();
let mut a: Idx<2> = (shape, [0, 0]).into();
let b: Idx<2> = (shape, [1, 1]).into();
while a <= b {
println!("a: {}", a);
a.inc();
}
}
#[test]
fn test_3d_tensor_axis_iteration() {
// Create a 3D Tensor with specific values
// Tensor shape is 2x2x2 for simplicity
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
// Axis 0 (Layer-wise):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 3, 4, 5, 6, 7, 8]
//
// This order suggests that for each "layer" (first level of arrays),
// the iterator goes through all rows and columns. It first completes
// the entire first layer, then moves to the second.
let a0 = Axis::new(&t, 0);
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
// Axis 1 (Row-wise within each layer):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 5, 6, 3, 4, 7, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first row across all layers, then the second row
// across all layers.
let a1 = Axis::new(&t, 1);
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
// Axis 2 (Column-wise within each layer):
//
// t[0][0][0] = 1
// t[0][1][0] = 3
// t[1][0][0] = 5
// t[1][1][0] = 7
// t[0][0][1] = 2
// t[0][1][1] = 4
// t[1][0][1] = 6
// t[1][1][1] = 8
// [1, 3, 5, 7, 2, 4, 6, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first column across all layers, then the second
// column across all layers.
let a2 = Axis::new(&t, 2);
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
}
// }

View File

@ -72,6 +72,22 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer.get_unchecked_mut(index.flat())
}
pub fn get_flat(&self, index: usize) -> &T {
&self.buffer[index]
}
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
self.buffer.get_unchecked(index)
}
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
self.buffer.get_mut(index)
}
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
self.buffer.get_unchecked_mut(index)
}
pub fn rank(&self) -> usize {
R
}