Work on tensor contraction and axis iteration
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
b17264ba18
commit
d54179cd5d
99
Cargo.lock
generated
99
Cargo.lock
generated
@ -14,22 +14,39 @@ version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
|
||||
|
||||
[[package]]
|
||||
name = "getset"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9"
|
||||
dependencies = [
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
||||
|
||||
[[package]]
|
||||
name = "mltensor"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"num",
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.1"
|
||||
@ -106,6 +123,30 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||
dependencies = [
|
||||
"proc-macro-error-attr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.71"
|
||||
@ -147,7 +188,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.43",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -161,6 +202,23 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.43"
|
||||
@ -172,8 +230,27 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tensorc"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"getset",
|
||||
"itertools",
|
||||
"num",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"static_assertions",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
@ -1,5 +1,5 @@
|
||||
[package]
|
||||
name = "mltensor"
|
||||
name = "tensorc"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
@ -7,6 +7,9 @@ edition = "2021"
|
||||
|
||||
[dependencies]
|
||||
bytemuck = "1.14.0"
|
||||
getset = "0.1.2"
|
||||
itertools = "0.12.0"
|
||||
num = "0.4.1"
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
serde_json = "1.0.108"
|
||||
static_assertions = "1.1.0"
|
||||
|
@ -1,4 +1,7 @@
|
||||
use mltensor::*;
|
||||
#![allow(mixed_script_confusables)]
|
||||
#![allow(non_snake_case)]
|
||||
use bytemuck::cast_slice;
|
||||
use tensorc::*;
|
||||
|
||||
fn tensor_product() {
|
||||
println!("Tensor Product\n");
|
||||
@ -20,8 +23,9 @@ fn tensor_product() {
|
||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||
|
||||
// Check buffer of the resulting tensor
|
||||
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||
assert_eq!(product.buffer(), &expected_buffer);
|
||||
let expect: &[i32] =
|
||||
cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]);
|
||||
assert_eq!(product.buffer(), expect);
|
||||
}
|
||||
|
||||
fn tensor_contraction() {
|
||||
@ -35,16 +39,25 @@ fn tensor_contraction() {
|
||||
let axis_rhs = [0]; // Contract over the first dimension of tensor2
|
||||
|
||||
// Perform contraction
|
||||
let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
||||
// let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
||||
|
||||
println!("T1: {}", tensor1);
|
||||
println!("T2: {}", tensor2);
|
||||
println!("T1 * T2 = {}", result);
|
||||
// println!("T1: {}", tensor1);
|
||||
// println!("T2: {}", tensor2);
|
||||
// println!("T1 * T2 = {}", result);
|
||||
|
||||
// Expected result, for example, could be a single number or a new tensor,
|
||||
// depending on how you defined the contraction operation.
|
||||
// Assert the result is as expected
|
||||
// assert_eq!(result, expected_result);
|
||||
|
||||
// let Λ = Tensor::<f64, 2>::from([
|
||||
// [1.0, 0.0, 0.0, 0.0],
|
||||
// [0.0, 1.0, 0.0 ,0.0],
|
||||
// [0.0, 0.0, 1.0, 0.0],
|
||||
// [0.0, 0.0, 0.0, 1.0]
|
||||
// ]);
|
||||
|
||||
// println!("Λ: {}", Λ);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
|
581
src/axis.rs
Normal file
581
src/axis.rs
Normal file
@ -0,0 +1,581 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
|
||||
#[derive(Clone, Debug, Getters)]
|
||||
pub struct Axis<'a, T: Value, const R: usize> {
|
||||
#[getset(get = "pub")]
|
||||
tensor: &'a Tensor<T, R>,
|
||||
#[getset(get = "pub")]
|
||||
shape: Shape<R>,
|
||||
#[getset(get = "pub")]
|
||||
dim: usize,
|
||||
#[getset(get = "pub")]
|
||||
len: usize,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
|
||||
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
|
||||
assert!(dim < R, "Axis out of bounds");
|
||||
Self {
|
||||
tensor,
|
||||
shape: tensor.shape(),
|
||||
dim,
|
||||
len: tensor.shape().get(dim),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn iter_level(&self, level: usize) -> AxisIterator<'a, T, R> {
|
||||
assert!(level < self.len, "Level out of bounds");
|
||||
let mut index = Idx::new(self.shape.clone(), [0; R]);
|
||||
index.set_axis(self.dim, level);
|
||||
AxisIterator::new(self.clone()).set_start(level).set_end(level + 1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Getters, MutGetters)]
|
||||
pub struct AxisIterator<'a, T: Value, const R: usize> {
|
||||
#[getset(get = "pub")]
|
||||
axis: Axis<'a, T, R>,
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
index: Idx<R>,
|
||||
#[getset(get = "pub")]
|
||||
end: Option<usize>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
|
||||
pub fn new(axis: Axis<'a, T, R>) -> Self {
|
||||
Self {
|
||||
axis: axis.clone(),
|
||||
index: Idx::new(axis.shape().clone(), [0; R]),
|
||||
end: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_start(self, start: usize) -> Self {
|
||||
assert!(start < self.axis.len, "Start out of bounds");
|
||||
let mut index = Idx::new(self.axis.shape().clone(), [0; R]);
|
||||
index.set_axis(self.axis.dim, start);
|
||||
Self {
|
||||
axis: self.axis.clone(),
|
||||
index,
|
||||
end: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_end(self, end: usize) -> Self {
|
||||
assert!(end <= self.axis.len, "End out of bounds");
|
||||
Self {
|
||||
axis: self.axis.clone(),
|
||||
index: self.index.clone(),
|
||||
end: Some(end),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_level(self, level: usize) -> Self {
|
||||
assert!(level < self.axis.len, "Level out of bounds");
|
||||
self.set_start(level).set_end(level + 1)
|
||||
}
|
||||
|
||||
pub fn disassemble(self) -> Vec<Self> {
|
||||
let mut result = Vec::new();
|
||||
for i in 0..self.axis().len {
|
||||
result.push(Self::new(self.axis.clone()).set_level(i));
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
pub fn axis_max_idx(&self) -> usize {
|
||||
self.end().unwrap_or(*self.axis().len())
|
||||
}
|
||||
|
||||
pub fn axis_idx(&self) -> usize {
|
||||
self.index().get_axis(*self.axis().dim())
|
||||
}
|
||||
|
||||
pub fn axis_dim(&self) -> usize {
|
||||
self.axis().dim().clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.axis_idx() == self.axis_max_idx() {
|
||||
return None;
|
||||
}
|
||||
// Increment the index along the fixed axis and check if it's within bounds
|
||||
let result = self.axis().tensor().get(self.index);
|
||||
let axis_dim = self.axis_dim();
|
||||
self.index_mut().inc_axis(axis_dim);
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> IntoIterator for Axis<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = AxisIterator<'a, T, R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
AxisIterator::new(self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contract<
|
||||
'a,
|
||||
T: Value + std::fmt::Debug,
|
||||
const R: usize,
|
||||
const S: usize,
|
||||
const N: usize,
|
||||
>(
|
||||
lhs: &'a Tensor<T, R>,
|
||||
rhs: &'a Tensor<T, S>,
|
||||
laxes: [Axis<'a, T, R>; N],
|
||||
raxes: [Axis<'a, T, S>; N],
|
||||
) -> Tensor<T, { R + S - 2 * N }>
|
||||
where
|
||||
[(); R - N]:,
|
||||
[(); S - N]:,
|
||||
[(); R + S - 2 * N]:,
|
||||
{
|
||||
let li = laxes
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|axis| axis.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
let ri = raxes
|
||||
.clone()
|
||||
.into_iter()
|
||||
.map(|axis| axis.into_iter())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let result = contract_axes(li.try_into().unwrap(), ri.try_into().unwrap());
|
||||
|
||||
let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
|
||||
laxes
|
||||
.iter()
|
||||
.map(|axis| *axis.dim())
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
|
||||
raxes
|
||||
.iter()
|
||||
.map(|axis| *axis.dim())
|
||||
.collect::<Vec<_>>()
|
||||
.try_into()
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||
let mut shape = Vec::new();
|
||||
shape.extend_from_slice(&lhs_shape_reduced.as_array());
|
||||
shape.extend_from_slice(&rhs_shape_reduced.as_array());
|
||||
|
||||
let shape: [usize; R + S - 2 * N] =
|
||||
shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
let shape = Shape::new(shape);
|
||||
|
||||
Tensor::new_with_buffer(shape, result)
|
||||
}
|
||||
|
||||
pub fn contract_axes<
|
||||
'a,
|
||||
T: Value + std::fmt::Debug,
|
||||
const R: usize,
|
||||
const S: usize,
|
||||
const N: usize,
|
||||
>(
|
||||
laxes: [AxisIterator<'a, T, R>; N],
|
||||
raxes: [AxisIterator<'a, T, S>; N],
|
||||
) -> Vec<T>
|
||||
where
|
||||
[(); R - N]:,
|
||||
[(); S - N]:,
|
||||
// [(); R + S - 2 * N]:,
|
||||
{
|
||||
let mut result = Vec::new();
|
||||
|
||||
// Disassemble each axis iterator into iterators over each level
|
||||
let laxes: Vec<Vec<AxisIterator<'a, T, R>>> =
|
||||
laxes.into_iter().map(|axis| axis.disassemble()).collect();
|
||||
let raxes: Vec<Vec<AxisIterator<'a, T, S>>> =
|
||||
raxes.into_iter().map(|axis| axis.disassemble()).collect();
|
||||
|
||||
// Iterate over all combinations of levels for contracted axes
|
||||
|
||||
let lcart = laxes.into_iter().multi_cartesian_product();
|
||||
let rcart = raxes.into_iter().multi_cartesian_product();
|
||||
let cartesian = lcart.zip(rcart);
|
||||
|
||||
for (i, (l, r)) in cartesian.enumerate() {
|
||||
println!("{} cartesian:", i);
|
||||
let mut sum = T::zero();
|
||||
|
||||
let level = l.into_iter().zip(r.into_iter());
|
||||
|
||||
for (i, (l, r)) in level.enumerate() {
|
||||
// let mut sum = T::zero();
|
||||
println!("{} level:", i);
|
||||
let value = l.into_iter().zip(r.into_iter());
|
||||
|
||||
for (i, (l, r)) in value.enumerate() {
|
||||
println!("{} product: {} + {} * {} = {}", i, sum, l, r, sum + *l * *r);
|
||||
sum = sum + *l * *r;
|
||||
}
|
||||
// result.push(sum);
|
||||
}
|
||||
|
||||
result.push(sum);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
// pub fn contract_axes<
|
||||
// 'a,
|
||||
// T: Value + std::fmt::Debug,
|
||||
// const R: usize,
|
||||
// const S: usize,
|
||||
// const N: usize,
|
||||
// >(
|
||||
// // Axes are (contracted, non-contracted)
|
||||
// laxes: ([AxisIterator<'a, T, R>; N], [AxisIterator<'a, T, R>; R - N]),
|
||||
// raxes: ([AxisIterator<'a, T, S>; N], [AxisIterator<'a, T, S>; S - N]),
|
||||
// ) -> Vec<T> //Tensor<T, { R + S - 2 * N }>
|
||||
// where
|
||||
// [(); R - N]:,
|
||||
// [(); S - N]:,
|
||||
// [(); R + S - 2 * N]:,
|
||||
// {
|
||||
// let (lc, lnc) = laxes;
|
||||
// let (rc, rnc) = raxes;
|
||||
|
||||
// // Step 1: Prepare cartesian products of contracted axes
|
||||
// let lc_prod = lc.into_iter().multi_cartesian_product();
|
||||
// let rc_prod = rc.into_iter().multi_cartesian_product();
|
||||
|
||||
// // Step 2: Prepare iterators for non-contracted axes
|
||||
// let lnc_prod = lnc.into_iter().multi_cartesian_product();
|
||||
// let rnc_prod = rnc.into_iter().multi_cartesian_product();
|
||||
|
||||
// // Initialize buffer for the resulting tensor
|
||||
// let mut result_buffer = Vec::new();
|
||||
|
||||
// // Step 3: Iterate over combinations and compute tensor product
|
||||
// for lnc_vals in lnc_prod {
|
||||
// for rnc_vals in rnc_prod.clone() {
|
||||
// let mut sum = T::zero();
|
||||
// for ((lc, rc), (lnc_val, rnc_val)) in lc_prod
|
||||
// .clone()
|
||||
// .zip(rc_prod.clone())
|
||||
// .zip(lnc_vals.iter().zip(rnc_vals.iter()))
|
||||
// {
|
||||
// for (lc_val, rc_val) in lc.into_iter().zip(rc.into_iter()) {
|
||||
// sum = sum + (*lc_val * *rc_val) * (**lnc_val * **rnc_val);
|
||||
// }
|
||||
// }
|
||||
// result_buffer.push(sum);
|
||||
// }
|
||||
// }
|
||||
|
||||
// result_buffer
|
||||
// }
|
||||
|
||||
// // Check that there are no duplicate axes
|
||||
// for (i, laxis) in laxes.iter().enumerate() {
|
||||
// for (j, laxisx) in laxes.iter().enumerate() {
|
||||
// if i != j && laxis == laxisx {
|
||||
// panic!("Duplicate axes");
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// for (i, raxis) in raxes.iter().enumerate() {
|
||||
// for (j, raxisx) in raxes.iter().enumerate() {
|
||||
// if i != j && raxis == raxisx {
|
||||
// panic!("Duplicate axes");
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// // Check that the sizes of the axes to contract are the same
|
||||
// // for (laxis, raxis) in laxes.iter().zip(raxes.iter()) {
|
||||
// // // Check that axes are within bounds
|
||||
// // if lhs.shape()[*laxis] != *raxis {
|
||||
// // println!("laxis: {}, raxis: {}", laxis, raxis);
|
||||
// // panic!("Axis out of bounds");
|
||||
// // }
|
||||
// // }
|
||||
// // 1. Axis iterators
|
||||
// //
|
||||
// // Create iterators for each non-contracted axis of the left-hand and right-hand tensors
|
||||
// // and take the Cartesian product of the iterators.
|
||||
|
||||
// let lcartesian_contract = lhs
|
||||
// .shape()
|
||||
// .iter()
|
||||
// .enumerate()
|
||||
// .filter(|(i, _)| laxes.contains(i))
|
||||
// .map(|(i, _)| {
|
||||
// println!("LHS is adding axis: {}", i);
|
||||
// let vals = Axis::new(lhs.clone(), i).into_iter().collect::<Vec<_>>();
|
||||
// println!("vals: {:?}", vals);
|
||||
// })
|
||||
// .collect::<Vec<_>>();
|
||||
|
||||
// for l in lcartesian_contract.iter() {
|
||||
// let l = l.clone();
|
||||
// println!("{:?}", l);
|
||||
// }
|
||||
// // .multi_cartesian_product();
|
||||
|
||||
// // let cart_product = lcartesian_contract.multi_cartesian_product();
|
||||
|
||||
// // for l in cart_product {
|
||||
// // println!("cartesian_contract l:");
|
||||
// // for r in l {
|
||||
// // println!("l: {:?}", r);
|
||||
// // }
|
||||
// // }
|
||||
|
||||
// let rcartesian_contract = rhs
|
||||
// .shape()
|
||||
// .iter()
|
||||
// .enumerate()
|
||||
// .filter(|(i, _)| raxes.contains(i))
|
||||
// .map(|(i, _)| {
|
||||
// println!("RHS is adding axis: {}", i);
|
||||
// let vals = Axis::new(rhs.clone(), i).into_iter().collect::<Vec<_>>();
|
||||
// println!("vals: {:?}", vals);
|
||||
// })
|
||||
// .collect::<Vec<_>>();
|
||||
|
||||
// for r in rcartesian_contract.iter() {
|
||||
// let r = r.clone();
|
||||
// println!("{:?}", r);
|
||||
// }
|
||||
// // .multi_cartesian_product();
|
||||
|
||||
// // println!("lcartesian_contract: {:?}", lcartesian_contract);
|
||||
// // println!("rcartesian_contract: {:?}", rcartesian_contract);
|
||||
|
||||
// // Initialize buffer for the resulting tensor
|
||||
// // let mut result_buffer = Vec::new();
|
||||
// // let zip_contract = lcartesian_contract.zip(rcartesian_contract);
|
||||
|
||||
// // for (i, (j, k)) in zip_contract.enumerate() {
|
||||
// // let mut sum = T::zero();
|
||||
// // print!("{}. sum = ", i);
|
||||
// // // for (lhsnc, rhsnc) in j.into_iter().zip(k.into_iter()) {
|
||||
// // // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
|
||||
// // // sum = sum + lhsnc * rhsnc;
|
||||
// // // }
|
||||
// // for lhsnc in j.iter() {
|
||||
// // for rhsnc in k.iter() {
|
||||
// // print!("{} * {} + {} ", lhsnc, rhsnc, sum);
|
||||
// // sum = sum + *lhsnc * *rhsnc;
|
||||
// // }
|
||||
// // }
|
||||
// // print!("= {}\n", sum);
|
||||
// // result_buffer.push(sum);
|
||||
// // }
|
||||
|
||||
// // Iterate over non-contracted axes combinations
|
||||
// // for (lhsnci, rhsnci) in
|
||||
// // lcartesian_contract.zip(rcartesian_contract)
|
||||
// // {
|
||||
// // let mut sum = T::zero();
|
||||
// // for (lhsnc, rhsnc) in lhsnci.iter().zip(rhsnci.iter()) {
|
||||
// // sum = sum + *lhsnc * *rhsnc;
|
||||
// // }
|
||||
// // // Append the result to the buffer
|
||||
// // result_buffer.push(sum);
|
||||
// // }
|
||||
|
||||
// // TODO! Implement the rest of the function
|
||||
|
||||
// let lhs_shape_reduced = lhs.shape().remove_dims::<{ N }>(
|
||||
// laxes
|
||||
// .iter()
|
||||
// .map(|axis| *axis)
|
||||
// .collect::<Vec<_>>()
|
||||
// .try_into()
|
||||
// .unwrap(),
|
||||
// );
|
||||
// let rhs_shape_reduced = rhs.shape().remove_dims::<{ N }>(
|
||||
// raxes
|
||||
// .iter()
|
||||
// .map(|axis| *axis)
|
||||
// .collect::<Vec<_>>()
|
||||
// .try_into()
|
||||
// .unwrap(),
|
||||
// );
|
||||
|
||||
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||
// let mut shape = Vec::new();
|
||||
// shape.extend_from_slice(&lhs_shape_reduced.as_array());
|
||||
// shape.extend_from_slice(&rhs_shape_reduced.as_array());
|
||||
|
||||
// let shape: [usize; R + S - 2 * N] =
|
||||
// shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
// let shape = Shape::new(shape);
|
||||
|
||||
// Tensor::new_with_buffer(shape, vec![])
|
||||
// }
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_tensor_contraction_simple() {
|
||||
// Define two 2D tensors (matrices)
|
||||
// Tensor A is 2x3
|
||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
||||
|
||||
// Tensor B is 1x3x2
|
||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
||||
|
||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
let contracted_tensor: Tensor<i32, 2> =
|
||||
contract(&a, &b, [Axis::new(&a, 1)], [Axis::new(&b, 0)]);
|
||||
|
||||
// Expect: [22, 28, 49, 64]
|
||||
// Check if the contracted tensor is as expected
|
||||
println!("A: {}", a);
|
||||
println!("B: {}", b);
|
||||
println!("A * B = {}", contracted_tensor);
|
||||
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
|
||||
}
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_tensor_contraction() {
|
||||
// // Define two 2D tensors (matrices)
|
||||
// // Tensor A is 2x3
|
||||
// let tensor_a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
// println!("A: {}", tensor_a);
|
||||
|
||||
// // Tensor B is 1x3x2
|
||||
// let tensor_b: Tensor<i32, 3> = Tensor::from([[[1, 2], [3, 4], [5, 6]]]);
|
||||
// println!("B: {}", tensor_b);
|
||||
|
||||
// // Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
// let contracted_tensor: Tensor<i32, 3> =
|
||||
// contract(tensor_a.clone(), tensor_b.clone(), [1], [1]);
|
||||
|
||||
// // Expect: [22, 28, 49, 64]
|
||||
// // Check if the contracted tensor is as expected
|
||||
// println!("A: {}", tensor_a);
|
||||
// println!("B: {}", tensor_b);
|
||||
// println!("A * B = {}", contracted_tensor);
|
||||
// println!("Expected: {}", Tensor::from([[22, 28], [49, 64]]));
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
axis_iter.for_each(|x| println!("t1: {}", x));
|
||||
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
axis_iter.for_each(|x| println!("t2: {}", x));
|
||||
// assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
let shape = tensor.shape();
|
||||
|
||||
let mut a: Idx<2> = (shape, [0, 0]).into();
|
||||
let b: Idx<2> = (shape, [1, 1]).into();
|
||||
|
||||
while a <= b {
|
||||
println!("a: {}", a);
|
||||
a.inc();
|
||||
}
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_3d_tensor_axis_iteration() {
|
||||
// // Create a 3D Tensor with specific values
|
||||
// // Tensor shape is 2x2x2 for simplicity
|
||||
// let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
|
||||
// // Axis 0 (Layer-wise):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][0][1] = 2
|
||||
// // t[0][1][0] = 3
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][0][1] = 6
|
||||
// // t[1][1][0] = 7
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
// //
|
||||
// // This order suggests that for each "layer" (first level of arrays),
|
||||
// // the iterator goes through all rows and columns. It first completes
|
||||
// // the entire first layer, then moves to the second.
|
||||
|
||||
// let a0 = Axis::new(t.clone(), 0);
|
||||
// let a0_order = a0.into_iter().collect::<Vec<_>>();
|
||||
// assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
|
||||
// // Axis 1 (Row-wise within each layer):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][0][1] = 2
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][0][1] = 6
|
||||
// // t[0][1][0] = 3
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][1][0] = 7
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
// //
|
||||
// // This indicates that within each "layer", the iterator first
|
||||
// // completes the first row across all layers, then the second row
|
||||
// // across all layers.
|
||||
|
||||
// let a1 = Axis::new(t.clone(), 1);
|
||||
// let a1_order = a1.into_iter().collect::<Vec<_>>();
|
||||
// assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
|
||||
// // Axis 2 (Column-wise within each layer):
|
||||
// //
|
||||
// // t[0][0][0] = 1
|
||||
// // t[0][1][0] = 3
|
||||
// // t[1][0][0] = 5
|
||||
// // t[1][1][0] = 7
|
||||
// // t[0][0][1] = 2
|
||||
// // t[0][1][1] = 4
|
||||
// // t[1][0][1] = 6
|
||||
// // t[1][1][1] = 8
|
||||
// // [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
// //
|
||||
// // This indicates that within each "layer", the iterator first
|
||||
// // completes the first column across all layers, then the second
|
||||
// // column across all layers.
|
||||
|
||||
// let a2 = Axis::new(t, 2);
|
||||
// let a2_order = a2.into_iter().collect::<Vec<_>>();
|
||||
// assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
// }
|
||||
// }
|
164
src/index.rs
164
src/index.rs
@ -51,6 +51,9 @@ impl<const R: usize> Idx<R> {
|
||||
/// `true` if the increment does not overflow and is still within bounds;
|
||||
/// `false` if it overflows, indicating the end of the tensor.
|
||||
pub fn inc(&mut self) -> bool {
|
||||
if self.indices[0] >= self.shape.get(0) {
|
||||
return false;
|
||||
}
|
||||
let mut carry = 1;
|
||||
for (i, &dim_size) in
|
||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||
@ -67,13 +70,40 @@ impl<const R: usize> Idx<R> {
|
||||
|
||||
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
||||
if carry == 1 {
|
||||
// Set the index to an invalid state (e.g., all indices to their max values)
|
||||
// Set the index to an invalid state to indicate the end of the iteration indicated
|
||||
// by setting the first index to the size of the first dimension
|
||||
self.indices[0] = self.shape.as_array()[0];
|
||||
return true; // Indicate that the iteration is complete
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// fn inc_axis
|
||||
|
||||
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
||||
assert!(fixed_axis < R, "Axis out of bounds");
|
||||
assert!(self.indices[fixed_axis] < self.shape.get(fixed_axis), "Index out of bounds");
|
||||
|
||||
// Try to increment non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] < self.shape.get(i) {
|
||||
self.indices[i] += 1;
|
||||
} else {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Increment the fixed axis and reset other axes to 0
|
||||
self.indices[fixed_axis] += 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dec(&mut self) {
|
||||
// Check if already at the start
|
||||
if self.indices.iter().all(|&i| i == 0) {
|
||||
@ -95,6 +125,40 @@ impl<const R: usize> Idx<R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
|
||||
// Check if the fixed axis index is already in an invalid state
|
||||
if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to decrement non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] > 0 {
|
||||
self.indices[i] -= 1;
|
||||
return true;
|
||||
} else {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement the fixed axis if possible and reset other axes to their max
|
||||
if self.indices[fixed_axis] > 0 {
|
||||
self.indices[fixed_axis] -= 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fixed axis already at minimum, set to invalid state
|
||||
self.indices[fixed_axis] = self.shape.get(fixed_axis);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
/// Converts the multi-dimensional index to a flat index.
|
||||
///
|
||||
/// This method calculates the flat index corresponding to the multi-dimensional index
|
||||
@ -136,6 +200,56 @@ impl<const R: usize> Idx<R> {
|
||||
})
|
||||
.0
|
||||
}
|
||||
|
||||
// pub fn inc_axis(&mut self, axis: usize) -> bool {
|
||||
// if axis >= R {
|
||||
// // Axis is out of bounds
|
||||
// return false;
|
||||
// }
|
||||
|
||||
// let dim_size = self.shape.get(axis);
|
||||
// if self.indices[axis] + 1 < dim_size {
|
||||
// // Increment the index if it's not at its maximum
|
||||
// self.indices[axis] += 1;
|
||||
// true
|
||||
// } else {
|
||||
// // Index is at its maximum for this axis
|
||||
// false
|
||||
// }
|
||||
// }
|
||||
|
||||
pub fn set_axis(&mut self, axis: usize, value: usize) {
|
||||
assert!(axis < R, "Axis out of bounds");
|
||||
// assert!(value < self.shape.get(axis), "Value out of bounds");
|
||||
self.indices[axis] = value;
|
||||
}
|
||||
|
||||
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
|
||||
assert!(axis < R, "Axis out of bounds");
|
||||
if value < self.shape.get(axis) {
|
||||
self.indices[axis] = value;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_axis(&self, axis: usize) -> usize {
|
||||
assert!(axis < R, "Axis out of bounds");
|
||||
self.indices[axis]
|
||||
}
|
||||
|
||||
pub fn indices(&self) -> &[usize; R] {
|
||||
&self.indices
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> Shape<R> {
|
||||
self.shape
|
||||
}
|
||||
|
||||
pub fn shape_mut(&mut self) -> &mut Shape<R> {
|
||||
&mut self.shape
|
||||
}
|
||||
}
|
||||
|
||||
// --- blanket impls ---
|
||||
@ -194,6 +308,12 @@ impl<const R: usize> From<Shape<R>> for Idx<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> From<Tensor<T, R>> for Idx<R> {
|
||||
fn from(tensor: Tensor<T, R>) -> Self {
|
||||
Self::zero(tensor.shape())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> std::fmt::Display for Idx<R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[")?;
|
||||
@ -249,3 +369,45 @@ impl<const R: usize> Sub for Idx<R> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Iterator ----
|
||||
|
||||
pub struct IdxIterator<const R: usize> {
|
||||
current: Idx<R>,
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<const R: usize> IdxIterator<R> {
|
||||
pub fn new(shape: Shape<R>) -> Self {
|
||||
Self {
|
||||
current: Idx::zero(shape),
|
||||
end: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Iterator for IdxIterator<R> {
|
||||
type Item = Idx<R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = self.current;
|
||||
self.end = self.current.inc();
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> IntoIterator for Idx<R> {
|
||||
type Item = Idx<R>;
|
||||
type IntoIter = IdxIterator<R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
IdxIterator {
|
||||
current: self,
|
||||
end: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2,6 +2,7 @@
|
||||
#![feature(generic_const_exprs)]
|
||||
pub mod index;
|
||||
pub mod shape;
|
||||
pub mod axis;
|
||||
pub mod tensor;
|
||||
|
||||
pub use index::Idx;
|
||||
@ -11,6 +12,10 @@ pub use shape::Shape;
|
||||
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||
use std::ops::{Index, IndexMut};
|
||||
pub use tensor::{Tensor, TensorIterator};
|
||||
pub use static_assertions::const_assert;
|
||||
pub use itertools::Itertools;
|
||||
pub use std::sync::Arc;
|
||||
pub use axis::*;
|
||||
|
||||
pub trait Value:
|
||||
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||
|
154
src/tensor.rs
154
src/tensor.rs
@ -23,11 +23,27 @@ impl<T: Value> Tensor<T, 1> {
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
pub fn new(shape: Shape<R>) -> Self {
|
||||
let total_size: usize = shape.iter().product();
|
||||
// Handle rank 0 tensor (scalar) as a special case
|
||||
let total_size = if R == 0 {
|
||||
// A rank 0 tensor should still have a buffer with one element
|
||||
1
|
||||
} else {
|
||||
// For tensors of rank 1 or higher, calculate the total size normally
|
||||
shape.iter().product()
|
||||
};
|
||||
|
||||
let buffer = vec![T::zero(); total_size];
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
pub fn idx(&self) -> Idx<R> {
|
||||
Idx::from(self.clone())
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> Shape<R> {
|
||||
self.shape
|
||||
}
|
||||
@ -111,85 +127,63 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// `self_dims` and `other_dims` specify the dimensions to contract over
|
||||
pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
|
||||
&self,
|
||||
rhs: &Tensor<T, S>,
|
||||
axis_lhs: [usize; NAXL],
|
||||
axis_rhs: [usize; NAXR],
|
||||
) -> Tensor<T, { R + S - NAXL - NAXR }>
|
||||
where
|
||||
[(); R - NAXL]:,
|
||||
[(); S - NAXR]:,
|
||||
{
|
||||
// Step 1: Validate the axes for both tensors to ensure they are within bounds
|
||||
for &axis in &axis_lhs {
|
||||
if axis >= R {
|
||||
panic!(
|
||||
"Axis {} is out of bounds for the left-hand tensor",
|
||||
axis
|
||||
);
|
||||
}
|
||||
}
|
||||
// // Step 1: Validate the axes for both tensors to ensure they are within bounds
|
||||
// for &axis in &axis_lhs {
|
||||
// if axis >= R {
|
||||
// panic!(
|
||||
// "Axis {} is out of bounds for the left-hand tensor",
|
||||
// axis
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
|
||||
for &axis in &axis_rhs {
|
||||
if axis >= S {
|
||||
panic!(
|
||||
"Axis {} is out of bounds for the right-hand tensor",
|
||||
axis
|
||||
);
|
||||
}
|
||||
}
|
||||
// for &axis in &axis_rhs {
|
||||
// if axis >= S {
|
||||
// panic!(
|
||||
// "Axis {} is out of bounds for the right-hand tensor",
|
||||
// axis
|
||||
// );
|
||||
// }
|
||||
// }
|
||||
|
||||
// Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
||||
let mut result_buffer = Vec::new();
|
||||
// // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
||||
// let mut result_buffer = Vec::new();
|
||||
|
||||
for i in 0..self.shape.size() {
|
||||
for j in 0..rhs.shape.size() {
|
||||
// Debug: Print indices being processed
|
||||
println!("Processing Indices: lhs = {}, rhs = {}", i, j);
|
||||
// for i in self.iter_over_non_contracted_axes(&axis_lhs) {
|
||||
// for j in rhs.iter_over_non_contracted_axes(&axis_rhs) {
|
||||
// let mut product_sum = T::zero();
|
||||
|
||||
if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) {
|
||||
let mut product_sum = T::zero();
|
||||
// // Step 3: Contraction over specified axes
|
||||
// for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
||||
// let idx_l = i[axis_l];
|
||||
// let idx_r = j[axis_r];
|
||||
|
||||
// Debug: Print axes of contraction
|
||||
println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs);
|
||||
// let value_lhs = self.buffer[self.flat(&idx_l)];
|
||||
// let value_rhs = rhs.buffer[rhs.flat(&idx_r)];
|
||||
// product_sum = product_sum + value_lhs * value_rhs;
|
||||
// }
|
||||
|
||||
for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
||||
// Debug: Print values being multiplied
|
||||
let value_lhs = self.get_by_axis(axis_l, i).unwrap();
|
||||
let value_rhs = rhs.get_by_axis(axis_r, j).unwrap();
|
||||
println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs);
|
||||
// result_buffer.push(product_sum);
|
||||
// }
|
||||
// }
|
||||
|
||||
product_sum = product_sum + value_lhs * value_rhs;
|
||||
}
|
||||
// // Step 4: Remove contracted dimensions to create new shapes for both tensors
|
||||
// let new_shape_lhs = self.shape.remove_dims::<{ NL }>(axis_lhs);
|
||||
// let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
||||
|
||||
// Debug: Print the product sum for the current indices
|
||||
println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum);
|
||||
// // Step 5: Concatenate the shapes to form the shape of the resultant tensor
|
||||
// let mut new_shape = Vec::new();
|
||||
// new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
||||
// new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
||||
|
||||
result_buffer.push(product_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
// let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
||||
// new_shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
// Step 3: Remove contracted dimensions to create new shapes for both tensors
|
||||
let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
|
||||
let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
||||
|
||||
// Step 4: Concatenate the shapes to form the shape of the resultant tensor
|
||||
let mut new_shape = Vec::new();
|
||||
|
||||
new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
||||
new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
||||
|
||||
let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
||||
new_shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
Tensor {
|
||||
buffer: result_buffer,
|
||||
shape: Shape::new(new_shape_array),
|
||||
}
|
||||
}
|
||||
// Tensor {
|
||||
// buffer: result_buffer,
|
||||
// shape: Shape::new(new_shape_array),
|
||||
// }
|
||||
// }
|
||||
|
||||
// Retrieve an element based on a specific axis and index
|
||||
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
|
||||
@ -366,3 +360,23 @@ impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize, const Y: usize, const Z: usize>
|
||||
From<[[[T; X]; Y]; Z]> for Tensor<T, 3>
|
||||
{
|
||||
fn from(array: [[[T; X]; Y]; Z]) -> Self {
|
||||
let shape = Shape::new([Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, plane) in array.iter().enumerate() {
|
||||
for (j, row) in plane.iter().enumerate() {
|
||||
for (k, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y + j * X + k] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user