Working on contract_axes

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
Julius Koskela 2023-12-29 05:22:35 +02:00
parent d54179cd5d
commit 4626504521
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
4 changed files with 210 additions and 427 deletions


@@ -28,7 +28,9 @@ impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
assert!(level < self.len, "Level out of bounds");
let mut index = Idx::new(self.shape.clone(), [0; R]);
index.set_axis(self.dim, level);
AxisIterator::new(self.clone()).set_start(level).set_end(level + 1)
AxisIterator::new(self.clone())
.set_start(level)
.set_end(level + 1)
}
}
@@ -38,7 +40,7 @@ pub struct AxisIterator<'a, T: Value, const R: usize> {
axis: Axis<'a, T, R>,
#[getset(get = "pub", get_mut = "pub")]
index: Idx<R>,
#[getset(get = "pub")]
end: Option<usize>,
}
@@ -84,31 +86,30 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
result
}
pub fn axis_max_idx(&self) -> usize {
self.end().unwrap_or(*self.axis().len())
}
pub fn axis_idx(&self) -> usize {
self.index().get_axis(*self.axis().dim())
}
pub fn axis_dim(&self) -> usize {
self.axis().dim().clone()
}
}
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
// Increment the index along the fixed axis and check if it's within bounds
if self.axis_idx() == self.axis_max_idx() {
return None;
}
let result = self.axis().tensor().get(self.index);
let axis_dim = self.axis_dim();
self.index_mut().inc_axis(axis_dim);
Some(result)
}
}
@@ -138,19 +139,6 @@ where
[(); S - N]:,
[(); R + S - 2 * N]:,
{
let li = laxes
.clone()
.into_iter()
.map(|axis| axis.into_iter())
.collect::<Vec<_>>();
let ri = raxes
.clone()
.into_iter()
.map(|axis| axis.into_iter())
.collect::<Vec<_>>();
let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap());
let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
laxes
.iter()
@@ -168,7 +156,6 @@ where
.unwrap(),
);
// Step 5: Concatenate the shapes to form the shape of the resultant tensor
let mut shape = Vec::new();
shape.extend_from_slice(&lhs_shape_reduced.as_array());
shape.extend_from_slice(&rhs_shape_reduced.as_array());
@@ -178,7 +165,38 @@ where
let shape = Shape::new(shape);
Tensor::new_with_buffer(shape, result)
let li = laxes
.clone()
.into_iter()
.map(|axis| axis.into_iter())
.collect::<Vec<_>>();
let ri = raxes
.clone()
.into_iter()
.map(|axis| axis.into_iter())
.collect::<Vec<_>>();
let result = contract_axes(
li.try_into().unwrap(),
ri.try_into().unwrap(),
);
println!("result: {:?}", result);
let mut t = Tensor::new_with_buffer(shape, result[0].clone());
// // Iterate over the remaining axes results (skipping the first one which is already in the tensor)
// for axes_result in result.iter().skip(1) {
// // Ensure the current axes_result has the same length as the tensor's buffer
// // assert_eq!(t.shape().size(), axes_result.len(), "Buffer size mismatch");
// // Perform the arithmetic operation for each element
// for (i, &value) in axes_result.iter().enumerate() {
// // Modify this line to choose between summing or multiplying
// t[i] = t[i] + value;
// }
// }
t
}
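// A minimal standalone sketch, independent of the Tensor API above, of the
// computation a single-axis contraction is expected to perform: contracting
// axis 1 of a 2x3 matrix with axis 0 of a 3x2 matrix is ordinary matrix
// multiplication, and the output shape is the two reduced shapes concatenated,
// here [2, 2]. The helper name is illustrative only.
fn contract_2x3_with_3x2(a: [[i32; 3]; 2], b: [[i32; 2]; 3]) -> [[i32; 2]; 2] {
    let mut out = [[0; 2]; 2];
    for i in 0..2 {
        for j in 0..2 {
            // Sum of products over the contracted index k.
            for k in 0..3 {
                out[i][j] += a[i][k] * b[k][j];
            }
        }
    }
    out
}
// contract_2x3_with_3x2([[1, 2, 3], [4, 5, 6]], [[1, 2], [3, 4], [5, 6]])
// yields [[22, 28], [49, 64]], the same values expected by the commented-out
// test_tensor_contraction in the tests module.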
pub fn contract_axes<
@@ -190,242 +208,38 @@ pub fn contract_axes<
>(
laxes: [AxisIterator<'a, T, R>; N],
raxes: [AxisIterator<'a, T, S>; N],
) -> Vec<T>
) -> Vec<Vec<T>>
where
[(); R - N]:,
[(); S - N]:,
// [(); R + S - 2 * N]:,
{
let mut result = Vec::new();
let mut result = vec![];
// Disassemble each axis iterator into iterators over each level
let laxes: Vec<Vec<AxisIterator<'a, T, R>>> =
laxes.into_iter().map(|axis| axis.disassemble()).collect();
let raxes: Vec<Vec<AxisIterator<'a, T, S>>> =
raxes.into_iter().map(|axis| axis.disassemble()).collect();
let laxes = laxes.into_iter().map(|axis| axis.disassemble());
let raxes = raxes.into_iter().map(|axis| axis.disassemble());
let axes = laxes.into_iter().zip(raxes);
// Iterate over all combinations of levels for contracted axes
for (laxis, raxis) in axes {
let mut axes_result: Vec<T> = vec![];
let levels = laxis.into_iter().zip(raxis.into_iter());
for (ll, rl) in levels {
let mut sum = T::zero();
let values = ll.into_iter().zip(rl.into_iter());
let lcart = laxes.into_iter().multi_cartesian_product();
let rcart = raxes.into_iter().multi_cartesian_product();
let cartesian = lcart.zip(rcart);
for (i, (l, r)) in cartesian.enumerate() {
println!("{} cartesian:", i);
let mut sum = T::zero();
let level = l.into_iter().zip(r.into_iter());
for (i, (l, r)) in level.enumerate() {
// let mut sum = T::zero();
println!("{} level:", i);
let value = l.into_iter().zip(r.into_iter());
for (i, (l, r)) in value.enumerate() {
println!("{} product: {} + {} * {} = {}", i, sum, l, r, sum + *l * *r);
sum = sum + *l * *r;
for (lv, rv) in values {
println!("{} + {} * {} = {}", sum, lv, rv, sum + *lv * *rv);
sum = sum + *lv * *rv;
}
// result.push(sum);
}
result.push(sum);
axes_result.push(sum);
}
result.push(axes_result);
}
result
}
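// A rough standalone model of the new per-level loop above, using plain
// vectors in place of AxisIterator (the helper name is illustrative only):
// each pair of contracted axes is disassembled into levels, matching levels
// are zipped, and the products along each level are summed, giving one Vec of
// per-level partial sums per contracted axis pair.
fn contract_axis_pair(lhs_levels: &[Vec<f64>], rhs_levels: &[Vec<f64>]) -> Vec<f64> {
    lhs_levels
        .iter()
        .zip(rhs_levels.iter())
        .map(|(ll, rl)| ll.iter().zip(rl.iter()).map(|(l, r)| l * r).sum())
        .collect()
}
// For [[1.0, 2.0], [3.0, 4.0]], axis 0 disassembles into the levels [1, 2]
// and [3, 4], and axis 1 into [1, 3] and [2, 4]; zipping those levels gives
// the partial sums [1*1 + 2*3, 3*2 + 4*4] = [7.0, 22.0].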
// pub fn contract_axes<
// 'a,
// T: Value + std::fmt::Debug,
// const R: usize,
// const S: usize,
// const N: usize,
// >(
// // Axes are (contracted, non-contracted)
// laxes: ([AxisIterator<'a, T, R>; N], [AxisIterator<'a, T, R>; R - N]),
// raxes: ([AxisIterator<'a, T, S>; N], [AxisIterator<'a, T, S>; S - N]),
// ) -> Vec<T> //Tensor<T, { R + S - 2 * N }>
// where
// [(); R - N]:,
// [(); S - N]:,
// [(); R + S - 2 * N]:,
// {
// let (lc, lnc) = laxes;
// let (rc, rnc) = raxes;
// // Step 1: Prepare cartesian products of contracted axes
// let lc_prod = lc.into_iter().multi_cartesian_product();
// let rc_prod = rc.into_iter().multi_cartesian_product();
// // Step 2: Prepare iterators for non-contracted axes
// let lnc_prod = lnc.into_iter().multi_cartesian_product();
// let rnc_prod = rnc.into_iter().multi_cartesian_product();
// // Initialize buffer for the resulting tensor
// let mut result_buffer = Vec::new();
// // Step 3: Iterate over combinations and compute tensor product
// for lnc_vals in lnc_prod {
// for rnc_vals in rnc_prod.clone() {
// let mut sum = T::zero();
// for ((lc, rc), (lnc_val, rnc_val)) in lc_prod
// .clone()
// .zip(rc_prod.clone())
// .zip(lnc_vals.iter().zip(rnc_vals.iter()))
// {
// for (lc_val, rc_val) in lc.into_iter().zip(rc.into_iter()) {
// sum = sum + (*lc_val * *rc_val) * (**lnc_val * **rnc_val);
// }
// }
// result_buffer.push(sum);
// }
// }
// result_buffer
// }
// // Check that there are no duplicate axes
// for (i, laxis) in laxes.iter().enumerate() {
// for (j, laxisx) in laxes.iter().enumerate() {
// if i != j && laxis == laxisx {
// panic!("Duplicate axes");
// }
// }
// }
// for (i, raxis) in raxes.iter().enumerate() {
// for (j, raxisx) in raxes.iter().enumerate() {
// if i != j && raxis == raxisx {
// panic!("Duplicate axes");
// }
// }
// }
// // Check that the sizes of the axes to contract are the same
// // for (laxis, raxis) in laxes.iter().zip(raxes.iter()) {
// // // Check that axes are within bounds
// // if lhs.shape()[*laxis] != *raxis {
// // println!("laxis: {}, raxis: {}", laxis, raxis);
// // panic!("Axis out of bounds");
// // }
// // }
// // 1. Axis iterators
// //
// // Create iterators for each non-contracted axis of the left-hand and right-hand tensors
// // and take the Cartesian product of the iterators.
// let lcartesian_contract = lhs
// .shape()
// .iter()
// .enumerate()
// .filter(|(i, _)| laxes.contains(i))
// .map(|(i, _)| {
// println!("LHS is adding axis: {}", i);
// let vals = Axis::new(lhs.clone(), i).into_iter().collect::<Vec<_>>();
// println!("vals: {:?}", vals);
// })
// .collect::<Vec<_>>();
// for l in lcartesian_contract.iter() {
// let l = l.clone();
// println!("{:?}", l);
// }
// // .multi_cartesian_product();
// // let cart_product = lcartesian_contract.multi_cartesian_product();
// // for l in cart_product {
// // println!("cartesian_contract l:");
// // for r in l {
// // println!("l: {:?}", r);
// // }
// // }
// let rcartesian_contract = rhs
// .shape()
// .iter()
// .enumerate()
// .filter(|(i, _)| raxes.contains(i))
// .map(|(i, _)| {
// println!("RHS is adding axis: {}", i);
// let vals = Axis::new(rhs.clone(), i).into_iter().collect::<Vec<_>>();
// println!("vals: {:?}", vals);
// })
// .collect::<Vec<_>>();
// for r in rcartesian_contract.iter() {
// let r = r.clone();
// println!("{:?}", r);
// }
// // .multi_cartesian_product();
// // println!("lcartesian_contract: {:?}", lcartesian_contract);
// // println!("rcartesian_contract: {:?}", rcartesian_contract);
// // Initialize buffer for the resulting tensor
// // let mut result_buffer = Vec::new();
// // let zip_contract = lcartesian_contract.zip(rcartesian_contract);
// // for (i, (j, k)) in zip_contract.enumerate() {
// // let mut sum = T::zero();
// // print!("{}. sum = ", i);
// // // for (lhsnc, rhsnc) in j.into_iter().zip(k.into_iter()) {
// // // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
// // // sum = sum + lhsnc * rhsnc;
// // // }
// // for lhsnc in j.iter() {
// // for rhsnc in k.iter() {
// // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
// // sum = sum + *lhsnc * *rhsnc;
// // }
// // }
// // print!("= {}\n", sum);
// // result_buffer.push(sum);
// // }
// // Iterate over non-contracted axes combinations
// // for (lhsnci, rhsnci) in
// // lcartesian_contract.zip(rcartesian_contract)
// // {
// // let mut sum = T::zero();
// // for (lhsnc, rhsnc) in lhsnci.iter().zip(rhsnci.iter()) {
// // sum = sum + *lhsnc * *rhsnc;
// // }
// // // Append the result to the buffer
// // result_buffer.push(sum);
// // }
// // TODO! Implement the rest of the function
// let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
// laxes
// .iter()
// .map(|axis| *axis)
// .collect::<Vec<_>>()
// .try_into()
// .unwrap(),
// );
// let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
// raxes
// .iter()
// .map(|axis| *axis)
// .collect::<Vec<_>>()
// .try_into()
// .unwrap(),
// );
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
// let mut shape = Vec::new();
// shape.extend_from_slice(&lhs_shape_reduced.as_array());
// shape.extend_from_slice(&rhs_shape_reduced.as_array());
// let shape: [usize; R + S - 2 * N] =
// shape.try_into().expect("Failed to create shape array");
// let shape = Shape::new(shape);
// Tensor::new_with_buffer(shape, vec![])
// }
#[cfg(test)]
mod tests {
use super::*;
@@ -452,130 +266,136 @@ mod tests {
}
}
// #[test]
// fn test_tensor_contraction() {
// // Define two 2D tensors (matrices)
// // Tensor A is 2x3
// let tensor_a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
// println!("A: {}", tensor_a);
#[test]
fn test_axis_iterator_disassemble() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// // Tensor B is 1x3x2
// let tensor_b: Tensor<i32, 3> = Tensor::from([[[1, 2], [3, 4], [5, 6]]]);
// println!("B: {}", tensor_b);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
// // Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
// let contracted_tensor: Tensor<i32, 3> =
// contract(tensor_a.clone(), tensor_b.clone(), [1], [1]);
let mut axis_iter = axis.into_iter().disassemble();
// // Expect: [22, 28, 49, 64]
// // Check if the contracted tensor is as expected
// println!("A: {}", tensor_a);
// println!("B: {}", tensor_b);
// println!("A * B = {}", contracted_tensor);
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
// }
assert_eq!(axis_iter[0].next(), Some(&1.0));
assert_eq!(axis_iter[0].next(), Some(&2.0));
assert_eq!(axis_iter[0].next(), None);
assert_eq!(axis_iter[1].next(), Some(&3.0));
assert_eq!(axis_iter[1].next(), Some(&4.0));
assert_eq!(axis_iter[1].next(), None);
#[test]
fn test_axis_iterator() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
let mut axis_iter = axis.into_iter().disassemble();
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter[0].next(), Some(&1.0));
assert_eq!(axis_iter[0].next(), Some(&3.0));
assert_eq!(axis_iter[0].next(), None);
assert_eq!(axis_iter[1].next(), Some(&2.0));
assert_eq!(axis_iter[1].next(), Some(&4.0));
assert_eq!(axis_iter[1].next(), None);
}
axis_iter.for_each(|x| println!("t1: {}", x));
// assert_eq!(axis_iter.next(), Some(&1.0));
// assert_eq!(axis_iter.next(), Some(&2.0));
// assert_eq!(axis_iter.next(), Some(&3.0));
// assert_eq!(axis_iter.next(), Some(&4.0));
#[test]
fn test_axis_iterator() {
// Creating a 2x2 Tensor for testing
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
// Testing iteration over the first axis (axis = 0)
let axis = Axis::new(&tensor, 0);
let mut axis_iter = axis.into_iter();
axis_iter.for_each(|x| println!("t2: {}", x));
// assert_eq!(axis_iter.next(), Some(&1.0));
// assert_eq!(axis_iter.next(), Some(&3.0));
// assert_eq!(axis_iter.next(), Some(&2.0));
// assert_eq!(axis_iter.next(), Some(&4.0));
let mut axis_iter = axis.into_iter();
let shape = tensor.shape();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&4.0));
let mut a: Idx<2> = (shape, [0, 0]).into();
let b: Idx<2> = (shape, [1, 1]).into();
// Resetting the iterator for the second axis (axis = 1)
let axis = Axis::new(&tensor, 1);
while a <= b {
println!("a: {}", a);
a.inc();
}
let mut axis_iter = axis.into_iter();
assert_eq!(axis_iter.next(), Some(&1.0));
assert_eq!(axis_iter.next(), Some(&3.0));
assert_eq!(axis_iter.next(), Some(&2.0));
assert_eq!(axis_iter.next(), Some(&4.0));
let shape = tensor.shape();
let mut a: Idx<2> = (shape, [0, 0]).into();
let b: Idx<2> = (shape, [1, 1]).into();
while a <= b {
println!("a: {}", a);
a.inc();
}
}
// #[test]
// fn test_3d_tensor_axis_iteration() {
// // Create a 3D Tensor with specific values
// // Tensor shape is 2x2x2 for simplicity
// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
#[test]
fn test_3d_tensor_axis_iteration() {
// Create a 3D Tensor with specific values
// Tensor shape is 2x2x2 for simplicity
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
// // Axis 0 (Layer-wise):
// //
// // t[0][0][0] = 1
// // t[0][0][1] = 2
// // t[0][1][0] = 3
// // t[0][1][1] = 4
// // t[1][0][0] = 5
// // t[1][0][1] = 6
// // t[1][1][0] = 7
// // t[1][1][1] = 8
// // [1, 2, 3, 4, 5, 6, 7, 8]
// //
// // This order suggests that for each "layer" (first level of arrays),
// // the iterator goes through all rows and columns. It first completes
// // the entire first layer, then moves to the second.
// Axis 0 (Layer-wise):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 3, 4, 5, 6, 7, 8]
//
// This order suggests that for each "layer" (first level of arrays),
// the iterator goes through all rows and columns. It first completes
// the entire first layer, then moves to the second.
// let a0 = Axis::new(t.clone(), 0);
// let a0_order = a0.into_iter().collect::<Vec<_>>();
// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
let a0 = Axis::new(&t, 0);
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
// // Axis 1 (Row-wise within each layer):
// //
// // t[0][0][0] = 1
// // t[0][0][1] = 2
// // t[1][0][0] = 5
// // t[1][0][1] = 6
// // t[0][1][0] = 3
// // t[0][1][1] = 4
// // t[1][1][0] = 7
// // t[1][1][1] = 8
// // [1, 2, 5, 6, 3, 4, 7, 8]
// //
// // This indicates that within each "layer", the iterator first
// // completes the first row across all layers, then the second row
// // across all layers.
// Axis 1 (Row-wise within each layer):
//
// t[0][0][0] = 1
// t[0][0][1] = 2
// t[1][0][0] = 5
// t[1][0][1] = 6
// t[0][1][0] = 3
// t[0][1][1] = 4
// t[1][1][0] = 7
// t[1][1][1] = 8
// [1, 2, 5, 6, 3, 4, 7, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first row across all layers, then the second row
// across all layers.
// let a1 = Axis::new(t.clone(), 1);
// let a1_order = a1.into_iter().collect::<Vec<_>>();
// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
let a1 = Axis::new(&t, 1);
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
// // Axis 2 (Column-wise within each layer):
// //
// // t[0][0][0] = 1
// // t[0][1][0] = 3
// // t[1][0][0] = 5
// // t[1][1][0] = 7
// // t[0][0][1] = 2
// // t[0][1][1] = 4
// // t[1][0][1] = 6
// // t[1][1][1] = 8
// // [1, 3, 5, 7, 2, 4, 6, 8]
// //
// // This indicates that within each "layer", the iterator first
// // completes the first column across all layers, then the second
// // column across all layers.
// Axis 2 (Column-wise within each layer):
//
// t[0][0][0] = 1
// t[0][1][0] = 3
// t[1][0][0] = 5
// t[1][1][0] = 7
// t[0][0][1] = 2
// t[0][1][1] = 4
// t[1][0][1] = 6
// t[1][1][1] = 8
// [1, 3, 5, 7, 2, 4, 6, 8]
//
// This indicates that within each "layer", the iterator first
// completes the first column across all layers, then the second
// column across all layers.
// let a2 = Axis::new(t, 2);
// let a2_order = a2.into_iter().collect::<Vec<_>>();
// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
// }
let a2 = Axis::new(&t, 2);
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
}
// }


@@ -82,26 +82,29 @@ impl<const R: usize> Idx<R> {
pub fn inc_axis(&mut self, fixed_axis: usize) {
assert!(fixed_axis < R, "Axis out of bounds");
assert!(self.indices[fixed_axis] < self.shape.get(fixed_axis), "Index out of bounds");
assert!(self.indices()[fixed_axis] < self.shape().get(fixed_axis), "Index out of bounds");
// Try to increment non-fixed axes
for i in (0..R).rev() {
if i != fixed_axis {
if self.indices[i] < self.shape.get(i) {
if self.indices[i] + 1 < self.shape.get(i) {
self.indices[i] += 1;
return ;
} else {
self.indices[i] = 0;
}
}
}
// Increment the fixed axis and reset other axes to 0
self.indices[fixed_axis] += 1;
for i in 0..R {
if i != fixed_axis {
self.indices[i] = 0;
}
}
if self.indices[fixed_axis] < self.shape.get(fixed_axis) {
self.indices[fixed_axis] += 1;
for i in 0..R {
if i != fixed_axis {
self.indices[i] = 0;
}
}
return ;
}
}
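// Worked example of the stepping order this produces (shape [2, 2],
// fixed_axis = 1): the non-fixed axes are advanced first, odometer style, and
// the fixed axis only moves once they wrap, so starting from [0, 0] successive
// calls step to [1, 0], [0, 1], [1, 1], and finally to [0, 2], where the
// fixed-axis component reaching the shape bound marks the end of iteration.
// This is the column-first order that test_axis_iterator checks for axis 1
// (1.0, 3.0, 2.0, 4.0 in [[1.0, 2.0], [3.0, 4.0]]).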
pub fn dec(&mut self) {


@@ -31,6 +31,7 @@ impl<T> Value for T where
+ Display
+ Serialize
+ Deserialize<'static>
+ std::iter::Sum
{
}


@@ -127,64 +127,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
}
}
// // Step 1: Validate the axes for both tensors to ensure they are within bounds
// for &axis in &axis_lhs {
// if axis >= R {
// panic!(
// "Axis {} is out of bounds for the left-hand tensor",
// axis
// );
// }
// }
// for &axis in &axis_rhs {
// if axis >= S {
// panic!(
// "Axis {} is out of bounds for the right-hand tensor",
// axis
// );
// }
// }
// // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
// let mut result_buffer = Vec::new();
// for i in self.iter_over_non_contracted_axes(&axis_lhs) {
// for j in rhs.iter_over_non_contracted_axes(&axis_rhs) {
// let mut product_sum = T::zero();
// // Step 3: Contraction over specified axes
// for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
// let idx_l = i[axis_l];
// let idx_r = j[axis_r];
// let value_lhs = self.buffer[self.flat(&idx_l)];
// let value_rhs = rhs.buffer[rhs.flat(&idx_r)];
// product_sum = product_sum + value_lhs * value_rhs;
// }
// result_buffer.push(product_sum);
// }
// }
// // Step 4: Remove contracted dimensions to create new shapes for both tensors
// let new_shape_lhs = self.shape.remove_dims::<{ NL }>(axis_lhs);
// let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
// let mut new_shape = Vec::new();
// new_shape.extend_from_slice(&new_shape_lhs.as_array());
// new_shape.extend_from_slice(&new_shape_rhs.as_array());
// let new_shape_array: [usize; R + S - NAXL - NAXR] =
// new_shape.try_into().expect("Failed to create shape array");
// Tensor {
// buffer: result_buffer,
// shape: Shape::new(new_shape_array),
// }
// }
// Retrieve an element based on a specific axis and index
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
// Convert axis and index to a flat index
@@ -237,6 +179,23 @@ impl<T: Value, const R: usize> IndexMut<Idx<R>> for Tensor<T, R> {
}
}
impl<T: Value, const R: usize> Index<usize> for Tensor<T, R> {
type Output = T;
fn index(&self, index: usize) -> &Self::Output {
&self.buffer[index]
}
}
impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.buffer[index]
}
}
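// Usage sketch for the new flat Index/IndexMut impls; this assumes the buffer
// is stored in the row-major order suggested by the axis-0 iteration tests:
//
//     let mut t = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
//     assert_eq!(t[1], 2.0);   // reads buffer[1] directly, bypassing Idx<R>
//     t[1] = 5.0;              // writes through IndexMut<usize>
//     assert_eq!(t[1], 5.0);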
// ---- Display ----
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Print the shape of the tensor