🚀 WIP implementation of N-rank Tensor

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
Julius Koskela 2023-12-26 01:47:13 +02:00
commit b17264ba18
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
12 changed files with 1300 additions and 0 deletions

.gitignore vendored Normal file

@@ -0,0 +1 @@
/target

Cargo.lock generated Normal file

@@ -0,0 +1,179 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bytemuck"
version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
[[package]]
name = "itoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]]
name = "mltensor"
version = "0.1.0"
dependencies = [
"bytemuck",
"num",
"serde",
"serde_json",
]
[[package]]
name = "num"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
dependencies = [
"num-bigint",
"num-complex",
"num-integer",
"num-iter",
"num-rational",
"num-traits",
]
[[package]]
name = "num-bigint"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-complex"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
dependencies = [
"num-traits",
]
[[package]]
name = "num-integer"
version = "0.1.45"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
dependencies = [
"autocfg",
"num-traits",
]
[[package]]
name = "num-iter"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
dependencies = [
"autocfg",
"num-integer",
"num-traits",
]
[[package]]
name = "num-rational"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
dependencies = [
"autocfg",
"num-bigint",
"num-integer",
"num-traits",
]
[[package]]
name = "num-traits"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
dependencies = [
"autocfg",
]
[[package]]
name = "proc-macro2"
version = "1.0.71"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.33"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae"
dependencies = [
"proc-macro2",
]
[[package]]
name = "ryu"
version = "1.0.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
[[package]]
name = "serde"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.193"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "serde_json"
version = "1.0.108"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b"
dependencies = [
"itoa",
"ryu",
"serde",
]
[[package]]
name = "syn"
version = "2.0.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"

Cargo.toml Normal file

@@ -0,0 +1,12 @@
[package]
name = "mltensor"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bytemuck = "1.14.0"
num = "0.4.1"
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"

LICENSE Normal file

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Julius Koskela
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md Normal file

@@ -0,0 +1,22 @@
# mltensor
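
WIP implementation of an N-rank tensor library in Rust. The example below
builds the tensor product of a 2x2 matrix and a 2-element vector: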
```rust
// Create two tensors with different ranks and shapes
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
// Calculate tensor product
let product = tensor1.tensor_product(&tensor2);
println!("T1 * T2 = {}", product);
// Check shape of the resulting tensor
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
// Check buffer of the resulting tensor
assert_eq!(product.buffer(), &[5, 6, 10, 12, 15, 18, 20, 24]);
```
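
Contraction over chosen axes is also supported; contracting the inner axes
of two matrices gives the matrix product (a short sketch based on
`examples/operations.rs`):

```rust
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor

// Contract axis 1 of tensor1 with axis 0 of tensor2 (matrix product)
let result = tensor1.contract(&tensor2, [1], [0]);

assert_eq!(result.shape(), Shape::new([2, 2]));
assert_eq!(result.buffer(), &[19, 22, 43, 50]);
```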

examples/operations.rs Normal file

@@ -0,0 +1,53 @@
use mltensor::*;
fn tensor_product() {
println!("Tensor Product\n");
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
println!("T1: {}", tensor1);
println!("T2: {}", tensor2);
let product = tensor1.tensor_product(&tensor2);
println!("T1 * T2 = {}", product);
// Check shape of the resulting tensor
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
// Check buffer of the resulting tensor
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
assert_eq!(product.buffer(), &expected_buffer);
}
fn tensor_contraction() {
println!("Tensor Contraction\n");
// Create two tensors
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
// Specify axes for contraction
let axis_lhs = [1]; // Contract over the second dimension of tensor1
let axis_rhs = [0]; // Contract over the first dimension of tensor2
// Perform contraction
let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
println!("T1: {}", tensor1);
println!("T2: {}", tensor2);
println!("T1 * T2 = {}", result);
// Contracting axis 1 of T1 with axis 0 of T2 is a matrix product, so the
// expected result is [[19, 22], [43, 50]].
assert_eq!(result.shape(), Shape::new([2, 2]));
assert_eq!(result.buffer(), &[19, 22, 43, 50]);
}
fn main() {
tensor_product();
tensor_contraction();
}

rust-toolchain.toml Normal file

@@ -0,0 +1,2 @@
[toolchain]
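# Nightly is required for the `generic_const_exprs` feature used in src/lib.rs.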
channel = "nightly"

rustfmt.toml Normal file

@@ -0,0 +1 @@
max_width = 80

src/index.rs Normal file

@@ -0,0 +1,251 @@
use super::*;
use std::cmp::Ordering;
use std::ops::{Add, Sub};
#[derive(Clone, Copy, Debug)]
pub struct Idx<const R: usize> {
indices: [usize; R],
shape: Shape<R>,
}
impl<const R: usize> Idx<R> {
pub const fn zero(shape: Shape<R>) -> Self {
Self {
indices: [0; R],
shape,
}
}
pub fn max(shape: Shape<R>) -> Self {
let max_indices =
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
Self {
indices: max_indices,
shape,
}
}
pub fn new(shape: Shape<R>, indices: [usize; R]) -> Self {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self { indices, shape }
}
pub fn is_zero(&self) -> bool {
self.indices.iter().all(|&i| i == 0)
}
pub fn is_overflow(&self) -> bool {
// `inc` marks overflow by setting the first index to the size of the
// first dimension, so that is the sentinel checked here.
self.indices[0] >= self.shape.get(0)
}
pub fn reset(&mut self) {
self.indices = [0; R];
}
/// Increments the index in row-major order (the last dimension varies
/// fastest) and reports whether the end has been passed.
///
/// # Returns
/// `true` if the increment overflowed past the last element, meaning
/// iteration is complete; `false` while the index remains in bounds.
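///
/// A small sketch of the expected behaviour (shape `[2, 2]`):
/// ```
/// # use mltensor::{Idx, Shape};
/// let shape = Shape::new([2, 2]);
/// let mut index = Idx::zero(shape);
/// assert!(!index.inc()); // [0, 0] -> [0, 1]
/// assert!(!index.inc()); // [0, 1] -> [1, 0]
/// assert!(!index.inc()); // [1, 0] -> [1, 1]
/// assert!(index.inc()); // past the last element: overflow
/// assert!(index.is_overflow());
/// ```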
pub fn inc(&mut self) -> bool {
let mut carry = 1;
for (i, &dim_size) in
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
{
if carry == 1 {
*i += 1;
if *i >= dim_size {
*i = 0; // Reset index in this dimension and carry over
} else {
carry = 0; // Increment successful, no carry needed
}
}
}
// If carry is still 1 after the loop, it means we've incremented past the last dimension
if carry == 1 {
// Set the index to an invalid state (e.g., all indices to their max values)
self.indices[0] = self.shape.as_array()[0];
return true; // Indicate that the iteration is complete
}
false
}
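/// Decrements the index by one in row-major order, borrowing across
/// dimensions; saturates at `[0; R]`.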
pub fn dec(&mut self) {
// Check if already at the start
if self.indices.iter().all(|&i| i == 0) {
return;
}
let mut borrow = true;
for (i, &dim_size) in
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
{
if borrow {
if *i == 0 {
*i = dim_size - 1; // Wrap around to the maximum index of this dimension
} else {
*i -= 1; // Decrement the index
borrow = false; // No more borrowing needed
}
}
}
}
/// Converts the multi-dimensional index to a flat index.
///
/// This method calculates the flat index corresponding to the multi-dimensional index
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`.
/// The calculation is based on the assumption that the tensor is stored in row-major order,
/// where the last dimension varies the fastest.
///
/// # Returns
/// The flat index corresponding to the multi-dimensional index.
///
/// # How It Works
/// - The method iterates over each pair of corresponding index and dimension size,
/// starting from the last dimension (hence `rev()` is used for reverse iteration).
/// - In each iteration, it performs two main operations:
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product
/// of dimension sizes (`product`). This calculates the contribution of the current index
/// to the overall flat index.
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`).
/// This updates `product` for the next iteration, as each dimension contributes to the
/// flat index based on the sizes of all preceding dimensions.
/// - The `fold` operation accumulates these results, starting with an initial state of
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product.
/// - The final flat index is obtained after the last iteration, which is the first element
/// of the tuple resulting from the `fold`.
///
/// # Example
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
/// - Starting with a flat index of 0 and a product of 1,
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5.
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
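///
/// The same walk-through as a doctest:
/// ```
/// # use mltensor::{Idx, Shape};
/// let shape = Shape::new([3, 4, 5]);
/// let index = Idx::new(shape, [1, 2, 3]);
/// assert_eq!(index.flat(), 33);
/// ```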
pub fn flat(&self) -> usize {
self.indices
.iter()
.zip(&self.shape.as_array())
.rev()
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
(flat_index + idx * product, product * dim_size)
})
.0
}
}
// --- blanket impls ---
impl<const R: usize> PartialEq for Idx<R> {
fn eq(&self, other: &Self) -> bool {
self.flat() == other.flat()
}
}
impl<const R: usize> Eq for Idx<R> {}
impl<const R: usize> PartialOrd for Idx<R> {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl<const R: usize> Ord for Idx<R> {
fn cmp(&self, other: &Self) -> Ordering {
self.flat().cmp(&other.flat())
}
}
impl<const R: usize> Index<usize> for Idx<R> {
type Output = usize;
fn index(&self, index: usize) -> &Self::Output {
&self.indices[index]
}
}
impl<const R: usize> IndexMut<usize> for Idx<R> {
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
&mut self.indices[index]
}
}
impl<const R: usize> From<(Shape<R>, [usize; R])> for Idx<R> {
fn from((shape, indices): (Shape<R>, [usize; R])) -> Self {
assert!(shape.check_indices(indices));
Self::new(shape, indices)
}
}
impl<const R: usize> From<(Shape<R>, usize)> for Idx<R> {
fn from((shape, flat_index): (Shape<R>, usize)) -> Self {
let indices = shape.index_from_flat(flat_index).indices;
Self::new(shape, indices)
}
}
impl<const R: usize> From<Shape<R>> for Idx<R> {
fn from(shape: Shape<R>) -> Self {
Self::zero(shape)
}
}
impl<const R: usize> std::fmt::Display for Idx<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
for (i, (&idx, &dim_size)) in self
.indices
.iter()
.zip(self.shape.as_array().iter())
.enumerate()
{
write!(f, "{}/{}", idx, dim_size - 1)?;
if i < self.indices.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "]")
}
}
// ---- Arithmetic Operations ----
impl<const R: usize> Add for Idx<R> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
let mut result_indices = [0; R];
for i in 0..R {
result_indices[i] = self.indices[i] + rhs.indices[i];
}
Self {
indices: result_indices,
shape: self.shape,
}
}
}
impl<const R: usize> Sub for Idx<R> {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
let mut result_indices = [0; R];
for i in 0..R {
result_indices[i] = self.indices[i].saturating_sub(rhs.indices[i]);
}
Self {
indices: result_indices,
shape: self.shape,
}
}
}

src/lib.rs Normal file

@@ -0,0 +1,212 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
pub mod index;
pub mod shape;
pub mod tensor;
pub use index::Idx;
use num::{Num, One, Zero};
pub use serde::{Deserialize, Serialize};
pub use shape::Shape;
pub use std::fmt::{Display, Formatter, Result as FmtResult};
use std::ops::{Index, IndexMut};
pub use tensor::{Tensor, TensorIterator};
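/// Scalar element types a `Tensor` can hold: numeric, copyable, printable
/// and serde-compatible.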
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
{
}
impl<T> Value for T where
T: Num
+ Zero
+ One
+ Copy
+ Clone
+ Display
+ Serialize
+ Deserialize<'static>
{
}
// ---- Tests ----
#[cfg(test)]
mod tests {
use super::*;
use serde_json;
#[test]
fn test_tensor_product() {
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
let product = tensor1.tensor_product(&tensor2);
// Check shape of the resulting tensor
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
// Check buffer of the resulting tensor
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
assert_eq!(product.buffer(), &expected_buffer);
}
#[test]
fn serde_shape_serialization_test() {
// Create a shape instance
let shape: Shape<3> = [1, 2, 3].into();
// Serialize the shape to a JSON string
let serialized =
serde_json::to_string(&shape).expect("Failed to serialize");
// Deserialize the JSON string back into a shape
let deserialized: Shape<3> =
serde_json::from_str(&serialized).expect("Failed to deserialize");
// Check that the deserialized shape is equal to the original
assert_eq!(shape, deserialized);
}
#[test]
fn tensor_serde_serialization_test() {
// Create an instance of Tensor
let tensor: Tensor<i32, 2> = Tensor::new(Shape::new([2, 2]));
// Serialize the Tensor to a JSON string
let serialized =
serde_json::to_string(&tensor).expect("Failed to serialize");
// Deserialize the JSON string back into a Tensor
let deserialized: Tensor<i32, 2> =
serde_json::from_str(&serialized).expect("Failed to deserialize");
// Check that the deserialized Tensor is equal to the original
assert_eq!(tensor.buffer(), deserialized.buffer());
assert_eq!(tensor.shape(), deserialized.shape());
}
#[test]
fn iterate_3d_tensor() {
let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
let mut tensor = Tensor::new(shape);
let mut num = 0;
// Fill the tensor with sequential numbers
for i in 0..2 {
for j in 0..2 {
for k in 0..2 {
tensor.buffer_mut()[i * 4 + j * 2 + k] = num;
num += 1;
}
}
}
println!("{}", tensor);
// Iterate over the tensor and check that the numbers are correct
let mut iter = TensorIterator::new(&tensor);
println!("{}", iter);
assert_eq!(iter.next(), Some(&0));
assert_eq!(iter.next(), Some(&1));
assert_eq!(iter.next(), Some(&2));
assert_eq!(iter.next(), Some(&3));
assert_eq!(iter.next(), Some(&4));
assert_eq!(iter.next(), Some(&5));
assert_eq!(iter.next(), Some(&6));
assert_eq!(iter.next(), Some(&7));
assert_eq!(iter.next(), None);
assert_eq!(iter.next(), None);
}
#[test]
fn iterate_rank_4_tensor() {
// Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
let shape = Shape::new([2, 2, 2, 2]);
let mut tensor = Tensor::new(shape);
let mut num = 0;
// Fill the tensor with sequential numbers
for i in 0..tensor.len() {
tensor.buffer_mut()[i] = num;
num += 1;
}
// Iterate over the tensor and check that the numbers are correct
let mut iter = TensorIterator::new(&tensor);
for expected_value in 0..tensor.len() {
assert_eq!(*iter.next().unwrap(), expected_value);
}
// Ensure the iterator is exhausted
assert!(iter.next().is_none());
}
#[test]
fn test_dec_method() {
let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
let mut index = Idx::zero(shape);
// Increment the index to the maximum
for _ in 0..26 {
// 3 * 3 * 3 - 1 = 26 increments to reach the end
index.inc();
}
// Check if the index is at the maximum
assert_eq!(index, Idx::new(shape, [2, 2, 2]));
// Decrement step by step and check the index
let expected_indices = [
[2, 2, 2],
[2, 2, 1],
[2, 2, 0],
[2, 1, 2],
[2, 1, 1],
[2, 1, 0],
[2, 0, 2],
[2, 0, 1],
[2, 0, 0],
[1, 2, 2],
[1, 2, 1],
[1, 2, 0],
[1, 1, 2],
[1, 1, 1],
[1, 1, 0],
[1, 0, 2],
[1, 0, 1],
[1, 0, 0],
[0, 2, 2],
[0, 2, 1],
[0, 2, 0],
[0, 1, 2],
[0, 1, 1],
[0, 1, 0],
[0, 0, 2],
[0, 0, 1],
[0, 0, 0],
];
for (i, &expected) in expected_indices.iter().enumerate() {
assert_eq!(
index,
Idx::new(shape, expected),
"Failed at index {}",
i
);
index.dec();
}
// Finally, the index should reach [0, 0, 0]
index.dec();
assert_eq!(index, Idx::zero(shape));
}
}

src/shape.rs Normal file

@@ -0,0 +1,178 @@
use super::*;
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeTuple, Serializer};
use std::fmt;
#[derive(Clone, Copy, Debug)]
pub struct Shape<const R: usize>([usize; R]);
impl<const R: usize> Shape<R> {
pub const fn new(shape: [usize; R]) -> Self {
Self(shape)
}
pub const fn as_array(&self) -> [usize; R] {
self.0
}
pub const fn rank(&self) -> usize {
R
}
pub fn flat_max(&self) -> usize {
self.size() - 1
}
pub fn size(&self) -> usize {
self.0.iter().product()
}
pub fn iter(&self) -> impl Iterator<Item = &usize> {
self.0.iter()
}
pub const fn get(&self, index: usize) -> usize {
self.0[index]
}
pub fn check_indices(&self, indices: [usize; R]) -> bool {
indices
.iter()
.zip(self.0.iter())
.all(|(&idx, &dim_size)| idx < dim_size)
}
/// Converts a flat index to a multi-dimensional index.
///
/// # Arguments
/// * `flat_index` - The flat index to convert.
///
/// # Returns
/// An `Idx<R>` instance representing the multi-dimensional index corresponding to the flat index.
///
/// # How It Works
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
/// and integer division to reduce the flat index for the next higher dimension.
/// - This process is repeated for each dimension to build the multi-dimensional index.
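///
/// # Example
/// With shape `[3, 4, 5]`, the flat index `33` recovers the
/// multi-dimensional index `[1, 2, 3]` (the inverse of `Idx::flat`):
/// ```
/// # use mltensor::{Idx, Shape};
/// let shape = Shape::new([3, 4, 5]);
/// assert_eq!(shape.index_from_flat(33), Idx::new(shape, [1, 2, 3]));
/// ```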
pub fn index_from_flat(&self, flat_index: usize) -> Idx<R> {
let mut indices = [0; R];
let mut remaining = flat_index;
// Walk the dimensions from last to first (row-major order); each index
// is written directly into its own position, so no final reversal is
// needed.
for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() {
*idx = remaining % dim_size;
remaining /= dim_size;
}
Idx::new(*self, indices)
}
pub const fn index_zero(&self) -> Idx<R> {
Idx::zero(*self)
}
pub fn index_max(&self) -> Idx<R> {
let max_indices =
self.0
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
Idx::new(*self, max_indices)
}
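/// Returns a new shape with the dimensions at the positions listed in
/// `dims_to_remove` dropped; the remaining dimensions keep their order.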
pub fn remove_dims<const NAX: usize>(
&self,
dims_to_remove: [usize; NAX],
) -> Shape<{ R - NAX }> {
// Create a new array to store the remaining dimensions
let mut new_shape = [0; R - NAX];
let mut new_index = 0;
// Iterate over the original dimensions
for (index, &dim) in self.0.iter().enumerate() {
// Skip dimensions that are in the dims_to_remove array
if dims_to_remove.contains(&index) {
continue;
}
// Add the dimension to the new shape array
new_shape[new_index] = dim;
new_index += 1;
}
Shape(new_shape)
}
}
// ---- Serialize and Deserialize ----
struct ShapeVisitor<const R: usize>;
impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
type Value = Shape<R>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "an array of length {}", R)
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let mut arr = [0; R];
for i in 0..R {
arr[i] = seq
.next_element()?
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
}
Ok(Shape(arr))
}
}
impl<'de, const R: usize> Deserialize<'de> for Shape<R> {
fn deserialize<D>(deserializer: D) -> Result<Shape<R>, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_tuple(R, ShapeVisitor)
}
}
impl<const R: usize> Serialize for Shape<R> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let mut seq = serializer.serialize_tuple(R)?;
for elem in &self.0 {
seq.serialize_element(elem)?;
}
seq.end()
}
}
// ---- Blanket Implementations ----
impl<const R: usize> From<[usize; R]> for Shape<R> {
fn from(shape: [usize; R]) -> Self {
Self::new(shape)
}
}
impl<const R: usize> PartialEq for Shape<R> {
fn eq(&self, other: &Self) -> bool {
self.0 == other.0
}
}
impl<const R: usize> Eq for Shape<R> {}
// ---- From and Into Implementations ----
impl<T, const R: usize> From<Tensor<T, R>> for Shape<R>
where
T: Value,
{
fn from(tensor: Tensor<T, R>) -> Self {
tensor.shape()
}
}

src/tensor.rs Normal file

@@ -0,0 +1,368 @@
use super::*;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Tensor<T, const R: usize> {
buffer: Vec<T>,
shape: Shape<R>,
}
impl<T: Value> Tensor<T, 1> {
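/// Dot product of two rank-1 tensors; panics if their shapes differ.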
pub fn dot(&self, other: &Tensor<T, 1>) -> T {
if self.shape != other.shape {
panic!("Shapes of tensors do not match");
}
let mut result = T::zero();
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
result = result + (*a * *b);
}
result
}
}
impl<T: Value, const R: usize> Tensor<T, R> {
pub fn new(shape: Shape<R>) -> Self {
let total_size: usize = shape.iter().product();
let buffer = vec![T::zero(); total_size];
Self { buffer, shape }
}
pub fn shape(&self) -> Shape<R> {
self.shape
}
pub fn buffer(&self) -> &[T] {
&self.buffer
}
pub fn buffer_mut(&mut self) -> &mut [T] {
&mut self.buffer
}
pub fn get(&self, index: Idx<R>) -> &T {
&self.buffer[index.flat()]
}
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
self.buffer.get_unchecked(index.flat())
}
pub fn get_mut(&mut self, index: Idx<R>) -> Option<&mut T> {
self.buffer.get_mut(index.flat())
}
pub unsafe fn get_unchecked_mut(&mut self, index: Idx<R>) -> &mut T {
self.buffer.get_unchecked_mut(index.flat())
}
pub fn rank(&self) -> usize {
R
}
pub fn len(&self) -> usize {
self.buffer.len()
}
pub fn iter(&self) -> TensorIterator<T, R> {
TensorIterator::new(self)
}
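/// Hadamard (element-wise) product; panics if the shapes differ.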
pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
if self.shape != other.shape {
panic!("Shapes of tensors do not match");
}
let mut result_buffer = Vec::with_capacity(self.buffer.len());
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
result_buffer.push(*a * *b);
}
Tensor {
buffer: result_buffer,
shape: self.shape,
}
}
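/// Outer (tensor) product: multiplies every element of `self` with every
/// element of `other`, yielding a tensor of rank `R + S`.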
pub fn tensor_product<const S: usize>(
&self,
other: &Tensor<T, S>,
) -> Tensor<T, { R + S }> {
let mut new_shape_vec = Vec::new();
new_shape_vec.extend_from_slice(&self.shape.as_array());
new_shape_vec.extend_from_slice(&other.shape.as_array());
let new_shape_array: [usize; R + S] = new_shape_vec
.try_into()
.expect("Failed to create shape array");
let mut new_buffer =
Vec::with_capacity(self.buffer.len() * other.buffer.len());
for &item_self in &self.buffer {
for &item_other in &other.buffer {
new_buffer.push(item_self * item_other);
}
}
Tensor {
buffer: new_buffer,
shape: Shape::new(new_shape_array),
}
}
/// Contracts `self` with `rhs` by summing products over the paired axes
/// `axis_lhs[t]` and `axis_rhs[t]`; contracting axis 1 of one matrix with
/// axis 0 of another is ordinary matrix multiplication.
pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
&self,
rhs: &Tensor<T, S>,
axis_lhs: [usize; NAXL],
axis_rhs: [usize; NAXR],
) -> Tensor<T, { R + S - NAXL - NAXR }>
where
[(); R - NAXL]:,
[(); S - NAXR]:,
{
// Step 1: Validate the axes: each contracted lhs axis pairs with one
// rhs axis, both must be in bounds, and paired dimensions must agree.
assert_eq!(NAXL, NAXR, "each contracted lhs axis needs an rhs axis");
for &axis in &axis_lhs {
if axis >= R {
panic!(
"Axis {} is out of bounds for the left-hand tensor",
axis
);
}
}
for &axis in &axis_rhs {
if axis >= S {
panic!(
"Axis {} is out of bounds for the right-hand tensor",
axis
);
}
}
for (&l, &r) in axis_lhs.iter().zip(axis_rhs.iter()) {
assert_eq!(
self.shape.get(l),
rhs.shape.get(r),
"contracted dimensions must have equal size"
);
}
// Step 2: Remove the contracted dimensions and concatenate the
// remainders to form the shape of the resultant tensor.
let free_lhs: Vec<usize> =
(0..R).filter(|a| !axis_lhs.contains(a)).collect();
let free_rhs: Vec<usize> =
(0..S).filter(|a| !axis_rhs.contains(a)).collect();
let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
let mut new_shape_array = [0; R + S - NAXL - NAXR];
new_shape_array[..R - NAXL].copy_from_slice(&new_shape_lhs.as_array());
new_shape_array[R - NAXL..].copy_from_slice(&new_shape_rhs.as_array());
let new_shape = Shape::new(new_shape_array);
// Step 3: For every element of the result, sum the products of the
// operands over all combinations of the contracted axes.
let contraction_len: usize =
axis_lhs.iter().map(|&a| self.shape.get(a)).product();
let mut result_buffer = Vec::with_capacity(new_shape.size());
for flat_out in 0..new_shape.size() {
let out_index = new_shape.index_from_flat(flat_out);
// Scatter the free components of the output index back into full
// operand indices; the contracted components are filled in below.
let mut lhs_indices = [0; R];
let mut rhs_indices = [0; S];
for (pos, &axis) in free_lhs.iter().enumerate() {
lhs_indices[axis] = out_index[pos];
}
for (pos, &axis) in free_rhs.iter().enumerate() {
rhs_indices[axis] = out_index[R - NAXL + pos];
}
let mut sum = T::zero();
for k in 0..contraction_len {
// Decompose `k` into one index per contracted axis pair.
let mut rem = k;
for (&l, &r) in axis_lhs.iter().zip(axis_rhs.iter()) {
let dim = self.shape.get(l);
lhs_indices[l] = rem % dim;
rhs_indices[r] = rem % dim;
rem /= dim;
}
sum = sum
+ *self.get(Idx::new(self.shape, lhs_indices))
* *rhs.get(Idx::new(rhs.shape, rhs_indices));
}
result_buffer.push(sum);
}
Tensor {
buffer: result_buffer,
shape: new_shape,
}
}
/// Returns the element at position `index` along `axis`, with all other
/// indices held at zero, or `None` if `index` is out of range.
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
// Convert the (axis, index) pair to a flat index into the buffer.
let flat_index = self.axis_to_flat_index(axis, index);
if flat_index >= self.buffer.len() {
return None;
}
Some(self.buffer[flat_index])
}
/// Converts an (axis, index) pair to a flat buffer index by scaling
/// `index` with the row-major stride of `axis`.
fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize {
// Ensure the given axis is within the tensor's dimensions.
if axis >= R {
panic!("Axis out of bounds");
}
let mut flat_index = 0;
let mut stride = 1;
// Accumulate the stride from the last dimension up to the target axis,
// then apply it to `index`.
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
if i > axis {
stride *= dim_size;
} else if i == axis {
flat_index += index * stride;
break; // We've reached the target axis
}
}
flat_index
}
}
// ---- Indexing ----
impl<T: Value, const R: usize> Index<Idx<R>> for Tensor<T, R> {
type Output = T;
fn index(&self, index: Idx<R>) -> &Self::Output {
&self.buffer[index.flat()]
}
}
impl<T: Value, const R: usize> IndexMut<Idx<R>> for Tensor<T, R> {
fn index_mut(&mut self, index: Idx<R>) -> &mut Self::Output {
&mut self.buffer[index.flat()]
}
}
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Print the shape of the tensor
write!(f, "Shape: [")?;
for (i, dim_size) in self.shape.as_array().iter().enumerate() {
write!(f, "{}", dim_size)?;
if i < R - 1 {
write!(f, ", ")?;
}
}
write!(f, "], Elements: [")?;
// Print the elements in a flattened form
for (i, elem) in self.buffer.iter().enumerate() {
write!(f, "{}", elem)?;
if i < self.buffer.len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "]")
}
}
// ---- Iterator ----
pub struct TensorIterator<'a, T: Value, const R: usize> {
tensor: &'a Tensor<T, R>,
index: Idx<R>,
}
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
Self {
tensor,
index: tensor.shape.index_zero(),
}
}
}
impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
if self.index.is_overflow() {
return None;
}
let result = unsafe { self.tensor.get_unchecked(self.index) };
self.index.inc();
Some(result)
}
}
impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
type Item = &'a T;
type IntoIter = TensorIterator<'a, T, R>;
fn into_iter(self) -> Self::IntoIter {
TensorIterator::new(self)
}
}
// ---- Formatting ----
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
// Print the current index and flat index
write!(
f,
"Current Index: {}, Flat Index: {}",
self.index,
self.index.flat()
)?;
// Print the tensor elements, highlighting the current element
write!(f, ", Tensor Elements: [")?;
for (i, elem) in self.tensor.buffer().iter().enumerate() {
if i == self.index.flat() {
write!(f, "*{}*", elem)?; // Highlight the current element
} else {
write!(f, "{}", elem)?;
}
if i < self.tensor.buffer().len() - 1 {
write!(f, ", ")?;
}
}
write!(f, "]")
}
}
// ---- From ----
impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
fn from(shape: Shape<R>) -> Self {
Self::new(shape)
}
}
impl<T: Value, const R: usize> From<[usize; R]> for Tensor<T, R> {
fn from(shape: [usize; R]) -> Self {
Self::new(Shape::new(shape))
}
}
impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
for Tensor<T, 2>
{
fn from(array: [[T; X]; Y]) -> Self {
let shape = Shape::new([Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, row) in array.iter().enumerate() {
for (j, &elem) in row.iter().enumerate() {
buffer[i * X + j] = elem;
}
}
tensor
}
}