Rank 3 contraction not working yet
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
33c056f2b9
commit
f9c29aefd5
@ -1,8 +1,8 @@
|
||||
#![allow(mixed_script_confusables)]
|
||||
#![allow(non_snake_case)]
|
||||
use bytemuck::cast_slice;
|
||||
use manifold::*;
|
||||
use manifold::contract;
|
||||
use manifold::*;
|
||||
|
||||
fn tensor_product() {
|
||||
println!("Tensor Product\n");
|
||||
@ -30,33 +30,49 @@ fn tensor_product() {
|
||||
}
|
||||
|
||||
fn test_tensor_contraction_23x32() {
|
||||
// Define two 2D tensors (matrices)
|
||||
// Define two 2D tensors (matrices)
|
||||
|
||||
// Tensor A is 2x3
|
||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
println!("a: {}", a);
|
||||
// Tensor A is 2x3
|
||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
println!("a: {:?}\n{}\n", a.shape(), a);
|
||||
|
||||
// Tensor B is 3x2
|
||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
||||
println!("b: {}", b);
|
||||
// Tensor B is 3x2
|
||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
||||
println!("b: {:?}\n{}\n", b.shape(), b);
|
||||
|
||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
let ctr10 = contract((&a, [1]), (&b, [0]));
|
||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
let ctr10 = contract((&a, [1]), (&b, [0]));
|
||||
|
||||
println!("[1, 0]: {}", ctr10);
|
||||
println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10);
|
||||
|
||||
let ctr01 = contract((&a, [0]), (&b, [1]));
|
||||
let ctr01 = contract((&a, [0]), (&b, [1]));
|
||||
|
||||
println!("[0, 1]: {}", ctr01);
|
||||
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
||||
// assert_eq!(
|
||||
// contracted_tensor.buffer(),
|
||||
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
||||
// "Contracted tensor buffer does not match expected"
|
||||
// );
|
||||
println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01);
|
||||
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
||||
// assert_eq!(
|
||||
// contracted_tensor.buffer(),
|
||||
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
||||
// "Contracted tensor buffer does not match expected"
|
||||
// );
|
||||
}
|
||||
|
||||
fn test_tensor_contraction_rank3() {
|
||||
let a = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
let b = Tensor::from([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
|
||||
let contracted_tensor = contract((&a, [2]), (&b, [0]));
|
||||
|
||||
println!("a: {}", a);
|
||||
println!("b: {}", b);
|
||||
println!("contracted_tensor: {}", contracted_tensor);
|
||||
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
|
||||
// Verify specific elements of contracted_tensor
|
||||
// assert_eq!(contracted_tensor[0][0][0][0], 50);
|
||||
// assert_eq!(contracted_tensor[0][0][0][1], 60);
|
||||
// ... further checks for other elements ...
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// tensor_product();
|
||||
test_tensor_contraction_23x32();
|
||||
test_tensor_contraction_rank3();
|
||||
}
|
||||
|
44
src/axis.rs
44
src/axis.rs
@ -144,35 +144,23 @@ where
|
||||
{
|
||||
let (lhs, la) = lhs;
|
||||
let (rhs, ra) = rhs;
|
||||
let raxes = ra.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
|
||||
let raxes: [Axis<'a, T, S>; N] =
|
||||
raxes.try_into().expect("Failed to create raxes array");
|
||||
let laxes = la.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
|
||||
let laxes: [Axis<'a, T, R>; N] =
|
||||
laxes.try_into().expect("Failed to create laxes array");
|
||||
let mut shape = Vec::new();
|
||||
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
|
||||
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
|
||||
let shape: [usize; R + S - 2 * N] =
|
||||
shape.try_into().expect("Failed to create shape array");
|
||||
let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
|
||||
let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();
|
||||
|
||||
let shape = Shape::new(shape);
|
||||
let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
|
||||
let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
|
||||
|
||||
let result = contract_axes(laxes, raxes);
|
||||
let mut shape = Vec::new();
|
||||
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
|
||||
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
|
||||
let shape: [usize; R + S - 2 * N] =
|
||||
shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
let mut t = Tensor::new_with_buffer(shape, result[0].clone());
|
||||
let shape = Shape::new(shape);
|
||||
|
||||
// Iterate over the remaining axes results (skipping the first one already used)
|
||||
for axis_result in result.iter().skip(1) {
|
||||
// Iterate through each element of the axis_result and add it to the corresponding element in t's buffer
|
||||
for (i, val) in axis_result.iter().enumerate() {
|
||||
unsafe {
|
||||
*t.get_flat_unchecked_mut(i) = *t.get_flat(i) + *val;
|
||||
}
|
||||
}
|
||||
}
|
||||
let result = contract_axes(&lnc, &rnc);
|
||||
|
||||
t
|
||||
Tensor::new_with_buffer(shape, result)
|
||||
}
|
||||
|
||||
pub fn contract_axes<
|
||||
@ -182,9 +170,9 @@ pub fn contract_axes<
|
||||
const S: usize,
|
||||
const N: usize,
|
||||
>(
|
||||
laxes: [Axis<'a, T, R>; N],
|
||||
raxes: [Axis<'a, T, S>; N],
|
||||
) -> Vec<Vec<T>>
|
||||
laxes: &'a [Axis<'a, T, R>],
|
||||
raxes: &'a [Axis<'a, T, S>],
|
||||
) -> Vec<T>
|
||||
where
|
||||
[(); R - N]:,
|
||||
[(); S - N]:,
|
||||
@ -207,7 +195,7 @@ where
|
||||
axes_result.push(sum);
|
||||
}
|
||||
}
|
||||
result.push(axes_result);
|
||||
result.extend_from_slice(&axes_result);
|
||||
}
|
||||
|
||||
result
|
||||
|
@ -4,6 +4,7 @@ pub mod index;
|
||||
pub mod shape;
|
||||
pub mod axis;
|
||||
pub mod tensor;
|
||||
pub mod tensor_view;
|
||||
|
||||
pub use index::Idx;
|
||||
use num::{Num, One, Zero};
|
||||
|
@ -207,27 +207,36 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
|
||||
|
||||
// ---- Display ----
|
||||
|
||||
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Print the shape of the tensor
|
||||
write!(f, "Shape: [")?;
|
||||
for (i, dim_size) in self.shape.as_array().iter().enumerate() {
|
||||
write!(f, "{}", dim_size)?;
|
||||
if i < R - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, "], Elements: [")?;
|
||||
use std::fmt;
|
||||
|
||||
// Print the elements in a flattened form
|
||||
for (i, elem) in self.buffer.iter().enumerate() {
|
||||
write!(f, "{}", elem)?;
|
||||
if i < self.buffer.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
impl<T, const R: usize> Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
{
|
||||
fn fmt_helper(buffer: &[T], shape: &[usize], f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result {
|
||||
if shape.is_empty() {
|
||||
// Base case: print individual elements
|
||||
write!(f, "{}", buffer[0])
|
||||
} else {
|
||||
let sub_len = shape[1..].iter().product::<usize>();
|
||||
write!(f, "[")?;
|
||||
for (i, chunk) in buffer.chunks(sub_len).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ",")?;
|
||||
}
|
||||
Tensor::<T, R>::fmt_helper(chunk, &shape[1..], f, level + 1)?;
|
||||
}
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
write!(f, "]")
|
||||
impl<T, const R: usize> fmt::Display for Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
|
||||
}
|
||||
}
|
||||
|
||||
|
8
src/tensor_view.rs
Normal file
8
src/tensor_view.rs
Normal file
@ -0,0 +1,8 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
|
||||
// #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
|
||||
pub struct TensorView<'a, T, const R: usize> {
|
||||
tensor: &'a Tensor<T, R>,
|
||||
shape: Shape<R>,
|
||||
}
|
Loading…
Reference in New Issue
Block a user