commit b17264ba188cdd0b023d072e7343d880b281552b Author: Julius Koskela Date: Tue Dec 26 01:47:13 2023 +0200 🚀 WIP implementation of N-rank Tensor Signed-off-by: Julius Koskela diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ea8c4bf --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +/target diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..6871f48 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,179 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bytemuck" +version = "1.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" + +[[package]] +name = "itoa" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" + +[[package]] +name = "mltensor" +version = "0.1.0" +dependencies = [ + "bytemuck", + "num", + "serde", + "serde_json", +] + +[[package]] +name = "num" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + +[[package]] +name = "num-bigint" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" +dependencies = [ + "autocfg", + "num-traits", +] + +[[package]] +name = "num-iter" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0" +dependencies = [ + "autocfg", + "num-bigint", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "proc-macro2" +version = "1.0.71" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.33" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae" 
+dependencies = [ + "proc-macro2", +] + +[[package]] +name = "ryu" +version = "1.0.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" + +[[package]] +name = "serde" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.193" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.108" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "syn" +version = "2.0.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..17c0c54 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "mltensor" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bytemuck = "1.14.0" +num = "0.4.1" +serde = { version = "1.0.193", features = ["derive"] } +serde_json = "1.0.108" \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..b3473c2 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2023 Julius Koskela + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..22e0174
--- /dev/null
+++ b/README.md
@@ -0,0 +1,22 @@
+# mltensor
+
+A WIP implementation of an N-rank tensor with compile-time rank and shape.
+
+```rust
+// Create two tensors with different ranks and shapes
+let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
+let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
+
+// Fill tensors with some values
+tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
+tensor2.buffer_mut().copy_from_slice(&[5, 6]);
+
+// Calculate tensor product
+let product = tensor1.tensor_product(&tensor2);
+
+println!("T1 * T2 = {}", product);
+
+// Check shape of the resulting tensor
+assert_eq!(product.shape(), Shape::new([2, 2, 2]));
+
+// Check buffer of the resulting tensor
+assert_eq!(product.buffer(), &[5, 6, 10, 12, 15, 18, 20, 24]);
+```
diff --git a/examples/operations.rs b/examples/operations.rs
new file mode 100644
index 0000000..6f77676
--- /dev/null
+++ b/examples/operations.rs
@@ -0,0 +1,53 @@
+use mltensor::*;
+
+fn tensor_product() {
+    println!("Tensor Product\n");
+    let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
+    let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
+
+    // Fill tensors with some values
+    tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
+    tensor2.buffer_mut().copy_from_slice(&[5, 6]);
+
+    println!("T1: {}", tensor1);
+    println!("T2: {}", tensor2);
+
+    let product = tensor1.tensor_product(&tensor2);
+
+    println!("T1 * T2 = {}", product);
+
+    // Check shape of the resulting tensor
+    assert_eq!(product.shape(), Shape::new([2, 2, 2]));
+
+    // Check buffer of the resulting tensor
+    let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
+    assert_eq!(product.buffer(), &expected_buffer);
+}
+
+fn tensor_contraction() {
+    println!("Tensor Contraction\n");
+    // Create two tensors
+    let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
+    let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
+
+    // Specify axes for contraction
+    let axis_lhs = [1]; // Contract over the second dimension of tensor1
+    let axis_rhs = [0]; // Contract over the first dimension of tensor2
+
+    // Perform contraction
+    let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
+
+    println!("T1: {}", tensor1);
+    println!("T2: {}", tensor2);
+    println!("T1 * T2 = {}", result);
+
+    // The expected result could be a single number or a new tensor,
+    // depending on how the contraction operation is defined.
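+    //
+    // For a standard matrix product, contracting axis 1 of T1 with axis 0
+    // of T2 would give [[1, 2], [3, 4]] x [[5, 6], [7, 8]] =
+    // [[19, 22], [43, 50]], i.e. the buffer [19, 22, 43, 50]
+    // (left commented out until the WIP contraction kernel is verified):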
+    // Assert the result is as expected
+    // assert_eq!(result, expected_result);
+}
+
+fn main() {
+    tensor_product();
+    tensor_contraction();
+}
\ No newline at end of file
diff --git a/rust-toolchain.toml b/rust-toolchain.toml
new file mode 100644
index 0000000..5d56faf
--- /dev/null
+++ b/rust-toolchain.toml
@@ -0,0 +1,2 @@
+[toolchain]
+channel = "nightly"
diff --git a/rustfmt.toml b/rustfmt.toml
new file mode 100644
index 0000000..df99c69
--- /dev/null
+++ b/rustfmt.toml
@@ -0,0 +1 @@
+max_width = 80
diff --git a/src/index.rs b/src/index.rs
new file mode 100644
index 0000000..7bc71e0
--- /dev/null
+++ b/src/index.rs
@@ -0,0 +1,251 @@
+use super::*;
+use std::cmp::Ordering;
+use std::ops::{Add, Sub};
+
+#[derive(Clone, Copy, Debug)]
+pub struct Idx<const R: usize> {
+    indices: [usize; R],
+    shape: Shape<R>,
+}
+
+impl<const R: usize> Idx<R> {
+    pub const fn zero(shape: Shape<R>) -> Self {
+        Self {
+            indices: [0; R],
+            shape,
+        }
+    }
+
+    pub fn max(shape: Shape<R>) -> Self {
+        let max_indices =
+            shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
+        Self {
+            indices: max_indices,
+            shape,
+        }
+    }
+
+    pub fn new(shape: Shape<R>, indices: [usize; R]) -> Self {
+        if !shape.check_indices(indices) {
+            panic!("indices out of bounds");
+        }
+        Self { indices, shape }
+    }
+
+    pub fn is_zero(&self) -> bool {
+        self.indices.iter().all(|&i| i == 0)
+    }
+
+    pub fn is_overflow(&self) -> bool {
+        // `inc` flags completion by pushing the first index to the size of
+        // the first dimension; check for that sentinel state here.
+        self.indices[0] >= self.shape.get(0)
+    }
+
+    pub fn reset(&mut self) {
+        self.indices = [0; R];
+    }
+
+    /// Increments the index and returns a boolean indicating whether the end
+    /// has been reached.
+    ///
+    /// # Returns
+    /// `true` if the increment overflowed past the last element, indicating
+    /// the end of the tensor; `false` if the index is still within bounds.
+    pub fn inc(&mut self) -> bool {
+        let mut carry = 1;
+        for (i, &dim_size) in
+            self.indices.iter_mut().zip(&self.shape.as_array()).rev()
+        {
+            if carry == 1 {
+                *i += 1;
+                if *i >= dim_size {
+                    *i = 0; // Reset index in this dimension and carry over
+                } else {
+                    carry = 0; // Increment successful, no carry needed
+                }
+            }
+        }
+
+        // If carry is still 1 after the loop, we have incremented past the
+        // last dimension
+        if carry == 1 {
+            // Set the index to the sentinel overflow state (first index equal
+            // to the size of the first dimension)
+            self.indices[0] = self.shape.as_array()[0];
+            return true; // Indicate that the iteration is complete
+        }
+        false
+    }
+
+    pub fn dec(&mut self) {
+        // Check if already at the start
+        if self.indices.iter().all(|&i| i == 0) {
+            return;
+        }
+
+        let mut borrow = true;
+        for (i, &dim_size) in
+            self.indices.iter_mut().zip(&self.shape.as_array()).rev()
+        {
+            if borrow {
+                if *i == 0 {
+                    // Wrap around to the maximum index of this dimension
+                    *i = dim_size - 1;
+                } else {
+                    *i -= 1; // Decrement the index
+                    borrow = false; // No more borrowing needed
+                }
+            }
+        }
+    }
+
+    /// Converts the multi-dimensional index to a flat index.
+    ///
+    /// This method calculates the flat index corresponding to the
+    /// multi-dimensional index stored in `self.indices`, given the shape of
+    /// the tensor stored in `self.shape`. The calculation assumes the tensor
+    /// is stored in row-major order, where the last dimension varies the
+    /// fastest.
+    ///
+    /// # Returns
+    /// The flat index corresponding to the multi-dimensional index.
+    ///
+    /// # How It Works
+    /// - The method iterates over each pair of corresponding index and
+    ///   dimension size, starting from the last dimension (hence `rev()` is
+    ///   used for reverse iteration).
+    /// - In each iteration, it performs two main operations:
+    ///   1. **Index Contribution**: Multiplies the current index (`idx`) by a
+    ///      running product of dimension sizes (`product`). This calculates
+    ///      the contribution of the current index to the overall flat index.
+    ///   2. **Product Update**: Multiplies `product` by the current dimension
+    ///      size (`dim_size`). This updates `product` for the next iteration,
+    ///      as each dimension contributes to the flat index based on the
+    ///      sizes of all following dimensions.
+    /// - The `fold` operation accumulates these results, starting with an
+    ///   initial state of `(0, 1)` where `0` is the initial flat index and
+    ///   `1` is the initial product.
+    /// - The final flat index is the first element of the tuple resulting
+    ///   from the `fold`.
+    ///
+    /// # Example
+    /// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
+    /// - Starting with a flat index of 0 and a product of 1,
+    /// - For the last dimension (size 5), add 3 * 1 to the flat index.
+    ///   Update the product to 1 * 5 = 5.
+    /// - For the second dimension (size 4), add 2 * 5 to the flat index.
+    ///   Update the product to 5 * 4 = 20.
+    /// - For the first dimension (size 3), add 1 * 20 to the flat index.
+    ///   The final flat index is 3 + 10 + 20 = 33.
+    pub fn flat(&self) -> usize {
+        self.indices
+            .iter()
+            .zip(&self.shape.as_array())
+            .rev()
+            .fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
+                (flat_index + idx * product, product * dim_size)
+            })
+            .0
+    }
+}
+
+// --- blanket impls ---
+
+impl<const R: usize> PartialEq for Idx<R> {
+    fn eq(&self, other: &Self) -> bool {
+        self.flat() == other.flat()
+    }
+}
+
+impl<const R: usize> Eq for Idx<R> {}
+
+impl<const R: usize> PartialOrd for Idx<R> {
+    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
+        self.flat().partial_cmp(&other.flat())
+    }
+}
+
+impl<const R: usize> Ord for Idx<R> {
+    fn cmp(&self, other: &Self) -> Ordering {
+        self.flat().cmp(&other.flat())
+    }
+}
+
+impl<const R: usize> Index<usize> for Idx<R> {
+    type Output = usize;
+
+    fn index(&self, index: usize) -> &Self::Output {
+        &self.indices[index]
+    }
+}
+
+impl<const R: usize> IndexMut<usize> for Idx<R> {
+    fn index_mut(&mut self, index: usize) -> &mut Self::Output {
+        &mut self.indices[index]
+    }
+}
+
+impl<const R: usize> From<(Shape<R>, [usize; R])> for Idx<R> {
+    fn from((shape, indices): (Shape<R>, [usize; R])) -> Self {
+        assert!(shape.check_indices(indices));
+        Self::new(shape, indices)
+    }
+}
+
+impl<const R: usize> From<(Shape<R>, usize)> for Idx<R> {
+    fn from((shape, flat_index): (Shape<R>, usize)) -> Self {
+        let indices = shape.index_from_flat(flat_index).indices;
+        Self::new(shape, indices)
+    }
+}
+
+impl<const R: usize> From<Shape<R>> for Idx<R> {
+    fn from(shape: Shape<R>) -> Self {
+        Self::zero(shape)
+    }
+}
+
+impl<const R: usize> std::fmt::Display for Idx<R> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        write!(f, "[")?;
+        for (i, (&idx, &dim_size)) in self
+            .indices
+            .iter()
+            .zip(self.shape.as_array().iter())
+            .enumerate()
+        {
+            write!(f, "{}/{}", idx, dim_size - 1)?;
+            if i < self.indices.len() - 1 {
+                write!(f, ", ")?;
+            }
+        }
+        write!(f, "]")
+    }
+}
+
+// ---- Arithmetic Operations ----
+
+impl<const R: usize> Add for Idx<R> {
+    type Output = Self;
+
+    fn add(self, rhs: Self) -> Self::Output {
+        assert_eq!(self.shape, rhs.shape, "Shape mismatch");
+
+        let mut result_indices = [0; R];
+        for i in 0..R {
+            result_indices[i] = self.indices[i] + rhs.indices[i];
+        }
+
+        Self {
+            indices: result_indices,
+            shape: self.shape,
+        }
+    }
+}
+
+impl<const R: usize> Sub for Idx<R> {
+    type Output = Self;
+
+    fn sub(self, rhs: Self) -> Self::Output {
+        assert_eq!(self.shape, rhs.shape, "Shape mismatch");
+
+        let mut result_indices = [0; R];
+        for i in 0..R {
+            result_indices[i] = self.indices[i].saturating_sub(rhs.indices[i]);
+        }
+
+        Self {
+            indices: result_indices,
+            shape: self.shape,
+        }
+    }
+}
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..be92164
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,212 @@
+#![allow(incomplete_features)]
+#![feature(generic_const_exprs)]
+pub mod index;
+pub mod shape;
+pub mod tensor;
+
+pub use index::Idx;
+use num::{Num, One, Zero};
+pub use serde::{Deserialize, Serialize};
+pub use shape::Shape;
+pub use std::fmt::{Display, Formatter, Result as FmtResult};
+use std::ops::{Index, IndexMut};
+pub use tensor::{Tensor, TensorIterator};
+
+pub trait Value:
+    Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
+{
+}
+
+impl<T> Value for T where
+    T: Num
+        + Zero
+        + One
+        + Copy
+        + Clone
+        + Display
+        + Serialize
+        + Deserialize<'static>
+{
+}
+
+// ---- Tests ----
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use serde_json;
+
+    #[test]
+    fn test_tensor_product() {
+        let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
+        let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
+
+        // Fill tensors with some values
+        tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
+        tensor2.buffer_mut().copy_from_slice(&[5, 6]);
+
+        let product = tensor1.tensor_product(&tensor2);
+
+        // Check shape of the resulting tensor
+        assert_eq!(product.shape(), Shape::new([2, 2, 2]));
+
+        // Check buffer of the resulting tensor
+        let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
+        assert_eq!(product.buffer(), &expected_buffer);
+    }
+
+    #[test]
+    fn serde_shape_serialization_test() {
+        // Create a shape instance
+        let shape: Shape<3> = [1, 2, 3].into();
+
+        // Serialize the shape to a JSON string
+        let serialized =
+            serde_json::to_string(&shape).expect("Failed to serialize");
+
+        // Deserialize the JSON string back into a shape
+        let deserialized: Shape<3> =
+            serde_json::from_str(&serialized).expect("Failed to deserialize");
+
+        // Check that the deserialized shape is equal to the original
+        assert_eq!(shape, deserialized);
+    }
+
+    #[test]
+    fn tensor_serde_serialization_test() {
+        // Create an instance of Tensor (the element type is arbitrary here;
+        // f64 is used for the round-trip)
+        let tensor: Tensor<f64, 2> = Tensor::new(Shape::new([2, 2]));
+
+        // Serialize the Tensor to a JSON string
+        let serialized =
+            serde_json::to_string(&tensor).expect("Failed to serialize");
+
+        // Deserialize the JSON string back into a Tensor
+        let deserialized: Tensor<f64, 2> =
+            serde_json::from_str(&serialized).expect("Failed to deserialize");
+
+        // Check that the deserialized Tensor is equal to the original
+        assert_eq!(tensor.buffer(), deserialized.buffer());
+        assert_eq!(tensor.shape(), deserialized.shape());
+    }
+
+    #[test]
+    fn iterate_3d_tensor() {
+        let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
+        let mut tensor = Tensor::new(shape);
+        let mut num = 0;
+
+        // Fill the tensor with sequential numbers
+        for i in 0..2 {
+            for j in 0..2 {
+                for k in 0..2 {
+                    tensor.buffer_mut()[i * 4 + j * 2 + k] = num;
+                    num += 1;
+                }
+            }
+        }
+
+        println!("{}", tensor);
+
+        // Iterate over the tensor and check that the numbers are correct
+        let mut iter = TensorIterator::new(&tensor);
+
+        println!("{}", iter);
+
+        assert_eq!(iter.next(), Some(&0));
+        assert_eq!(iter.next(), Some(&1));
+        assert_eq!(iter.next(), Some(&2));
+        assert_eq!(iter.next(), Some(&3));
+        assert_eq!(iter.next(), Some(&4));
+        assert_eq!(iter.next(), Some(&5));
+        assert_eq!(iter.next(), Some(&6));
+        assert_eq!(iter.next(), Some(&7));
+        assert_eq!(iter.next(), None);
+        assert_eq!(iter.next(), None);
+    }
+
+    #[test]
+    fn iterate_rank_4_tensor() {
+        // Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
+        let shape = Shape::new([2, 2, 2, 2]);
+        let mut tensor = Tensor::new(shape);
+        let mut num = 0;
+
+        // Fill the tensor with sequential numbers
+        for i in 0..tensor.len() {
+            tensor.buffer_mut()[i] = num;
+            num += 1;
+        }
+
+        // Iterate over the tensor and check that the numbers are correct
+        let mut iter = TensorIterator::new(&tensor);
+        for expected_value in 0..tensor.len() {
+            assert_eq!(*iter.next().unwrap(), expected_value);
+        }
+
+        // Ensure the iterator is exhausted
+        assert!(iter.next().is_none());
+    }
+
+    #[test]
+    fn test_dec_method() {
+        let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
+        let mut index = Idx::zero(shape);
+
+        // Increment the index to the maximum
+        for _ in 0..26 {
+            // 3 * 3 * 3 - 1 = 26 increments to reach the end
+            index.inc();
+        }
+
+        // Check if the index is at the maximum
+        assert_eq!(index, Idx::new(shape, [2, 2, 2]));
+
+        // Decrement step by step and check the index
+        let expected_indices = [
+            [2, 2, 2],
+            [2, 2, 1],
+            [2, 2, 0],
+            [2, 1, 2],
+            [2, 1, 1],
+            [2, 1, 0],
+            [2, 0, 2],
+            [2, 0, 1],
+            [2, 0, 0],
+            [1, 2, 2],
+            [1, 2, 1],
+            [1, 2, 0],
+            [1, 1, 2],
+            [1, 1, 1],
+            [1, 1, 0],
+            [1, 0, 2],
+            [1, 0, 1],
+            [1, 0, 0],
+            [0, 2, 2],
+            [0, 2, 1],
+            [0, 2, 0],
+            [0, 1, 2],
+            [0, 1, 1],
+            [0, 1, 0],
+            [0, 0, 2],
+            [0, 0, 1],
+            [0, 0, 0],
+        ];
+
+        for (i, &expected) in expected_indices.iter().enumerate() {
+            assert_eq!(
+                index,
+                Idx::new(shape, expected),
+                "Failed at index {}",
+                i
+            );
+            index.dec();
+        }
+
+        // Finally, the index should reach [0, 0, 0]
+        index.dec();
+        assert_eq!(index, Idx::zero(shape));
+    }
+}
diff --git a/src/shape.rs b/src/shape.rs
new file mode 100644
index 0000000..e6409ea
--- /dev/null
+++ b/src/shape.rs
@@ -0,0 +1,178 @@
+use super::*;
+use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
+use serde::ser::{Serialize, SerializeTuple, Serializer};
+use std::fmt;
+
+#[derive(Clone, Copy, Debug)]
+pub struct Shape<const R: usize>([usize; R]);
+
+impl<const R: usize> Shape<R> {
+    pub const fn new(shape: [usize; R]) -> Self {
+        Self(shape)
+    }
+
+    pub const fn as_array(&self) -> [usize; R] {
+        self.0
+    }
+
+    pub const fn rank(&self) -> usize {
+        R
+    }
+
+    pub fn flat_max(&self) -> usize {
+        self.size() - 1
+    }
+
+    pub fn size(&self) -> usize {
+        self.0.iter().product()
+    }
+
+    pub fn iter(&self) -> impl Iterator<Item = &usize> {
+        self.0.iter()
+    }
+
+    pub const fn get(&self, index: usize) -> usize {
+        self.0[index]
+    }
+
+    pub fn check_indices(&self, indices: [usize; R]) -> bool {
+        indices
+            .iter()
+            .zip(self.0.iter())
+            .all(|(&idx, &dim_size)| idx < dim_size)
+    }
+
+    /// Converts a flat index to a multi-dimensional index.
+    ///
+    /// # Arguments
+    /// * `flat_index` - The flat index to convert.
+    ///
+    /// # Returns
+    /// An `Idx` instance representing the multi-dimensional index
+    /// corresponding to the flat index.
+    ///
+    /// # How It Works
+    /// - The method iterates over the dimensions of the tensor in reverse
+    ///   order (assuming row-major order).
+    /// - In each iteration, it uses the modulo operation to find the index in
+    ///   the current dimension and integer division to reduce the flat index
+    ///   for the next higher dimension.
+    /// - This process is repeated for each dimension to build the
+    ///   multi-dimensional index.
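+    ///
+    /// # Example
+    /// Inverting the worked example from `Idx::flat`: for shape `[3, 4, 5]`
+    /// and flat index `33`,
+    /// - last dimension: 33 % 5 = 3, remaining 33 / 5 = 6,
+    /// - middle dimension: 6 % 4 = 2, remaining 6 / 4 = 1,
+    /// - first dimension: 1 % 3 = 1, remaining 0,
+    ///
+    /// which recovers the multi-dimensional index `[1, 2, 3]`.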
+    pub fn index_from_flat(&self, flat_index: usize) -> Idx<R> {
+        let mut indices = [0; R];
+        let mut remaining = flat_index;
+
+        for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() {
+            *idx = remaining % dim_size;
+            remaining /= dim_size;
+        }
+
+        // The reversed iteration above already writes each coordinate into
+        // its correct slot, so `indices` is in the original dimension order.
+        Idx::new(*self, indices)
+    }
+
+    pub const fn index_zero(&self) -> Idx<R> {
+        Idx::zero(*self)
+    }
+
+    pub fn index_max(&self) -> Idx<R> {
+        let max_indices =
+            self.0
+                .map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
+        Idx::new(*self, max_indices)
+    }
+
+    pub fn remove_dims<const NAX: usize>(
+        &self,
+        dims_to_remove: [usize; NAX],
+    ) -> Shape<{ R - NAX }> {
+        // Create a new array to store the remaining dimensions
+        let mut new_shape = [0; R - NAX];
+        let mut new_index = 0;
+
+        // Iterate over the original dimensions
+        for (index, &dim) in self.0.iter().enumerate() {
+            // Skip dimensions that are in the dims_to_remove array
+            if dims_to_remove.contains(&index) {
+                continue;
+            }
+
+            // Add the dimension to the new shape array
+            new_shape[new_index] = dim;
+            new_index += 1;
+        }
+
+        Shape(new_shape)
+    }
+}
+
+// ---- Serialize and Deserialize ----
+
+struct ShapeVisitor<const R: usize>;
+
+impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
+    type Value = Shape<R>;
+
+    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+        write!(formatter, "an array of length {}", R)
+    }
+
+    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
+    where
+        A: SeqAccess<'de>,
+    {
+        let mut arr = [0; R];
+        for i in 0..R {
+            arr[i] = seq
+                .next_element()?
+                .ok_or_else(|| de::Error::invalid_length(i, &self))?;
+        }
+        Ok(Shape(arr))
+    }
+}
+
+impl<'de, const R: usize> Deserialize<'de> for Shape<R> {
+    fn deserialize<D>(deserializer: D) -> Result<Shape<R>, D::Error>
+    where
+        D: Deserializer<'de>,
+    {
+        deserializer.deserialize_tuple(R, ShapeVisitor)
+    }
+}
+
+impl<const R: usize> Serialize for Shape<R> {
+    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+    where
+        S: Serializer,
+    {
+        let mut seq = serializer.serialize_tuple(R)?;
+        for elem in &self.0 {
+            seq.serialize_element(elem)?;
+        }
+        seq.end()
+    }
+}
+
+// ---- Blanket Implementations ----
+
+impl<const R: usize> From<[usize; R]> for Shape<R> {
+    fn from(shape: [usize; R]) -> Self {
+        Self::new(shape)
+    }
+}
+
+impl<const R: usize> PartialEq for Shape<R> {
+    fn eq(&self, other: &Self) -> bool {
+        self.0 == other.0
+    }
+}
+
+impl<const R: usize> Eq for Shape<R> {}
+
+// ---- From and Into Implementations ----
+
+impl<T, const R: usize> From<Tensor<T, R>> for Shape<R>
+where
+    T: Value,
+{
+    fn from(tensor: Tensor<T, R>) -> Self {
+        tensor.shape()
+    }
+}
diff --git a/src/tensor.rs b/src/tensor.rs
new file mode 100644
index 0000000..d8c78f6
--- /dev/null
+++ b/src/tensor.rs
@@ -0,0 +1,368 @@
+use super::*;
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct Tensor<T: Value, const R: usize> {
+    buffer: Vec<T>,
+    shape: Shape<R>,
+}
+
+impl<T: Value, const R: usize> Tensor<T, R> {
+    pub fn dot(&self, other: &Tensor<T, R>) -> T {
+        if self.shape != other.shape {
+            panic!("Shapes of tensors do not match");
+        }
+
+        let mut result = T::zero();
+        for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
+            result = result + (*a * *b);
+        }
+
+        result
+    }
+}
+
+impl<T: Value, const R: usize> Tensor<T, R> {
+    pub fn new(shape: Shape<R>) -> Self {
+        let total_size: usize = shape.iter().product();
+        let buffer = vec![T::zero(); total_size];
+        Self { buffer, shape }
+    }
+
+    pub fn shape(&self) -> Shape<R> {
+        self.shape
+    }
+
+    pub fn buffer(&self) -> &[T] {
+        &self.buffer
+    }
+
+    pub fn buffer_mut(&mut self) -> &mut [T] {
+        &mut self.buffer
+    }
+
+    pub fn get(&self, index: Idx<R>) -> &T {
+        &self.buffer[index.flat()]
+    }
+
+    pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
+        self.buffer.get_unchecked(index.flat())
+    }
+
+    pub fn get_mut(&mut self, index: Idx<R>) -> Option<&mut T> {
+        self.buffer.get_mut(index.flat())
+    }
+
+    pub unsafe fn get_unchecked_mut(&mut self, index: Idx<R>) -> &mut T {
+        self.buffer.get_unchecked_mut(index.flat())
+    }
+
+    pub fn rank(&self) -> usize {
+        R
+    }
+
+    pub fn len(&self) -> usize {
+        self.buffer.len()
+    }
+
+    pub fn iter(&self) -> TensorIterator<'_, T, R> {
+        TensorIterator::new(self)
+    }
+
+    pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
+        if self.shape != other.shape {
+            panic!("Shapes of tensors do not match");
+        }
+
+        let mut result_buffer = Vec::with_capacity(self.buffer.len());
+
+        for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
+            result_buffer.push(*a * *b);
+        }
+
+        Tensor {
+            buffer: result_buffer,
+            shape: self.shape,
+        }
+    }
+
+    pub fn tensor_product<const S: usize>(
+        &self,
+        other: &Tensor<T, S>,
+    ) -> Tensor<T, { R + S }> {
+        let mut new_shape_vec = Vec::new();
+        new_shape_vec.extend_from_slice(&self.shape.as_array());
+        new_shape_vec.extend_from_slice(&other.shape.as_array());
+
+        let new_shape_array: [usize; R + S] = new_shape_vec
+            .try_into()
+            .expect("Failed to create shape array");
+
+        let mut new_buffer =
+            Vec::with_capacity(self.buffer.len() * other.buffer.len());
+        for &item_self in &self.buffer {
+            for &item_other in &other.buffer {
+                new_buffer.push(item_self * item_other);
+            }
+        }
+
+        Tensor {
+            buffer: new_buffer,
+            shape: Shape::new(new_shape_array),
+        }
+    }
+
+    // `axis_lhs` and `axis_rhs` specify the dimensions to contract over
+    pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
+        &self,
+        rhs: &Tensor<T, S>,
+        axis_lhs: [usize; NAXL],
+        axis_rhs: [usize; NAXR],
+    ) -> Tensor<T, { R + S - NAXL - NAXR }>
+    where
+        [(); R - NAXL]:,
+        [(); S - NAXR]:,
+    {
+        // Step 1: Validate the axes for both tensors to ensure they are
+        // within bounds
+        for &axis in &axis_lhs {
+            if axis >= R {
+                panic!(
+                    "Axis {} is out of bounds for the left-hand tensor",
+                    axis
+                );
+            }
+        }
+
+        for &axis in &axis_rhs {
+            if axis >= S {
+                panic!(
+                    "Axis {} is out of bounds for the right-hand tensor",
+                    axis
+                );
+            }
+        }
+
+        // Step 2: Iterate over the tensors, multiplying and summing elements
+        // across contracted dimensions
+        let mut result_buffer = Vec::new();
+
+        for i in 0..self.shape.size() {
+            for j in 0..rhs.shape.size() {
+                // Debug: Print indices being processed
+                println!("Processing Indices: lhs = {}, rhs = {}", i, j);
+
+                if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) {
+                    let mut product_sum = T::zero();
+
+                    // Debug: Print axes of contraction
+                    println!(
+                        "Contracting Axes: lhs = {:?}, rhs = {:?}",
+                        axis_lhs, axis_rhs
+                    );
+
+                    for (&axis_l, &axis_r) in
+                        axis_lhs.iter().zip(axis_rhs.iter())
+                    {
+                        // Debug: Print values being multiplied
+                        let value_lhs = self.get_by_axis(axis_l, i).unwrap();
+                        let value_rhs = rhs.get_by_axis(axis_r, j).unwrap();
+                        println!(
+                            "Multiplying: lhs_value = {}, rhs_value = {}",
+                            value_lhs, value_rhs
+                        );
+
+                        product_sum = product_sum + value_lhs * value_rhs;
+                    }
+
+                    // Debug: Print the product sum for the current indices
+                    println!(
+                        "Product Sum for indices (lhs = {}, rhs = {}) = {}",
+                        i, j, product_sum
+                    );
+
+                    result_buffer.push(product_sum);
+                }
+            }
+        }
+
+        // Step 3: Remove contracted dimensions to create new shapes for both
+        // tensors
+        let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
+        let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
+
+        // Step 4: Concatenate the shapes to form the shape of the resultant
+        // tensor
+        let mut new_shape = Vec::new();
+
+        new_shape.extend_from_slice(&new_shape_lhs.as_array());
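+        // The resulting rank is R + S - NAXL - NAXR; e.g. contracting two
+        // 2x2 matrices over one axis each yields 2 + 2 - 1 - 1 = 2, i.e. a
+        // [2, 2] result shape.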
+        new_shape.extend_from_slice(&new_shape_rhs.as_array());
+
+        let new_shape_array: [usize; R + S - NAXL - NAXR] =
+            new_shape.try_into().expect("Failed to create shape array");
+
+        Tensor {
+            buffer: result_buffer,
+            shape: Shape::new(new_shape_array),
+        }
+    }
+
+    // Retrieve an element based on a specific axis and index
+    pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
+        // Convert axis and index to a flat index
+        let flat_index = self.axis_to_flat_index(axis, index);
+        if flat_index >= self.buffer.len() {
+            return None;
+        }
+
+        Some(self.buffer[flat_index])
+    }
+
+    // Convert axis and index to a flat index in the buffer
+    fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize {
+        let mut flat_index = 0;
+        let mut stride = 1;
+
+        // Ensure the given axis is within the tensor's dimensions
+        if axis >= R {
+            panic!("Axis out of bounds");
+        }
+
+        // Calculate the stride for each dimension and accumulate the flat
+        // index
+        for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
+            println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
+            if i > axis {
+                stride *= dim_size;
+            } else if i == axis {
+                flat_index += index * stride;
+                break; // We've reached the target axis
+            }
+        }
+
+        flat_index
+    }
+}
+
+// ---- Indexing ----
+
+impl<T: Value, const R: usize> Index<Idx<R>> for Tensor<T, R> {
+    type Output = T;
+
+    fn index(&self, index: Idx<R>) -> &Self::Output {
+        &self.buffer[index.flat()]
+    }
+}
+
+impl<T: Value, const R: usize> IndexMut<Idx<R>> for Tensor<T, R> {
+    fn index_mut(&mut self, index: Idx<R>) -> &mut Self::Output {
+        &mut self.buffer[index.flat()]
+    }
+}
+
+impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        // Print the shape of the tensor
+        write!(f, "Shape: [")?;
+        for (i, dim_size) in self.shape.as_array().iter().enumerate() {
+            write!(f, "{}", dim_size)?;
+            if i < R - 1 {
+                write!(f, ", ")?;
+            }
+        }
+        write!(f, "], Elements: [")?;
+
+        // Print the elements in a flattened form
+        for (i, elem) in self.buffer.iter().enumerate() {
+            write!(f, "{}", elem)?;
+            if i < self.buffer.len() - 1 {
+                write!(f, ", ")?;
+            }
+        }
+
+        write!(f, "]")
+    }
+}
+
+// ---- Iterator ----
+
+pub struct TensorIterator<'a, T: Value, const R: usize> {
+    tensor: &'a Tensor<T, R>,
+    index: Idx<R>,
+}
+
+impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
+    pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
+        Self {
+            tensor,
+            index: tensor.shape.index_zero(),
+        }
+    }
+}
+
+impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
+    type Item = &'a T;
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if self.index.is_overflow() {
+            return None;
+        }
+
+        let result = unsafe { self.tensor.get_unchecked(self.index) };
+        self.index.inc();
+        Some(result)
+    }
+}
+
+impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
+    type Item = &'a T;
+    type IntoIter = TensorIterator<'a, T, R>;
+
+    fn into_iter(self) -> Self::IntoIter {
+        TensorIterator::new(self)
+    }
+}
+
+// ---- Formatting ----
+
+impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
+    fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
+        // Print the current index and flat index
+        write!(
+            f,
+            "Current Index: {}, Flat Index: {}",
+            self.index,
+            self.index.flat()
+        )?;
+
+        // Print the tensor elements, highlighting the current element
+        write!(f, ", Tensor Elements: [")?;
+        for (i, elem) in self.tensor.buffer().iter().enumerate() {
+            if i == self.index.flat() {
+                write!(f, "*{}*", elem)?; // Highlight the current element
+            } else {
+                write!(f, "{}", elem)?;
+            }
+            if i < self.tensor.buffer().len() - 1 {
+                write!(f, ", ")?;
+            }
+        }
+
+        write!(f, "]")
+    }
+}
+
+// ---- From ----
+
+impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
+    fn from(shape: Shape<R>) -> Self {
+        Self::new(shape)
+    }
+}
+
+impl<T: Value, const R: usize> From<[usize; R]> for Tensor<T, R> {
+    fn from(shape: [usize; R]) -> Self {
+        Self::new(Shape::new(shape))
+    }
+}
+
+impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
+    for Tensor<T, 2>
+{
+    fn from(array: [[T; X]; Y]) -> Self {
+        let shape = Shape::new([Y, X]);
+        let mut tensor = Tensor::new(shape);
+        let buffer = tensor.buffer_mut();
+
+        for (i, row) in array.iter().enumerate() {
+            for (j, &elem) in row.iter().enumerate() {
+                buffer[i * X + j] = elem;
+            }
+        }
+
+        tensor
+    }
+}
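+
+// A minimal sanity check (sketch) for the 2-D `From` conversion and `dot`
+// above; it exercises only the public API defined in this file, with values
+// chosen so the expected buffer and dot product are easy to verify by hand.
+#[cfg(test)]
+mod conversion_tests {
+    use super::*;
+
+    #[test]
+    fn from_nested_array_is_row_major() {
+        let t = Tensor::from([[1, 2], [3, 4]]);
+        assert_eq!(t.shape(), Shape::new([2, 2]));
+        // Rows are laid out contiguously (row-major order)
+        assert_eq!(t.buffer(), &[1, 2, 3, 4]);
+        // `dot` is the element-wise product sum:
+        // 1*1 + 2*2 + 3*3 + 4*4 = 30
+        assert_eq!(t.dot(&t), 30);
+    }
+}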