Rank 3 contraction not working yet

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2023-12-30 01:39:57 +02:00
parent 33c056f2b9
commit f9c29aefd5
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
5 changed files with 86 additions and 64 deletions

View File

@ -1,8 +1,8 @@
#![allow(mixed_script_confusables)]
#![allow(non_snake_case)]
use bytemuck::cast_slice;
use manifold::*;
use manifold::contract;
use manifold::*;
fn tensor_product() {
println!("Tensor Product\n");
@ -30,33 +30,49 @@ fn tensor_product() {
}
fn test_tensor_contraction_23x32() {
// Define two 2D tensors (matrices)
// Define two 2D tensors (matrices)
// Tensor A is 2x3
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
println!("a: {}", a);
// Tensor A is 2x3
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
println!("a: {:?}\n{}\n", a.shape(), a);
// Tensor B is 3x2
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
println!("b: {}", b);
// Tensor B is 3x2
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
println!("b: {:?}\n{}\n", b.shape(), b);
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let ctr10 = contract((&a, [1]), (&b, [0]));
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
let ctr10 = contract((&a, [1]), (&b, [0]));
println!("[1, 0]: {}", ctr10);
println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10);
let ctr01 = contract((&a, [0]), (&b, [1]));
let ctr01 = contract((&a, [0]), (&b, [1]));
println!("[0, 1]: {}", ctr01);
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
// assert_eq!(
// contracted_tensor.buffer(),
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
// "Contracted tensor buffer does not match expected"
// );
println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01);
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
// assert_eq!(
// contracted_tensor.buffer(),
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
// "Contracted tensor buffer does not match expected"
// );
}
fn test_tensor_contraction_rank3() {
let a = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
let b = Tensor::from([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
let contracted_tensor = contract((&a, [2]), (&b, [0]));
println!("a: {}", a);
println!("b: {}", b);
println!("contracted_tensor: {}", contracted_tensor);
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
// Verify specific elements of contracted_tensor
// assert_eq!(contracted_tensor[0][0][0][0], 50);
// assert_eq!(contracted_tensor[0][0][0][1], 60);
// ... further checks for other elements ...
}
fn main() {
// tensor_product();
test_tensor_contraction_23x32();
test_tensor_contraction_rank3();
}

View File

@ -144,35 +144,23 @@ where
{
let (lhs, la) = lhs;
let (rhs, ra) = rhs;
let raxes = ra.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
let raxes: [Axis<'a, T, S>; N] =
raxes.try_into().expect("Failed to create raxes array");
let laxes = la.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
let laxes: [Axis<'a, T, R>; N] =
laxes.try_into().expect("Failed to create laxes array");
let mut shape = Vec::new();
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
let shape: [usize; R + S - 2 * N] =
shape.try_into().expect("Failed to create shape array");
let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();
let shape = Shape::new(shape);
let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
let result = contract_axes(laxes, raxes);
let mut shape = Vec::new();
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
let shape: [usize; R + S - 2 * N] =
shape.try_into().expect("Failed to create shape array");
let mut t = Tensor::new_with_buffer(shape, result[0].clone());
let shape = Shape::new(shape);
// Iterate over the remaining axes results (skipping the first one already used)
for axis_result in result.iter().skip(1) {
// Iterate through each element of the axis_result and add it to the corresponding element in t's buffer
for (i, val) in axis_result.iter().enumerate() {
unsafe {
*t.get_flat_unchecked_mut(i) = *t.get_flat(i) + *val;
}
}
}
let result = contract_axes(&lnc, &rnc);
t
Tensor::new_with_buffer(shape, result)
}
pub fn contract_axes<
@ -182,9 +170,9 @@ pub fn contract_axes<
const S: usize,
const N: usize,
>(
laxes: [Axis<'a, T, R>; N],
raxes: [Axis<'a, T, S>; N],
) -> Vec<Vec<T>>
laxes: &'a [Axis<'a, T, R>],
raxes: &'a [Axis<'a, T, S>],
) -> Vec<T>
where
[(); R - N]:,
[(); S - N]:,
@ -207,7 +195,7 @@ where
axes_result.push(sum);
}
}
result.push(axes_result);
result.extend_from_slice(&axes_result);
}
result

View File

@ -4,6 +4,7 @@ pub mod index;
pub mod shape;
pub mod axis;
pub mod tensor;
pub mod tensor_view;
pub use index::Idx;
use num::{Num, One, Zero};

View File

@ -207,27 +207,36 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
// ---- Display ----
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
// Print the shape of the tensor
write!(f, "Shape: [")?;
for (i, dim_size) in self.shape.as_array().iter().enumerate() {
write!(f, "{}", dim_size)?;
if i < R - 1 {
write!(f, ", ")?;
}
}
write!(f, "], Elements: [")?;
use std::fmt;
// Print the elements in a flattened form
for (i, elem) in self.buffer.iter().enumerate() {
write!(f, "{}", elem)?;
if i < self.buffer.len() - 1 {
write!(f, ", ")?;
impl<T, const R: usize> Tensor<T, R>
where
T: fmt::Display + Clone,
{
fn fmt_helper(buffer: &[T], shape: &[usize], f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result {
if shape.is_empty() {
// Base case: print individual elements
write!(f, "{}", buffer[0])
} else {
let sub_len = shape[1..].iter().product::<usize>();
write!(f, "[")?;
for (i, chunk) in buffer.chunks(sub_len).enumerate() {
if i > 0 {
write!(f, ",")?;
}
Tensor::<T, R>::fmt_helper(chunk, &shape[1..], f, level + 1)?;
}
write!(f, "]")
}
}
}
write!(f, "]")
impl<T, const R: usize> fmt::Display for Tensor<T, R>
where
T: fmt::Display + Clone,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
}
}

8
src/tensor_view.rs Normal file
View File

@ -0,0 +1,8 @@
use super::*;
use getset::{Getters, MutGetters};
// #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
pub struct TensorView<'a, T, const R: usize> {
tensor: &'a Tensor<T, R>,
shape: Shape<R>,
}