Work on tensor contraction and axis iteration
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
b17264ba18
commit
d54179cd5d
99
Cargo.lock
generated
99
Cargo.lock
generated
@ -14,22 +14,39 @@ version = "1.14.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
|
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "either"
|
||||||
|
version = "1.9.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "getset"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro-error",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itertools"
|
||||||
|
version = "0.12.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itoa"
|
name = "itoa"
|
||||||
version = "1.0.10"
|
version = "1.0.10"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "mltensor"
|
|
||||||
version = "0.1.0"
|
|
||||||
dependencies = [
|
|
||||||
"bytemuck",
|
|
||||||
"num",
|
|
||||||
"serde",
|
|
||||||
"serde_json",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "num"
|
name = "num"
|
||||||
version = "0.4.1"
|
version = "0.4.1"
|
||||||
@ -106,6 +123,30 @@ dependencies = [
|
|||||||
"autocfg",
|
"autocfg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro-error"
|
||||||
|
version = "1.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro-error-attr",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 1.0.109",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "proc-macro-error-attr"
|
||||||
|
version = "1.0.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"version_check",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "proc-macro2"
|
name = "proc-macro2"
|
||||||
version = "1.0.71"
|
version = "1.0.71"
|
||||||
@ -147,7 +188,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"syn",
|
"syn 2.0.43",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
@ -161,6 +202,23 @@ dependencies = [
|
|||||||
"serde",
|
"serde",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "static_assertions"
|
||||||
|
version = "1.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "syn"
|
||||||
|
version = "1.0.109"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"unicode-ident",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "syn"
|
name = "syn"
|
||||||
version = "2.0.43"
|
version = "2.0.43"
|
||||||
@ -172,8 +230,27 @@ dependencies = [
|
|||||||
"unicode-ident",
|
"unicode-ident",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tensorc"
|
||||||
|
version = "0.1.0"
|
||||||
|
dependencies = [
|
||||||
|
"bytemuck",
|
||||||
|
"getset",
|
||||||
|
"itertools",
|
||||||
|
"num",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
"static_assertions",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.12"
|
version = "1.0.12"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "version_check"
|
||||||
|
version = "0.9.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||||
|
@ -1,5 +1,5 @@
|
|||||||
[package]
|
[package]
|
||||||
name = "mltensor"
|
name = "tensorc"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
@ -7,6 +7,9 @@ edition = "2021"
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
bytemuck = "1.14.0"
|
bytemuck = "1.14.0"
|
||||||
|
getset = "0.1.2"
|
||||||
|
itertools = "0.12.0"
|
||||||
num = "0.4.1"
|
num = "0.4.1"
|
||||||
serde = { version = "1.0.193", features = ["derive"] }
|
serde = { version = "1.0.193", features = ["derive"] }
|
||||||
serde_json = "1.0.108"
|
serde_json = "1.0.108"
|
||||||
|
static_assertions = "1.1.0"
|
||||||
|
@ -1,53 +1,66 @@
|
|||||||
use mltensor::*;
|
#![allow(mixed_script_confusables)]
|
||||||
|
#![allow(non_snake_case)]
|
||||||
|
use bytemuck::cast_slice;
|
||||||
|
use tensorc::*;
|
||||||
|
|
||||||
fn tensor_product() {
|
fn tensor_product() {
|
||||||
println!("Tensor Product\n");
|
println!("Tensor Product\n");
|
||||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||||
|
|
||||||
// Fill tensors with some values
|
// Fill tensors with some values
|
||||||
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||||
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||||
|
|
||||||
println!("T1: {}", tensor1);
|
println!("T1: {}", tensor1);
|
||||||
println!("T2: {}", tensor2);
|
println!("T2: {}", tensor2);
|
||||||
|
|
||||||
let product = tensor1.tensor_product(&tensor2);
|
let product = tensor1.tensor_product(&tensor2);
|
||||||
|
|
||||||
println!("T1 * T2 = {}", product);
|
println!("T1 * T2 = {}", product);
|
||||||
|
|
||||||
// Check shape of the resulting tensor
|
// Check shape of the resulting tensor
|
||||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||||
|
|
||||||
// Check buffer of the resulting tensor
|
// Check buffer of the resulting tensor
|
||||||
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
let expect: &[i32] =
|
||||||
assert_eq!(product.buffer(), &expected_buffer);
|
cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]);
|
||||||
|
assert_eq!(product.buffer(), expect);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn tensor_contraction() {
|
fn tensor_contraction() {
|
||||||
println!("Tensor Contraction\n");
|
println!("Tensor Contraction\n");
|
||||||
// Create two tensors
|
// Create two tensors
|
||||||
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
|
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
|
||||||
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
|
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
|
||||||
|
|
||||||
// Specify axes for contraction
|
// Specify axes for contraction
|
||||||
let axis_lhs = [1]; // Contract over the second dimension of tensor1
|
let axis_lhs = [1]; // Contract over the second dimension of tensor1
|
||||||
let axis_rhs = [0]; // Contract over the first dimension of tensor2
|
let axis_rhs = [0]; // Contract over the first dimension of tensor2
|
||||||
|
|
||||||
// Perform contraction
|
// Perform contraction
|
||||||
let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
// let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
||||||
|
|
||||||
println!("T1: {}", tensor1);
|
// println!("T1: {}", tensor1);
|
||||||
println!("T2: {}", tensor2);
|
// println!("T2: {}", tensor2);
|
||||||
println!("T1 * T2 = {}", result);
|
// println!("T1 * T2 = {}", result);
|
||||||
|
|
||||||
// Expected result, for example, could be a single number or a new tensor,
|
// Expected result, for example, could be a single number or a new tensor,
|
||||||
// depending on how you defined the contraction operation.
|
// depending on how you defined the contraction operation.
|
||||||
// Assert the result is as expected
|
// Assert the result is as expected
|
||||||
// assert_eq!(result, expected_result);
|
// assert_eq!(result, expected_result);
|
||||||
|
|
||||||
|
// let Λ = Tensor::<f64, 2>::from([
|
||||||
|
// [1.0, 0.0, 0.0, 0.0],
|
||||||
|
// [0.0, 1.0, 0.0 ,0.0],
|
||||||
|
// [0.0, 0.0, 1.0, 0.0],
|
||||||
|
// [0.0, 0.0, 0.0, 1.0]
|
||||||
|
// ]);
|
||||||
|
|
||||||
|
// println!("Λ: {}", Λ);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
fn main() {
|
||||||
tensor_product();
|
tensor_product();
|
||||||
tensor_contraction();
|
tensor_contraction();
|
||||||
}
|
}
|
581
src/axis.rs
Normal file
581
src/axis.rs
Normal file
@ -0,0 +1,581 @@
|
|||||||
|
use super::*;
|
||||||
|
use getset::{Getters, MutGetters};
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Getters)]
|
||||||
|
pub struct Axis<'a, T: Value, const R: usize> {
|
||||||
|
#[getset(get = "pub")]
|
||||||
|
tensor: &'a Tensor<T, R>,
|
||||||
|
#[getset(get = "pub")]
|
||||||
|
shape: Shape<R>,
|
||||||
|
#[getset(get = "pub")]
|
||||||
|
dim: usize,
|
||||||
|
#[getset(get = "pub")]
|
||||||
|
len: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
|
||||||
|
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
|
||||||
|
assert!(dim < R, "Axis out of bounds");
|
||||||
|
Self {
|
||||||
|
tensor,
|
||||||
|
shape: tensor.shape(),
|
||||||
|
dim,
|
||||||
|
len: tensor.shape().get(dim),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn iter_level(&self, level: usize) -> AxisIterator<'a, T, R> {
|
||||||
|
assert!(level < self.len, "Level out of bounds");
|
||||||
|
let mut index = Idx::new(self.shape.clone(), [0; R]);
|
||||||
|
index.set_axis(self.dim, level);
|
||||||
|
AxisIterator::new(self.clone()).set_start(level).set_end(level + 1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug, Getters, MutGetters)]
|
||||||
|
pub struct AxisIterator<'a, T: Value, const R: usize> {
|
||||||
|
#[getset(get = "pub")]
|
||||||
|
axis: Axis<'a, T, R>,
|
||||||
|
#[getset(get = "pub", get_mut = "pub")]
|
||||||
|
index: Idx<R>,
|
||||||
|
#[getset(get = "pub")]
|
||||||
|
end: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
|
||||||
|
pub fn new(axis: Axis<'a, T, R>) -> Self {
|
||||||
|
Self {
|
||||||
|
axis: axis.clone(),
|
||||||
|
index: Idx::new(axis.shape().clone(), [0; R]),
|
||||||
|
end: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_start(self, start: usize) -> Self {
|
||||||
|
assert!(start < self.axis.len, "Start out of bounds");
|
||||||
|
let mut index = Idx::new(self.axis.shape().clone(), [0; R]);
|
||||||
|
index.set_axis(self.axis.dim, start);
|
||||||
|
Self {
|
||||||
|
axis: self.axis.clone(),
|
||||||
|
index,
|
||||||
|
end: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_end(self, end: usize) -> Self {
|
||||||
|
assert!(end <= self.axis.len, "End out of bounds");
|
||||||
|
Self {
|
||||||
|
axis: self.axis.clone(),
|
||||||
|
index: self.index.clone(),
|
||||||
|
end: Some(end),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn set_level(self, level: usize) -> Self {
|
||||||
|
assert!(level < self.axis.len, "Level out of bounds");
|
||||||
|
self.set_start(level).set_end(level + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn disassemble(self) -> Vec<Self> {
|
||||||
|
let mut result = Vec::new();
|
||||||
|
for i in 0..self.axis().len {
|
||||||
|
result.push(Self::new(self.axis.clone()).set_level(i));
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn axis_max_idx(&self) -> usize {
|
||||||
|
self.end().unwrap_or(*self.axis().len())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn axis_idx(&self) -> usize {
|
||||||
|
self.index().get_axis(*self.axis().dim())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn axis_dim(&self) -> usize {
|
||||||
|
self.axis().dim().clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
|
||||||
|
type Item = &'a T;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.axis_idx() == self.axis_max_idx() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
// Increment the index along the fixed axis and check if it's within bounds
|
||||||
|
let result = self.axis().tensor().get(self.index);
|
||||||
|
let axis_dim = self.axis_dim();
|
||||||
|
self.index_mut().inc_axis(axis_dim);
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a, T: Value, const R: usize> IntoIterator for Axis<'a, T, R> {
|
||||||
|
type Item = &'a T;
|
||||||
|
type IntoIter = AxisIterator<'a, T, R>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
AxisIterator::new(self)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn contract<
|
||||||
|
'a,
|
||||||
|
T: Value + std::fmt::Debug,
|
||||||
|
const R: usize,
|
||||||
|
const S: usize,
|
||||||
|
const N: usize,
|
||||||
|
>(
|
||||||
|
lhs: &'a Tensor<T, R>,
|
||||||
|
rhs: &'a Tensor<T, S>,
|
||||||
|
laxes: [Axis<'a, T, R>; N],
|
||||||
|
raxes: [Axis<'a, T, S>; N],
|
||||||
|
) -> Tensor<T, { R + S - 2 * N }>
|
||||||
|
where
|
||||||
|
[(); R - N]:,
|
||||||
|
[(); S - N]:,
|
||||||
|
[(); R + S - 2 * N]:,
|
||||||
|
{
|
||||||
|
let li = laxes
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.map(|axis| axis.into_iter())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let ri = raxes
|
||||||
|
.clone()
|
||||||
|
.into_iter()
|
||||||
|
.map(|axis| axis.into_iter())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap());
|
||||||
|
|
||||||
|
let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
|
||||||
|
laxes
|
||||||
|
.iter()
|
||||||
|
.map(|axis| *axis.dim())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.try_into()
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
|
||||||
|
raxes
|
||||||
|
.iter()
|
||||||
|
.map(|axis| *axis.dim())
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.try_into()
|
||||||
|
.unwrap(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||||
|
let mut shape = Vec::new();
|
||||||
|
shape.extend_from_slice(&lhs_shape_reduced.as_array());
|
||||||
|
shape.extend_from_slice(&rhs_shape_reduced.as_array());
|
||||||
|
|
||||||
|
let shape: [usize; R + S - 2 * N] =
|
||||||
|
shape.try_into().expect("Failed to create shape array");
|
||||||
|
|
||||||
|
let shape = Shape::new(shape);
|
||||||
|
|
||||||
|
Tensor::new_with_buffer(shape, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn contract_axes<
|
||||||
|
'a,
|
||||||
|
T: Value + std::fmt::Debug,
|
||||||
|
const R: usize,
|
||||||
|
const S: usize,
|
||||||
|
const N: usize,
|
||||||
|
>(
|
||||||
|
laxes: [AxisIterator<'a, T, R>; N],
|
||||||
|
raxes: [AxisIterator<'a, T, S>; N],
|
||||||
|
) -> Vec<T>
|
||||||
|
where
|
||||||
|
[(); R - N]:,
|
||||||
|
[(); S - N]:,
|
||||||
|
// [(); R + S - 2 * N]:,
|
||||||
|
{
|
||||||
|
let mut result = Vec::new();
|
||||||
|
|
||||||
|
// Disassemble each axis iterator into iterators over each level
|
||||||
|
let laxes: Vec<Vec<AxisIterator<'a, T, R>>> =
|
||||||
|
laxes.into_iter().map(|axis| axis.disassemble()).collect();
|
||||||
|
let raxes: Vec<Vec<AxisIterator<'a, T, S>>> =
|
||||||
|
raxes.into_iter().map(|axis| axis.disassemble()).collect();
|
||||||
|
|
||||||
|
// Iterate over all combinations of levels for contracted axes
|
||||||
|
|
||||||
|
let lcart = laxes.into_iter().multi_cartesian_product();
|
||||||
|
let rcart = raxes.into_iter().multi_cartesian_product();
|
||||||
|
let cartesian = lcart.zip(rcart);
|
||||||
|
|
||||||
|
for (i, (l, r)) in cartesian.enumerate() {
|
||||||
|
println!("{} cartesian:", i);
|
||||||
|
let mut sum = T::zero();
|
||||||
|
|
||||||
|
let level = l.into_iter().zip(r.into_iter());
|
||||||
|
|
||||||
|
for (i, (l, r)) in level.enumerate() {
|
||||||
|
// let mut sum = T::zero();
|
||||||
|
println!("{} level:", i);
|
||||||
|
let value = l.into_iter().zip(r.into_iter());
|
||||||
|
|
||||||
|
for (i, (l, r)) in value.enumerate() {
|
||||||
|
println!("{} product: {} + {} * {} = {}", i, sum, l, r, sum + *l * *r);
|
||||||
|
sum = sum + *l * *r;
|
||||||
|
}
|
||||||
|
// result.push(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.push(sum);
|
||||||
|
}
|
||||||
|
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
// pub fn contract_axes<
|
||||||
|
// 'a,
|
||||||
|
// T: Value + std::fmt::Debug,
|
||||||
|
// const R: usize,
|
||||||
|
// const S: usize,
|
||||||
|
// const N: usize,
|
||||||
|
// >(
|
||||||
|
// // Axes are (contracted, non-contracted)
|
||||||
|
// laxes: ([AxisIterator<'a, T, R>; N], [AxisIterator<'a, T, R>; R - N]),
|
||||||
|
// raxes: ([AxisIterator<'a, T, S>; N], [AxisIterator<'a, T, S>; S - N]),
|
||||||
|
// ) -> Vec<T> //Tensor<T, { R + S - 2 * N }>
|
||||||
|
// where
|
||||||
|
// [(); R - N]:,
|
||||||
|
// [(); S - N]:,
|
||||||
|
// [(); R + S - 2 * N]:,
|
||||||
|
// {
|
||||||
|
// let (lc, lnc) = laxes;
|
||||||
|
// let (rc, rnc) = raxes;
|
||||||
|
|
||||||
|
// // Step 1: Prepare cartesian products of contracted axes
|
||||||
|
// let lc_prod = lc.into_iter().multi_cartesian_product();
|
||||||
|
// let rc_prod = rc.into_iter().multi_cartesian_product();
|
||||||
|
|
||||||
|
// // Step 2: Prepare iterators for non-contracted axes
|
||||||
|
// let lnc_prod = lnc.into_iter().multi_cartesian_product();
|
||||||
|
// let rnc_prod = rnc.into_iter().multi_cartesian_product();
|
||||||
|
|
||||||
|
// // Initialize buffer for the resulting tensor
|
||||||
|
// let mut result_buffer = Vec::new();
|
||||||
|
|
||||||
|
// // Step 3: Iterate over combinations and compute tensor product
|
||||||
|
// for lnc_vals in lnc_prod {
|
||||||
|
// for rnc_vals in rnc_prod.clone() {
|
||||||
|
// let mut sum = T::zero();
|
||||||
|
// for ((lc, rc), (lnc_val, rnc_val)) in lc_prod
|
||||||
|
// .clone()
|
||||||
|
// .zip(rc_prod.clone())
|
||||||
|
// .zip(lnc_vals.iter().zip(rnc_vals.iter()))
|
||||||
|
// {
|
||||||
|
// for (lc_val, rc_val) in lc.into_iter().zip(rc.into_iter()) {
|
||||||
|
// sum = sum + (*lc_val * *rc_val) * (**lnc_val * **rnc_val);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// result_buffer.push(sum);
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// result_buffer
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // Check that there are no duplicate axes
|
||||||
|
// for (i, laxis) in laxes.iter().enumerate() {
|
||||||
|
// for (j, laxisx) in laxes.iter().enumerate() {
|
||||||
|
// if i != j && laxis == laxisx {
|
||||||
|
// panic!("Duplicate axes");
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// for (i, raxis) in raxes.iter().enumerate() {
|
||||||
|
// for (j, raxisx) in raxes.iter().enumerate() {
|
||||||
|
// if i != j && raxis == raxisx {
|
||||||
|
// panic!("Duplicate axes");
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
// // Check that the sizes of the axes to contract are the same
|
||||||
|
// // for (laxis, raxis) in laxes.iter().zip(raxes.iter()) {
|
||||||
|
// // // Check that axes are within bounds
|
||||||
|
// // if lhs.shape()[*laxis] != *raxis {
|
||||||
|
// // println!("laxis: {}, raxis: {}", laxis, raxis);
|
||||||
|
// // panic!("Axis out of bounds");
|
||||||
|
// // }
|
||||||
|
// // }
|
||||||
|
// // 1. Axis iterators
|
||||||
|
// //
|
||||||
|
// // Create iterators for each non-contracted axis of the left-hand and right-hand tensors
|
||||||
|
// // and take the Cartesian product of the iterators.
|
||||||
|
|
||||||
|
// let lcartesian_contract = lhs
|
||||||
|
// .shape()
|
||||||
|
// .iter()
|
||||||
|
// .enumerate()
|
||||||
|
// .filter(|(i, _)| laxes.contains(i))
|
||||||
|
// .map(|(i, _)| {
|
||||||
|
// println!("LHS is adding axis: {}", i);
|
||||||
|
// let vals = Axis::new(lhs.clone(), i).into_iter().collect::<Vec<_>>();
|
||||||
|
// println!("vals: {:?}", vals);
|
||||||
|
// })
|
||||||
|
// .collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// for l in lcartesian_contract.iter() {
|
||||||
|
// let l = l.clone();
|
||||||
|
// println!("{:?}", l);
|
||||||
|
// }
|
||||||
|
// // .multi_cartesian_product();
|
||||||
|
|
||||||
|
// // let cart_product = lcartesian_contract.multi_cartesian_product();
|
||||||
|
|
||||||
|
// // for l in cart_product {
|
||||||
|
// // println!("cartesian_contract l:");
|
||||||
|
// // for r in l {
|
||||||
|
// // println!("l: {:?}", r);
|
||||||
|
// // }
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// let rcartesian_contract = rhs
|
||||||
|
// .shape()
|
||||||
|
// .iter()
|
||||||
|
// .enumerate()
|
||||||
|
// .filter(|(i, _)| raxes.contains(i))
|
||||||
|
// .map(|(i, _)| {
|
||||||
|
// println!("RHS is adding axis: {}", i);
|
||||||
|
// let vals = Axis::new(rhs.clone(), i).into_iter().collect::<Vec<_>>();
|
||||||
|
// println!("vals: {:?}", vals);
|
||||||
|
// })
|
||||||
|
// .collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// for r in rcartesian_contract.iter() {
|
||||||
|
// let r = r.clone();
|
||||||
|
// println!("{:?}", r);
|
||||||
|
// }
|
||||||
|
// // .multi_cartesian_product();
|
||||||
|
|
||||||
|
// // println!("lcartesian_contract: {:?}", lcartesian_contract);
|
||||||
|
// // println!("rcartesian_contract: {:?}", rcartesian_contract);
|
||||||
|
|
||||||
|
// // Initialize buffer for the resulting tensor
|
||||||
|
// // let mut result_buffer = Vec::new();
|
||||||
|
// // let zip_contract = lcartesian_contract.zip(rcartesian_contract);
|
||||||
|
|
||||||
|
// // for (i, (j, k)) in zip_contract.enumerate() {
|
||||||
|
// // let mut sum = T::zero();
|
||||||
|
// // print!("{}. sum = ", i);
|
||||||
|
// // // for (lhsnc, rhsnc) in j.into_iter().zip(k.into_iter()) {
|
||||||
|
// // // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
|
||||||
|
// // // sum = sum + lhsnc * rhsnc;
|
||||||
|
// // // }
|
||||||
|
// // for lhsnc in j.iter() {
|
||||||
|
// // for rhsnc in k.iter() {
|
||||||
|
// // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
|
||||||
|
// // sum = sum + *lhsnc * *rhsnc;
|
||||||
|
// // }
|
||||||
|
// // }
|
||||||
|
// // print!("= {}\n", sum);
|
||||||
|
// // result_buffer.push(sum);
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// // Iterate over non-contracted axes combinations
|
||||||
|
// // for (lhsnci, rhsnci) in
|
||||||
|
// // lcartesian_contract.zip(rcartesian_contract)
|
||||||
|
// // {
|
||||||
|
// // let mut sum = T::zero();
|
||||||
|
// // for (lhsnc, rhsnc) in lhsnci.iter().zip(rhsnci.iter()) {
|
||||||
|
// // sum = sum + *lhsnc * *rhsnc;
|
||||||
|
// // }
|
||||||
|
// // // Append the result to the buffer
|
||||||
|
// // result_buffer.push(sum);
|
||||||
|
// // }
|
||||||
|
|
||||||
|
// // TODO! Implement the rest of the function
|
||||||
|
|
||||||
|
// let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
|
||||||
|
// laxes
|
||||||
|
// .iter()
|
||||||
|
// .map(|axis| *axis)
|
||||||
|
// .collect::<Vec<_>>()
|
||||||
|
// .try_into()
|
||||||
|
// .unwrap(),
|
||||||
|
// );
|
||||||
|
// let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
|
||||||
|
// raxes
|
||||||
|
// .iter()
|
||||||
|
// .map(|axis| *axis)
|
||||||
|
// .collect::<Vec<_>>()
|
||||||
|
// .try_into()
|
||||||
|
// .unwrap(),
|
||||||
|
// );
|
||||||
|
|
||||||
|
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||||
|
// let mut shape = Vec::new();
|
||||||
|
// shape.extend_from_slice(&lhs_shape_reduced.as_array());
|
||||||
|
// shape.extend_from_slice(&rhs_shape_reduced.as_array());
|
||||||
|
|
||||||
|
// let shape: [usize; R + S - 2 * N] =
|
||||||
|
// shape.try_into().expect("Failed to create shape array");
|
||||||
|
|
||||||
|
// let shape = Shape::new(shape);
|
||||||
|
|
||||||
|
// Tensor::new_with_buffer(shape, vec![])
|
||||||
|
// }
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_tensor_contraction_simple() {
|
||||||
|
// Define two 2D tensors (matrices)
|
||||||
|
// Tensor A is 2x3
|
||||||
|
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
||||||
|
|
||||||
|
// Tensor B is 1x3x2
|
||||||
|
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
||||||
|
|
||||||
|
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||||
|
let contracted_tensor: Tensor<i32, 2> =
|
||||||
|
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
|
||||||
|
|
||||||
|
// Expect: [22, 28, 49, 64]
|
||||||
|
// Check if the contracted tensor is as expected
|
||||||
|
println!("A: {}", a);
|
||||||
|
println!("B: {}", b);
|
||||||
|
println!("A * B = {}", contracted_tensor);
|
||||||
|
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// #[test]
|
||||||
|
// fn test_tensor_contraction() {
|
||||||
|
// // Define two 2D tensors (matrices)
|
||||||
|
// // Tensor A is 2x3
|
||||||
|
// let tensor_a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||||
|
// println!("A: {}", tensor_a);
|
||||||
|
|
||||||
|
// // Tensor B is 1x3x2
|
||||||
|
// let tensor_b: Tensor<i32, 3> = Tensor::from([[[1, 2], [3, 4], [5, 6]]]);
|
||||||
|
// println!("B: {}", tensor_b);
|
||||||
|
|
||||||
|
// // Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||||
|
// let contracted_tensor: Tensor<i32, 3> =
|
||||||
|
// contract(tensor_a.clone(), tensor_b.clone(), [1], [1]);
|
||||||
|
|
||||||
|
// // Expect: [22, 28, 49, 64]
|
||||||
|
// // Check if the contracted tensor is as expected
|
||||||
|
// println!("A: {}", tensor_a);
|
||||||
|
// println!("B: {}", tensor_b);
|
||||||
|
// println!("A * B = {}", contracted_tensor);
|
||||||
|
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
|
||||||
|
// }
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_axis_iterator() {
|
||||||
|
// Creating a 2x2 Tensor for testing
|
||||||
|
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||||
|
|
||||||
|
// Testing iteration over the first axis (axis = 0)
|
||||||
|
let axis = Axis::new(&tensor, 0);
|
||||||
|
|
||||||
|
let mut axis_iter = axis.into_iter();
|
||||||
|
|
||||||
|
axis_iter.for_each(|x| println!("t1: {}", x));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||||
|
|
||||||
|
// Resetting the iterator for the second axis (axis = 1)
|
||||||
|
let axis = Axis::new(&tensor, 1);
|
||||||
|
|
||||||
|
let mut axis_iter = axis.into_iter();
|
||||||
|
axis_iter.for_each(|x| println!("t2: {}", x));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||||
|
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||||
|
|
||||||
|
let shape = tensor.shape();
|
||||||
|
|
||||||
|
let mut a: Idx<2> = (shape, [0, 0]).into();
|
||||||
|
let b: Idx<2> = (shape, [1, 1]).into();
|
||||||
|
|
||||||
|
while a <= b {
|
||||||
|
println!("a: {}", a);
|
||||||
|
a.inc();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// #[test]
|
||||||
|
// fn test_3d_tensor_axis_iteration() {
|
||||||
|
// // Create a 3D Tensor with specific values
|
||||||
|
// // Tensor shape is 2x2x2 for simplicity
|
||||||
|
// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||||
|
|
||||||
|
// // Axis 0 (Layer-wise):
|
||||||
|
// //
|
||||||
|
// // t[0][0][0] = 1
|
||||||
|
// // t[0][0][1] = 2
|
||||||
|
// // t[0][1][0] = 3
|
||||||
|
// // t[0][1][1] = 4
|
||||||
|
// // t[1][0][0] = 5
|
||||||
|
// // t[1][0][1] = 6
|
||||||
|
// // t[1][1][0] = 7
|
||||||
|
// // t[1][1][1] = 8
|
||||||
|
// // [1, 2, 3, 4, 5, 6, 7, 8]
|
||||||
|
// //
|
||||||
|
// // This order suggests that for each "layer" (first level of arrays),
|
||||||
|
// // the iterator goes through all rows and columns. It first completes
|
||||||
|
// // the entire first layer, then moves to the second.
|
||||||
|
|
||||||
|
// let a0 = Axis::new(t.clone(), 0);
|
||||||
|
// let a0_order = a0.into_iter().collect::<Vec<_>>();
|
||||||
|
// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||||
|
|
||||||
|
// // Axis 1 (Row-wise within each layer):
|
||||||
|
// //
|
||||||
|
// // t[0][0][0] = 1
|
||||||
|
// // t[0][0][1] = 2
|
||||||
|
// // t[1][0][0] = 5
|
||||||
|
// // t[1][0][1] = 6
|
||||||
|
// // t[0][1][0] = 3
|
||||||
|
// // t[0][1][1] = 4
|
||||||
|
// // t[1][1][0] = 7
|
||||||
|
// // t[1][1][1] = 8
|
||||||
|
// // [1, 2, 5, 6, 3, 4, 7, 8]
|
||||||
|
// //
|
||||||
|
// // This indicates that within each "layer", the iterator first
|
||||||
|
// // completes the first row across all layers, then the second row
|
||||||
|
// // across all layers.
|
||||||
|
|
||||||
|
// let a1 = Axis::new(t.clone(), 1);
|
||||||
|
// let a1_order = a1.into_iter().collect::<Vec<_>>();
|
||||||
|
// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||||
|
|
||||||
|
// // Axis 2 (Column-wise within each layer):
|
||||||
|
// //
|
||||||
|
// // t[0][0][0] = 1
|
||||||
|
// // t[0][1][0] = 3
|
||||||
|
// // t[1][0][0] = 5
|
||||||
|
// // t[1][1][0] = 7
|
||||||
|
// // t[0][0][1] = 2
|
||||||
|
// // t[0][1][1] = 4
|
||||||
|
// // t[1][0][1] = 6
|
||||||
|
// // t[1][1][1] = 8
|
||||||
|
// // [1, 3, 5, 7, 2, 4, 6, 8]
|
||||||
|
// //
|
||||||
|
// // This indicates that within each "layer", the iterator first
|
||||||
|
// // completes the first column across all layers, then the second
|
||||||
|
// // column across all layers.
|
||||||
|
|
||||||
|
// let a2 = Axis::new(t, 2);
|
||||||
|
// let a2_order = a2.into_iter().collect::<Vec<_>>();
|
||||||
|
// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||||
|
// }
|
||||||
|
// }
|
164
src/index.rs
164
src/index.rs
@ -51,6 +51,9 @@ impl<const R: usize> Idx<R> {
|
|||||||
/// `true` if the increment does not overflow and is still within bounds;
|
/// `true` if the increment does not overflow and is still within bounds;
|
||||||
/// `false` if it overflows, indicating the end of the tensor.
|
/// `false` if it overflows, indicating the end of the tensor.
|
||||||
pub fn inc(&mut self) -> bool {
|
pub fn inc(&mut self) -> bool {
|
||||||
|
if self.indices[0] >= self.shape.get(0) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
let mut carry = 1;
|
let mut carry = 1;
|
||||||
for (i, &dim_size) in
|
for (i, &dim_size) in
|
||||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||||
@ -67,13 +70,40 @@ impl<const R: usize> Idx<R> {
|
|||||||
|
|
||||||
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
||||||
if carry == 1 {
|
if carry == 1 {
|
||||||
// Set the index to an invalid state (e.g., all indices to their max values)
|
// Set the index to an invalid state to indicate the end of the iteration indicated
|
||||||
|
// by setting the first index to the size of the first dimension
|
||||||
self.indices[0] = self.shape.as_array()[0];
|
self.indices[0] = self.shape.as_array()[0];
|
||||||
return true; // Indicate that the iteration is complete
|
return true; // Indicate that the iteration is complete
|
||||||
}
|
}
|
||||||
false
|
false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// fn inc_axis
|
||||||
|
|
||||||
|
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
||||||
|
assert!(fixed_axis < R, "Axis out of bounds");
|
||||||
|
assert!(self.indices[fixed_axis] < self.shape.get(fixed_axis), "Index out of bounds");
|
||||||
|
|
||||||
|
// Try to increment non-fixed axes
|
||||||
|
for i in (0..R).rev() {
|
||||||
|
if i != fixed_axis {
|
||||||
|
if self.indices[i] < self.shape.get(i) {
|
||||||
|
self.indices[i] += 1;
|
||||||
|
} else {
|
||||||
|
self.indices[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Increment the fixed axis and reset other axes to 0
|
||||||
|
self.indices[fixed_axis] += 1;
|
||||||
|
for i in 0..R {
|
||||||
|
if i != fixed_axis {
|
||||||
|
self.indices[i] = 0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
pub fn dec(&mut self) {
|
pub fn dec(&mut self) {
|
||||||
// Check if already at the start
|
// Check if already at the start
|
||||||
if self.indices.iter().all(|&i| i == 0) {
|
if self.indices.iter().all(|&i| i == 0) {
|
||||||
@ -95,6 +125,40 @@ impl<const R: usize> Idx<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
|
||||||
|
// Check if the fixed axis index is already in an invalid state
|
||||||
|
if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to decrement non-fixed axes
|
||||||
|
for i in (0..R).rev() {
|
||||||
|
if i != fixed_axis {
|
||||||
|
if self.indices[i] > 0 {
|
||||||
|
self.indices[i] -= 1;
|
||||||
|
return true;
|
||||||
|
} else {
|
||||||
|
self.indices[i] = self.shape.get(i) - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decrement the fixed axis if possible and reset other axes to their max
|
||||||
|
if self.indices[fixed_axis] > 0 {
|
||||||
|
self.indices[fixed_axis] -= 1;
|
||||||
|
for i in 0..R {
|
||||||
|
if i != fixed_axis {
|
||||||
|
self.indices[i] = self.shape.get(i) - 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Fixed axis already at minimum, set to invalid state
|
||||||
|
self.indices[fixed_axis] = self.shape.get(fixed_axis);
|
||||||
|
}
|
||||||
|
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
/// Converts the multi-dimensional index to a flat index.
|
/// Converts the multi-dimensional index to a flat index.
|
||||||
///
|
///
|
||||||
/// This method calculates the flat index corresponding to the multi-dimensional index
|
/// This method calculates the flat index corresponding to the multi-dimensional index
|
||||||
@ -136,6 +200,56 @@ impl<const R: usize> Idx<R> {
|
|||||||
})
|
})
|
||||||
.0
|
.0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// pub fn inc_axis(&mut self, axis: usize) -> bool {
|
||||||
|
// if axis >= R {
|
||||||
|
// // Axis is out of bounds
|
||||||
|
// return false;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// let dim_size = self.shape.get(axis);
|
||||||
|
// if self.indices[axis] + 1 < dim_size {
|
||||||
|
// // Increment the index if it's not at its maximum
|
||||||
|
// self.indices[axis] += 1;
|
||||||
|
// true
|
||||||
|
// } else {
|
||||||
|
// // Index is at its maximum for this axis
|
||||||
|
// false
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
pub fn set_axis(&mut self, axis: usize, value: usize) {
|
||||||
|
assert!(axis < R, "Axis out of bounds");
|
||||||
|
// assert!(value < self.shape.get(axis), "Value out of bounds");
|
||||||
|
self.indices[axis] = value;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
|
||||||
|
assert!(axis < R, "Axis out of bounds");
|
||||||
|
if value < self.shape.get(axis) {
|
||||||
|
self.indices[axis] = value;
|
||||||
|
true
|
||||||
|
} else {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_axis(&self, axis: usize) -> usize {
|
||||||
|
assert!(axis < R, "Axis out of bounds");
|
||||||
|
self.indices[axis]
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn indices(&self) -> &[usize; R] {
|
||||||
|
&self.indices
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shape(&self) -> Shape<R> {
|
||||||
|
self.shape
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn shape_mut(&mut self) -> &mut Shape<R> {
|
||||||
|
&mut self.shape
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// --- blanket impls ---
|
// --- blanket impls ---
|
||||||
@ -194,6 +308,12 @@ impl<const R: usize> From<Shape<R>> for Idx<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const R: usize> From<Tensor<T, R>> for Idx<R> {
|
||||||
|
fn from(tensor: Tensor<T, R>) -> Self {
|
||||||
|
Self::zero(tensor.shape())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
impl<const R: usize> std::fmt::Display for Idx<R> {
|
impl<const R: usize> std::fmt::Display for Idx<R> {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
write!(f, "[")?;
|
write!(f, "[")?;
|
||||||
@ -249,3 +369,45 @@ impl<const R: usize> Sub for Idx<R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ---- Iterator ----
|
||||||
|
|
||||||
|
pub struct IdxIterator<const R: usize> {
|
||||||
|
current: Idx<R>,
|
||||||
|
end: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> IdxIterator<R> {
|
||||||
|
pub fn new(shape: Shape<R>) -> Self {
|
||||||
|
Self {
|
||||||
|
current: Idx::zero(shape),
|
||||||
|
end: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> Iterator for IdxIterator<R> {
|
||||||
|
type Item = Idx<R>;
|
||||||
|
|
||||||
|
fn next(&mut self) -> Option<Self::Item> {
|
||||||
|
if self.end {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = self.current;
|
||||||
|
self.end = self.current.inc();
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<const R: usize> IntoIterator for Idx<R> {
|
||||||
|
type Item = Idx<R>;
|
||||||
|
type IntoIter = IdxIterator<R>;
|
||||||
|
|
||||||
|
fn into_iter(self) -> Self::IntoIter {
|
||||||
|
IdxIterator {
|
||||||
|
current: self,
|
||||||
|
end: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
#![feature(generic_const_exprs)]
|
#![feature(generic_const_exprs)]
|
||||||
pub mod index;
|
pub mod index;
|
||||||
pub mod shape;
|
pub mod shape;
|
||||||
|
pub mod axis;
|
||||||
pub mod tensor;
|
pub mod tensor;
|
||||||
|
|
||||||
pub use index::Idx;
|
pub use index::Idx;
|
||||||
@ -11,6 +12,10 @@ pub use shape::Shape;
|
|||||||
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||||
use std::ops::{Index, IndexMut};
|
use std::ops::{Index, IndexMut};
|
||||||
pub use tensor::{Tensor, TensorIterator};
|
pub use tensor::{Tensor, TensorIterator};
|
||||||
|
pub use static_assertions::const_assert;
|
||||||
|
pub use itertools::Itertools;
|
||||||
|
pub use std::sync::Arc;
|
||||||
|
pub use axis::*;
|
||||||
|
|
||||||
pub trait Value:
|
pub trait Value:
|
||||||
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||||
|
164
src/tensor.rs
164
src/tensor.rs
@ -22,12 +22,28 @@ impl<T: Value> Tensor<T, 1> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||||
pub fn new(shape: Shape<R>) -> Self {
|
pub fn new(shape: Shape<R>) -> Self {
|
||||||
let total_size: usize = shape.iter().product();
|
// Handle rank 0 tensor (scalar) as a special case
|
||||||
|
let total_size = if R == 0 {
|
||||||
|
// A rank 0 tensor should still have a buffer with one element
|
||||||
|
1
|
||||||
|
} else {
|
||||||
|
// For tensors of rank 1 or higher, calculate the total size normally
|
||||||
|
shape.iter().product()
|
||||||
|
};
|
||||||
|
|
||||||
let buffer = vec![T::zero(); total_size];
|
let buffer = vec![T::zero(); total_size];
|
||||||
Self { buffer, shape }
|
Self { buffer, shape }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
|
||||||
|
Self { buffer, shape }
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn idx(&self) -> Idx<R> {
|
||||||
|
Idx::from(self.clone())
|
||||||
|
}
|
||||||
|
|
||||||
pub fn shape(&self) -> Shape<R> {
|
pub fn shape(&self) -> Shape<R> {
|
||||||
self.shape
|
self.shape
|
||||||
}
|
}
|
||||||
@ -111,93 +127,71 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// `self_dims` and `other_dims` specify the dimensions to contract over
|
// // Step 1: Validate the axes for both tensors to ensure they are within bounds
|
||||||
pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
|
// for &axis in &axis_lhs {
|
||||||
&self,
|
// if axis >= R {
|
||||||
rhs: &Tensor<T, S>,
|
// panic!(
|
||||||
axis_lhs: [usize; NAXL],
|
// "Axis {} is out of bounds for the left-hand tensor",
|
||||||
axis_rhs: [usize; NAXR],
|
// axis
|
||||||
) -> Tensor<T, { R + S - NAXL - NAXR }>
|
// );
|
||||||
where
|
// }
|
||||||
[(); R - NAXL]:,
|
// }
|
||||||
[(); S - NAXR]:,
|
|
||||||
{
|
|
||||||
// Step 1: Validate the axes for both tensors to ensure they are within bounds
|
|
||||||
for &axis in &axis_lhs {
|
|
||||||
if axis >= R {
|
|
||||||
panic!(
|
|
||||||
"Axis {} is out of bounds for the left-hand tensor",
|
|
||||||
axis
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for &axis in &axis_rhs {
|
// for &axis in &axis_rhs {
|
||||||
if axis >= S {
|
// if axis >= S {
|
||||||
panic!(
|
// panic!(
|
||||||
"Axis {} is out of bounds for the right-hand tensor",
|
// "Axis {} is out of bounds for the right-hand tensor",
|
||||||
axis
|
// axis
|
||||||
);
|
// );
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
||||||
// Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
// // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
||||||
let mut result_buffer = Vec::new();
|
// let mut result_buffer = Vec::new();
|
||||||
|
|
||||||
for i in 0..self.shape.size() {
|
// for i in self.iter_over_non_contracted_axes(&axis_lhs) {
|
||||||
for j in 0..rhs.shape.size() {
|
// for j in rhs.iter_over_non_contracted_axes(&axis_rhs) {
|
||||||
// Debug: Print indices being processed
|
// let mut product_sum = T::zero();
|
||||||
println!("Processing Indices: lhs = {}, rhs = {}", i, j);
|
|
||||||
|
|
||||||
if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) {
|
// // Step 3: Contraction over specified axes
|
||||||
let mut product_sum = T::zero();
|
// for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
||||||
|
// let idx_l = i[axis_l];
|
||||||
|
// let idx_r = j[axis_r];
|
||||||
|
|
||||||
// Debug: Print axes of contraction
|
// let value_lhs = self.buffer[self.flat(&idx_l)];
|
||||||
println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs);
|
// let value_rhs = rhs.buffer[rhs.flat(&idx_r)];
|
||||||
|
// product_sum = product_sum + value_lhs * value_rhs;
|
||||||
|
// }
|
||||||
|
|
||||||
for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
// result_buffer.push(product_sum);
|
||||||
// Debug: Print values being multiplied
|
// }
|
||||||
let value_lhs = self.get_by_axis(axis_l, i).unwrap();
|
// }
|
||||||
let value_rhs = rhs.get_by_axis(axis_r, j).unwrap();
|
|
||||||
println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs);
|
|
||||||
|
|
||||||
product_sum = product_sum + value_lhs * value_rhs;
|
// // Step 4: Remove contracted dimensions to create new shapes for both tensors
|
||||||
}
|
// let new_shape_lhs = self.shape.remove_dims::<{ NL }>(axis_lhs);
|
||||||
|
// let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
||||||
|
|
||||||
// Debug: Print the product sum for the current indices
|
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||||
println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum);
|
// let mut new_shape = Vec::new();
|
||||||
|
// new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
||||||
|
// new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
||||||
|
|
||||||
result_buffer.push(product_sum);
|
// let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
||||||
}
|
// new_shape.try_into().expect("Failed to create shape array");
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Step 3: Remove contracted dimensions to create new shapes for both tensors
|
// Tensor {
|
||||||
let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
|
// buffer: result_buffer,
|
||||||
let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
// shape: Shape::new(new_shape_array),
|
||||||
|
// }
|
||||||
// Step 4: Concatenate the shapes to form the shape of the resultant tensor
|
// }
|
||||||
let mut new_shape = Vec::new();
|
|
||||||
|
|
||||||
new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
|
||||||
new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
|
||||||
|
|
||||||
let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
|
||||||
new_shape.try_into().expect("Failed to create shape array");
|
|
||||||
|
|
||||||
Tensor {
|
|
||||||
buffer: result_buffer,
|
|
||||||
shape: Shape::new(new_shape_array),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Retrieve an element based on a specific axis and index
|
// Retrieve an element based on a specific axis and index
|
||||||
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
|
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
|
||||||
// Convert axis and index to a flat index
|
// Convert axis and index to a flat index
|
||||||
let flat_index = self.axis_to_flat_index(axis, index);
|
let flat_index = self.axis_to_flat_index(axis, index);
|
||||||
if flat_index >= self.buffer.len() {
|
if flat_index >= self.buffer.len() {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
Some(self.buffer[flat_index])
|
Some(self.buffer[flat_index])
|
||||||
}
|
}
|
||||||
@ -214,7 +208,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
|||||||
|
|
||||||
// Calculate the stride for each dimension and accumulate the flat index
|
// Calculate the stride for each dimension and accumulate the flat index
|
||||||
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
|
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
|
||||||
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
|
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
|
||||||
if i > axis {
|
if i > axis {
|
||||||
stride *= dim_size;
|
stride *= dim_size;
|
||||||
} else if i == axis {
|
} else if i == axis {
|
||||||
@ -366,3 +360,23 @@ impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
|
|||||||
tensor
|
tensor
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
impl<T: Value, const X: usize, const Y: usize, const Z: usize>
|
||||||
|
From<[[[T; X]; Y]; Z]> for Tensor<T, 3>
|
||||||
|
{
|
||||||
|
fn from(array: [[[T; X]; Y]; Z]) -> Self {
|
||||||
|
let shape = Shape::new([Z, Y, X]);
|
||||||
|
let mut tensor = Tensor::new(shape);
|
||||||
|
let buffer = tensor.buffer_mut();
|
||||||
|
|
||||||
|
for (i, plane) in array.iter().enumerate() {
|
||||||
|
for (j, row) in plane.iter().enumerate() {
|
||||||
|
for (k, &elem) in row.iter().enumerate() {
|
||||||
|
buffer[i * X * Y + j * X + k] = elem;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tensor
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user