Work on tensor contraction and axis iteration

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2023-12-29 01:47:18 +02:00
parent b17264ba18
commit d54179cd5d
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
7 changed files with 979 additions and 124 deletions

99
Cargo.lock generated
View File

@ -14,22 +14,39 @@ version = "1.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
[[package]]
name = "either"
version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "getset"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "itertools"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
[[package]]
name = "mltensor"
version = "0.1.0"
dependencies = [
"bytemuck",
"num",
"serde",
"serde_json",
]
[[package]]
name = "num"
version = "0.4.1"
@ -106,6 +123,30 @@ dependencies = [
"autocfg",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn 1.0.109",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro2"
version = "1.0.71"
@ -147,7 +188,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.43",
]
[[package]]
@ -161,6 +202,23 @@ dependencies = [
"serde",
]
[[package]]
name = "static_assertions"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.43"
@ -172,8 +230,27 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "tensorc"
version = "0.1.0"
dependencies = [
"bytemuck",
"getset",
"itertools",
"num",
"serde",
"serde_json",
"static_assertions",
]
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "version_check"
version = "0.9.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"

View File

@ -1,5 +1,5 @@
[package]
name = "mltensor"
name = "tensorc"
version = "0.1.0"
edition = "2021"
@ -7,6 +7,9 @@ edition = "2021"
[dependencies]
bytemuck = "1.14.0"
getset = "0.1.2"
itertools = "0.12.0"
num = "0.4.1"
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
static_assertions = "1.1.0"

View File

@ -1,4 +1,7 @@
use mltensor::*;
#![allow(mixed_script_confusables)]
#![allow(non_snake_case)]
use bytemuck::cast_slice;
use tensorc::*;
fn tensor_product() {
println!("Tensor Product\n");
@ -20,8 +23,9 @@ fn tensor_product() {
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
// Check buffer of the resulting tensor
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
assert_eq!(product.buffer(), &expected_buffer);
let expect: &[i32] =
cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]);
assert_eq!(product.buffer(), expect);
}
fn tensor_contraction() {
@ -35,16 +39,25 @@ fn tensor_contraction() {
let axis_rhs = [0]; // Contract over the first dimension of tensor2
// Perform contraction
let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
// let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
println!("T1: {}", tensor1);
println!("T2: {}", tensor2);
println!("T1 * T2 = {}", result);
// println!("T1: {}", tensor1);
// println!("T2: {}", tensor2);
// println!("T1 * T2 = {}", result);
// Expected result, for example, could be a single number or a new tensor,
// depending on how you defined the contraction operation.
// Assert the result is as expected
// assert_eq!(result, expected_result);
// let Λ = Tensor::<f64, 2>::from([
// [1.0, 0.0, 0.0, 0.0],
// [0.0, 1.0, 0.0 ,0.0],
// [0.0, 0.0, 1.0, 0.0],
// [0.0, 0.0, 0.0, 1.0]
// ]);
// println!("Λ: {}", Λ);
}
fn main() {

581
src/axis.rs Normal file
View File

@ -0,0 +1,581 @@
use super::*;
use getset::{Getters, MutGetters};
/// A borrowed view of one dimension (`dim`) of a tensor.
///
/// Getters are generated by `getset` and return references to the fields.
#[derive(Clone, Debug, Getters)]
pub struct Axis<'a, T: Value, const R: usize> {
    // The tensor being viewed; an Axis borrows, it never owns the data.
    #[getset(get = "pub")]
    tensor: &'a Tensor<T, R>,
    // Cached copy of the tensor's full shape.
    #[getset(get = "pub")]
    shape: Shape<R>,
    // Which dimension (0-based, < R) this axis refers to.
    #[getset(get = "pub")]
    dim: usize,
    // Number of levels along `dim`, i.e. `shape.get(dim)`.
    #[getset(get = "pub")]
    len: usize,
}
impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
    /// Creates a view over dimension `dim` of `tensor`.
    ///
    /// # Panics
    /// Panics if `dim >= R`.
    pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
        assert!(dim < R, "Axis out of bounds");
        Self {
            tensor,
            shape: tensor.shape(),
            dim,
            len: tensor.shape().get(dim),
        }
    }

    /// Returns an iterator restricted to a single `level` (slice) along
    /// this axis.
    ///
    /// # Panics
    /// Panics if `level >= self.len`.
    pub fn iter_level(&self, level: usize) -> AxisIterator<'a, T, R> {
        assert!(level < self.len, "Level out of bounds");
        // `set_start` positions the index at `level` along the fixed axis
        // and `set_end` bounds the walk to exactly that one level.
        // (The original also built an `Idx` here and immediately discarded
        // it — dead work, removed.)
        AxisIterator::new(self.clone()).set_start(level).set_end(level + 1)
    }
}
/// Iterator state for walking a tensor along one fixed axis.
///
/// Getters are generated by `getset`; `index` additionally has a public
/// mutable getter.
#[derive(Clone, Debug, Getters, MutGetters)]
pub struct AxisIterator<'a, T: Value, const R: usize> {
    // The axis (tensor + dimension) being traversed.
    #[getset(get = "pub")]
    axis: Axis<'a, T, R>,
    // Current multi-dimensional position within the tensor.
    #[getset(get = "pub", get_mut = "pub")]
    index: Idx<R>,
    // Optional exclusive upper bound along the fixed axis; `None` means
    // iterate to the full axis length.
    #[getset(get = "pub")]
    end: Option<usize>,
}
impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
    /// Creates an iterator over `axis`, positioned at level 0 with no end
    /// bound.
    pub fn new(axis: Axis<'a, T, R>) -> Self {
        // Build the index before moving `axis` into the struct, avoiding
        // the clone the original needed.
        let index = Idx::new(axis.shape().clone(), [0; R]);
        Self { axis, index, end: None }
    }

    /// Positions the iterator at `start` along the fixed axis, resetting
    /// every other coordinate to 0. Unlike the original, a previously set
    /// end bound is preserved instead of being silently discarded.
    ///
    /// # Panics
    /// Panics if `start >= axis.len`.
    pub fn set_start(self, start: usize) -> Self {
        assert!(start < self.axis.len, "Start out of bounds");
        let mut index = Idx::new(self.axis.shape().clone(), [0; R]);
        index.set_axis(self.axis.dim, start);
        Self { index, ..self }
    }

    /// Sets the exclusive end bound along the fixed axis.
    ///
    /// # Panics
    /// Panics if `end > axis.len`.
    pub fn set_end(self, end: usize) -> Self {
        assert!(end <= self.axis.len, "End out of bounds");
        Self { end: Some(end), ..self }
    }

    /// Restricts the iterator to exactly one `level` along the fixed axis.
    ///
    /// # Panics
    /// Panics if `level >= axis.len`.
    pub fn set_level(self, level: usize) -> Self {
        assert!(level < self.axis.len, "Level out of bounds");
        self.set_start(level).set_end(level + 1)
    }

    /// Splits this iterator into one single-level iterator per level of
    /// the axis.
    pub fn disassemble(self) -> Vec<Self> {
        (0..self.axis.len)
            .map(|level| Self::new(self.axis.clone()).set_level(level))
            .collect()
    }

    /// Exclusive upper bound along the fixed axis: the explicit `end`, or
    /// the full axis length when unbounded.
    pub fn axis_max_idx(&self) -> usize {
        self.end.unwrap_or(self.axis.len)
    }

    /// Current position along the fixed axis.
    pub fn axis_idx(&self) -> usize {
        self.index.get_axis(self.axis.dim)
    }

    /// The dimension this iterator is fixed to.
    pub fn axis_dim(&self) -> usize {
        self.axis.dim
    }
}
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Stop once the position along the fixed axis reaches the end
        // bound (explicit `end`, or the axis length when unbounded).
        if self.axis_idx() == self.axis_max_idx() {
            return None;
        }
        // Increment the index along the fixed axis and check if it's within bounds
        let result = self.axis().tensor().get(self.index);
        let axis_dim = self.axis_dim();
        // NOTE(review): `inc_axis` advances the fixed axis and resets every
        // other coordinate to zero, so this visits only elements whose
        // non-fixed coordinates are 0. The commented-out 3D tests expect a
        // full traversal of all elements — confirm which behavior is
        // intended before enabling those tests.
        self.index_mut().inc_axis(axis_dim);
        Some(result)
    }
}
/// Consuming an `Axis` yields its elements via `AxisIterator`, starting at
/// level 0 with no end bound.
impl<'a, T: Value, const R: usize> IntoIterator for Axis<'a, T, R> {
    type Item = &'a T;
    type IntoIter = AxisIterator<'a, T, R>;

    fn into_iter(self) -> Self::IntoIter {
        AxisIterator::new(self)
    }
}
/// Contracts `lhs` and `rhs` over `N` paired axes, producing a tensor of
/// rank `R + S - 2 * N` whose shape is the non-contracted dimensions of
/// `lhs` followed by those of `rhs`.
///
/// Requires the unstable `generic_const_exprs` feature for the rank
/// arithmetic in the return type.
///
/// NOTE(review): the element computation is delegated to `contract_axes`,
/// and the resulting buffer length is never checked against the computed
/// output shape — confirm they agree for all inputs.
///
/// # Panics
/// Panics if the axis arrays cannot be converted back into fixed-size
/// arrays, or if the concatenated reduced shapes do not fit the output
/// rank.
pub fn contract<
    'a,
    T: Value + std::fmt::Debug,
    const R: usize,
    const S: usize,
    const N: usize,
>(
    lhs: &'a Tensor<T, R>,
    rhs: &'a Tensor<T, S>,
    laxes: [Axis<'a, T, R>; N],
    raxes: [Axis<'a, T, S>; N],
) -> Tensor<T, { R + S - 2 * N }>
where
    [(); R - N]:,
    [(); S - N]:,
    [(); R + S - 2 * N]:,
{
    // Turn each contracted Axis into its element iterator.
    let li = laxes
        .clone()
        .into_iter()
        .map(|axis| axis.into_iter())
        .collect::<Vec<_>>();
    let ri = raxes
        .clone()
        .into_iter()
        .map(|axis| axis.into_iter())
        .collect::<Vec<_>>();
    let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap());
    // Drop the contracted dimensions from each operand's shape.
    let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
        laxes
            .iter()
            .map(|axis| *axis.dim())
            .collect::<Vec<_>>()
            .try_into()
            .unwrap(),
    );
    let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
        raxes
            .iter()
            .map(|axis| *axis.dim())
            .collect::<Vec<_>>()
            .try_into()
            .unwrap(),
    );
    // Concatenate the reduced shapes to form the shape of the result.
    let mut shape = Vec::new();
    shape.extend_from_slice(&lhs_shape_reduced.as_array());
    shape.extend_from_slice(&rhs_shape_reduced.as_array());
    let shape: [usize; R + S - 2 * N] =
        shape.try_into().expect("Failed to create shape array");
    let shape = Shape::new(shape);
    Tensor::new_with_buffer(shape, result)
}
/// Computes the flat buffer of a tensor contraction from the paired axis
/// iterators of the two operands.
///
/// Each axis iterator is split into one iterator per level; for every
/// combination of contracted levels (Cartesian product, left paired with
/// right) the elementwise products are summed into a single output value.
///
/// Removed from the original: the per-iteration `println!` debug traces
/// (and the `enumerate` counters that existed only to feed them) — debug
/// printing does not belong in library code.
pub fn contract_axes<
    'a,
    T: Value + std::fmt::Debug,
    const R: usize,
    const S: usize,
    const N: usize,
>(
    laxes: [AxisIterator<'a, T, R>; N],
    raxes: [AxisIterator<'a, T, S>; N],
) -> Vec<T>
where
    [(); R - N]:,
    [(); S - N]:,
{
    // Disassemble each axis iterator into iterators over each level.
    let laxes: Vec<Vec<AxisIterator<'a, T, R>>> =
        laxes.into_iter().map(|axis| axis.disassemble()).collect();
    let raxes: Vec<Vec<AxisIterator<'a, T, S>>> =
        raxes.into_iter().map(|axis| axis.disassemble()).collect();

    // Pair every combination of contracted levels on the left with the
    // matching combination on the right.
    let lcart = laxes.into_iter().multi_cartesian_product();
    let rcart = raxes.into_iter().multi_cartesian_product();

    let mut result = Vec::new();
    for (l, r) in lcart.zip(rcart) {
        // One output element per level combination: the sum of elementwise
        // products over all paired levels.
        let mut sum = T::zero();
        for (lv, rv) in l.into_iter().zip(r) {
            for (a, b) in lv.zip(rv) {
                sum = sum + *a * *b;
            }
        }
        result.push(sum);
    }
    result
}
// pub fn contract_axes<
// 'a,
// T: Value + std::fmt::Debug,
// const R: usize,
// const S: usize,
// const N: usize,
// >(
// // Axes are (contracted, non-contracted)
// laxes: ([AxisIterator<'a, T, R>; N], [AxisIterator<'a, T, R>; R - N]),
// raxes: ([AxisIterator<'a, T, S>; N], [AxisIterator<'a, T, S>; S - N]),
// ) -> Vec<T> //Tensor<T, { R + S - 2 * N }>
// where
// [(); R - N]:,
// [(); S - N]:,
// [(); R + S - 2 * N]:,
// {
// let (lc, lnc) = laxes;
// let (rc, rnc) = raxes;
// // Step 1: Prepare cartesian products of contracted axes
// let lc_prod = lc.into_iter().multi_cartesian_product();
// let rc_prod = rc.into_iter().multi_cartesian_product();
// // Step 2: Prepare iterators for non-contracted axes
// let lnc_prod = lnc.into_iter().multi_cartesian_product();
// let rnc_prod = rnc.into_iter().multi_cartesian_product();
// // Initialize buffer for the resulting tensor
// let mut result_buffer = Vec::new();
// // Step 3: Iterate over combinations and compute tensor product
// for lnc_vals in lnc_prod {
// for rnc_vals in rnc_prod.clone() {
// let mut sum = T::zero();
// for ((lc, rc), (lnc_val, rnc_val)) in lc_prod
// .clone()
// .zip(rc_prod.clone())
// .zip(lnc_vals.iter().zip(rnc_vals.iter()))
// {
// for (lc_val, rc_val) in lc.into_iter().zip(rc.into_iter()) {
// sum = sum + (*lc_val * *rc_val) * (**lnc_val * **rnc_val);
// }
// }
// result_buffer.push(sum);
// }
// }
// result_buffer
// }
// // Check that there are no duplicate axes
// for (i, laxis) in laxes.iter().enumerate() {
// for (j, laxisx) in laxes.iter().enumerate() {
// if i != j && laxis == laxisx {
// panic!("Duplicate axes");
// }
// }
// }
// for (i, raxis) in raxes.iter().enumerate() {
// for (j, raxisx) in raxes.iter().enumerate() {
// if i != j && raxis == raxisx {
// panic!("Duplicate axes");
// }
// }
// }
// // Check that the sizes of the axes to contract are the same
// // for (laxis, raxis) in laxes.iter().zip(raxes.iter()) {
// // // Check that axes are within bounds
// // if lhs.shape()[*laxis] != *raxis {
// // println!("laxis: {}, raxis: {}", laxis, raxis);
// // panic!("Axis out of bounds");
// // }
// // }
// // 1. Axis iterators
// //
// // Create iterators for each non-contracted axis of the left-hand and right-hand tensors
// // and take the Cartesian product of the iterators.
// let lcartesian_contract = lhs
// .shape()
// .iter()
// .enumerate()
// .filter(|(i, _)| laxes.contains(i))
// .map(|(i, _)| {
// println!("LHS is adding axis: {}", i);
// let vals = Axis::new(lhs.clone(), i).into_iter().collect::<Vec<_>>();
// println!("vals: {:?}", vals);
// })
// .collect::<Vec<_>>();
// for l in lcartesian_contract.iter() {
// let l = l.clone();
// println!("{:?}", l);
// }
// // .multi_cartesian_product();
// // let cart_product = lcartesian_contract.multi_cartesian_product();
// // for l in cart_product {
// // println!("cartesian_contract l:");
// // for r in l {
// // println!("l: {:?}", r);
// // }
// // }
// let rcartesian_contract = rhs
// .shape()
// .iter()
// .enumerate()
// .filter(|(i, _)| raxes.contains(i))
// .map(|(i, _)| {
// println!("RHS is adding axis: {}", i);
// let vals = Axis::new(rhs.clone(), i).into_iter().collect::<Vec<_>>();
// println!("vals: {:?}", vals);
// })
// .collect::<Vec<_>>();
// for r in rcartesian_contract.iter() {
// let r = r.clone();
// println!("{:?}", r);
// }
// // .multi_cartesian_product();
// // println!("lcartesian_contract: {:?}", lcartesian_contract);
// // println!("rcartesian_contract: {:?}", rcartesian_contract);
// // Initialize buffer for the resulting tensor
// // let mut result_buffer = Vec::new();
// // let zip_contract = lcartesian_contract.zip(rcartesian_contract);
// // for (i, (j, k)) in zip_contract.enumerate() {
// // let mut sum = T::zero();
// // print!("{}. sum = ", i);
// // // for (lhsnc, rhsnc) in j.into_iter().zip(k.into_iter()) {
// // // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
// // // sum = sum + lhsnc * rhsnc;
// // // }
// // for lhsnc in j.iter() {
// // for rhsnc in k.iter() {
// // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
// // sum = sum + *lhsnc * *rhsnc;
// // }
// // }
// // print!("= {}\n", sum);
// // result_buffer.push(sum);
// // }
// // Iterate over non-contracted axes combinations
// // for (lhsnci, rhsnci) in
// // lcartesian_contract.zip(rcartesian_contract)
// // {
// // let mut sum = T::zero();
// // for (lhsnc, rhsnc) in lhsnci.iter().zip(rhsnci.iter()) {
// // sum = sum + *lhsnc * *rhsnc;
// // }
// // // Append the result to the buffer
// // result_buffer.push(sum);
// // }
// // TODO! Implement the rest of the function
// let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
// laxes
// .iter()
// .map(|axis| *axis)
// .collect::<Vec<_>>()
// .try_into()
// .unwrap(),
// );
// let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
// raxes
// .iter()
// .map(|axis| *axis)
// .collect::<Vec<_>>()
// .try_into()
// .unwrap(),
// );
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
// let mut shape = Vec::new();
// shape.extend_from_slice(&lhs_shape_reduced.as_array());
// shape.extend_from_slice(&rhs_shape_reduced.as_array());
// let shape: [usize; R + S - 2 * N] =
// shape.try_into().expect("Failed to create shape array");
// let shape = Shape::new(shape);
// Tensor::new_with_buffer(shape, vec![])
// }
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for `contract`: contracts axis 1 of `a` with axis 0 of
    /// `b`, which for rank-2 tensors is an ordinary matrix product.
    /// Currently print-only; the value assertion is still commented out.
    #[test]
    fn test_tensor_contraction_simple() {
        // Both operands are 2x2 matrices. (The "2x3" / "1x3x2" sizes in
        // the original comments were left over from an earlier example.)
        let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
        let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
        // Contract over the last axis of A (axis 1) and the first axis of
        // B (axis 0).
        let contracted_tensor: Tensor<i32, 2> =
            contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
        // For the matrix product of [[1,2],[3,4]] with itself the expected
        // buffer is [7, 10, 15, 22] — the [22, 28, 49, 64] value mentioned
        // previously belonged to the older 2x3 example.
        println!("A: {}", a);
        println!("B: {}", b);
        println!("A * B = {}", contracted_tensor);
        // println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
    }
}
// #[test]
// fn test_tensor_contraction() {
// // Define two 2D tensors (matrices)
// // Tensor A is 2x3
// let tensor_a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
// println!("A: {}", tensor_a);
// // Tensor B is 1x3x2
// let tensor_b: Tensor<i32, 3> = Tensor::from([[[1, 2], [3, 4], [5, 6]]]);
// println!("B: {}", tensor_b);
// // Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
// let contracted_tensor: Tensor<i32, 3> =
// contract(tensor_a.clone(), tensor_b.clone(), [1], [1]);
// // Expect: [22, 28, 49, 64]
// // Check if the contracted tensor is as expected
// println!("A: {}", tensor_a);
// println!("B: {}", tensor_b);
// println!("A * B = {}", contracted_tensor);
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
// }
/// Exercises `Axis` iteration over both axes of a 2x2 tensor, then walks
/// an `Idx` from [0, 0] up to [1, 1] via `inc`. Print-only for now; the
/// per-element assertions remain commented out (see the NOTE on the
/// `Iterator` impl about `inc_axis` traversal order before enabling them).
#[test]
fn test_axis_iterator() {
    // Creating a 2x2 Tensor for testing
    let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
    // Testing iteration over the first axis (axis = 0)
    let axis = Axis::new(&tensor, 0);
    let mut axis_iter = axis.into_iter();
    axis_iter.for_each(|x| println!("t1: {}", x));
    // Expected row-major order along axis 0: 1.0, 2.0, 3.0, 4.0
    // assert_eq!(axis_iter.next(), Some(&1.0));
    // assert_eq!(axis_iter.next(), Some(&2.0));
    // assert_eq!(axis_iter.next(), Some(&3.0));
    // assert_eq!(axis_iter.next(), Some(&4.0));
    // Resetting the iterator for the second axis (axis = 1)
    let axis = Axis::new(&tensor, 1);
    let mut axis_iter = axis.into_iter();
    axis_iter.for_each(|x| println!("t2: {}", x));
    // Expected column-major order along axis 1: 1.0, 3.0, 2.0, 4.0
    // assert_eq!(axis_iter.next(), Some(&1.0));
    // assert_eq!(axis_iter.next(), Some(&3.0));
    // assert_eq!(axis_iter.next(), Some(&2.0));
    // assert_eq!(axis_iter.next(), Some(&4.0));
    // Manual Idx walk: increment from [0, 0] up to and including [1, 1].
    let shape = tensor.shape();
    let mut a: Idx<2> = (shape, [0, 0]).into();
    let b: Idx<2> = (shape, [1, 1]).into();
    while a <= b {
        println!("a: {}", a);
        a.inc();
    }
}
// #[test]
// fn test_3d_tensor_axis_iteration() {
// // Create a 3D Tensor with specific values
// // Tensor shape is 2x2x2 for simplicity
// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
// // Axis 0 (Layer-wise):
// //
// // t[0][0][0] = 1
// // t[0][0][1] = 2
// // t[0][1][0] = 3
// // t[0][1][1] = 4
// // t[1][0][0] = 5
// // t[1][0][1] = 6
// // t[1][1][0] = 7
// // t[1][1][1] = 8
// // [1, 2, 3, 4, 5, 6, 7, 8]
// //
// // This order suggests that for each "layer" (first level of arrays),
// // the iterator goes through all rows and columns. It first completes
// // the entire first layer, then moves to the second.
// let a0 = Axis::new(t.clone(), 0);
// let a0_order = a0.into_iter().collect::<Vec<_>>();
// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
// // Axis 1 (Row-wise within each layer):
// //
// // t[0][0][0] = 1
// // t[0][0][1] = 2
// // t[1][0][0] = 5
// // t[1][0][1] = 6
// // t[0][1][0] = 3
// // t[0][1][1] = 4
// // t[1][1][0] = 7
// // t[1][1][1] = 8
// // [1, 2, 5, 6, 3, 4, 7, 8]
// //
// // This indicates that within each "layer", the iterator first
// // completes the first row across all layers, then the second row
// // across all layers.
// let a1 = Axis::new(t.clone(), 1);
// let a1_order = a1.into_iter().collect::<Vec<_>>();
// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
// // Axis 2 (Column-wise within each layer):
// //
// // t[0][0][0] = 1
// // t[0][1][0] = 3
// // t[1][0][0] = 5
// // t[1][1][0] = 7
// // t[0][0][1] = 2
// // t[0][1][1] = 4
// // t[1][0][1] = 6
// // t[1][1][1] = 8
// // [1, 3, 5, 7, 2, 4, 6, 8]
// //
// // This indicates that within each "layer", the iterator first
// // completes the first column across all layers, then the second
// // column across all layers.
// let a2 = Axis::new(t, 2);
// let a2_order = a2.into_iter().collect::<Vec<_>>();
// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
// }
// }

View File

@ -51,6 +51,9 @@ impl<const R: usize> Idx<R> {
/// `true` if the increment does not overflow and is still within bounds;
/// `false` if it overflows, indicating the end of the tensor.
pub fn inc(&mut self) -> bool {
if self.indices[0] >= self.shape.get(0) {
return false;
}
let mut carry = 1;
for (i, &dim_size) in
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
@ -67,13 +70,40 @@ impl<const R: usize> Idx<R> {
// If carry is still 1 after the loop, it means we've incremented past the last dimension
if carry == 1 {
// Set the index to an invalid state (e.g., all indices to their max values)
// Set the index to an invalid state to indicate the end of the iteration indicated
// by setting the first index to the size of the first dimension
self.indices[0] = self.shape.as_array()[0];
return true; // Indicate that the iteration is complete
}
false
}
// fn inc_axis
/// Steps the index one position forward along `fixed_axis` and resets
/// every other coordinate to 0.
///
/// The original version first ran a loop that incremented/wrapped the
/// non-fixed axes, but those values were unconditionally overwritten by
/// the reset loop that followed — pure dead work (and it could transiently
/// push a coordinate out of bounds). The loop is removed; the observable
/// end state is identical.
///
/// # Panics
/// Panics if `fixed_axis >= R`, or if the fixed axis is already at its
/// end-of-iteration sentinel (index == dimension size).
pub fn inc_axis(&mut self, fixed_axis: usize) {
    assert!(fixed_axis < R, "Axis out of bounds");
    assert!(
        self.indices[fixed_axis] < self.shape.get(fixed_axis),
        "Index out of bounds"
    );
    // Advance the fixed axis and zero everything else.
    self.indices[fixed_axis] += 1;
    for i in 0..R {
        if i != fixed_axis {
            self.indices[i] = 0;
        }
    }
}
pub fn dec(&mut self) {
// Check if already at the start
if self.indices.iter().all(|&i| i == 0) {
@ -95,6 +125,40 @@ impl<const R: usize> Idx<R> {
}
}
/// Decrements the index "around" `fixed_axis`: the non-fixed axes count
/// down with wrap-around, and the fixed axis itself only decrements once
/// every non-fixed axis has wrapped past zero.
///
/// Returns `false` when the fixed axis is already in the invalid
/// end-of-iteration state (index == dimension size), `true` otherwise.
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
    // Check if the fixed axis index is already in an invalid state
    if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
        return false;
    }
    // Try to decrement non-fixed axes (last axis varies fastest). The
    // first axis that can be decremented ends the call; axes that are
    // already 0 wrap to their maximum and the borrow propagates onward.
    for i in (0..R).rev() {
        if i != fixed_axis {
            if self.indices[i] > 0 {
                self.indices[i] -= 1;
                return true;
            } else {
                self.indices[i] = self.shape.get(i) - 1;
            }
        }
    }
    // All non-fixed axes wrapped: decrement the fixed axis if possible and
    // reset the others to their maxima.
    if self.indices[fixed_axis] > 0 {
        self.indices[fixed_axis] -= 1;
        for i in 0..R {
            if i != fixed_axis {
                self.indices[i] = self.shape.get(i) - 1;
            }
        }
    } else {
        // Fixed axis already at minimum: park it at the one-past-the-end
        // sentinel to mark the iteration as finished.
        self.indices[fixed_axis] = self.shape.get(fixed_axis);
    }
    true
}
/// Converts the multi-dimensional index to a flat index.
///
/// This method calculates the flat index corresponding to the multi-dimensional index
@ -136,6 +200,56 @@ impl<const R: usize> Idx<R> {
})
.0
}
// pub fn inc_axis(&mut self, axis: usize) -> bool {
// if axis >= R {
// // Axis is out of bounds
// return false;
// }
// let dim_size = self.shape.get(axis);
// if self.indices[axis] + 1 < dim_size {
// // Increment the index if it's not at its maximum
// self.indices[axis] += 1;
// true
// } else {
// // Index is at its maximum for this axis
// false
// }
// }
/// Sets the coordinate along `axis` to `value` without checking `value`
/// against the shape.
///
/// NOTE(review): the value bound check is commented out — presumably to
/// allow parking an axis at the one-past-the-end sentinel used by `inc`;
/// confirm that is intentional. Use `try_set_axis` for a checked set.
///
/// # Panics
/// Panics if `axis >= R`.
pub fn set_axis(&mut self, axis: usize, value: usize) {
    assert!(axis < R, "Axis out of bounds");
    // assert!(value < self.shape.get(axis), "Value out of bounds");
    self.indices[axis] = value;
}
/// Sets the coordinate along `axis` to `value` only when `value` is
/// within the shape's bounds; returns whether the assignment happened.
///
/// # Panics
/// Panics if `axis >= R`.
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
    assert!(axis < R, "Axis out of bounds");
    let in_bounds = value < self.shape.get(axis);
    if in_bounds {
        self.indices[axis] = value;
    }
    in_bounds
}
/// Returns the current coordinate along `axis`.
///
/// # Panics
/// Panics if `axis >= R`.
pub fn get_axis(&self, axis: usize) -> usize {
    assert!(axis < R, "Axis out of bounds");
    self.indices[axis]
}

/// Borrow of the raw coordinate array.
pub fn indices(&self) -> &[usize; R] {
    &self.indices
}

/// The shape this index is bound to (returned by value).
pub fn shape(&self) -> Shape<R> {
    self.shape
}

/// Mutable access to the bound shape.
///
/// NOTE(review): mutating the shape can desynchronize it from `indices`;
/// confirm callers actually need this.
pub fn shape_mut(&mut self) -> &mut Shape<R> {
    &mut self.shape
}
}
// --- blanket impls ---
@ -194,6 +308,12 @@ impl<const R: usize> From<Shape<R>> for Idx<R> {
}
}
/// Builds the all-zero starting index for a tensor's shape.
///
/// NOTE(review): takes the tensor by value even though only its shape is
/// read — callers currently have to clone the whole tensor (buffer
/// included) to use this conversion.
impl<T: Value, const R: usize> From<Tensor<T, R>> for Idx<R> {
    fn from(tensor: Tensor<T, R>) -> Self {
        Self::zero(tensor.shape())
    }
}
impl<const R: usize> std::fmt::Display for Idx<R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "[")?;
@ -249,3 +369,45 @@ impl<const R: usize> Sub for Idx<R> {
}
}
}
// ---- Iterator ----
/// Iterator that yields every index of a shape in row-major order.
pub struct IdxIterator<const R: usize> {
    // Next index to yield.
    current: Idx<R>,
    // Set once `current.inc()` reports overflow; the following `next()`
    // returns None.
    end: bool,
}

impl<const R: usize> IdxIterator<R> {
    /// Starts iteration at the all-zero index of `shape`.
    pub fn new(shape: Shape<R>) -> Self {
        Self {
            current: Idx::zero(shape),
            end: false,
        }
    }
}

impl<const R: usize> Iterator for IdxIterator<R> {
    type Item = Idx<R>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.end {
            return None;
        }
        let result = self.current;
        // `inc` returns true once the increment overflows past the last
        // valid index, so the value yielded on this call is the final one.
        self.end = self.current.inc();
        Some(result)
    }
}

/// Consuming an `Idx` yields an iterator that continues from that index
/// (not from zero).
impl<const R: usize> IntoIterator for Idx<R> {
    type Item = Idx<R>;
    type IntoIter = IdxIterator<R>;

    fn into_iter(self) -> Self::IntoIter {
        IdxIterator {
            current: self,
            end: false,
        }
    }
}

View File

@ -2,6 +2,7 @@
#![feature(generic_const_exprs)]
pub mod index;
pub mod shape;
pub mod axis;
pub mod tensor;
pub use index::Idx;
@ -11,6 +12,10 @@ pub use shape::Shape;
pub use std::fmt::{Display, Formatter, Result as FmtResult};
use std::ops::{Index, IndexMut};
pub use tensor::{Tensor, TensorIterator};
pub use static_assertions::const_assert;
pub use itertools::Itertools;
pub use std::sync::Arc;
pub use axis::*;
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>

View File

@ -23,11 +23,27 @@ impl<T: Value> Tensor<T, 1> {
impl<T: Value, const R: usize> Tensor<T, R> {
/// Allocates a zero-filled tensor of the given shape.
///
/// The diff residue in this view left two shadowed `total_size` bindings
/// (the replaced line and its replacement); this keeps only the final
/// version.
pub fn new(shape: Shape<R>) -> Self {
    // A rank-0 tensor (scalar) still needs a one-element buffer; for
    // R >= 1 the buffer holds the product of all dimension sizes. (The
    // empty product is also 1, so the branch documents intent more than
    // it changes the value.)
    let total_size: usize = if R == 0 {
        1
    } else {
        shape.iter().product()
    };
    let buffer = vec![T::zero(); total_size];
    Self { buffer, shape }
}
/// Wraps an existing buffer in a tensor with the given shape.
///
/// NOTE(review): `buffer.len()` is not validated against the shape's
/// element count — confirm every caller passes a matching buffer.
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
    Self { buffer, shape }
}
pub fn idx(&self) -> Idx<R> {
Idx::from(self.clone())
}
/// Returns a copy of this tensor's shape (`Shape` is returned by value).
pub fn shape(&self) -> Shape<R> {
    self.shape
}
@ -111,85 +127,63 @@ impl<T: Value, const R: usize> Tensor<T, R> {
}
}
// `self_dims` and `other_dims` specify the dimensions to contract over
pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
&self,
rhs: &Tensor<T, S>,
axis_lhs: [usize; NAXL],
axis_rhs: [usize; NAXR],
) -> Tensor<T, { R + S - NAXL - NAXR }>
where
[(); R - NAXL]:,
[(); S - NAXR]:,
{
// Step 1: Validate the axes for both tensors to ensure they are within bounds
for &axis in &axis_lhs {
if axis >= R {
panic!(
"Axis {} is out of bounds for the left-hand tensor",
axis
);
}
}
// // Step 1: Validate the axes for both tensors to ensure they are within bounds
// for &axis in &axis_lhs {
// if axis >= R {
// panic!(
// "Axis {} is out of bounds for the left-hand tensor",
// axis
// );
// }
// }
for &axis in &axis_rhs {
if axis >= S {
panic!(
"Axis {} is out of bounds for the right-hand tensor",
axis
);
}
}
// for &axis in &axis_rhs {
// if axis >= S {
// panic!(
// "Axis {} is out of bounds for the right-hand tensor",
// axis
// );
// }
// }
// Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
let mut result_buffer = Vec::new();
// // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
// let mut result_buffer = Vec::new();
for i in 0..self.shape.size() {
for j in 0..rhs.shape.size() {
// Debug: Print indices being processed
println!("Processing Indices: lhs = {}, rhs = {}", i, j);
// for i in self.iter_over_non_contracted_axes(&axis_lhs) {
// for j in rhs.iter_over_non_contracted_axes(&axis_rhs) {
// let mut product_sum = T::zero();
if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) {
let mut product_sum = T::zero();
// // Step 3: Contraction over specified axes
// for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
// let idx_l = i[axis_l];
// let idx_r = j[axis_r];
// Debug: Print axes of contraction
println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs);
// let value_lhs = self.buffer[self.flat(&idx_l)];
// let value_rhs = rhs.buffer[rhs.flat(&idx_r)];
// product_sum = product_sum + value_lhs * value_rhs;
// }
for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
// Debug: Print values being multiplied
let value_lhs = self.get_by_axis(axis_l, i).unwrap();
let value_rhs = rhs.get_by_axis(axis_r, j).unwrap();
println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs);
// result_buffer.push(product_sum);
// }
// }
product_sum = product_sum + value_lhs * value_rhs;
}
// // Step 4: Remove contracted dimensions to create new shapes for both tensors
// let new_shape_lhs = self.shape.remove_dims::<{ NL }>(axis_lhs);
// let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
// Debug: Print the product sum for the current indices
println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum);
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
// let mut new_shape = Vec::new();
// new_shape.extend_from_slice(&new_shape_lhs.as_array());
// new_shape.extend_from_slice(&new_shape_rhs.as_array());
result_buffer.push(product_sum);
}
}
}
// let new_shape_array: [usize; R + S - NAXL - NAXR] =
// new_shape.try_into().expect("Failed to create shape array");
// Step 3: Remove contracted dimensions to create new shapes for both tensors
let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
// Step 4: Concatenate the shapes to form the shape of the resultant tensor
let mut new_shape = Vec::new();
new_shape.extend_from_slice(&new_shape_lhs.as_array());
new_shape.extend_from_slice(&new_shape_rhs.as_array());
let new_shape_array: [usize; R + S - NAXL - NAXR] =
new_shape.try_into().expect("Failed to create shape array");
Tensor {
buffer: result_buffer,
shape: Shape::new(new_shape_array),
}
}
// Tensor {
// buffer: result_buffer,
// shape: Shape::new(new_shape_array),
// }
// }
// Retrieve an element based on a specific axis and index
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
@ -366,3 +360,23 @@ impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
tensor
}
}
/// Builds a rank-3 tensor of shape [Z, Y, X] from a nested array, copying
/// the elements in row-major order.
impl<T: Value, const X: usize, const Y: usize, const Z: usize>
    From<[[[T; X]; Y]; Z]> for Tensor<T, 3>
{
    fn from(array: [[[T; X]; Y]; Z]) -> Self {
        let mut tensor = Tensor::new(Shape::new([Z, Y, X]));
        {
            let buffer = tensor.buffer_mut();
            // Running flat offset; equals i * X * Y + j * X + k at each
            // step, matching the original explicit computation.
            let mut flat = 0;
            for plane in array.iter() {
                for row in plane.iter() {
                    for &elem in row.iter() {
                        buffer[flat] = elem;
                        flat += 1;
                    }
                }
            }
        }
        tensor
    }
}