79 lines
2.5 KiB
Rust
79 lines
2.5 KiB
Rust
#![allow(mixed_script_confusables)]
|
|
#![allow(non_snake_case)]
|
|
use bytemuck::cast_slice;
|
|
use manifold::contract;
|
|
use manifold::*;
|
|
|
|
fn tensor_product() {
|
|
println!("Tensor Product\n");
|
|
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
|
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
|
|
|
// Fill tensors with some values
|
|
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
|
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
|
|
|
println!("T1: {}", tensor1);
|
|
println!("T2: {}", tensor2);
|
|
|
|
let product = tensor1.tensor_product(&tensor2);
|
|
|
|
println!("T1 * T2 = {}", product);
|
|
|
|
// Check shape of the resulting tensor
|
|
assert_eq!(product.shape(), &Shape::new([2, 2, 2]));
|
|
|
|
// Check buffer of the resulting tensor
|
|
let expect: &[i32] =
|
|
cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]);
|
|
assert_eq!(product.buffer(), expect);
|
|
}
|
|
|
|
fn test_tensor_contraction_23x32() {
|
|
// Define two 2D tensors (matrices)
|
|
|
|
// Tensor A is 2x3
|
|
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
|
println!("a: {:?}\n{}\n", a.shape(), a);
|
|
|
|
// Tensor B is 3x2
|
|
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
|
println!("b: {:?}\n{}\n", b.shape(), b);
|
|
|
|
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
|
let ctr10 = contract((&a, [1]), (&b, [0]));
|
|
|
|
println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10);
|
|
|
|
let ctr01 = contract((&a, [0]), (&b, [1]));
|
|
|
|
println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01);
|
|
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
|
// assert_eq!(
|
|
// contracted_tensor.buffer(),
|
|
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
|
// "Contracted tensor buffer does not match expected"
|
|
// );
|
|
}
|
|
|
|
fn test_tensor_contraction_rank3() {
|
|
let a = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
|
let b = Tensor::from([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
|
|
let contracted_tensor = contract((&a, [2]), (&b, [0]));
|
|
|
|
println!("a: {}", a);
|
|
println!("b: {}", b);
|
|
println!("contracted_tensor: {}", contracted_tensor);
|
|
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
|
|
// Verify specific elements of contracted_tensor
|
|
// assert_eq!(contracted_tensor[0][0][0][0], 50);
|
|
// assert_eq!(contracted_tensor[0][0][0][1], 60);
|
|
// ... further checks for other elements ...
|
|
}
|
|
|
|
fn main() {
|
|
// tensor_product();
|
|
test_tensor_contraction_23x32();
|
|
test_tensor_contraction_rank3();
|
|
}
|