diff --git a/examples/operations.rs b/examples/operations.rs index 2b50e1b..5ebb0a9 100644 --- a/examples/operations.rs +++ b/examples/operations.rs @@ -6,7 +6,7 @@ use manifold::*; fn tensor_product() { println!("Tensor Product\n"); - let mut tensor1 = Tensor::::from([2, 2]); // 2x2 tensor + let mut tensor1 = Tensor::::from([[2], [2]]); // 2x2 tensor let mut tensor2 = Tensor::::from([2]); // 2-element vector // Fill tensors with some values @@ -57,8 +57,8 @@ fn test_tensor_contraction_23x32() { } fn test_tensor_contraction_rank3() { - let a = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); - let b = Tensor::from([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]); + let a = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + let b = tensor!([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]); let contracted_tensor = contract((&a, [2]), (&b, [0])); println!("a: {}", a); @@ -73,6 +73,10 @@ fn test_tensor_contraction_rank3() { fn transpose() { let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); + let b = tensor!( + [[1, 2, 3], + [4, 5, 6]] + ); // let iter = a.idx().iter_transposed([1, 0]); diff --git a/src/axis.rs b/src/axis.rs index 8fbca71..db63940 100644 --- a/src/axis.rs +++ b/src/axis.rs @@ -75,18 +75,6 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> { self.set_start(level).set_end(level + 1) } - // pub fn disassemble(self) -> Vec { - // let mut result = Vec::new(); - // for i in 0..self.axis().len() { - // result.push(Self::new(self.axis()).set_level(i)); - // } - // result - // } - - // pub fn disassemble(&'a self) -> impl Iterator + 'a { - // (0..self.axis().len()).map(move |i| Self::new(self.axis()).set_level(i)) - // } - pub fn level(&'a self, level: usize) -> impl Iterator + 'a { Self::new(self.axis()).set_level(level) } diff --git a/src/index.rs b/src/index.rs index 1026d3f..cde85cd 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,7 +1,7 @@ use super::*; +use getset::{Getters, MutGetters}; use std::cmp::Ordering; use std::ops::{Add, Sub}; -use getset::{Getters, MutGetters}; #[derive(Clone, Copy, Debug, Getters, MutGetters)] pub struct Idx<'a, const R: usize> { @@ -32,7 +32,10 @@ impl<'a, const R: usize> Idx<'a, R> { if !shape.check_indices(indices) { panic!("indices out of bounds"); } - Self { indices, shape: shape } + Self { + indices, + shape: shape, + } } pub fn is_zero(&self) -> bool { @@ -54,7 +57,7 @@ impl<'a, const R: usize> Idx<'a, R> { /// `true` if the increment does not overflow and is still within bounds; /// `false` if it overflows, indicating the end of the tensor. pub fn inc(&mut self) -> bool { - if self.indices[0] >= self.shape.get(0) { + if self.indices()[0] >= self.shape().get(0) { return false; } let mut carry = 1; @@ -84,55 +87,45 @@ impl<'a, const R: usize> Idx<'a, R> { // fn inc_axis pub fn inc_axis(&mut self, fixed_axis: usize) { - assert!(fixed_axis < R, "Axis out of bounds"); - assert!(self.indices()[fixed_axis] < self.shape().get(fixed_axis), "Index out of bounds"); + assert!(fixed_axis < R, "Axis out of bounds"); + assert!( + self.indices()[fixed_axis] < self.shape().get(fixed_axis), + "Index out of bounds" + ); - // Try to increment non-fixed axes + // Try to increment non-fixed axes for i in (0..R).rev() { if i != fixed_axis { - if self.indices[i] + 1 < self.shape.get(i) { + if self.indices[i] + 1 < self.shape.get(i) { self.indices[i] += 1; - return ; + return; } else { self.indices[i] = 0; } } } - if self.indices[fixed_axis] < self.shape.get(fixed_axis) { - self.indices[fixed_axis] += 1; - for i in 0..R { - if i != fixed_axis { - self.indices[i] = 0; - } - } - return ; - } + if self.indices[fixed_axis] < self.shape.get(fixed_axis) { + self.indices[fixed_axis] += 1; + for i in 0..R { + if i != fixed_axis { + self.indices[i] = 0; + } + } + return; + } } - // pub fn inc_transposed(&mut self, order: [usize; R]) -> bool { - // // Iterate over axes in the specified order - // for &axis in order.iter().rev() { - // if self.indices[axis] + 1 < self.shape.get(axis) { - // self.indices[axis] += 1; - // return true; - // } else { - // self.indices[axis] = 0; - // } - // } - // false - // } - - pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool { - if self.indices[order[0]] >= self.shape.get(order[0]) { + pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool { + if self.indices()[order[0]] >= self.shape().get(order[0]) { return false; } + let mut carry = 1; - for i in - order.iter().rev() - { - let dim_size = self.shape().get(*i); - let i = self.index_mut(*i); + + for i in order.iter().rev() { + let dim_size = self.shape().get(*i); + let i = self.index_mut(*i); if carry == 1 { *i += 1; if *i >= dim_size { @@ -144,9 +137,10 @@ impl<'a, const R: usize> Idx<'a, R> { } if carry == 1 { - self.indices[order[0]] = self.shape.as_array()[order[0]]; + self.indices_mut()[order[0]] = self.shape().get(order[0]); return true; } + false } @@ -171,7 +165,7 @@ impl<'a, const R: usize> Idx<'a, R> { } } - pub fn dec_axis(&mut self, fixed_axis: usize) -> bool { + pub fn dec_axis(&mut self, fixed_axis: usize) -> bool { // Check if the fixed axis index is already in an invalid state if self.indices[fixed_axis] == self.shape.get(fixed_axis) { return false; @@ -202,29 +196,29 @@ impl<'a, const R: usize> Idx<'a, R> { self.indices[fixed_axis] = self.shape.get(fixed_axis); } - true + true } - pub fn dec_transposed(&mut self, order: [usize; R]) { - // Iterate over the axes in the specified order - for &axis in &order { - // Try to decrement the current axis - if self.indices[axis] > 0 { - self.indices[axis] -= 1; - // Reset all preceding axes in the order to their maximum - for &prev_axis in &order { - if prev_axis == axis { - break; - } - self.indices[prev_axis] = self.shape.get(prev_axis) - 1; - } - return; - } - } + pub fn dec_transposed(&mut self, order: [usize; R]) { + // Iterate over the axes in the specified order + for &axis in &order { + // Try to decrement the current axis + if self.indices[axis] > 0 { + self.indices[axis] -= 1; + // Reset all preceding axes in the order to their maximum + for &prev_axis in &order { + if prev_axis == axis { + break; + } + self.indices[prev_axis] = self.shape.get(prev_axis) - 1; + } + return; + } + } - // If no axis can be decremented, set the first axis in the order to indicate overflow - self.indices[order[0]] = self.shape.get(order[0]); - } + // If no axis can be decremented, set the first axis in the order to indicate overflow + self.indices[order[0]] = self.shape.get(order[0]); + } /// Converts the multi-dimensional index to a flat index. /// @@ -289,9 +283,12 @@ impl<'a, const R: usize> Idx<'a, R> { self.indices[axis] } - pub fn iter_transposed(&self, order: [usize; R]) -> IdxTransposedIterator<'a, R> { - IdxTransposedIterator::new(self.shape(), order) - } + pub fn iter_transposed( + &self, + order: [usize; R], + ) -> IdxTransposedIterator<'a, R> { + IdxTransposedIterator::new(self.shape(), order) + } } // --- blanket impls --- @@ -456,7 +453,7 @@ impl<'a, const R: usize> IntoIterator for Idx<'a, R> { pub struct IdxTransposedIterator<'a, const R: usize> { current: Idx<'a, R>, - order: [usize; R], + order: [usize; R], end: bool, } @@ -465,7 +462,7 @@ impl<'a, const R: usize> IdxTransposedIterator<'a, R> { Self { current: Idx::zero(shape), end: false, - order, + order, } } } diff --git a/src/lib.rs b/src/lib.rs index e3745e1..87bdcd5 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -36,6 +36,13 @@ impl Value for T where { } +#[macro_export] +macro_rules! tensor { + ($array:expr) => { + Tensor::from($array) + }; +} + // ---- Tests ---- #[cfg(test)] @@ -45,7 +52,7 @@ mod tests { #[test] fn test_tensor_product() { - let mut tensor1 = Tensor::::from([2, 2]); // 2x2 tensor + let mut tensor1 = Tensor::::from([[2], [2]]); // 2x2 tensor let mut tensor2 = Tensor::::from([2]); // 2-element vector // Fill tensors with some values diff --git a/src/tensor.rs b/src/tensor.rs index c9f8d69..b470826 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,6 +1,7 @@ use super::*; use crate::error::*; use getset::{Getters, MutGetters}; +use std::fmt; #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)] pub struct Tensor { @@ -45,8 +46,6 @@ impl Tensor { } pub fn transpose(self, order: [usize; R]) -> Result { - // let shape = self.shape().reorder(order); - let buffer = Idx::from(self.shape()).iter_transposed(order) .map(|index| { println!("index: {}", index); @@ -220,8 +219,6 @@ impl IndexMut for Tensor { // ---- Display ---- -use std::fmt; - impl Tensor where T: fmt::Display + Clone, @@ -333,11 +330,27 @@ impl From> for Tensor { } } -impl From<[usize; R]> for Tensor { - fn from(shape: [usize; R]) -> Self { - let shape = Shape::new(shape); - Self::new(shape.into()) - } +impl From for Tensor { + fn from(value: T) -> Self { + let shape = Shape::new([]); + let mut tensor = Tensor::new(shape); + tensor.buffer_mut()[0] = value; + tensor + } +} + +impl From<[T; X]> for Tensor { + fn from(array: [T; X]) -> Self { + let shape = Shape::new([X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, &elem) in array.iter().enumerate() { + buffer[i] = elem; + } + + tensor + } } impl From<[[T; X]; Y]> @@ -377,3 +390,75 @@ impl tensor } } + +impl + From<[[[[T; X]; Y]; Z]; W]> for Tensor +{ + fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self { + let shape = Shape::new([W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, hyperplane) in array.iter().enumerate() { + for (j, plane) in hyperplane.iter().enumerate() { + for (k, row) in plane.iter().enumerate() { + for (l, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem; + } + } + } + } + + tensor + } +} + +impl + From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor +{ + fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self { + let shape = Shape::new([V, W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, hyperhyperplane) in array.iter().enumerate() { + for (j, hyperplane) in hyperhyperplane.iter().enumerate() { + for (k, plane) in hyperplane.iter().enumerate() { + for (l, row) in plane.iter().enumerate() { + for (m, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z * W + j * X * Y * Z + k * X * Y + l * X + m] = elem; + } + } + } + } + } + + tensor + } +} + +impl + From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor +{ + fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self { + let shape = Shape::new([U, V, W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, hyperhyperhyperplane) in array.iter().enumerate() { + for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() { + for (k, hyperplane) in hyperhyperplane.iter().enumerate() { + for (l, plane) in hyperplane.iter().enumerate() { + for (m, row) in plane.iter().enumerate() { + for (n, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z * W * V + j * X * Y * Z * W + k * X * Y * Z + l * X * Y + m * X + n] = elem; + } + } + } + } + } + } + + tensor + } +}