diff --git a/Cargo.lock b/Cargo.lock index 6871f48..0cd0294 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,22 +14,39 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "getset" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "itertools" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" -[[package]] -name = "mltensor" -version = "0.1.0" -dependencies = [ - "bytemuck", - "num", - "serde", - "serde_json", -] - [[package]] name = "num" version = "0.4.1" @@ -106,6 +123,30 @@ dependencies = [ "autocfg", ] +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.71" @@ -147,7 +188,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.43", ] [[package]] @@ -161,6 +202,23 @@ dependencies = [ "serde", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.43" @@ -172,8 +230,27 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "tensorc" +version = "0.1.0" +dependencies = [ + "bytemuck", + "getset", + "itertools", + "num", + "serde", + "serde_json", + "static_assertions", +] + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" diff --git a/Cargo.toml b/Cargo.toml index 17c0c54..a6d91d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "mltensor" +name = "tensorc" version = "0.1.0" edition = "2021" @@ -7,6 +7,9 @@ edition = "2021" [dependencies] bytemuck = "1.14.0" +getset = "0.1.2" +itertools = "0.12.0" num = "0.4.1" serde = { version = "1.0.193", features = ["derive"] } -serde_json = "1.0.108" \ No newline at end of file +serde_json = "1.0.108" +static_assertions = "1.1.0" diff --git a/examples/operations.rs b/examples/operations.rs index 6f77676..b1c45ad 100644 --- a/examples/operations.rs +++ b/examples/operations.rs @@ -1,53 +1,66 @@ -use mltensor::*; +#![allow(mixed_script_confusables)] +#![allow(non_snake_case)] +use bytemuck::cast_slice; +use tensorc::*; fn tensor_product() { - println!("Tensor Product\n"); - let mut tensor1 = Tensor::::from([2, 2]); // 2x2 tensor - let mut tensor2 = Tensor::::from([2]); // 2-element vector + println!("Tensor Product\n"); + let mut tensor1 = Tensor::::from([2, 2]); // 2x2 tensor + let mut tensor2 = Tensor::::from([2]); // 2-element vector - // Fill tensors with some values - tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); - tensor2.buffer_mut().copy_from_slice(&[5, 6]); + // Fill tensors with some values + tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); + tensor2.buffer_mut().copy_from_slice(&[5, 6]); - println!("T1: {}", tensor1); - println!("T2: {}", tensor2); + println!("T1: {}", tensor1); + println!("T2: {}", tensor2); - let product = tensor1.tensor_product(&tensor2); + let product = tensor1.tensor_product(&tensor2); - println!("T1 * T2 = {}", product); + println!("T1 * T2 = {}", product); - // Check shape of the resulting tensor - assert_eq!(product.shape(), Shape::new([2, 2, 2])); + // Check shape of the resulting tensor + assert_eq!(product.shape(), Shape::new([2, 2, 2])); - // Check buffer of the resulting tensor - let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24]; - assert_eq!(product.buffer(), &expected_buffer); + // Check buffer of the resulting tensor + let expect: &[i32] = + cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]); + assert_eq!(product.buffer(), expect); } fn tensor_contraction() { - println!("Tensor Contraction\n"); - // Create two tensors - let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor - let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor + println!("Tensor Contraction\n"); + // Create two tensors + let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor + let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor - // Specify axes for contraction - let axis_lhs = [1]; // Contract over the second dimension of tensor1 - let axis_rhs = [0]; // Contract over the first dimension of tensor2 + // Specify axes for contraction + let axis_lhs = [1]; // Contract over the second dimension of tensor1 + let axis_rhs = [0]; // Contract over the first dimension of tensor2 - // Perform contraction - let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs); + // Perform contraction + // let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs); - println!("T1: {}", tensor1); - println!("T2: {}", tensor2); - println!("T1 * T2 = {}", result); + // println!("T1: {}", tensor1); + // println!("T2: {}", tensor2); + // println!("T1 * T2 = {}", result); - // Expected result, for example, could be a single number or a new tensor, - // depending on how you defined the contraction operation. - // Assert the result is as expected - // assert_eq!(result, expected_result); + // Expected result, for example, could be a single number or a new tensor, + // depending on how you defined the contraction operation. + // Assert the result is as expected + // assert_eq!(result, expected_result); + + // let Λ = Tensor::::from([ + // [1.0, 0.0, 0.0, 0.0], + // [0.0, 1.0, 0.0 ,0.0], + // [0.0, 0.0, 1.0, 0.0], + // [0.0, 0.0, 0.0, 1.0] + // ]); + + // println!("Λ: {}", Λ); } fn main() { - tensor_product(); - tensor_contraction(); -} \ No newline at end of file + tensor_product(); + tensor_contraction(); +} diff --git a/src/axis.rs b/src/axis.rs new file mode 100644 index 0000000..ca68986 --- /dev/null +++ b/src/axis.rs @@ -0,0 +1,581 @@ +use super::*; +use getset::{Getters, MutGetters}; + +#[derive(Clone, Debug, Getters)] +pub struct Axis<'a, T: Value, const R: usize> { + #[getset(get = "pub")] + tensor: &'a Tensor, + #[getset(get = "pub")] + shape: Shape, + #[getset(get = "pub")] + dim: usize, + #[getset(get = "pub")] + len: usize, +} + +impl<'a, T: Value, const R: usize> Axis<'a, T, R> { + pub fn new(tensor: &'a Tensor, dim: usize) -> Self { + assert!(dim < R, "Axis out of bounds"); + Self { + tensor, + shape: tensor.shape(), + dim, + len: tensor.shape().get(dim), + } + } + + pub fn iter_level(&self, level: usize) -> AxisIterator<'a, T, R> { + assert!(level < self.len, "Level out of bounds"); + let mut index = Idx::new(self.shape.clone(), [0; R]); + index.set_axis(self.dim, level); + AxisIterator::new(self.clone()).set_start(level).set_end(level + 1) + } +} + +#[derive(Clone, Debug, Getters, MutGetters)] +pub struct AxisIterator<'a, T: Value, const R: usize> { + #[getset(get = "pub")] + axis: Axis<'a, T, R>, + #[getset(get = "pub", get_mut = "pub")] + index: Idx, + #[getset(get = "pub")] + end: Option, +} + +impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> { + pub fn new(axis: Axis<'a, T, R>) -> Self { + Self { + axis: axis.clone(), + index: Idx::new(axis.shape().clone(), [0; R]), + end: None, + } + } + + pub fn set_start(self, start: usize) -> Self { + assert!(start < self.axis.len, "Start out of bounds"); + let mut index = Idx::new(self.axis.shape().clone(), [0; R]); + index.set_axis(self.axis.dim, start); + Self { + axis: self.axis.clone(), + index, + end: None, + } + } + + pub fn set_end(self, end: usize) -> Self { + assert!(end <= self.axis.len, "End out of bounds"); + Self { + axis: self.axis.clone(), + index: self.index.clone(), + end: Some(end), + } + } + + pub fn set_level(self, level: usize) -> Self { + assert!(level < self.axis.len, "Level out of bounds"); + self.set_start(level).set_end(level + 1) + } + + pub fn disassemble(self) -> Vec { + let mut result = Vec::new(); + for i in 0..self.axis().len { + result.push(Self::new(self.axis.clone()).set_level(i)); + } + result + } + + pub fn axis_max_idx(&self) -> usize { + self.end().unwrap_or(*self.axis().len()) + } + + pub fn axis_idx(&self) -> usize { + self.index().get_axis(*self.axis().dim()) + } + + pub fn axis_dim(&self) -> usize { + self.axis().dim().clone() + } +} + +impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.axis_idx() == self.axis_max_idx() { + return None; + } + // Increment the index along the fixed axis and check if it's within bounds + let result = self.axis().tensor().get(self.index); + let axis_dim = self.axis_dim(); + self.index_mut().inc_axis(axis_dim); + Some(result) + } +} + +impl<'a, T: Value, const R: usize> IntoIterator for Axis<'a, T, R> { + type Item = &'a T; + type IntoIter = AxisIterator<'a, T, R>; + + fn into_iter(self) -> Self::IntoIter { + AxisIterator::new(self) + } +} + +pub fn contract< + 'a, + T: Value + std::fmt::Debug, + const R: usize, + const S: usize, + const N: usize, +>( + lhs: &'a Tensor, + rhs: &'a Tensor, + laxes: [Axis<'a, T, R>; N], + raxes: [Axis<'a, T, S>; N], +) -> Tensor +where + [(); R - N]:, + [(); S - N]:, + [(); R + S - 2 * N]:, +{ + let li = laxes + .clone() + .into_iter() + .map(|axis| axis.into_iter()) + .collect::>(); + let ri = raxes + .clone() + .into_iter() + .map(|axis| axis.into_iter()) + .collect::>(); + + let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap()); + + let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>( + laxes + .iter() + .map(|axis| *axis.dim()) + .collect::>() + .try_into() + .unwrap(), + ); + let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>( + raxes + .iter() + .map(|axis| *axis.dim()) + .collect::>() + .try_into() + .unwrap(), + ); + + // Step 5: Concatenate the shapes to form the shape of the resultant tensor + let mut shape = Vec::new(); + shape.extend_from_slice(&lhs_shape_reduced.as_array()); + shape.extend_from_slice(&rhs_shape_reduced.as_array()); + + let shape: [usize; R + S - 2 * N] = + shape.try_into().expect("Failed to create shape array"); + + let shape = Shape::new(shape); + + Tensor::new_with_buffer(shape, result) +} + +pub fn contract_axes< + 'a, + T: Value + std::fmt::Debug, + const R: usize, + const S: usize, + const N: usize, +>( + laxes: [AxisIterator<'a, T, R>; N], + raxes: [AxisIterator<'a, T, S>; N], +) -> Vec +where + [(); R - N]:, + [(); S - N]:, + // [(); R + S - 2 * N]:, +{ + let mut result = Vec::new(); + + // Disassemble each axis iterator into iterators over each level + let laxes: Vec>> = + laxes.into_iter().map(|axis| axis.disassemble()).collect(); + let raxes: Vec>> = + raxes.into_iter().map(|axis| axis.disassemble()).collect(); + + // Iterate over all combinations of levels for contracted axes + + let lcart = laxes.into_iter().multi_cartesian_product(); + let rcart = raxes.into_iter().multi_cartesian_product(); + let cartesian = lcart.zip(rcart); + + for (i, (l, r)) in cartesian.enumerate() { + println!("{} cartesian:", i); + let mut sum = T::zero(); + + let level = l.into_iter().zip(r.into_iter()); + + for (i, (l, r)) in level.enumerate() { + // let mut sum = T::zero(); + println!("{} level:", i); + let value = l.into_iter().zip(r.into_iter()); + + for (i, (l, r)) in value.enumerate() { + println!("{} product: {} + {} * {} = {}", i, sum, l, r, sum + *l * *r); + sum = sum + *l * *r; + } + // result.push(sum); + } + + result.push(sum); + } + + result +} + +// pub fn contract_axes< +// 'a, +// T: Value + std::fmt::Debug, +// const R: usize, +// const S: usize, +// const N: usize, +// >( +// // Axes are (contracted, non-contracted) +// laxes: ([AxisIterator<'a, T, R>; N], [AxisIterator<'a, T, R>; R - N]), +// raxes: ([AxisIterator<'a, T, S>; N], [AxisIterator<'a, T, S>; S - N]), +// ) -> Vec //Tensor +// where +// [(); R - N]:, +// [(); S - N]:, +// [(); R + S - 2 * N]:, +// { +// let (lc, lnc) = laxes; +// let (rc, rnc) = raxes; + +// // Step 1: Prepare cartesian products of contracted axes +// let lc_prod = lc.into_iter().multi_cartesian_product(); +// let rc_prod = rc.into_iter().multi_cartesian_product(); + +// // Step 2: Prepare iterators for non-contracted axes +// let lnc_prod = lnc.into_iter().multi_cartesian_product(); +// let rnc_prod = rnc.into_iter().multi_cartesian_product(); + +// // Initialize buffer for the resulting tensor +// let mut result_buffer = Vec::new(); + +// // Step 3: Iterate over combinations and compute tensor product +// for lnc_vals in lnc_prod { +// for rnc_vals in rnc_prod.clone() { +// let mut sum = T::zero(); +// for ((lc, rc), (lnc_val, rnc_val)) in lc_prod +// .clone() +// .zip(rc_prod.clone()) +// .zip(lnc_vals.iter().zip(rnc_vals.iter())) +// { +// for (lc_val, rc_val) in lc.into_iter().zip(rc.into_iter()) { +// sum = sum + (*lc_val * *rc_val) * (**lnc_val * **rnc_val); +// } +// } +// result_buffer.push(sum); +// } +// } + +// result_buffer +// } + +// // Check that there are no duplicate axes +// for (i, laxis) in laxes.iter().enumerate() { +// for (j, laxisx) in laxes.iter().enumerate() { +// if i != j && laxis == laxisx { +// panic!("Duplicate axes"); +// } +// } +// } + +// for (i, raxis) in raxes.iter().enumerate() { +// for (j, raxisx) in raxes.iter().enumerate() { +// if i != j && raxis == raxisx { +// panic!("Duplicate axes"); +// } +// } +// } +// // Check that the sizes of the axes to contract are the same +// // for (laxis, raxis) in laxes.iter().zip(raxes.iter()) { +// // // Check that axes are within bounds +// // if lhs.shape()[*laxis] != *raxis { +// // println!("laxis: {}, raxis: {}", laxis, raxis); +// // panic!("Axis out of bounds"); +// // } +// // } +// // 1. Axis iterators +// // +// // Create iterators for each non-contracted axis of the left-hand and right-hand tensors +// // and take the Cartesian product of the iterators. + +// let lcartesian_contract = lhs +// .shape() +// .iter() +// .enumerate() +// .filter(|(i, _)| laxes.contains(i)) +// .map(|(i, _)| { +// println!("LHS is adding axis: {}", i); +// let vals = Axis::new(lhs.clone(), i).into_iter().collect::>(); +// println!("vals: {:?}", vals); +// }) +// .collect::>(); + +// for l in lcartesian_contract.iter() { +// let l = l.clone(); +// println!("{:?}", l); +// } +// // .multi_cartesian_product(); + +// // let cart_product = lcartesian_contract.multi_cartesian_product(); + +// // for l in cart_product { +// // println!("cartesian_contract l:"); +// // for r in l { +// // println!("l: {:?}", r); +// // } +// // } + +// let rcartesian_contract = rhs +// .shape() +// .iter() +// .enumerate() +// .filter(|(i, _)| raxes.contains(i)) +// .map(|(i, _)| { +// println!("RHS is adding axis: {}", i); +// let vals = Axis::new(rhs.clone(), i).into_iter().collect::>(); +// println!("vals: {:?}", vals); +// }) +// .collect::>(); + +// for r in rcartesian_contract.iter() { +// let r = r.clone(); +// println!("{:?}", r); +// } +// // .multi_cartesian_product(); + +// // println!("lcartesian_contract: {:?}", lcartesian_contract); +// // println!("rcartesian_contract: {:?}", rcartesian_contract); + +// // Initialize buffer for the resulting tensor +// // let mut result_buffer = Vec::new(); +// // let zip_contract = lcartesian_contract.zip(rcartesian_contract); + +// // for (i, (j, k)) in zip_contract.enumerate() { +// // let mut sum = T::zero(); +// // print!("{}. sum = ", i); +// // // for (lhsnc, rhsnc) in j.into_iter().zip(k.into_iter()) { +// // // print!("{} * {} + {} ", lhsnc, rhsnc, sum); +// // // sum = sum + lhsnc * rhsnc; +// // // } +// // for lhsnc in j.iter() { +// // for rhsnc in k.iter() { +// // print!("{} * {} + {} ", lhsnc, rhsnc, sum); +// // sum = sum + *lhsnc * *rhsnc; +// // } +// // } +// // print!("= {}\n", sum); +// // result_buffer.push(sum); +// // } + +// // Iterate over non-contracted axes combinations +// // for (lhsnci, rhsnci) in +// // lcartesian_contract.zip(rcartesian_contract) +// // { +// // let mut sum = T::zero(); +// // for (lhsnc, rhsnc) in lhsnci.iter().zip(rhsnci.iter()) { +// // sum = sum + *lhsnc * *rhsnc; +// // } +// // // Append the result to the buffer +// // result_buffer.push(sum); +// // } + +// // TODO! Implement the rest of the function + +// let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>( +// laxes +// .iter() +// .map(|axis| *axis) +// .collect::>() +// .try_into() +// .unwrap(), +// ); +// let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>( +// raxes +// .iter() +// .map(|axis| *axis) +// .collect::>() +// .try_into() +// .unwrap(), +// ); + +// // Step 5: Concatenate the shapes to form the shape of the resultant tensor +// let mut shape = Vec::new(); +// shape.extend_from_slice(&lhs_shape_reduced.as_array()); +// shape.extend_from_slice(&rhs_shape_reduced.as_array()); + +// let shape: [usize; R + S - 2 * N] = +// shape.try_into().expect("Failed to create shape array"); + +// let shape = Shape::new(shape); + +// Tensor::new_with_buffer(shape, vec![]) +// } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tensor_contraction_simple() { + // Define two 2D tensors (matrices) + // Tensor A is 2x3 + let a: Tensor = Tensor::from([[1, 2], [3, 4]]); + + // Tensor B is 1x3x2 + let b: Tensor = Tensor::from([[1, 2], [3, 4]]); + + // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) + let contracted_tensor: Tensor = + contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]); + + // Expect: [22, 28, 49, 64] + // Check if the contracted tensor is as expected + println!("A: {}", a); + println!("B: {}", b); + println!("A * B = {}", contracted_tensor); + // println!("Expected: {}", Tensor::from([[22, 28], [49, 64]])); + } +} + +// #[test] +// fn test_tensor_contraction() { +// // Define two 2D tensors (matrices) +// // Tensor A is 2x3 +// let tensor_a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); +// println!("A: {}", tensor_a); + +// // Tensor B is 1x3x2 +// let tensor_b: Tensor = Tensor::from([[[1, 2], [3, 4], [5, 6]]]); +// println!("B: {}", tensor_b); + +// // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) +// let contracted_tensor: Tensor = +// contract(tensor_a.clone(), tensor_b.clone(), [1], [1]); + +// // Expect: [22, 28, 49, 64] +// // Check if the contracted tensor is as expected +// println!("A: {}", tensor_a); +// println!("B: {}", tensor_b); +// println!("A * B = {}", contracted_tensor); +// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]])); +// } + + #[test] + fn test_axis_iterator() { + // Creating a 2x2 Tensor for testing + let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); + + // Testing iteration over the first axis (axis = 0) + let axis = Axis::new(&tensor, 0); + + let mut axis_iter = axis.into_iter(); + + axis_iter.for_each(|x| println!("t1: {}", x)); + // assert_eq!(axis_iter.next(), Some(&1.0)); + // assert_eq!(axis_iter.next(), Some(&2.0)); + // assert_eq!(axis_iter.next(), Some(&3.0)); + // assert_eq!(axis_iter.next(), Some(&4.0)); + + // Resetting the iterator for the second axis (axis = 1) + let axis = Axis::new(&tensor, 1); + + let mut axis_iter = axis.into_iter(); + axis_iter.for_each(|x| println!("t2: {}", x)); + // assert_eq!(axis_iter.next(), Some(&1.0)); + // assert_eq!(axis_iter.next(), Some(&3.0)); + // assert_eq!(axis_iter.next(), Some(&2.0)); + // assert_eq!(axis_iter.next(), Some(&4.0)); + + let shape = tensor.shape(); + + let mut a: Idx<2> = (shape, [0, 0]).into(); + let b: Idx<2> = (shape, [1, 1]).into(); + + while a <= b { + println!("a: {}", a); + a.inc(); + } + } + +// #[test] +// fn test_3d_tensor_axis_iteration() { +// // Create a 3D Tensor with specific values +// // Tensor shape is 2x2x2 for simplicity +// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + +// // Axis 0 (Layer-wise): +// // +// // t[0][0][0] = 1 +// // t[0][0][1] = 2 +// // t[0][1][0] = 3 +// // t[0][1][1] = 4 +// // t[1][0][0] = 5 +// // t[1][0][1] = 6 +// // t[1][1][0] = 7 +// // t[1][1][1] = 8 +// // [1, 2, 3, 4, 5, 6, 7, 8] +// // +// // This order suggests that for each "layer" (first level of arrays), +// // the iterator goes through all rows and columns. It first completes +// // the entire first layer, then moves to the second. + +// let a0 = Axis::new(t.clone(), 0); +// let a0_order = a0.into_iter().collect::>(); +// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); + +// // Axis 1 (Row-wise within each layer): +// // +// // t[0][0][0] = 1 +// // t[0][0][1] = 2 +// // t[1][0][0] = 5 +// // t[1][0][1] = 6 +// // t[0][1][0] = 3 +// // t[0][1][1] = 4 +// // t[1][1][0] = 7 +// // t[1][1][1] = 8 +// // [1, 2, 5, 6, 3, 4, 7, 8] +// // +// // This indicates that within each "layer", the iterator first +// // completes the first row across all layers, then the second row +// // across all layers. + +// let a1 = Axis::new(t.clone(), 1); +// let a1_order = a1.into_iter().collect::>(); +// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); + +// // Axis 2 (Column-wise within each layer): +// // +// // t[0][0][0] = 1 +// // t[0][1][0] = 3 +// // t[1][0][0] = 5 +// // t[1][1][0] = 7 +// // t[0][0][1] = 2 +// // t[0][1][1] = 4 +// // t[1][0][1] = 6 +// // t[1][1][1] = 8 +// // [1, 3, 5, 7, 2, 4, 6, 8] +// // +// // This indicates that within each "layer", the iterator first +// // completes the first column across all layers, then the second +// // column across all layers. + +// let a2 = Axis::new(t, 2); +// let a2_order = a2.into_iter().collect::>(); +// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); +// } +// } diff --git a/src/index.rs b/src/index.rs index 7bc71e0..d938e21 100644 --- a/src/index.rs +++ b/src/index.rs @@ -51,6 +51,9 @@ impl Idx { /// `true` if the increment does not overflow and is still within bounds; /// `false` if it overflows, indicating the end of the tensor. pub fn inc(&mut self) -> bool { + if self.indices[0] >= self.shape.get(0) { + return false; + } let mut carry = 1; for (i, &dim_size) in self.indices.iter_mut().zip(&self.shape.as_array()).rev() @@ -67,13 +70,40 @@ impl Idx { // If carry is still 1 after the loop, it means we've incremented past the last dimension if carry == 1 { - // Set the index to an invalid state (e.g., all indices to their max values) + // Set the index to an invalid state to indicate the end of the iteration indicated + // by setting the first index to the size of the first dimension self.indices[0] = self.shape.as_array()[0]; return true; // Indicate that the iteration is complete } false } + // fn inc_axis + + pub fn inc_axis(&mut self, fixed_axis: usize) { + assert!(fixed_axis < R, "Axis out of bounds"); + assert!(self.indices[fixed_axis] < self.shape.get(fixed_axis), "Index out of bounds"); + + // Try to increment non-fixed axes + for i in (0..R).rev() { + if i != fixed_axis { + if self.indices[i] < self.shape.get(i) { + self.indices[i] += 1; + } else { + self.indices[i] = 0; + } + } + } + + // Increment the fixed axis and reset other axes to 0 + self.indices[fixed_axis] += 1; + for i in 0..R { + if i != fixed_axis { + self.indices[i] = 0; + } + } + } + pub fn dec(&mut self) { // Check if already at the start if self.indices.iter().all(|&i| i == 0) { @@ -95,6 +125,40 @@ impl Idx { } } + pub fn dec_axis(&mut self, fixed_axis: usize) -> bool { + // Check if the fixed axis index is already in an invalid state + if self.indices[fixed_axis] == self.shape.get(fixed_axis) { + return false; + } + + // Try to decrement non-fixed axes + for i in (0..R).rev() { + if i != fixed_axis { + if self.indices[i] > 0 { + self.indices[i] -= 1; + return true; + } else { + self.indices[i] = self.shape.get(i) - 1; + } + } + } + + // Decrement the fixed axis if possible and reset other axes to their max + if self.indices[fixed_axis] > 0 { + self.indices[fixed_axis] -= 1; + for i in 0..R { + if i != fixed_axis { + self.indices[i] = self.shape.get(i) - 1; + } + } + } else { + // Fixed axis already at minimum, set to invalid state + self.indices[fixed_axis] = self.shape.get(fixed_axis); + } + + true + } + /// Converts the multi-dimensional index to a flat index. /// /// This method calculates the flat index corresponding to the multi-dimensional index @@ -136,6 +200,56 @@ impl Idx { }) .0 } + + // pub fn inc_axis(&mut self, axis: usize) -> bool { + // if axis >= R { + // // Axis is out of bounds + // return false; + // } + + // let dim_size = self.shape.get(axis); + // if self.indices[axis] + 1 < dim_size { + // // Increment the index if it's not at its maximum + // self.indices[axis] += 1; + // true + // } else { + // // Index is at its maximum for this axis + // false + // } + // } + + pub fn set_axis(&mut self, axis: usize, value: usize) { + assert!(axis < R, "Axis out of bounds"); + // assert!(value < self.shape.get(axis), "Value out of bounds"); + self.indices[axis] = value; + } + + pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool { + assert!(axis < R, "Axis out of bounds"); + if value < self.shape.get(axis) { + self.indices[axis] = value; + true + } else { + false + } + } + + pub fn get_axis(&self, axis: usize) -> usize { + assert!(axis < R, "Axis out of bounds"); + self.indices[axis] + } + + pub fn indices(&self) -> &[usize; R] { + &self.indices + } + + pub fn shape(&self) -> Shape { + self.shape + } + + pub fn shape_mut(&mut self) -> &mut Shape { + &mut self.shape + } } // --- blanket impls --- @@ -194,6 +308,12 @@ impl From> for Idx { } } +impl From> for Idx { + fn from(tensor: Tensor) -> Self { + Self::zero(tensor.shape()) + } +} + impl std::fmt::Display for Idx { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[")?; @@ -249,3 +369,45 @@ impl Sub for Idx { } } } + +// ---- Iterator ---- + +pub struct IdxIterator { + current: Idx, + end: bool, +} + +impl IdxIterator { + pub fn new(shape: Shape) -> Self { + Self { + current: Idx::zero(shape), + end: false, + } + } +} + +impl Iterator for IdxIterator { + type Item = Idx; + + fn next(&mut self) -> Option { + if self.end { + return None; + } + + let result = self.current; + self.end = self.current.inc(); + Some(result) + } +} + +impl IntoIterator for Idx { + type Item = Idx; + type IntoIter = IdxIterator; + + fn into_iter(self) -> Self::IntoIter { + IdxIterator { + current: self, + end: false, + } + } +} diff --git a/src/lib.rs b/src/lib.rs index be92164..0f363cc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -2,6 +2,7 @@ #![feature(generic_const_exprs)] pub mod index; pub mod shape; +pub mod axis; pub mod tensor; pub use index::Idx; @@ -11,6 +12,10 @@ pub use shape::Shape; pub use std::fmt::{Display, Formatter, Result as FmtResult}; use std::ops::{Index, IndexMut}; pub use tensor::{Tensor, TensorIterator}; +pub use static_assertions::const_assert; +pub use itertools::Itertools; +pub use std::sync::Arc; +pub use axis::*; pub trait Value: Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static> diff --git a/src/tensor.rs b/src/tensor.rs index d8c78f6..0693d7c 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -22,12 +22,28 @@ impl Tensor { } impl Tensor { - pub fn new(shape: Shape) -> Self { - let total_size: usize = shape.iter().product(); + pub fn new(shape: Shape) -> Self { + // Handle rank 0 tensor (scalar) as a special case + let total_size = if R == 0 { + // A rank 0 tensor should still have a buffer with one element + 1 + } else { + // For tensors of rank 1 or higher, calculate the total size normally + shape.iter().product() + }; + let buffer = vec![T::zero(); total_size]; Self { buffer, shape } } + pub fn new_with_buffer(shape: Shape, buffer: Vec) -> Self { + Self { buffer, shape } + } + + pub fn idx(&self) -> Idx { + Idx::from(self.clone()) + } + pub fn shape(&self) -> Shape { self.shape } @@ -111,93 +127,71 @@ impl Tensor { } } - // `self_dims` and `other_dims` specify the dimensions to contract over - pub fn contract( - &self, - rhs: &Tensor, - axis_lhs: [usize; NAXL], - axis_rhs: [usize; NAXR], - ) -> Tensor - where - [(); R - NAXL]:, - [(); S - NAXR]:, - { - // Step 1: Validate the axes for both tensors to ensure they are within bounds - for &axis in &axis_lhs { - if axis >= R { - panic!( - "Axis {} is out of bounds for the left-hand tensor", - axis - ); - } - } + // // Step 1: Validate the axes for both tensors to ensure they are within bounds + // for &axis in &axis_lhs { + // if axis >= R { + // panic!( + // "Axis {} is out of bounds for the left-hand tensor", + // axis + // ); + // } + // } - for &axis in &axis_rhs { - if axis >= S { - panic!( - "Axis {} is out of bounds for the right-hand tensor", - axis - ); - } - } + // for &axis in &axis_rhs { + // if axis >= S { + // panic!( + // "Axis {} is out of bounds for the right-hand tensor", + // axis + // ); + // } + // } - // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions - let mut result_buffer = Vec::new(); + // // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions + // let mut result_buffer = Vec::new(); - for i in 0..self.shape.size() { - for j in 0..rhs.shape.size() { - // Debug: Print indices being processed - println!("Processing Indices: lhs = {}, rhs = {}", i, j); + // for i in self.iter_over_non_contracted_axes(&axis_lhs) { + // for j in rhs.iter_over_non_contracted_axes(&axis_rhs) { + // let mut product_sum = T::zero(); - if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) { - let mut product_sum = T::zero(); + // // Step 3: Contraction over specified axes + // for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) { + // let idx_l = i[axis_l]; + // let idx_r = j[axis_r]; - // Debug: Print axes of contraction - println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs); + // let value_lhs = self.buffer[self.flat(&idx_l)]; + // let value_rhs = rhs.buffer[rhs.flat(&idx_r)]; + // product_sum = product_sum + value_lhs * value_rhs; + // } - for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) { - // Debug: Print values being multiplied - let value_lhs = self.get_by_axis(axis_l, i).unwrap(); - let value_rhs = rhs.get_by_axis(axis_r, j).unwrap(); - println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs); + // result_buffer.push(product_sum); + // } + // } - product_sum = product_sum + value_lhs * value_rhs; - } + // // Step 4: Remove contracted dimensions to create new shapes for both tensors + // let new_shape_lhs = self.shape.remove_dims::<{ NL }>(axis_lhs); + // let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs); - // Debug: Print the product sum for the current indices - println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum); + // // Step 5: Concatenate the shapes to form the shape of the resultant tensor + // let mut new_shape = Vec::new(); + // new_shape.extend_from_slice(&new_shape_lhs.as_array()); + // new_shape.extend_from_slice(&new_shape_rhs.as_array()); - result_buffer.push(product_sum); - } - } - } + // let new_shape_array: [usize; R + S - NAXL - NAXR] = + // new_shape.try_into().expect("Failed to create shape array"); - // Step 3: Remove contracted dimensions to create new shapes for both tensors - let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs); - let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs); - - // Step 4: Concatenate the shapes to form the shape of the resultant tensor - let mut new_shape = Vec::new(); - - new_shape.extend_from_slice(&new_shape_lhs.as_array()); - new_shape.extend_from_slice(&new_shape_rhs.as_array()); - - let new_shape_array: [usize; R + S - NAXL - NAXR] = - new_shape.try_into().expect("Failed to create shape array"); - - Tensor { - buffer: result_buffer, - shape: Shape::new(new_shape_array), - } - } + // Tensor { + // buffer: result_buffer, + // shape: Shape::new(new_shape_array), + // } + // } // Retrieve an element based on a specific axis and index pub fn get_by_axis(&self, axis: usize, index: usize) -> Option { // Convert axis and index to a flat index let flat_index = self.axis_to_flat_index(axis, index); - if flat_index >= self.buffer.len() { - return None; - } + if flat_index >= self.buffer.len() { + return None; + } Some(self.buffer[flat_index]) } @@ -214,7 +208,7 @@ impl Tensor { // Calculate the stride for each dimension and accumulate the flat index for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() { - println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride); + println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride); if i > axis { stride *= dim_size; } else if i == axis { @@ -366,3 +360,23 @@ impl From<[[T; X]; Y]> tensor } } + +impl + From<[[[T; X]; Y]; Z]> for Tensor +{ + fn from(array: [[[T; X]; Y]; Z]) -> Self { + let shape = Shape::new([Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, plane) in array.iter().enumerate() { + for (j, row) in plane.iter().enumerate() { + for (k, &elem) in row.iter().enumerate() { + buffer[i * X * Y + j * X + k] = elem; + } + } + } + + tensor + } +}