Contracting matrices of different shapes fails
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
6ec13ecb6e
commit
1c85daeb76
346
src/axis.rs
346
src/axis.rs
@ -176,26 +176,22 @@ where
|
||||
.map(|axis| axis.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = contract_axes(
|
||||
li.try_into().unwrap(),
|
||||
ri.try_into().unwrap(),
|
||||
);
|
||||
let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap());
|
||||
|
||||
println!("result: {:?}", result);
|
||||
println!("result: {:?}", result);
|
||||
|
||||
let mut t = Tensor::new_with_buffer(shape, result[0].clone());
|
||||
|
||||
// // Iterate over the remaining axes results (skipping the first one which is already in the tensor)
|
||||
// for axes_result in result.iter().skip(1) {
|
||||
// // Ensure the current axes_result has the same length as the tensor's buffer
|
||||
// // assert_eq!(t.shape().size(), axes_result.len(), "Buffer size mismatch");
|
||||
// Iterate over the remaining axes results (skipping the first one already used)
|
||||
for axis_result in result.iter().skip(1) {
|
||||
// Iterate through each element of the axis_result and add it to the corresponding element in t's buffer
|
||||
for (i, val) in axis_result.iter().enumerate() {
|
||||
unsafe {
|
||||
*t.get_flat_unchecked_mut(i) = *t.get_flat(i) + *val;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// // Perform the arithmetic operation for each element
|
||||
// for (i, &value) in axes_result.iter().enumerate() {
|
||||
// // Modify this line to choose between summing or multiplying
|
||||
// t[i] = t[i] + value;
|
||||
// }
|
||||
// }
|
||||
t
|
||||
}
|
||||
|
||||
@ -217,19 +213,21 @@ where
|
||||
|
||||
let laxes = laxes.into_iter().map(|axis| axis.disassemble());
|
||||
let raxes = raxes.into_iter().map(|axis| axis.disassemble());
|
||||
let axes = laxes.into_iter().zip(raxes);
|
||||
let axes = laxes.zip(raxes);
|
||||
|
||||
for (laxis, raxis) in axes {
|
||||
let mut axes_result: Vec<T> = vec![];
|
||||
for rlevel in raxis {
|
||||
for llevel in laxis.clone() {
|
||||
let mut sum = T::zero();
|
||||
for (lv, rv) in llevel.clone().into_iter().zip(rlevel.clone().into_iter()) {
|
||||
sum = sum + *lv * *rv;
|
||||
}
|
||||
axes_result.push(sum);
|
||||
}
|
||||
}
|
||||
for rlevel in raxis {
|
||||
for llevel in laxis.clone() {
|
||||
let mut sum = T::zero();
|
||||
for (lv, rv) in
|
||||
llevel.clone().into_iter().zip(rlevel.clone().into_iter())
|
||||
{
|
||||
sum = sum + *lv * *rv;
|
||||
}
|
||||
axes_result.push(sum);
|
||||
}
|
||||
}
|
||||
result.push(axes_result);
|
||||
}
|
||||
|
||||
@ -254,144 +252,164 @@ mod tests {
|
||||
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
|
||||
|
||||
assert_eq!(contracted_tensor.shape(), Shape::new([2, 2]));
|
||||
assert_eq!(
|
||||
contracted_tensor.buffer(),
|
||||
&[7, 10, 15, 22],
|
||||
"Contracted tensor buffer does not match expected"
|
||||
);
|
||||
assert_eq!(
|
||||
contracted_tensor.buffer(),
|
||||
&[7, 10, 15, 22],
|
||||
"Contracted tensor buffer does not match expected"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_contraction_23x32() {
|
||||
// Define two 2D tensors (matrices)
|
||||
// Tensor A is 2x3
|
||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
||||
|
||||
// Tensor B is 1x3x2
|
||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
let contracted_tensor: Tensor<i32, 2> =
|
||||
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
|
||||
|
||||
assert_eq!(contracted_tensor.shape(), Shape::new([3, 3]));
|
||||
assert_eq!(
|
||||
contracted_tensor.buffer(),
|
||||
&[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
||||
"Contracted tensor buffer does not match expected"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator_disassemble() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
assert_eq!(axis_iter[0].next(), Some(&2.0));
|
||||
assert_eq!(axis_iter[0].next(), None);
|
||||
assert_eq!(axis_iter[1].next(), Some(&3.0));
|
||||
assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
assert_eq!(axis_iter[1].next(), None);
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
assert_eq!(axis_iter[0].next(), Some(&3.0));
|
||||
assert_eq!(axis_iter[0].next(), None);
|
||||
assert_eq!(axis_iter[1].next(), Some(&2.0));
|
||||
assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
assert_eq!(axis_iter[1].next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
let shape = tensor.shape();
|
||||
|
||||
let mut a: Idx<2> = (shape, [0, 0]).into();
|
||||
let b: Idx<2> = (shape, [1, 1]).into();
|
||||
|
||||
while a <= b {
|
||||
println!("a: {}", a);
|
||||
a.inc();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_3d_tensor_axis_iteration() {
|
||||
// Create a 3D Tensor with specific values
|
||||
// Tensor shape is 2x2x2 for simplicity
|
||||
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
|
||||
// Axis 0 (Layer-wise):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
//
|
||||
// This order suggests that for each "layer" (first level of arrays),
|
||||
// the iterator goes through all rows and columns. It first completes
|
||||
// the entire first layer, then moves to the second.
|
||||
|
||||
let a0 = Axis::new(&t, 0);
|
||||
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
|
||||
// Axis 1 (Row-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first row across all layers, then the second row
|
||||
// across all layers.
|
||||
|
||||
let a1 = Axis::new(&t, 1);
|
||||
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
|
||||
// Axis 2 (Column-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][1][0] = 3
|
||||
// t[1][0][0] = 5
|
||||
// t[1][1][0] = 7
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][1] = 8
|
||||
// [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first column across all layers, then the second
|
||||
// column across all layers.
|
||||
|
||||
let a2 = Axis::new(&t, 2);
|
||||
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator_disassemble() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
assert_eq!(axis_iter[0].next(), Some(&2.0));
|
||||
assert_eq!(axis_iter[0].next(), None);
|
||||
assert_eq!(axis_iter[1].next(), Some(&3.0));
|
||||
assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
assert_eq!(axis_iter[1].next(), None);
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
assert_eq!(axis_iter[0].next(), Some(&3.0));
|
||||
assert_eq!(axis_iter[0].next(), None);
|
||||
assert_eq!(axis_iter[1].next(), Some(&2.0));
|
||||
assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
assert_eq!(axis_iter[1].next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
let shape = tensor.shape();
|
||||
|
||||
let mut a: Idx<2> = (shape, [0, 0]).into();
|
||||
let b: Idx<2> = (shape, [1, 1]).into();
|
||||
|
||||
while a <= b {
|
||||
println!("a: {}", a);
|
||||
a.inc();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_3d_tensor_axis_iteration() {
|
||||
// Create a 3D Tensor with specific values
|
||||
// Tensor shape is 2x2x2 for simplicity
|
||||
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
|
||||
// Axis 0 (Layer-wise):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
//
|
||||
// This order suggests that for each "layer" (first level of arrays),
|
||||
// the iterator goes through all rows and columns. It first completes
|
||||
// the entire first layer, then moves to the second.
|
||||
|
||||
let a0 = Axis::new(&t, 0);
|
||||
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
|
||||
// Axis 1 (Row-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first row across all layers, then the second row
|
||||
// across all layers.
|
||||
|
||||
let a1 = Axis::new(&t, 1);
|
||||
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
|
||||
// Axis 2 (Column-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][1][0] = 3
|
||||
// t[1][0][0] = 5
|
||||
// t[1][1][0] = 7
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][1] = 8
|
||||
// [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first column across all layers, then the second
|
||||
// column across all layers.
|
||||
|
||||
let a2 = Axis::new(&t, 2);
|
||||
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
}
|
||||
// }
|
||||
|
@ -72,6 +72,22 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
self.buffer.get_unchecked_mut(index.flat())
|
||||
}
|
||||
|
||||
pub fn get_flat(&self, index: usize) -> &T {
|
||||
&self.buffer[index]
|
||||
}
|
||||
|
||||
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
|
||||
self.buffer.get_unchecked(index)
|
||||
}
|
||||
|
||||
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
|
||||
self.buffer.get_mut(index)
|
||||
}
|
||||
|
||||
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
|
||||
self.buffer.get_unchecked_mut(index)
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
R
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user