🚀 WIP implementation of N-rank Tensor
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
commit
b17264ba18
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
179
Cargo.lock
generated
Normal file
179
Cargo.lock
generated
Normal file
@ -0,0 +1,179 @@
|
|||||||
|
# This file is automatically @generated by Cargo.
|
||||||
|
# It is not intended for manual editing.
|
||||||
|
version = 3
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "autocfg"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bytemuck"
|
||||||
|
version = "1.14.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itoa"
|
||||||
|
version = "1.0.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mltensor"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"num",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
|
||||||
|
dependencies = [
|
||||||
|
"num-bigint",
|
||||||
|
"num-complex",
|
||||||
|
"num-integer",
|
||||||
|
"num-iter",
|
||||||
|
"num-rational",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-bigint"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-complex"
|
||||||
|
version = "0.4.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
|
||||||
|
dependencies = [
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-integer"
|
||||||
|
version = "0.1.45"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-iter"
|
||||||
|
version = "0.1.43"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-rational"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
"num-bigint",
|
||||||
|
"num-integer",
|
||||||
|
"num-traits",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num-traits"
|
||||||
|
version = "0.2.17"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
|
||||||
|
dependencies = [
|
||||||
|
"autocfg",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro2"
|
||||||
|
version = "1.0.71"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "quote"
|
||||||
|
version = "1.0.33"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "ryu"
|
||||||
|
version = "1.0.16"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde"
|
||||||
|
version = "1.0.193"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
|
||||||
|
dependencies = [
|
||||||
|
"serde_derive",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_derive"
|
||||||
|
version = "1.0.193"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "serde_json"
|
||||||
|
version = "1.0.108"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b"
|
||||||
|
dependencies = [
|
||||||
|
"itoa",
|
||||||
|
"ryu",
|
||||||
|
"serde",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "2.0.43"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-ident"
|
||||||
|
version = "1.0.12"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
12
Cargo.toml
Normal file
12
Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
|||||||
|
[package]
|
||||||
|
name = "mltensor"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
bytemuck = "1.14.0"
|
||||||
|
num = "0.4.1"
|
||||||
|
serde = { version = "1.0.193", features = ["derive"] }
|
||||||
|
serde_json = "1.0.108"
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2023 Julius Koskela
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
22
README.md
Normal file
22
README.md
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
# Tensorc
|
||||||
|
|
||||||
|
```rust
|
||||||
|
// Create two tensors with different ranks and shapes
|
||||||
|
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||||
|
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||||
|
|
||||||
|
// Fill tensors with some values
|
||||||
|
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||||
|
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||||
|
|
||||||
|
// Calculate tensor product
|
||||||
|
let product = tensor1.tensor_product(&tensor2);
|
||||||
|
|
||||||
|
println!("T1 * T2 = {}", product);
|
||||||
|
|
||||||
|
// Check shape of the resulting tensor
|
||||||
|
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||||
|
|
||||||
|
// Check buffer of the resulting tensor
|
||||||
|
assert_eq!(product.buffer(), &[5, 6, 10, 12, 15, 18, 20, 24]);
|
||||||
|
```
|
53
examples/operations.rs
Normal file
53
examples/operations.rs
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
use mltensor::*;
|
||||||
|
|
||||||
|
fn tensor_product() {
|
||||||
|
println!("Tensor Product\n");
|
||||||
|
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||||
|
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||||
|
|
||||||
|
// Fill tensors with some values
|
||||||
|
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||||
|
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||||
|
|
||||||
|
println!("T1: {}", tensor1);
|
||||||
|
println!("T2: {}", tensor2);
|
||||||
|
|
||||||
|
let product = tensor1.tensor_product(&tensor2);
|
||||||
|
|
||||||
|
println!("T1 * T2 = {}", product);
|
||||||
|
|
||||||
|
// Check shape of the resulting tensor
|
||||||
|
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||||
|
|
||||||
|
// Check buffer of the resulting tensor
|
||||||
|
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||||
|
assert_eq!(product.buffer(), &expected_buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tensor_contraction() {
|
||||||
|
println!("Tensor Contraction\n");
|
||||||
|
// Create two tensors
|
||||||
|
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
|
||||||
|
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
|
||||||
|
|
||||||
|
// Specify axes for contraction
|
||||||
|
let axis_lhs = [1]; // Contract over the second dimension of tensor1
|
||||||
|
let axis_rhs = [0]; // Contract over the first dimension of tensor2
|
||||||
|
|
||||||
|
// Perform contraction
|
||||||
|
let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
||||||
|
|
||||||
|
println!("T1: {}", tensor1);
|
||||||
|
println!("T2: {}", tensor2);
|
||||||
|
println!("T1 * T2 = {}", result);
|
||||||
|
|
||||||
|
// Expected result, for example, could be a single number or a new tensor,
|
||||||
|
// depending on how you defined the contraction operation.
|
||||||
|
// Assert the result is as expected
|
||||||
|
// assert_eq!(result, expected_result);
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
tensor_product();
|
||||||
|
tensor_contraction();
|
||||||
|
}
|
2
rust-toolchain.toml
Normal file
2
rust-toolchain.toml
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
[toolchain]
|
||||||
|
channel = "nightly"
|
1
rustfmt.toml
Normal file
1
rustfmt.toml
Normal file
@ -0,0 +1 @@
|
|||||||
|
max_width = 80
|
251
src/index.rs
Normal file
251
src/index.rs
Normal file
@ -0,0 +1,251 @@
|
|||||||
|
use super::*;
|
||||||
|
use std::cmp::Ordering;
|
||||||
|
use std::ops::{Add, Sub};
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct Idx<const R: usize> {
|
||||||
|
indices: [usize; R],
|
||||||
|
shape: Shape<R>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Idx<R> {
|
||||||
|
pub const fn zero(shape: Shape<R>) -> Self {
|
||||||
|
Self {
|
||||||
|
indices: [0; R],
|
||||||
|
shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn max(shape: Shape<R>) -> Self {
|
||||||
|
let max_indices =
|
||||||
|
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
||||||
|
Self {
|
||||||
|
indices: max_indices,
|
||||||
|
shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn new(shape: Shape<R>, indices: [usize; R]) -> Self {
|
||||||
|
if !shape.check_indices(indices) {
|
||||||
|
panic!("indices out of bounds");
|
||||||
|
}
|
||||||
|
Self { indices, shape }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_zero(&self) -> bool {
|
||||||
|
self.indices.iter().all(|&i| i == 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn is_overflow(&self) -> bool {
|
||||||
|
// Check if the last index is equal to the size of the last dimension
|
||||||
|
self.indices[0] >= self.shape.get(R - 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reset(&mut self) {
|
||||||
|
self.indices = [0; R];
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Increments the index and returns a boolean indicating whether the end has been reached.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// `true` if the increment does not overflow and is still within bounds;
|
||||||
|
/// `false` if it overflows, indicating the end of the tensor.
|
||||||
|
pub fn inc(&mut self) -> bool {
|
||||||
|
let mut carry = 1;
|
||||||
|
for (i, &dim_size) in
|
||||||
|
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||||
|
{
|
||||||
|
if carry == 1 {
|
||||||
|
*i += 1;
|
||||||
|
if *i >= dim_size {
|
||||||
|
*i = 0; // Reset index in this dimension and carry over
|
||||||
|
} else {
|
||||||
|
carry = 0; // Increment successful, no carry needed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
||||||
|
if carry == 1 {
|
||||||
|
// Set the index to an invalid state (e.g., all indices to their max values)
|
||||||
|
self.indices[0] = self.shape.as_array()[0];
|
||||||
|
return true; // Indicate that the iteration is complete
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn dec(&mut self) {
|
||||||
|
// Check if already at the start
|
||||||
|
if self.indices.iter().all(|&i| i == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut borrow = true;
|
||||||
|
for (i, &dim_size) in
|
||||||
|
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||||
|
{
|
||||||
|
if borrow {
|
||||||
|
if *i == 0 {
|
||||||
|
*i = dim_size - 1; // Wrap around to the maximum index of this dimension
|
||||||
|
} else {
|
||||||
|
*i -= 1; // Decrement the index
|
||||||
|
borrow = false; // No more borrowing needed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts the multi-dimensional index to a flat index.
|
||||||
|
///
|
||||||
|
/// This method calculates the flat index corresponding to the multi-dimensional index
|
||||||
|
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`.
|
||||||
|
/// The calculation is based on the assumption that the tensor is stored in row-major order,
|
||||||
|
/// where the last dimension varies the fastest.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// The flat index corresponding to the multi-dimensional index.
|
||||||
|
///
|
||||||
|
/// # How It Works
|
||||||
|
/// - The method iterates over each pair of corresponding index and dimension size,
|
||||||
|
/// starting from the last dimension (hence `rev()` is used for reverse iteration).
|
||||||
|
/// - In each iteration, it performs two main operations:
|
||||||
|
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product
|
||||||
|
/// of dimension sizes (`product`). This calculates the contribution of the current index
|
||||||
|
/// to the overall flat index.
|
||||||
|
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`).
|
||||||
|
/// This updates `product` for the next iteration, as each dimension contributes to the
|
||||||
|
/// flat index based on the sizes of all preceding dimensions.
|
||||||
|
/// - The `fold` operation accumulates these results, starting with an initial state of
|
||||||
|
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product.
|
||||||
|
/// - The final flat index is obtained after the last iteration, which is the first element
|
||||||
|
/// of the tuple resulting from the `fold`.
|
||||||
|
///
|
||||||
|
/// # Example
|
||||||
|
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
|
||||||
|
/// - Starting with a flat index of 0 and a product of 1,
|
||||||
|
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5.
|
||||||
|
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
|
||||||
|
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
|
||||||
|
pub fn flat(&self) -> usize {
|
||||||
|
self.indices
|
||||||
|
.iter()
|
||||||
|
.zip(&self.shape.as_array())
|
||||||
|
.rev()
|
||||||
|
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
|
||||||
|
(flat_index + idx * product, product * dim_size)
|
||||||
|
})
|
||||||
|
.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// --- blanket impls ---
|
||||||
|
|
||||||
|
impl<const R: usize> PartialEq for Idx<R> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.flat() == other.flat()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Eq for Idx<R> {}
|
||||||
|
|
||||||
|
impl<const R: usize> PartialOrd for Idx<R> {
|
||||||
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
|
self.flat().partial_cmp(&other.flat())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Ord for Idx<R> {
|
||||||
|
fn cmp(&self, other: &Self) -> Ordering {
|
||||||
|
self.flat().cmp(&other.flat())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Index<usize> for Idx<R> {
|
||||||
|
type Output = usize;
|
||||||
|
|
||||||
|
fn index(&self, index: usize) -> &Self::Output {
|
||||||
|
&self.indices[index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> IndexMut<usize> for Idx<R> {
|
||||||
|
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||||
|
&mut self.indices[index]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> From<(Shape<R>, [usize; R])> for Idx<R> {
|
||||||
|
fn from((shape, indices): (Shape<R>, [usize; R])) -> Self {
|
||||||
|
assert!(shape.check_indices(indices));
|
||||||
|
Self::new(shape, indices)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> From<(Shape<R>, usize)> for Idx<R> {
|
||||||
|
fn from((shape, flat_index): (Shape<R>, usize)) -> Self {
|
||||||
|
let indices = shape.index_from_flat(flat_index).indices;
|
||||||
|
Self::new(shape, indices)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> From<Shape<R>> for Idx<R> {
|
||||||
|
fn from(shape: Shape<R>) -> Self {
|
||||||
|
Self::zero(shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> std::fmt::Display for Idx<R> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
write!(f, "[")?;
|
||||||
|
for (i, (&idx, &dim_size)) in self
|
||||||
|
.indices
|
||||||
|
.iter()
|
||||||
|
.zip(self.shape.as_array().iter())
|
||||||
|
.enumerate()
|
||||||
|
{
|
||||||
|
write!(f, "{}/{}", idx, dim_size - 1)?;
|
||||||
|
if i < self.indices.len() - 1 {
|
||||||
|
write!(f, ", ")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write!(f, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Arithmetic Operations ----
|
||||||
|
|
||||||
|
impl<const R: usize> Add for Idx<R> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn add(self, rhs: Self) -> Self::Output {
|
||||||
|
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
|
||||||
|
|
||||||
|
let mut result_indices = [0; R];
|
||||||
|
for i in 0..R {
|
||||||
|
result_indices[i] = self.indices[i] + rhs.indices[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
indices: result_indices,
|
||||||
|
shape: self.shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Sub for Idx<R> {
|
||||||
|
type Output = Self;
|
||||||
|
|
||||||
|
fn sub(self, rhs: Self) -> Self::Output {
|
||||||
|
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
|
||||||
|
|
||||||
|
let mut result_indices = [0; R];
|
||||||
|
for i in 0..R {
|
||||||
|
result_indices[i] = self.indices[i].saturating_sub(rhs.indices[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
indices: result_indices,
|
||||||
|
shape: self.shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
212
src/lib.rs
Normal file
212
src/lib.rs
Normal file
@ -0,0 +1,212 @@
|
|||||||
|
#![allow(incomplete_features)]
|
||||||
|
#![feature(generic_const_exprs)]
|
||||||
|
pub mod index;
|
||||||
|
pub mod shape;
|
||||||
|
pub mod tensor;
|
||||||
|
|
||||||
|
pub use index::Idx;
|
||||||
|
use num::{Num, One, Zero};
|
||||||
|
pub use serde::{Deserialize, Serialize};
|
||||||
|
pub use shape::Shape;
|
||||||
|
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||||
|
use std::ops::{Index, IndexMut};
|
||||||
|
pub use tensor::{Tensor, TensorIterator};
|
||||||
|
|
||||||
|
pub trait Value:
|
||||||
|
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Value for T where
|
||||||
|
T: Num
|
||||||
|
+ Zero
|
||||||
|
+ One
|
||||||
|
+ Copy
|
||||||
|
+ Clone
|
||||||
|
+ Display
|
||||||
|
+ Serialize
|
||||||
|
+ Deserialize<'static>
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Tests ----
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use serde_json;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tensor_product() {
|
||||||
|
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||||
|
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||||
|
|
||||||
|
// Fill tensors with some values
|
||||||
|
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||||
|
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||||
|
|
||||||
|
let product = tensor1.tensor_product(&tensor2);
|
||||||
|
|
||||||
|
// Check shape of the resulting tensor
|
||||||
|
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||||
|
|
||||||
|
// Check buffer of the resulting tensor
|
||||||
|
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||||
|
assert_eq!(product.buffer(), &expected_buffer);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn serde_shape_serialization_test() {
|
||||||
|
// Create a shape instance
|
||||||
|
let shape: Shape<3> = [1, 2, 3].into();
|
||||||
|
|
||||||
|
// Serialize the shape to a JSON string
|
||||||
|
let serialized =
|
||||||
|
serde_json::to_string(&shape).expect("Failed to serialize");
|
||||||
|
|
||||||
|
// Deserialize the JSON string back into a shape
|
||||||
|
let deserialized: Shape<3> =
|
||||||
|
serde_json::from_str(&serialized).expect("Failed to deserialize");
|
||||||
|
|
||||||
|
// Check that the deserialized shape is equal to the original
|
||||||
|
assert_eq!(shape, deserialized);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tensor_serde_serialization_test() {
|
||||||
|
// Create an instance of Tensor
|
||||||
|
let tensor: Tensor<i32, 2> = Tensor::new(Shape::new([2, 2]));
|
||||||
|
|
||||||
|
// Serialize the Tensor to a JSON string
|
||||||
|
let serialized =
|
||||||
|
serde_json::to_string(&tensor).expect("Failed to serialize");
|
||||||
|
|
||||||
|
// Deserialize the JSON string back into a Tensor
|
||||||
|
let deserialized: Tensor<i32, 2> =
|
||||||
|
serde_json::from_str(&serialized).expect("Failed to deserialize");
|
||||||
|
|
||||||
|
// Check that the deserialized Tensor is equal to the original
|
||||||
|
assert_eq!(tensor.buffer(), deserialized.buffer());
|
||||||
|
assert_eq!(tensor.shape(), deserialized.shape());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn iterate_3d_tensor() {
|
||||||
|
let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
|
||||||
|
let mut tensor = Tensor::new(shape);
|
||||||
|
let mut num = 0;
|
||||||
|
|
||||||
|
// Fill the tensor with sequential numbers
|
||||||
|
for i in 0..2 {
|
||||||
|
for j in 0..2 {
|
||||||
|
for k in 0..2 {
|
||||||
|
tensor.buffer_mut()[i * 4 + j * 2 + k] = num;
|
||||||
|
num += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("{}", tensor);
|
||||||
|
|
||||||
|
// Iterate over the tensor and check that the numbers are correct
|
||||||
|
|
||||||
|
let mut iter = TensorIterator::new(&tensor);
|
||||||
|
|
||||||
|
println!("{}", iter);
|
||||||
|
|
||||||
|
assert_eq!(iter.next(), Some(&0));
|
||||||
|
|
||||||
|
assert_eq!(iter.next(), Some(&1));
|
||||||
|
assert_eq!(iter.next(), Some(&2));
|
||||||
|
assert_eq!(iter.next(), Some(&3));
|
||||||
|
assert_eq!(iter.next(), Some(&4));
|
||||||
|
assert_eq!(iter.next(), Some(&5));
|
||||||
|
assert_eq!(iter.next(), Some(&6));
|
||||||
|
assert_eq!(iter.next(), Some(&7));
|
||||||
|
assert_eq!(iter.next(), None);
|
||||||
|
assert_eq!(iter.next(), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn iterate_rank_4_tensor() {
|
||||||
|
// Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
|
||||||
|
let shape = Shape::new([2, 2, 2, 2]);
|
||||||
|
let mut tensor = Tensor::new(shape);
|
||||||
|
let mut num = 0;
|
||||||
|
|
||||||
|
// Fill the tensor with sequential numbers
|
||||||
|
for i in 0..tensor.len() {
|
||||||
|
tensor.buffer_mut()[i] = num;
|
||||||
|
num += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Iterate over the tensor and check that the numbers are correct
|
||||||
|
let mut iter = TensorIterator::new(&tensor);
|
||||||
|
for expected_value in 0..tensor.len() {
|
||||||
|
assert_eq!(*iter.next().unwrap(), expected_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure the iterator is exhausted
|
||||||
|
assert!(iter.next().is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_dec_method() {
|
||||||
|
let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
|
||||||
|
let mut index = Idx::zero(shape);
|
||||||
|
|
||||||
|
// Increment the index to the maximum
|
||||||
|
for _ in 0..26 {
|
||||||
|
// 3 * 3 * 3 - 1 = 26 increments to reach the end
|
||||||
|
index.inc();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the index is at the maximum
|
||||||
|
assert_eq!(index, Idx::new(shape, [2, 2, 2]));
|
||||||
|
|
||||||
|
// Decrement step by step and check the index
|
||||||
|
let expected_indices = [
|
||||||
|
[2, 2, 2],
|
||||||
|
[2, 2, 1],
|
||||||
|
[2, 2, 0],
|
||||||
|
[2, 1, 2],
|
||||||
|
[2, 1, 1],
|
||||||
|
[2, 1, 0],
|
||||||
|
[2, 0, 2],
|
||||||
|
[2, 0, 1],
|
||||||
|
[2, 0, 0],
|
||||||
|
[1, 2, 2],
|
||||||
|
[1, 2, 1],
|
||||||
|
[1, 2, 0],
|
||||||
|
[1, 1, 2],
|
||||||
|
[1, 1, 1],
|
||||||
|
[1, 1, 0],
|
||||||
|
[1, 0, 2],
|
||||||
|
[1, 0, 1],
|
||||||
|
[1, 0, 0],
|
||||||
|
[0, 2, 2],
|
||||||
|
[0, 2, 1],
|
||||||
|
[0, 2, 0],
|
||||||
|
[0, 1, 2],
|
||||||
|
[0, 1, 1],
|
||||||
|
[0, 1, 0],
|
||||||
|
[0, 0, 2],
|
||||||
|
[0, 0, 1],
|
||||||
|
[0, 0, 0],
|
||||||
|
];
|
||||||
|
|
||||||
|
for (i, &expected) in expected_indices.iter().enumerate() {
|
||||||
|
assert_eq!(
|
||||||
|
index,
|
||||||
|
Idx::new(shape, expected),
|
||||||
|
"Failed at index {}",
|
||||||
|
i
|
||||||
|
);
|
||||||
|
index.dec();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finally, the index should reach [0, 0, 0]
|
||||||
|
index.dec();
|
||||||
|
assert_eq!(index, Idx::zero(shape));
|
||||||
|
}
|
||||||
|
}
|
178
src/shape.rs
Normal file
178
src/shape.rs
Normal file
@ -0,0 +1,178 @@
|
|||||||
|
use super::*;
|
||||||
|
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
||||||
|
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
||||||
|
use std::fmt;
|
||||||
|
|
||||||
|
#[derive(Clone, Copy, Debug)]
|
||||||
|
pub struct Shape<const R: usize>([usize; R]);
|
||||||
|
|
||||||
|
impl<const R: usize> Shape<R> {
|
||||||
|
pub const fn new(shape: [usize; R]) -> Self {
|
||||||
|
Self(shape)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn as_array(&self) -> [usize; R] {
|
||||||
|
self.0
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn rank(&self) -> usize {
|
||||||
|
R
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn flat_max(&self) -> usize {
|
||||||
|
self.size() - 1
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn size(&self) -> usize {
|
||||||
|
self.0.iter().product()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iter(&self) -> impl Iterator<Item = &usize> {
|
||||||
|
self.0.iter()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn get(&self, index: usize) -> usize {
|
||||||
|
self.0[index]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||||
|
indices
|
||||||
|
.iter()
|
||||||
|
.zip(self.0.iter())
|
||||||
|
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a flat index to a multi-dimensional index.
|
||||||
|
///
|
||||||
|
/// # Arguments
|
||||||
|
/// * `flat_index` - The flat index to convert.
|
||||||
|
///
|
||||||
|
/// # Returns
|
||||||
|
/// An `Idx<R>` instance representing the multi-dimensional index corresponding to the flat index.
|
||||||
|
///
|
||||||
|
/// # How It Works
|
||||||
|
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
|
||||||
|
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
|
||||||
|
/// and integer division to reduce the flat index for the next higher dimension.
|
||||||
|
/// - This process is repeated for each dimension to build the multi-dimensional index.
|
||||||
|
pub fn index_from_flat(&self, flat_index: usize) -> Idx<R> {
|
||||||
|
let mut indices = [0; R];
|
||||||
|
let mut remaining = flat_index;
|
||||||
|
|
||||||
|
for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() {
|
||||||
|
*idx = remaining % dim_size;
|
||||||
|
remaining /= dim_size;
|
||||||
|
}
|
||||||
|
|
||||||
|
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||||
|
Idx::new(*self, indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub const fn index_zero(&self) -> Idx<R> {
|
||||||
|
Idx::zero(*self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn index_max(&self) -> Idx<R> {
|
||||||
|
let max_indices =
|
||||||
|
self.0
|
||||||
|
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||||
|
Idx::new(*self, max_indices)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_dims<const NAX: usize>(
|
||||||
|
&self,
|
||||||
|
dims_to_remove: [usize; NAX],
|
||||||
|
) -> Shape<{ R - NAX }> {
|
||||||
|
// Create a new array to store the remaining dimensions
|
||||||
|
let mut new_shape = [0; R - NAX];
|
||||||
|
let mut new_index = 0;
|
||||||
|
|
||||||
|
// Iterate over the original dimensions
|
||||||
|
for (index, &dim) in self.0.iter().enumerate() {
|
||||||
|
// Skip dimensions that are in the dims_to_remove array
|
||||||
|
if dims_to_remove.contains(&index) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add the dimension to the new shape array
|
||||||
|
new_shape[new_index] = dim;
|
||||||
|
new_index += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Shape(new_shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Serialize and Deserialize ----
|
||||||
|
|
||||||
|
struct ShapeVisitor<const R: usize>;
|
||||||
|
|
||||||
|
impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
|
||||||
|
type Value = Shape<R>;
|
||||||
|
|
||||||
|
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||||
|
formatter.write_str(concat!("an array of length ", "{R}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||||
|
where
|
||||||
|
A: SeqAccess<'de>,
|
||||||
|
{
|
||||||
|
let mut arr = [0; R];
|
||||||
|
for i in 0..R {
|
||||||
|
arr[i] = seq
|
||||||
|
.next_element()?
|
||||||
|
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
|
||||||
|
}
|
||||||
|
Ok(Shape(arr))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'de, const R: usize> Deserialize<'de> for Shape<R> {
|
||||||
|
fn deserialize<D>(deserializer: D) -> Result<Shape<R>, D::Error>
|
||||||
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
|
deserializer.deserialize_tuple(R, ShapeVisitor)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Serialize for Shape<R> {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
|
let mut seq = serializer.serialize_tuple(R)?;
|
||||||
|
for elem in &self.0 {
|
||||||
|
seq.serialize_element(elem)?;
|
||||||
|
}
|
||||||
|
seq.end()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Blanket Implementations ----
|
||||||
|
|
||||||
|
impl<const R: usize> From<[usize; R]> for Shape<R> {
|
||||||
|
fn from(shape: [usize; R]) -> Self {
|
||||||
|
Self::new(shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> PartialEq for Shape<R> {
|
||||||
|
fn eq(&self, other: &Self) -> bool {
|
||||||
|
self.0 == other.0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Eq for Shape<R> {}
|
||||||
|
|
||||||
|
// ---- From and Into Implementations ----
|
||||||
|
|
||||||
|
impl<T, const R: usize> From<Tensor<T, R>> for Shape<R>
|
||||||
|
where
|
||||||
|
T: Value,
|
||||||
|
{
|
||||||
|
fn from(tensor: Tensor<T, R>) -> Self {
|
||||||
|
tensor.shape()
|
||||||
|
}
|
||||||
|
}
|
368
src/tensor.rs
Normal file
368
src/tensor.rs
Normal file
@ -0,0 +1,368 @@
|
|||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct Tensor<T, const R: usize> {
|
||||||
|
buffer: Vec<T>,
|
||||||
|
shape: Shape<R>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value> Tensor<T, 1> {
|
||||||
|
pub fn dot(&self, other: &Tensor<T, 1>) -> T {
|
||||||
|
if self.shape != other.shape {
|
||||||
|
panic!("Shapes of tensors do not match");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result = T::zero();
|
||||||
|
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
|
||||||
|
result = result + (*a * *b);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||||
|
pub fn new(shape: Shape<R>) -> Self {
|
||||||
|
let total_size: usize = shape.iter().product();
|
||||||
|
let buffer = vec![T::zero(); total_size];
|
||||||
|
Self { buffer, shape }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shape(&self) -> Shape<R> {
|
||||||
|
self.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn buffer(&self) -> &[T] {
|
||||||
|
&self.buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn buffer_mut(&mut self) -> &mut [T] {
|
||||||
|
&mut self.buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get(&self, index: Idx<R>) -> &T {
|
||||||
|
&self.buffer[index.flat()]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
|
||||||
|
self.buffer.get_unchecked(index.flat())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_mut(&mut self, index: Idx<R>) -> Option<&mut T> {
|
||||||
|
self.buffer.get_mut(index.flat())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub unsafe fn get_unchecked_mut(&mut self, index: Idx<R>) -> &mut T {
|
||||||
|
self.buffer.get_unchecked_mut(index.flat())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn rank(&self) -> usize {
|
||||||
|
R
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn len(&self) -> usize {
|
||||||
|
self.buffer.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iter(&self) -> TensorIterator<T, R> {
|
||||||
|
TensorIterator::new(self)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
|
||||||
|
if self.shape != other.shape {
|
||||||
|
panic!("Shapes of tensors do not match");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut result_buffer = Vec::with_capacity(self.buffer.len());
|
||||||
|
|
||||||
|
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
|
||||||
|
result_buffer.push(*a * *b);
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor {
|
||||||
|
buffer: result_buffer,
|
||||||
|
shape: self.shape,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn tensor_product<const S: usize>(
|
||||||
|
&self,
|
||||||
|
other: &Tensor<T, S>,
|
||||||
|
) -> Tensor<T, { R + S }> {
|
||||||
|
let mut new_shape_vec = Vec::new();
|
||||||
|
new_shape_vec.extend_from_slice(&self.shape.as_array());
|
||||||
|
new_shape_vec.extend_from_slice(&other.shape.as_array());
|
||||||
|
|
||||||
|
let new_shape_array: [usize; R + S] = new_shape_vec
|
||||||
|
.try_into()
|
||||||
|
.expect("Failed to create shape array");
|
||||||
|
|
||||||
|
let mut new_buffer =
|
||||||
|
Vec::with_capacity(self.buffer.len() * other.buffer.len());
|
||||||
|
for &item_self in &self.buffer {
|
||||||
|
for &item_other in &other.buffer {
|
||||||
|
new_buffer.push(item_self * item_other);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Tensor {
|
||||||
|
buffer: new_buffer,
|
||||||
|
shape: Shape::new(new_shape_array),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// `self_dims` and `other_dims` specify the dimensions to contract over
|
||||||
|
pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
|
||||||
|
&self,
|
||||||
|
rhs: &Tensor<T, S>,
|
||||||
|
axis_lhs: [usize; NAXL],
|
||||||
|
axis_rhs: [usize; NAXR],
|
||||||
|
) -> Tensor<T, { R + S - NAXL - NAXR }>
|
||||||
|
where
|
||||||
|
[(); R - NAXL]:,
|
||||||
|
[(); S - NAXR]:,
|
||||||
|
{
|
||||||
|
// Step 1: Validate the axes for both tensors to ensure they are within bounds
|
||||||
|
for &axis in &axis_lhs {
|
||||||
|
if axis >= R {
|
||||||
|
panic!(
|
||||||
|
"Axis {} is out of bounds for the left-hand tensor",
|
||||||
|
axis
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for &axis in &axis_rhs {
|
||||||
|
if axis >= S {
|
||||||
|
panic!(
|
||||||
|
"Axis {} is out of bounds for the right-hand tensor",
|
||||||
|
axis
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
||||||
|
let mut result_buffer = Vec::new();
|
||||||
|
|
||||||
|
for i in 0..self.shape.size() {
|
||||||
|
for j in 0..rhs.shape.size() {
|
||||||
|
// Debug: Print indices being processed
|
||||||
|
println!("Processing Indices: lhs = {}, rhs = {}", i, j);
|
||||||
|
|
||||||
|
if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) {
|
||||||
|
let mut product_sum = T::zero();
|
||||||
|
|
||||||
|
// Debug: Print axes of contraction
|
||||||
|
println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs);
|
||||||
|
|
||||||
|
for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
||||||
|
// Debug: Print values being multiplied
|
||||||
|
let value_lhs = self.get_by_axis(axis_l, i).unwrap();
|
||||||
|
let value_rhs = rhs.get_by_axis(axis_r, j).unwrap();
|
||||||
|
println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs);
|
||||||
|
|
||||||
|
product_sum = product_sum + value_lhs * value_rhs;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug: Print the product sum for the current indices
|
||||||
|
println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum);
|
||||||
|
|
||||||
|
result_buffer.push(product_sum);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 3: Remove contracted dimensions to create new shapes for both tensors
|
||||||
|
let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
|
||||||
|
let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
||||||
|
|
||||||
|
// Step 4: Concatenate the shapes to form the shape of the resultant tensor
|
||||||
|
let mut new_shape = Vec::new();
|
||||||
|
|
||||||
|
new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
||||||
|
new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
||||||
|
|
||||||
|
let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
||||||
|
new_shape.try_into().expect("Failed to create shape array");
|
||||||
|
|
||||||
|
Tensor {
|
||||||
|
buffer: result_buffer,
|
||||||
|
shape: Shape::new(new_shape_array),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve an element based on a specific axis and index
|
||||||
|
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
|
||||||
|
// Convert axis and index to a flat index
|
||||||
|
let flat_index = self.axis_to_flat_index(axis, index);
|
||||||
|
if flat_index >= self.buffer.len() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(self.buffer[flat_index])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert axis and index to a flat index in the buffer
|
||||||
|
fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize {
|
||||||
|
let mut flat_index = 0;
|
||||||
|
let mut stride = 1;
|
||||||
|
|
||||||
|
// Ensure the given axis is within the tensor's dimensions
|
||||||
|
if axis >= R {
|
||||||
|
panic!("Axis out of bounds");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Calculate the stride for each dimension and accumulate the flat index
|
||||||
|
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
|
||||||
|
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
|
||||||
|
if i > axis {
|
||||||
|
stride *= dim_size;
|
||||||
|
} else if i == axis {
|
||||||
|
flat_index += index * stride;
|
||||||
|
break; // We've reached the target axis
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
flat_index
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Indexing ----
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> Index<Idx<R>> for Tensor<T, R> {
|
||||||
|
type Output = T;
|
||||||
|
|
||||||
|
fn index(&self, index: Idx<R>) -> &Self::Output {
|
||||||
|
&self.buffer[index.flat()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> IndexMut<Idx<R>> for Tensor<T, R> {
|
||||||
|
fn index_mut(&mut self, index: Idx<R>) -> &mut Self::Output {
|
||||||
|
&mut self.buffer[index.flat()]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
|
||||||
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
|
// Print the shape of the tensor
|
||||||
|
write!(f, "Shape: [")?;
|
||||||
|
for (i, dim_size) in self.shape.as_array().iter().enumerate() {
|
||||||
|
write!(f, "{}", dim_size)?;
|
||||||
|
if i < R - 1 {
|
||||||
|
write!(f, ", ")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
write!(f, "], Elements: [")?;
|
||||||
|
|
||||||
|
// Print the elements in a flattened form
|
||||||
|
for (i, elem) in self.buffer.iter().enumerate() {
|
||||||
|
write!(f, "{}", elem)?;
|
||||||
|
if i < self.buffer.len() - 1 {
|
||||||
|
write!(f, ", ")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
write!(f, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Iterator ----
|
||||||
|
|
||||||
|
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||||
|
tensor: &'a Tensor<T, R>,
|
||||||
|
index: Idx<R>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||||
|
pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
|
||||||
|
Self {
|
||||||
|
tensor,
|
||||||
|
index: tensor.shape.index_zero(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
|
||||||
|
type Item = &'a T;
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.index.is_overflow() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = unsafe { self.tensor.get_unchecked(self.index) };
|
||||||
|
self.index.inc();
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
|
||||||
|
type Item = &'a T;
|
||||||
|
type IntoIter = TensorIterator<'a, T, R>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
TensorIterator::new(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- Formatting ----
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
|
||||||
|
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||||
|
// Print the current index and flat index
|
||||||
|
write!(
|
||||||
|
f,
|
||||||
|
"Current Index: {}, Flat Index: {}",
|
||||||
|
self.index,
|
||||||
|
self.index.flat()
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Print the tensor elements, highlighting the current element
|
||||||
|
write!(f, ", Tensor Elements: [")?;
|
||||||
|
for (i, elem) in self.tensor.buffer().iter().enumerate() {
|
||||||
|
if i == self.index.flat() {
|
||||||
|
write!(f, "*{}*", elem)?; // Highlight the current element
|
||||||
|
} else {
|
||||||
|
write!(f, "{}", elem)?;
|
||||||
|
}
|
||||||
|
if i < self.tensor.buffer().len() - 1 {
|
||||||
|
write!(f, ", ")?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
write!(f, "]")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ---- From ----
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
|
||||||
|
fn from(shape: Shape<R>) -> Self {
|
||||||
|
Self::new(shape)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> From<[usize; R]> for Tensor<T, R> {
|
||||||
|
fn from(shape: [usize; R]) -> Self {
|
||||||
|
let shape = Shape::new(shape);
|
||||||
|
Self::new(shape.into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
|
||||||
|
for Tensor<T, 2>
|
||||||
|
{
|
||||||
|
fn from(array: [[T; X]; Y]) -> Self {
|
||||||
|
let shape = Shape::new([Y, X]);
|
||||||
|
let mut tensor = Tensor::new(shape);
|
||||||
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
|
for (i, row) in array.iter().enumerate() {
|
||||||
|
for (j, &elem) in row.iter().enumerate() {
|
||||||
|
buffer[i * X + j] = elem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user