Working on contract_axes
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
d54179cd5d
commit
4626504521
540
src/axis.rs
540
src/axis.rs
@ -28,7 +28,9 @@ impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
|
||||
assert!(level < self.len, "Level out of bounds");
|
||||
let mut index = Idx::new(self.shape.clone(), [0; R]);
|
||||
index.set_axis(self.dim, level);
|
||||
AxisIterator::new(self.clone()).set_start(level).set_end(level + 1)
|
||||
AxisIterator::new(self.clone())
|
||||
.set_start(level)
|
||||
.set_end(level + 1)
|
||||
}
|
||||
}
|
||||
|
||||
@ -38,7 +40,7 @@ pub struct AxisIterator<'a, T: Value, const R: usize> {
|
||||
axis: Axis<'a, T, R>,
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
index: Idx<R>,
|
||||
#[getset(get = "pub")]
|
||||
#[getset(get = "pub")]
|
||||
end: Option<usize>,
|
||||
}
|
||||
|
||||
@ -84,31 +86,30 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
|
||||
result
|
||||
}
|
||||
|
||||
pub fn axis_max_idx(&self) -> usize {
|
||||
self.end().unwrap_or(*self.axis().len())
|
||||
}
|
||||
pub fn axis_max_idx(&self) -> usize {
|
||||
self.end().unwrap_or(*self.axis().len())
|
||||
}
|
||||
|
||||
pub fn axis_idx(&self) -> usize {
|
||||
self.index().get_axis(*self.axis().dim())
|
||||
}
|
||||
pub fn axis_idx(&self) -> usize {
|
||||
self.index().get_axis(*self.axis().dim())
|
||||
}
|
||||
|
||||
pub fn axis_dim(&self) -> usize {
|
||||
self.axis().dim().clone()
|
||||
}
|
||||
pub fn axis_dim(&self) -> usize {
|
||||
self.axis().dim().clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.axis_idx() == self.axis_max_idx() {
|
||||
return None;
|
||||
}
|
||||
// Increment the index along the fixed axis and check if it's within bounds
|
||||
if self.axis_idx() == self.axis_max_idx() {
|
||||
return None;
|
||||
}
|
||||
let result = self.axis().tensor().get(self.index);
|
||||
let axis_dim = self.axis_dim();
|
||||
self.index_mut().inc_axis(axis_dim);
|
||||
Some(result)
|
||||
let axis_dim = self.axis_dim();
|
||||
self.index_mut().inc_axis(axis_dim);
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
@ -138,19 +139,6 @@ where
|
||||
[(); S - N]:,
|
||||
[(); R + S - 2 * N]:,
|
||||
{
|
||||
let li = laxes
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|axis| axis.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
let ri = raxes
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|axis| axis.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap());
|
||||
|
||||
let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
|
||||
laxes
|
||||
.iter()
|
||||
@ -168,7 +156,6 @@ where
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||
let mut shape = Vec::new();
|
||||
shape.extend_from_slice(&lhs_shape_reduced.as_array());
|
||||
shape.extend_from_slice(&rhs_shape_reduced.as_array());
|
||||
@ -178,7 +165,38 @@ where
|
||||
|
||||
let shape = Shape::new(shape);
|
||||
|
||||
Tensor::new_with_buffer(shape, result)
|
||||
let li = laxes
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|axis| axis.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
let ri = raxes
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|axis| axis.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = contract_axes(
|
||||
li.try_into().unwrap(),
|
||||
ri.try_into().unwrap(),
|
||||
);
|
||||
|
||||
println!("result: {:?}", result);
|
||||
|
||||
let mut t = Tensor::new_with_buffer(shape, result[0].clone());
|
||||
|
||||
// // Iterate over the remaining axes results (skipping the first one which is already in the tensor)
|
||||
// for axes_result in result.iter().skip(1) {
|
||||
// // Ensure the current axes_result has the same length as the tensor's buffer
|
||||
// // assert_eq!(t.shape().size(), axes_result.len(), "Buffer size mismatch");
|
||||
|
||||
// // Perform the arithmetic operation for each element
|
||||
// for (i, &value) in axes_result.iter().enumerate() {
|
||||
// // Modify this line to choose between summing or multiplying
|
||||
// t[i] = t[i] + value;
|
||||
// }
|
||||
// }
|
||||
t
|
||||
}
|
||||
|
||||
pub fn contract_axes<
|
||||
@ -190,242 +208,38 @@ pub fn contract_axes<
|
||||
>(
|
||||
laxes: [AxisIterator<'a, T, R>; N],
|
||||
raxes: [AxisIterator<'a, T, S>; N],
|
||||
) -> Vec<T>
|
||||
) -> Vec<Vec<T>>
|
||||
where
|
||||
[(); R - N]:,
|
||||
[(); S - N]:,
|
||||
// [(); R + S - 2 * N]:,
|
||||
{
|
||||
let mut result = Vec::new();
|
||||
let mut result = vec![];
|
||||
|
||||
// Disassemble each axis iterator into iterators over each level
|
||||
let laxes: Vec<Vec<AxisIterator<'a, T, R>>> =
|
||||
laxes.into_iter().map(|axis| axis.disassemble()).collect();
|
||||
let raxes: Vec<Vec<AxisIterator<'a, T, S>>> =
|
||||
raxes.into_iter().map(|axis| axis.disassemble()).collect();
|
||||
let laxes = laxes.into_iter().map(|axis| axis.disassemble());
|
||||
let raxes = raxes.into_iter().map(|axis| axis.disassemble());
|
||||
let axes = laxes.into_iter().zip(raxes);
|
||||
|
||||
// Iterate over all combinations of levels for contracted axes
|
||||
for (laxis, raxis) in axes {
|
||||
let mut axes_result: Vec<T> = vec![];
|
||||
let levels = laxis.into_iter().zip(raxis.into_iter());
|
||||
for (ll, rl) in levels {
|
||||
let mut sum = T::zero();
|
||||
let values = ll.into_iter().zip(rl.into_iter());
|
||||
|
||||
let lcart = laxes.into_iter().multi_cartesian_product();
|
||||
let rcart = raxes.into_iter().multi_cartesian_product();
|
||||
let cartesian = lcart.zip(rcart);
|
||||
|
||||
for (i, (l, r)) in cartesian.enumerate() {
|
||||
println!("{} cartesian:", i);
|
||||
let mut sum = T::zero();
|
||||
|
||||
let level = l.into_iter().zip(r.into_iter());
|
||||
|
||||
for (i, (l, r)) in level.enumerate() {
|
||||
// let mut sum = T::zero();
|
||||
println!("{} level:", i);
|
||||
let value = l.into_iter().zip(r.into_iter());
|
||||
|
||||
for (i, (l, r)) in value.enumerate() {
|
||||
println!("{} product: {} + {} * {} = {}", i, sum, l, r, sum + *l * *r);
|
||||
sum = sum + *l * *r;
|
||||
for (lv, rv) in values {
|
||||
println!("{} + {} * {} = {}", sum, lv, rv, sum + *lv * *rv);
|
||||
sum = sum + *lv * *rv;
|
||||
}
|
||||
// result.push(sum);
|
||||
}
|
||||
|
||||
result.push(sum);
|
||||
axes_result.push(sum);
|
||||
}
|
||||
result.push(axes_result);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
// pub fn contract_axes<
|
||||
// 'a,
|
||||
// T: Value + std::fmt::Debug,
|
||||
// const R: usize,
|
||||
// const S: usize,
|
||||
// const N: usize,
|
||||
// >(
|
||||
// // Axes are (contracted, non-contracted)
|
||||
// laxes: ([AxisIterator<'a, T, R>; N], [AxisIterator<'a, T, R>; R - N]),
|
||||
// raxes: ([AxisIterator<'a, T, S>; N], [AxisIterator<'a, T, S>; S - N]),
|
||||
// ) -> Vec<T> //Tensor<T, { R + S - 2 * N }>
|
||||
// where
|
||||
// [(); R - N]:,
|
||||
// [(); S - N]:,
|
||||
// [(); R + S - 2 * N]:,
|
||||
// {
|
||||
// let (lc, lnc) = laxes;
|
||||
// let (rc, rnc) = raxes;
|
||||
|
||||
// // Step 1: Prepare cartesian products of contracted axes
|
||||
// let lc_prod = lc.into_iter().multi_cartesian_product();
|
||||
// let rc_prod = rc.into_iter().multi_cartesian_product();
|
||||
|
||||
// // Step 2: Prepare iterators for non-contracted axes
|
||||
// let lnc_prod = lnc.into_iter().multi_cartesian_product();
|
||||
// let rnc_prod = rnc.into_iter().multi_cartesian_product();
|
||||
|
||||
// // Initialize buffer for the resulting tensor
|
||||
// let mut result_buffer = Vec::new();
|
||||
|
||||
// // Step 3: Iterate over combinations and compute tensor product
|
||||
// for lnc_vals in lnc_prod {
|
||||
// for rnc_vals in rnc_prod.clone() {
|
||||
// let mut sum = T::zero();
|
||||
// for ((lc, rc), (lnc_val, rnc_val)) in lc_prod
|
||||
// .clone()
|
||||
// .zip(rc_prod.clone())
|
||||
// .zip(lnc_vals.iter().zip(rnc_vals.iter()))
|
||||
// {
|
||||
// for (lc_val, rc_val) in lc.into_iter().zip(rc.into_iter()) {
|
||||
// sum = sum + (*lc_val * *rc_val) * (**lnc_val * **rnc_val);
|
||||
// }
|
||||
// }
|
||||
// result_buffer.push(sum);
|
||||
// }
|
||||
// }
|
||||
|
||||
// result_buffer
|
||||
// }
|
||||
|
||||
// // Check that there are no duplicate axes
|
||||
// for (i, laxis) in laxes.iter().enumerate() {
|
||||
// for (j, laxisx) in laxes.iter().enumerate() {
|
||||
// if i != j && laxis == laxisx {
|
||||
// panic!("Duplicate axes");
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// for (i, raxis) in raxes.iter().enumerate() {
|
||||
// for (j, raxisx) in raxes.iter().enumerate() {
|
||||
// if i != j && raxis == raxisx {
|
||||
// panic!("Duplicate axes");
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// // Check that the sizes of the axes to contract are the same
|
||||
// // for (laxis, raxis) in laxes.iter().zip(raxes.iter()) {
|
||||
// // // Check that axes are within bounds
|
||||
// // if lhs.shape()[*laxis] != *raxis {
|
||||
// // println!("laxis: {}, raxis: {}", laxis, raxis);
|
||||
// // panic!("Axis out of bounds");
|
||||
// // }
|
||||
// // }
|
||||
// // 1. Axis iterators
|
||||
// //
|
||||
// // Create iterators for each non-contracted axis of the left-hand and right-hand tensors
|
||||
// // and take the Cartesian product of the iterators.
|
||||
|
||||
// let lcartesian_contract = lhs
|
||||
// .shape()
|
||||
// .iter()
|
||||
// .enumerate()
|
||||
// .filter(|(i, _)| laxes.contains(i))
|
||||
// .map(|(i, _)| {
|
||||
// println!("LHS is adding axis: {}", i);
|
||||
// let vals = Axis::new(lhs.clone(), i).into_iter().collect::<Vec<_>>();
|
||||
// println!("vals: {:?}", vals);
|
||||
// })
|
||||
// .collect::<Vec<_>>();
|
||||
|
||||
// for l in lcartesian_contract.iter() {
|
||||
// let l = l.clone();
|
||||
// println!("{:?}", l);
|
||||
// }
|
||||
// // .multi_cartesian_product();
|
||||
|
||||
// // let cart_product = lcartesian_contract.multi_cartesian_product();
|
||||
|
||||
// // for l in cart_product {
|
||||
// // println!("cartesian_contract l:");
|
||||
// // for r in l {
|
||||
// // println!("l: {:?}", r);
|
||||
// // }
|
||||
// // }
|
||||
|
||||
// let rcartesian_contract = rhs
|
||||
// .shape()
|
||||
// .iter()
|
||||
// .enumerate()
|
||||
// .filter(|(i, _)| raxes.contains(i))
|
||||
// .map(|(i, _)| {
|
||||
// println!("RHS is adding axis: {}", i);
|
||||
// let vals = Axis::new(rhs.clone(), i).into_iter().collect::<Vec<_>>();
|
||||
// println!("vals: {:?}", vals);
|
||||
// })
|
||||
// .collect::<Vec<_>>();
|
||||
|
||||
// for r in rcartesian_contract.iter() {
|
||||
// let r = r.clone();
|
||||
// println!("{:?}", r);
|
||||
// }
|
||||
// // .multi_cartesian_product();
|
||||
|
||||
// // println!("lcartesian_contract: {:?}", lcartesian_contract);
|
||||
// // println!("rcartesian_contract: {:?}", rcartesian_contract);
|
||||
|
||||
// // Initialize buffer for the resulting tensor
|
||||
// // let mut result_buffer = Vec::new();
|
||||
// // let zip_contract = lcartesian_contract.zip(rcartesian_contract);
|
||||
|
||||
// // for (i, (j, k)) in zip_contract.enumerate() {
|
||||
// // let mut sum = T::zero();
|
||||
// // print!("{}. sum = ", i);
|
||||
// // // for (lhsnc, rhsnc) in j.into_iter().zip(k.into_iter()) {
|
||||
// // // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
|
||||
// // // sum = sum + lhsnc * rhsnc;
|
||||
// // // }
|
||||
// // for lhsnc in j.iter() {
|
||||
// // for rhsnc in k.iter() {
|
||||
// // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
|
||||
// // sum = sum + *lhsnc * *rhsnc;
|
||||
// // }
|
||||
// // }
|
||||
// // print!("= {}\n", sum);
|
||||
// // result_buffer.push(sum);
|
||||
// // }
|
||||
|
||||
// // Iterate over non-contracted axes combinations
|
||||
// // for (lhsnci, rhsnci) in
|
||||
// // lcartesian_contract.zip(rcartesian_contract)
|
||||
// // {
|
||||
// // let mut sum = T::zero();
|
||||
// // for (lhsnc, rhsnc) in lhsnci.iter().zip(rhsnci.iter()) {
|
||||
// // sum = sum + *lhsnc * *rhsnc;
|
||||
// // }
|
||||
// // // Append the result to the buffer
|
||||
// // result_buffer.push(sum);
|
||||
// // }
|
||||
|
||||
// // TODO! Implement the rest of the function
|
||||
|
||||
// let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
|
||||
// laxes
|
||||
// .iter()
|
||||
// .map(|axis| *axis)
|
||||
// .collect::<Vec<_>>()
|
||||
// .try_into()
|
||||
// .unwrap(),
|
||||
// );
|
||||
// let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
|
||||
// raxes
|
||||
// .iter()
|
||||
// .map(|axis| *axis)
|
||||
// .collect::<Vec<_>>()
|
||||
// .try_into()
|
||||
// .unwrap(),
|
||||
// );
|
||||
|
||||
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||
// let mut shape = Vec::new();
|
||||
// shape.extend_from_slice(&lhs_shape_reduced.as_array());
|
||||
// shape.extend_from_slice(&rhs_shape_reduced.as_array());
|
||||
|
||||
// let shape: [usize; R + S - 2 * N] =
|
||||
// shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
// let shape = Shape::new(shape);
|
||||
|
||||
// Tensor::new_with_buffer(shape, vec![])
|
||||
// }
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
@ -452,130 +266,136 @@ mod tests {
|
||||
}
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_tensor_contraction() {
|
||||
// // Define two 2D tensors (matrices)
|
||||
// // Tensor A is 2x3
|
||||
// let tensor_a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
// println!("A: {}", tensor_a);
|
||||
#[test]
|
||||
fn test_axis_iterator_disassemble() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// // Tensor B is 1x3x2
|
||||
// let tensor_b: Tensor<i32, 3> = Tensor::from([[[1, 2], [3, 4], [5, 6]]]);
|
||||
// println!("B: {}", tensor_b);
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
// // Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
// let contracted_tensor: Tensor<i32, 3> =
|
||||
// contract(tensor_a.clone(), tensor_b.clone(), [1], [1]);
|
||||
let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
// // Expect: [22, 28, 49, 64]
|
||||
// // Check if the contracted tensor is as expected
|
||||
// println!("A: {}", tensor_a);
|
||||
// println!("B: {}", tensor_b);
|
||||
// println!("A * B = {}", contracted_tensor);
|
||||
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
|
||||
// }
|
||||
assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
assert_eq!(axis_iter[0].next(), Some(&2.0));
|
||||
assert_eq!(axis_iter[0].next(), None);
|
||||
assert_eq!(axis_iter[1].next(), Some(&3.0));
|
||||
assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
assert_eq!(axis_iter[1].next(), None);
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
assert_eq!(axis_iter[0].next(), Some(&3.0));
|
||||
assert_eq!(axis_iter[0].next(), None);
|
||||
assert_eq!(axis_iter[1].next(), Some(&2.0));
|
||||
assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
assert_eq!(axis_iter[1].next(), None);
|
||||
}
|
||||
|
||||
axis_iter.for_each(|x| println!("t1: {}", x));
|
||||
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
#[test]
|
||||
fn test_axis_iterator() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
axis_iter.for_each(|x| println!("t2: {}", x));
|
||||
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
let shape = tensor.shape();
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
let mut a: Idx<2> = (shape, [0, 0]).into();
|
||||
let b: Idx<2> = (shape, [1, 1]).into();
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
while a <= b {
|
||||
println!("a: {}", a);
|
||||
a.inc();
|
||||
}
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
let shape = tensor.shape();
|
||||
|
||||
let mut a: Idx<2> = (shape, [0, 0]).into();
|
||||
let b: Idx<2> = (shape, [1, 1]).into();
|
||||
|
||||
while a <= b {
|
||||
println!("a: {}", a);
|
||||
a.inc();
|
||||
}
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_3d_tensor_axis_iteration() {
|
||||
// // Create a 3D Tensor with specific values
|
||||
// // Tensor shape is 2x2x2 for simplicity
|
||||
// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
#[test]
|
||||
fn test_3d_tensor_axis_iteration() {
|
||||
// Create a 3D Tensor with specific values
|
||||
// Tensor shape is 2x2x2 for simplicity
|
||||
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
|
||||
// // Axis 0 (Layer-wise):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][0][1] = 2
|
||||
// // t[0][1][0] = 3
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][0][1] = 6
|
||||
// // t[1][1][0] = 7
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
// //
|
||||
// // This order suggests that for each "layer" (first level of arrays),
|
||||
// // the iterator goes through all rows and columns. It first completes
|
||||
// // the entire first layer, then moves to the second.
|
||||
// Axis 0 (Layer-wise):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
//
|
||||
// This order suggests that for each "layer" (first level of arrays),
|
||||
// the iterator goes through all rows and columns. It first completes
|
||||
// the entire first layer, then moves to the second.
|
||||
|
||||
// let a0 = Axis::new(t.clone(), 0);
|
||||
// let a0_order = a0.into_iter().collect::<Vec<_>>();
|
||||
// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
let a0 = Axis::new(&t, 0);
|
||||
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
|
||||
// // Axis 1 (Row-wise within each layer):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][0][1] = 2
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][0][1] = 6
|
||||
// // t[0][1][0] = 3
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][1][0] = 7
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
// //
|
||||
// // This indicates that within each "layer", the iterator first
|
||||
// // completes the first row across all layers, then the second row
|
||||
// // across all layers.
|
||||
// Axis 1 (Row-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first row across all layers, then the second row
|
||||
// across all layers.
|
||||
|
||||
// let a1 = Axis::new(t.clone(), 1);
|
||||
// let a1_order = a1.into_iter().collect::<Vec<_>>();
|
||||
// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
let a1 = Axis::new(&t, 1);
|
||||
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
|
||||
// // Axis 2 (Column-wise within each layer):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][1][0] = 3
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][1][0] = 7
|
||||
// // t[0][0][1] = 2
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][0][1] = 6
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
// //
|
||||
// // This indicates that within each "layer", the iterator first
|
||||
// // completes the first column across all layers, then the second
|
||||
// // column across all layers.
|
||||
// Axis 2 (Column-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][1][0] = 3
|
||||
// t[1][0][0] = 5
|
||||
// t[1][1][0] = 7
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][1] = 8
|
||||
// [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first column across all layers, then the second
|
||||
// column across all layers.
|
||||
|
||||
// let a2 = Axis::new(t, 2);
|
||||
// let a2_order = a2.into_iter().collect::<Vec<_>>();
|
||||
// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
// }
|
||||
let a2 = Axis::new(&t, 2);
|
||||
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
}
|
||||
// }
|
||||
|
21
src/index.rs
21
src/index.rs
@ -82,26 +82,29 @@ impl<const R: usize> Idx<R> {
|
||||
|
||||
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
||||
assert!(fixed_axis < R, "Axis out of bounds");
|
||||
assert!(self.indices[fixed_axis] < self.shape.get(fixed_axis), "Index out of bounds");
|
||||
assert!(self.indices()[fixed_axis] < self.shape().get(fixed_axis), "Index out of bounds");
|
||||
|
||||
// Try to increment non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] < self.shape.get(i) {
|
||||
if self.indices[i] + 1 < self.shape.get(i) {
|
||||
self.indices[i] += 1;
|
||||
return ;
|
||||
} else {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment the fixed axis and reset other axes to 0
|
||||
self.indices[fixed_axis] += 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
if self.indices[fixed_axis] < self.shape.get(fixed_axis) {
|
||||
self.indices[fixed_axis] += 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
return ;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dec(&mut self) {
|
||||
|
@ -31,6 +31,7 @@ impl<T> Value for T where
|
||||
+ Display
|
||||
+ Serialize
|
||||
+ Deserialize<'static>
|
||||
+ std::iter::Sum
|
||||
{
|
||||
}
|
||||
|
||||
|
@ -127,64 +127,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// // Step 1: Validate the axes for both tensors to ensure they are within bounds
|
||||
// for &axis in &axis_lhs {
|
||||
// if axis >= R {
|
||||
// panic!(
|
||||
// "Axis {} is out of bounds for the left-hand tensor",
|
||||
// axis
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
|
||||
// for &axis in &axis_rhs {
|
||||
// if axis >= S {
|
||||
// panic!(
|
||||
// "Axis {} is out of bounds for the right-hand tensor",
|
||||
// axis
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
||||
// let mut result_buffer = Vec::new();
|
||||
|
||||
// for i in self.iter_over_non_contracted_axes(&axis_lhs) {
|
||||
// for j in rhs.iter_over_non_contracted_axes(&axis_rhs) {
|
||||
// let mut product_sum = T::zero();
|
||||
|
||||
// // Step 3: Contraction over specified axes
|
||||
// for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
||||
// let idx_l = i[axis_l];
|
||||
// let idx_r = j[axis_r];
|
||||
|
||||
// let value_lhs = self.buffer[self.flat(&idx_l)];
|
||||
// let value_rhs = rhs.buffer[rhs.flat(&idx_r)];
|
||||
// product_sum = product_sum + value_lhs * value_rhs;
|
||||
// }
|
||||
|
||||
// result_buffer.push(product_sum);
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Step 4: Remove contracted dimensions to create new shapes for both tensors
|
||||
// let new_shape_lhs = self.shape.remove_dims::<{ NL }>(axis_lhs);
|
||||
// let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
||||
|
||||
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||
// let mut new_shape = Vec::new();
|
||||
// new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
||||
// new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
||||
|
||||
// let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
||||
// new_shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
// Tensor {
|
||||
// buffer: result_buffer,
|
||||
// shape: Shape::new(new_shape_array),
|
||||
// }
|
||||
// }
|
||||
|
||||
// Retrieve an element based on a specific axis and index
|
||||
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
|
||||
// Convert axis and index to a flat index
|
||||
@ -237,6 +179,23 @@ impl<T: Value, const R: usize> IndexMut<Idx<R>> for Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Index<usize> for Tensor<T, R> {
|
||||
type Output = T;
|
||||
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
&self.buffer[index]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
&mut self.buffer[index]
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ---- Display ----
|
||||
|
||||
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Print the shape of the tensor
|
||||
|
Loading…
Reference in New Issue
Block a user