🚀 WIP implementation of N-rank Tensor
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
commit
b17264ba18
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
||||
/target
|
179
Cargo.lock
generated
Normal file
179
Cargo.lock
generated
Normal file
@ -0,0 +1,179 @@
|
||||
# This file is automatically @generated by Cargo.
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
||||
|
||||
[[package]]
|
||||
name = "mltensor"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"num",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b05180d69e3da0e530ba2a1dae5110317e49e3b7f3d41be227dc5f92e49ee7af"
|
||||
dependencies = [
|
||||
"num-bigint",
|
||||
"num-complex",
|
||||
"num-integer",
|
||||
"num-iter",
|
||||
"num-rational",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-bigint"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-complex"
|
||||
version = "0.4.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-integer"
|
||||
version = "0.1.45"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-iter"
|
||||
version = "0.1.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7d03e6c028c5dc5cac6e2dec0efda81fc887605bb3d884578bb6d6bf7514e252"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-rational"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0638a1c9d0a3c0914158145bc76cff373a75a627e6ecbfb71cbe6f453a5a19b0"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"num-bigint",
|
||||
"num-integer",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num-traits"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.71"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "75cb1540fadbd5b8fbccc4dddad2734eba435053f725621c070711a14bb5f4b8"
|
||||
dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "quote"
|
||||
version = "1.0.33"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5267fca4496028628a95160fc423a33e8b2e6af8a5302579e322e4b520293cae"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.193"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25dd9975e68d0cb5aa1120c288333fc98731bd1dd12f561e468ea4728c042b89"
|
||||
dependencies = [
|
||||
"serde_derive",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_derive"
|
||||
version = "1.0.193"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde_json"
|
||||
version = "1.0.108"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b"
|
||||
dependencies = [
|
||||
"itoa",
|
||||
"ryu",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.43"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ee659fb5f3d355364e1f3e5bc10fb82068efbf824a1e9d1c9504244a6469ad53"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
12
Cargo.toml
Normal file
12
Cargo.toml
Normal file
@ -0,0 +1,12 @@
|
||||
[package]
|
||||
name = "mltensor"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
|
||||
[dependencies]
|
||||
bytemuck = "1.14.0"
|
||||
num = "0.4.1"
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
serde_json = "1.0.108"
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Julius Koskela
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
22
README.md
Normal file
22
README.md
Normal file
@ -0,0 +1,22 @@
|
||||
# Tensorc
|
||||
|
||||
```rust
|
||||
// Create two tensors with different ranks and shapes
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
|
||||
// Fill tensors with some values
|
||||
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||
|
||||
// Calculate tensor product
|
||||
let product = tensor1.tensor_product(&tensor2);
|
||||
|
||||
println!("T1 * T2 = {}", product);
|
||||
|
||||
// Check shape of the resulting tensor
|
||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||
|
||||
// Check buffer of the resulting tensor
|
||||
assert_eq!(product.buffer(), &[5, 6, 10, 12, 15, 18, 20, 24]);
|
||||
```
|
53
examples/operations.rs
Normal file
53
examples/operations.rs
Normal file
@ -0,0 +1,53 @@
|
||||
use mltensor::*;
|
||||
|
||||
fn tensor_product() {
|
||||
println!("Tensor Product\n");
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
|
||||
// Fill tensors with some values
|
||||
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||
|
||||
println!("T1: {}", tensor1);
|
||||
println!("T2: {}", tensor2);
|
||||
|
||||
let product = tensor1.tensor_product(&tensor2);
|
||||
|
||||
println!("T1 * T2 = {}", product);
|
||||
|
||||
// Check shape of the resulting tensor
|
||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||
|
||||
// Check buffer of the resulting tensor
|
||||
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||
assert_eq!(product.buffer(), &expected_buffer);
|
||||
}
|
||||
|
||||
fn tensor_contraction() {
|
||||
println!("Tensor Contraction\n");
|
||||
// Create two tensors
|
||||
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
|
||||
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
|
||||
|
||||
// Specify axes for contraction
|
||||
let axis_lhs = [1]; // Contract over the second dimension of tensor1
|
||||
let axis_rhs = [0]; // Contract over the first dimension of tensor2
|
||||
|
||||
// Perform contraction
|
||||
let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
||||
|
||||
println!("T1: {}", tensor1);
|
||||
println!("T2: {}", tensor2);
|
||||
println!("T1 * T2 = {}", result);
|
||||
|
||||
// Expected result, for example, could be a single number or a new tensor,
|
||||
// depending on how you defined the contraction operation.
|
||||
// Assert the result is as expected
|
||||
// assert_eq!(result, expected_result);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
tensor_product();
|
||||
tensor_contraction();
|
||||
}
|
2
rust-toolchain.toml
Normal file
2
rust-toolchain.toml
Normal file
@ -0,0 +1,2 @@
|
||||
[toolchain]
|
||||
channel = "nightly"
|
1
rustfmt.toml
Normal file
1
rustfmt.toml
Normal file
@ -0,0 +1 @@
|
||||
max_width = 80
|
251
src/index.rs
Normal file
251
src/index.rs
Normal file
@ -0,0 +1,251 @@
|
||||
use super::*;
|
||||
use std::cmp::Ordering;
|
||||
use std::ops::{Add, Sub};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Idx<const R: usize> {
|
||||
indices: [usize; R],
|
||||
shape: Shape<R>,
|
||||
}
|
||||
|
||||
impl<const R: usize> Idx<R> {
|
||||
pub const fn zero(shape: Shape<R>) -> Self {
|
||||
Self {
|
||||
indices: [0; R],
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max(shape: Shape<R>) -> Self {
|
||||
let max_indices =
|
||||
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
||||
Self {
|
||||
indices: max_indices,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(shape: Shape<R>, indices: [usize; R]) -> Self {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self { indices, shape }
|
||||
}
|
||||
|
||||
pub fn is_zero(&self) -> bool {
|
||||
self.indices.iter().all(|&i| i == 0)
|
||||
}
|
||||
|
||||
pub fn is_overflow(&self) -> bool {
|
||||
// Check if the last index is equal to the size of the last dimension
|
||||
self.indices[0] >= self.shape.get(R - 1)
|
||||
}
|
||||
|
||||
pub fn reset(&mut self) {
|
||||
self.indices = [0; R];
|
||||
}
|
||||
|
||||
/// Increments the index and returns a boolean indicating whether the end has been reached.
|
||||
///
|
||||
/// # Returns
|
||||
/// `true` if the increment does not overflow and is still within bounds;
|
||||
/// `false` if it overflows, indicating the end of the tensor.
|
||||
pub fn inc(&mut self) -> bool {
|
||||
let mut carry = 1;
|
||||
for (i, &dim_size) in
|
||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||
{
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
if *i >= dim_size {
|
||||
*i = 0; // Reset index in this dimension and carry over
|
||||
} else {
|
||||
carry = 0; // Increment successful, no carry needed
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
||||
if carry == 1 {
|
||||
// Set the index to an invalid state (e.g., all indices to their max values)
|
||||
self.indices[0] = self.shape.as_array()[0];
|
||||
return true; // Indicate that the iteration is complete
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn dec(&mut self) {
|
||||
// Check if already at the start
|
||||
if self.indices.iter().all(|&i| i == 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut borrow = true;
|
||||
for (i, &dim_size) in
|
||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||
{
|
||||
if borrow {
|
||||
if *i == 0 {
|
||||
*i = dim_size - 1; // Wrap around to the maximum index of this dimension
|
||||
} else {
|
||||
*i -= 1; // Decrement the index
|
||||
borrow = false; // No more borrowing needed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Converts the multi-dimensional index to a flat index.
|
||||
///
|
||||
/// This method calculates the flat index corresponding to the multi-dimensional index
|
||||
/// stored in `self.indices`, given the shape of the tensor stored in `self.shape`.
|
||||
/// The calculation is based on the assumption that the tensor is stored in row-major order,
|
||||
/// where the last dimension varies the fastest.
|
||||
///
|
||||
/// # Returns
|
||||
/// The flat index corresponding to the multi-dimensional index.
|
||||
///
|
||||
/// # How It Works
|
||||
/// - The method iterates over each pair of corresponding index and dimension size,
|
||||
/// starting from the last dimension (hence `rev()` is used for reverse iteration).
|
||||
/// - In each iteration, it performs two main operations:
|
||||
/// 1. **Index Contribution**: Multiplies the current index (`idx`) by a running product
|
||||
/// of dimension sizes (`product`). This calculates the contribution of the current index
|
||||
/// to the overall flat index.
|
||||
/// 2. **Product Update**: Multiplies `product` by the current dimension size (`dim_size`).
|
||||
/// This updates `product` for the next iteration, as each dimension contributes to the
|
||||
/// flat index based on the sizes of all preceding dimensions.
|
||||
/// - The `fold` operation accumulates these results, starting with an initial state of
|
||||
/// `(0, 1)` where `0` is the initial flat index and `1` is the initial product.
|
||||
/// - The final flat index is obtained after the last iteration, which is the first element
|
||||
/// of the tuple resulting from the `fold`.
|
||||
///
|
||||
/// # Example
|
||||
/// Consider a tensor with shape `[3, 4, 5]` and an index `[1, 2, 3]`.
|
||||
/// - Starting with a flat index of 0 and a product of 1,
|
||||
/// - For the last dimension (size 5), add 3 * 1 to the flat index. Update the product to 1 * 5 = 5.
|
||||
/// - For the second dimension (size 4), add 2 * 5 to the flat index. Update the product to 5 * 4 = 20.
|
||||
/// - For the first dimension (size 3), add 1 * 20 to the flat index. The final flat index is 3 + 10 + 20 = 33.
|
||||
pub fn flat(&self) -> usize {
|
||||
self.indices
|
||||
.iter()
|
||||
.zip(&self.shape.as_array())
|
||||
.rev()
|
||||
.fold((0, 1), |(flat_index, product), (&idx, &dim_size)| {
|
||||
(flat_index + idx * product, product * dim_size)
|
||||
})
|
||||
.0
|
||||
}
|
||||
}
|
||||
|
||||
// --- blanket impls ---
|
||||
|
||||
impl<const R: usize> PartialEq for Idx<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.flat() == other.flat()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Eq for Idx<R> {}
|
||||
|
||||
impl<const R: usize> PartialOrd for Idx<R> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.flat().partial_cmp(&other.flat())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Ord for Idx<R> {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.flat().cmp(&other.flat())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Index<usize> for Idx<R> {
|
||||
type Output = usize;
|
||||
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
&self.indices[index]
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> IndexMut<usize> for Idx<R> {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
&mut self.indices[index]
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(Shape<R>, [usize; R])> for Idx<R> {
|
||||
fn from((shape, indices): (Shape<R>, [usize; R])) -> Self {
|
||||
assert!(shape.check_indices(indices));
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(Shape<R>, usize)> for Idx<R> {
|
||||
fn from((shape, flat_index): (Shape<R>, usize)) -> Self {
|
||||
let indices = shape.index_from_flat(flat_index).indices;
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<Shape<R>> for Idx<R> {
|
||||
fn from(shape: Shape<R>) -> Self {
|
||||
Self::zero(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> std::fmt::Display for Idx<R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[")?;
|
||||
for (i, (&idx, &dim_size)) in self
|
||||
.indices
|
||||
.iter()
|
||||
.zip(self.shape.as_array().iter())
|
||||
.enumerate()
|
||||
{
|
||||
write!(f, "{}/{}", idx, dim_size - 1)?;
|
||||
if i < self.indices.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Arithmetic Operations ----
|
||||
|
||||
impl<const R: usize> Add for Idx<R> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
|
||||
|
||||
let mut result_indices = [0; R];
|
||||
for i in 0..R {
|
||||
result_indices[i] = self.indices[i] + rhs.indices[i];
|
||||
}
|
||||
|
||||
Self {
|
||||
indices: result_indices,
|
||||
shape: self.shape,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Sub for Idx<R> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
assert_eq!(self.shape, rhs.shape, "Shape mismatch");
|
||||
|
||||
let mut result_indices = [0; R];
|
||||
for i in 0..R {
|
||||
result_indices[i] = self.indices[i].saturating_sub(rhs.indices[i]);
|
||||
}
|
||||
|
||||
Self {
|
||||
indices: result_indices,
|
||||
shape: self.shape,
|
||||
}
|
||||
}
|
||||
}
|
212
src/lib.rs
Normal file
212
src/lib.rs
Normal file
@ -0,0 +1,212 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(generic_const_exprs)]
|
||||
pub mod index;
|
||||
pub mod shape;
|
||||
pub mod tensor;
|
||||
|
||||
pub use index::Idx;
|
||||
use num::{Num, One, Zero};
|
||||
pub use serde::{Deserialize, Serialize};
|
||||
pub use shape::Shape;
|
||||
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||
use std::ops::{Index, IndexMut};
|
||||
pub use tensor::{Tensor, TensorIterator};
|
||||
|
||||
pub trait Value:
|
||||
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T> Value for T where
|
||||
T: Num
|
||||
+ Zero
|
||||
+ One
|
||||
+ Copy
|
||||
+ Clone
|
||||
+ Display
|
||||
+ Serialize
|
||||
+ Deserialize<'static>
|
||||
{
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use serde_json;
|
||||
|
||||
#[test]
|
||||
fn test_tensor_product() {
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
|
||||
// Fill tensors with some values
|
||||
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||
|
||||
let product = tensor1.tensor_product(&tensor2);
|
||||
|
||||
// Check shape of the resulting tensor
|
||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||
|
||||
// Check buffer of the resulting tensor
|
||||
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||
assert_eq!(product.buffer(), &expected_buffer);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serde_shape_serialization_test() {
|
||||
// Create a shape instance
|
||||
let shape: Shape<3> = [1, 2, 3].into();
|
||||
|
||||
// Serialize the shape to a JSON string
|
||||
let serialized =
|
||||
serde_json::to_string(&shape).expect("Failed to serialize");
|
||||
|
||||
// Deserialize the JSON string back into a shape
|
||||
let deserialized: Shape<3> =
|
||||
serde_json::from_str(&serialized).expect("Failed to deserialize");
|
||||
|
||||
// Check that the deserialized shape is equal to the original
|
||||
assert_eq!(shape, deserialized);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tensor_serde_serialization_test() {
|
||||
// Create an instance of Tensor
|
||||
let tensor: Tensor<i32, 2> = Tensor::new(Shape::new([2, 2]));
|
||||
|
||||
// Serialize the Tensor to a JSON string
|
||||
let serialized =
|
||||
serde_json::to_string(&tensor).expect("Failed to serialize");
|
||||
|
||||
// Deserialize the JSON string back into a Tensor
|
||||
let deserialized: Tensor<i32, 2> =
|
||||
serde_json::from_str(&serialized).expect("Failed to deserialize");
|
||||
|
||||
// Check that the deserialized Tensor is equal to the original
|
||||
assert_eq!(tensor.buffer(), deserialized.buffer());
|
||||
assert_eq!(tensor.shape(), deserialized.shape());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iterate_3d_tensor() {
|
||||
let shape = Shape::new([2, 2, 2]); // 3D tensor with shape 2x2x2
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let mut num = 0;
|
||||
|
||||
// Fill the tensor with sequential numbers
|
||||
for i in 0..2 {
|
||||
for j in 0..2 {
|
||||
for k in 0..2 {
|
||||
tensor.buffer_mut()[i * 4 + j * 2 + k] = num;
|
||||
num += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
println!("{}", tensor);
|
||||
|
||||
// Iterate over the tensor and check that the numbers are correct
|
||||
|
||||
let mut iter = TensorIterator::new(&tensor);
|
||||
|
||||
println!("{}", iter);
|
||||
|
||||
assert_eq!(iter.next(), Some(&0));
|
||||
|
||||
assert_eq!(iter.next(), Some(&1));
|
||||
assert_eq!(iter.next(), Some(&2));
|
||||
assert_eq!(iter.next(), Some(&3));
|
||||
assert_eq!(iter.next(), Some(&4));
|
||||
assert_eq!(iter.next(), Some(&5));
|
||||
assert_eq!(iter.next(), Some(&6));
|
||||
assert_eq!(iter.next(), Some(&7));
|
||||
assert_eq!(iter.next(), None);
|
||||
assert_eq!(iter.next(), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn iterate_rank_4_tensor() {
|
||||
// Define the shape of the rank-4 tensor (e.g., 2x2x2x2)
|
||||
let shape = Shape::new([2, 2, 2, 2]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let mut num = 0;
|
||||
|
||||
// Fill the tensor with sequential numbers
|
||||
for i in 0..tensor.len() {
|
||||
tensor.buffer_mut()[i] = num;
|
||||
num += 1;
|
||||
}
|
||||
|
||||
// Iterate over the tensor and check that the numbers are correct
|
||||
let mut iter = TensorIterator::new(&tensor);
|
||||
for expected_value in 0..tensor.len() {
|
||||
assert_eq!(*iter.next().unwrap(), expected_value);
|
||||
}
|
||||
|
||||
// Ensure the iterator is exhausted
|
||||
assert!(iter.next().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_dec_method() {
|
||||
let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
|
||||
let mut index = Idx::zero(shape);
|
||||
|
||||
// Increment the index to the maximum
|
||||
for _ in 0..26 {
|
||||
// 3 * 3 * 3 - 1 = 26 increments to reach the end
|
||||
index.inc();
|
||||
}
|
||||
|
||||
// Check if the index is at the maximum
|
||||
assert_eq!(index, Idx::new(shape, [2, 2, 2]));
|
||||
|
||||
// Decrement step by step and check the index
|
||||
let expected_indices = [
|
||||
[2, 2, 2],
|
||||
[2, 2, 1],
|
||||
[2, 2, 0],
|
||||
[2, 1, 2],
|
||||
[2, 1, 1],
|
||||
[2, 1, 0],
|
||||
[2, 0, 2],
|
||||
[2, 0, 1],
|
||||
[2, 0, 0],
|
||||
[1, 2, 2],
|
||||
[1, 2, 1],
|
||||
[1, 2, 0],
|
||||
[1, 1, 2],
|
||||
[1, 1, 1],
|
||||
[1, 1, 0],
|
||||
[1, 0, 2],
|
||||
[1, 0, 1],
|
||||
[1, 0, 0],
|
||||
[0, 2, 2],
|
||||
[0, 2, 1],
|
||||
[0, 2, 0],
|
||||
[0, 1, 2],
|
||||
[0, 1, 1],
|
||||
[0, 1, 0],
|
||||
[0, 0, 2],
|
||||
[0, 0, 1],
|
||||
[0, 0, 0],
|
||||
];
|
||||
|
||||
for (i, &expected) in expected_indices.iter().enumerate() {
|
||||
assert_eq!(
|
||||
index,
|
||||
Idx::new(shape, expected),
|
||||
"Failed at index {}",
|
||||
i
|
||||
);
|
||||
index.dec();
|
||||
}
|
||||
|
||||
// Finally, the index should reach [0, 0, 0]
|
||||
index.dec();
|
||||
assert_eq!(index, Idx::zero(shape));
|
||||
}
|
||||
}
|
178
src/shape.rs
Normal file
178
src/shape.rs
Normal file
@ -0,0 +1,178 @@
|
||||
use super::*;
|
||||
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
||||
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Shape<const R: usize>([usize; R]);
|
||||
|
||||
impl<const R: usize> Shape<R> {
|
||||
pub const fn new(shape: [usize; R]) -> Self {
|
||||
Self(shape)
|
||||
}
|
||||
|
||||
pub const fn as_array(&self) -> [usize; R] {
|
||||
self.0
|
||||
}
|
||||
|
||||
pub const fn rank(&self) -> usize {
|
||||
R
|
||||
}
|
||||
|
||||
pub fn flat_max(&self) -> usize {
|
||||
self.size() - 1
|
||||
}
|
||||
|
||||
pub fn size(&self) -> usize {
|
||||
self.0.iter().product()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> impl Iterator<Item = &usize> {
|
||||
self.0.iter()
|
||||
}
|
||||
|
||||
pub const fn get(&self, index: usize) -> usize {
|
||||
self.0[index]
|
||||
}
|
||||
|
||||
pub fn check_indices(&self, indices: [usize; R]) -> bool {
|
||||
indices
|
||||
.iter()
|
||||
.zip(self.0.iter())
|
||||
.all(|(&idx, &dim_size)| idx < dim_size)
|
||||
}
|
||||
|
||||
/// Converts a flat index to a multi-dimensional index.
|
||||
///
|
||||
/// # Arguments
|
||||
/// * `flat_index` - The flat index to convert.
|
||||
///
|
||||
/// # Returns
|
||||
/// An `Idx<R>` instance representing the multi-dimensional index corresponding to the flat index.
|
||||
///
|
||||
/// # How It Works
|
||||
/// - The method iterates over the dimensions of the tensor in reverse order (assuming row-major order).
|
||||
/// - In each iteration, it uses the modulo operation to find the index in the current dimension
|
||||
/// and integer division to reduce the flat index for the next higher dimension.
|
||||
/// - This process is repeated for each dimension to build the multi-dimensional index.
|
||||
pub fn index_from_flat(&self, flat_index: usize) -> Idx<R> {
|
||||
let mut indices = [0; R];
|
||||
let mut remaining = flat_index;
|
||||
|
||||
for (idx, &dim_size) in indices.iter_mut().zip(self.0.iter()).rev() {
|
||||
*idx = remaining % dim_size;
|
||||
remaining /= dim_size;
|
||||
}
|
||||
|
||||
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||
Idx::new(*self, indices)
|
||||
}
|
||||
|
||||
pub const fn index_zero(&self) -> Idx<R> {
|
||||
Idx::zero(*self)
|
||||
}
|
||||
|
||||
pub fn index_max(&self) -> Idx<R> {
|
||||
let max_indices =
|
||||
self.0
|
||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||
Idx::new(*self, max_indices)
|
||||
}
|
||||
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
&self,
|
||||
dims_to_remove: [usize; NAX],
|
||||
) -> Shape<{ R - NAX }> {
|
||||
// Create a new array to store the remaining dimensions
|
||||
let mut new_shape = [0; R - NAX];
|
||||
let mut new_index = 0;
|
||||
|
||||
// Iterate over the original dimensions
|
||||
for (index, &dim) in self.0.iter().enumerate() {
|
||||
// Skip dimensions that are in the dims_to_remove array
|
||||
if dims_to_remove.contains(&index) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Add the dimension to the new shape array
|
||||
new_shape[new_index] = dim;
|
||||
new_index += 1;
|
||||
}
|
||||
|
||||
Shape(new_shape)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Serialize and Deserialize ----
|
||||
|
||||
struct ShapeVisitor<const R: usize>;
|
||||
|
||||
impl<'de, const R: usize> Visitor<'de> for ShapeVisitor<R> {
|
||||
type Value = Shape<R>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
formatter.write_str(concat!("an array of length ", "{R}"))
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
where
|
||||
A: SeqAccess<'de>,
|
||||
{
|
||||
let mut arr = [0; R];
|
||||
for i in 0..R {
|
||||
arr[i] = seq
|
||||
.next_element()?
|
||||
.ok_or_else(|| de::Error::invalid_length(i, &self))?;
|
||||
}
|
||||
Ok(Shape(arr))
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, const R: usize> Deserialize<'de> for Shape<R> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Shape<R>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_tuple(R, ShapeVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Serialize for Shape<R> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
let mut seq = serializer.serialize_tuple(R)?;
|
||||
for elem in &self.0 {
|
||||
seq.serialize_element(elem)?;
|
||||
}
|
||||
seq.end()
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Blanket Implementations ----
|
||||
|
||||
impl<const R: usize> From<[usize; R]> for Shape<R> {
|
||||
fn from(shape: [usize; R]) -> Self {
|
||||
Self::new(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> PartialEq for Shape<R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.0 == other.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Eq for Shape<R> {}
|
||||
|
||||
// ---- From and Into Implementations ----
|
||||
|
||||
impl<T, const R: usize> From<Tensor<T, R>> for Shape<R>
|
||||
where
|
||||
T: Value,
|
||||
{
|
||||
fn from(tensor: Tensor<T, R>) -> Self {
|
||||
tensor.shape()
|
||||
}
|
||||
}
|
368
src/tensor.rs
Normal file
368
src/tensor.rs
Normal file
@ -0,0 +1,368 @@
|
||||
use super::*;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct Tensor<T, const R: usize> {
|
||||
buffer: Vec<T>,
|
||||
shape: Shape<R>,
|
||||
}
|
||||
|
||||
impl<T: Value> Tensor<T, 1> {
|
||||
pub fn dot(&self, other: &Tensor<T, 1>) -> T {
|
||||
if self.shape != other.shape {
|
||||
panic!("Shapes of tensors do not match");
|
||||
}
|
||||
|
||||
let mut result = T::zero();
|
||||
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
|
||||
result = result + (*a * *b);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
pub fn new(shape: Shape<R>) -> Self {
|
||||
let total_size: usize = shape.iter().product();
|
||||
let buffer = vec![T::zero(); total_size];
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> Shape<R> {
|
||||
self.shape
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &[T] {
|
||||
&self.buffer
|
||||
}
|
||||
|
||||
pub fn buffer_mut(&mut self) -> &mut [T] {
|
||||
&mut self.buffer
|
||||
}
|
||||
|
||||
pub fn get(&self, index: Idx<R>) -> &T {
|
||||
&self.buffer[index.flat()]
|
||||
}
|
||||
|
||||
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
|
||||
self.buffer.get_unchecked(index.flat())
|
||||
}
|
||||
|
||||
pub fn get_mut(&mut self, index: Idx<R>) -> Option<&mut T> {
|
||||
self.buffer.get_mut(index.flat())
|
||||
}
|
||||
|
||||
pub unsafe fn get_unchecked_mut(&mut self, index: Idx<R>) -> &mut T {
|
||||
self.buffer.get_unchecked_mut(index.flat())
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
R
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.buffer.len()
|
||||
}
|
||||
|
||||
pub fn iter(&self) -> TensorIterator<T, R> {
|
||||
TensorIterator::new(self)
|
||||
}
|
||||
|
||||
pub fn elementwise_multiply(&self, other: &Tensor<T, R>) -> Tensor<T, R> {
|
||||
if self.shape != other.shape {
|
||||
panic!("Shapes of tensors do not match");
|
||||
}
|
||||
|
||||
let mut result_buffer = Vec::with_capacity(self.buffer.len());
|
||||
|
||||
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
|
||||
result_buffer.push(*a * *b);
|
||||
}
|
||||
|
||||
Tensor {
|
||||
buffer: result_buffer,
|
||||
shape: self.shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tensor_product<const S: usize>(
|
||||
&self,
|
||||
other: &Tensor<T, S>,
|
||||
) -> Tensor<T, { R + S }> {
|
||||
let mut new_shape_vec = Vec::new();
|
||||
new_shape_vec.extend_from_slice(&self.shape.as_array());
|
||||
new_shape_vec.extend_from_slice(&other.shape.as_array());
|
||||
|
||||
let new_shape_array: [usize; R + S] = new_shape_vec
|
||||
.try_into()
|
||||
.expect("Failed to create shape array");
|
||||
|
||||
let mut new_buffer =
|
||||
Vec::with_capacity(self.buffer.len() * other.buffer.len());
|
||||
for &item_self in &self.buffer {
|
||||
for &item_other in &other.buffer {
|
||||
new_buffer.push(item_self * item_other);
|
||||
}
|
||||
}
|
||||
|
||||
Tensor {
|
||||
buffer: new_buffer,
|
||||
shape: Shape::new(new_shape_array),
|
||||
}
|
||||
}
|
||||
|
||||
// `self_dims` and `other_dims` specify the dimensions to contract over
|
||||
pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
|
||||
&self,
|
||||
rhs: &Tensor<T, S>,
|
||||
axis_lhs: [usize; NAXL],
|
||||
axis_rhs: [usize; NAXR],
|
||||
) -> Tensor<T, { R + S - NAXL - NAXR }>
|
||||
where
|
||||
[(); R - NAXL]:,
|
||||
[(); S - NAXR]:,
|
||||
{
|
||||
// Step 1: Validate the axes for both tensors to ensure they are within bounds
|
||||
for &axis in &axis_lhs {
|
||||
if axis >= R {
|
||||
panic!(
|
||||
"Axis {} is out of bounds for the left-hand tensor",
|
||||
axis
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for &axis in &axis_rhs {
|
||||
if axis >= S {
|
||||
panic!(
|
||||
"Axis {} is out of bounds for the right-hand tensor",
|
||||
axis
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
||||
let mut result_buffer = Vec::new();
|
||||
|
||||
for i in 0..self.shape.size() {
|
||||
for j in 0..rhs.shape.size() {
|
||||
// Debug: Print indices being processed
|
||||
println!("Processing Indices: lhs = {}, rhs = {}", i, j);
|
||||
|
||||
if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) {
|
||||
let mut product_sum = T::zero();
|
||||
|
||||
// Debug: Print axes of contraction
|
||||
println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs);
|
||||
|
||||
for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
||||
// Debug: Print values being multiplied
|
||||
let value_lhs = self.get_by_axis(axis_l, i).unwrap();
|
||||
let value_rhs = rhs.get_by_axis(axis_r, j).unwrap();
|
||||
println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs);
|
||||
|
||||
product_sum = product_sum + value_lhs * value_rhs;
|
||||
}
|
||||
|
||||
// Debug: Print the product sum for the current indices
|
||||
println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum);
|
||||
|
||||
result_buffer.push(product_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Remove contracted dimensions to create new shapes for both tensors
|
||||
let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
|
||||
let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
||||
|
||||
// Step 4: Concatenate the shapes to form the shape of the resultant tensor
|
||||
let mut new_shape = Vec::new();
|
||||
|
||||
new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
||||
new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
||||
|
||||
let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
||||
new_shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
Tensor {
|
||||
buffer: result_buffer,
|
||||
shape: Shape::new(new_shape_array),
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve an element based on a specific axis and index
|
||||
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
|
||||
// Convert axis and index to a flat index
|
||||
let flat_index = self.axis_to_flat_index(axis, index);
|
||||
if flat_index >= self.buffer.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(self.buffer[flat_index])
|
||||
}
|
||||
|
||||
// Convert axis and index to a flat index in the buffer
|
||||
fn axis_to_flat_index(&self, axis: usize, index: usize) -> usize {
|
||||
let mut flat_index = 0;
|
||||
let mut stride = 1;
|
||||
|
||||
// Ensure the given axis is within the tensor's dimensions
|
||||
if axis >= R {
|
||||
panic!("Axis out of bounds");
|
||||
}
|
||||
|
||||
// Calculate the stride for each dimension and accumulate the flat index
|
||||
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
|
||||
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
|
||||
if i > axis {
|
||||
stride *= dim_size;
|
||||
} else if i == axis {
|
||||
flat_index += index * stride;
|
||||
break; // We've reached the target axis
|
||||
}
|
||||
}
|
||||
|
||||
flat_index
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Indexing ----
|
||||
|
||||
impl<T: Value, const R: usize> Index<Idx<R>> for Tensor<T, R> {
|
||||
type Output = T;
|
||||
|
||||
fn index(&self, index: Idx<R>) -> &Self::Output {
|
||||
&self.buffer[index.flat()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> IndexMut<Idx<R>> for Tensor<T, R> {
|
||||
fn index_mut(&mut self, index: Idx<R>) -> &mut Self::Output {
|
||||
&mut self.buffer[index.flat()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Print the shape of the tensor
|
||||
write!(f, "Shape: [")?;
|
||||
for (i, dim_size) in self.shape.as_array().iter().enumerate() {
|
||||
write!(f, "{}", dim_size)?;
|
||||
if i < R - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, "], Elements: [")?;
|
||||
|
||||
// Print the elements in a flattened form
|
||||
for (i, elem) in self.buffer.iter().enumerate() {
|
||||
write!(f, "{}", elem)?;
|
||||
if i < self.buffer.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Iterator ----
|
||||
|
||||
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||
tensor: &'a Tensor<T, R>,
|
||||
index: Idx<R>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||
pub const fn new(tensor: &'a Tensor<T, R>) -> Self {
|
||||
Self {
|
||||
tensor,
|
||||
index: tensor.shape.index_zero(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Iterator for TensorIterator<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.index.is_overflow() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = unsafe { self.tensor.get_unchecked(self.index) };
|
||||
self.index.inc();
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> IntoIterator for &'a Tensor<T, R> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = TensorIterator<'a, T, R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
TensorIterator::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Formatting ----
|
||||
|
||||
impl<'a, T: Value, const R: usize> Display for TensorIterator<'a, T, R> {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
// Print the current index and flat index
|
||||
write!(
|
||||
f,
|
||||
"Current Index: {}, Flat Index: {}",
|
||||
self.index,
|
||||
self.index.flat()
|
||||
)?;
|
||||
|
||||
// Print the tensor elements, highlighting the current element
|
||||
write!(f, ", Tensor Elements: [")?;
|
||||
for (i, elem) in self.tensor.buffer().iter().enumerate() {
|
||||
if i == self.index.flat() {
|
||||
write!(f, "*{}*", elem)?; // Highlight the current element
|
||||
} else {
|
||||
write!(f, "{}", elem)?;
|
||||
}
|
||||
if i < self.tensor.buffer().len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
|
||||
// ---- From ----
|
||||
|
||||
impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
|
||||
fn from(shape: Shape<R>) -> Self {
|
||||
Self::new(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> From<[usize; R]> for Tensor<T, R> {
|
||||
fn from(shape: [usize; R]) -> Self {
|
||||
let shape = Shape::new(shape);
|
||||
Self::new(shape.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
|
||||
for Tensor<T, 2>
|
||||
{
|
||||
fn from(array: [[T; X]; Y]) -> Self {
|
||||
let shape = Shape::new([Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, row) in array.iter().enumerate() {
|
||||
for (j, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X + j] = elem;
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user