Add tensor! macro

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-02 00:03:12 +02:00
parent 08747a2932
commit 776c58d92a
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
5 changed files with 170 additions and 89 deletions

View File

@ -6,7 +6,7 @@ use manifold::*;
fn tensor_product() {
println!("Tensor Product\n");
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values
@ -57,8 +57,8 @@ fn test_tensor_contraction_23x32() {
}
fn test_tensor_contraction_rank3() {
let a = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
let b = Tensor::from([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
let a = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
let b = tensor!([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
let contracted_tensor = contract((&a, [2]), (&b, [0]));
println!("a: {}", a);
@ -73,6 +73,10 @@ fn test_tensor_contraction_rank3() {
fn transpose() {
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let b = tensor!(
[[1, 2, 3],
[4, 5, 6]]
);
// let iter = a.idx().iter_transposed([1, 0]);

View File

@ -75,18 +75,6 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
self.set_start(level).set_end(level + 1)
}
// pub fn disassemble(self) -> Vec<Self> {
// let mut result = Vec::new();
// for i in 0..self.axis().len() {
// result.push(Self::new(self.axis()).set_level(i));
// }
// result
// }
// pub fn disassemble(&'a self) -> impl Iterator<Item = Self> + 'a {
// (0..self.axis().len()).map(move |i| Self::new(self.axis()).set_level(i))
// }
pub fn level(&'a self, level: usize) -> impl Iterator<Item = &'a T> + 'a {
Self::new(self.axis()).set_level(level)
}

View File

@ -1,7 +1,7 @@
use super::*;
use getset::{Getters, MutGetters};
use std::cmp::Ordering;
use std::ops::{Add, Sub};
use getset::{Getters, MutGetters};
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
pub struct Idx<'a, const R: usize> {
@ -32,7 +32,10 @@ impl<'a, const R: usize> Idx<'a, R> {
if !shape.check_indices(indices) {
panic!("indices out of bounds");
}
Self { indices, shape: shape }
Self {
indices,
shape: shape,
}
}
pub fn is_zero(&self) -> bool {
@ -54,7 +57,7 @@ impl<'a, const R: usize> Idx<'a, R> {
/// `true` if the increment does not overflow and is still within bounds;
/// `false` if it overflows, indicating the end of the tensor.
pub fn inc(&mut self) -> bool {
if self.indices[0] >= self.shape.get(0) {
if self.indices()[0] >= self.shape().get(0) {
return false;
}
let mut carry = 1;
@ -84,55 +87,45 @@ impl<'a, const R: usize> Idx<'a, R> {
// fn inc_axis
pub fn inc_axis(&mut self, fixed_axis: usize) {
assert!(fixed_axis < R, "Axis out of bounds");
assert!(self.indices()[fixed_axis] < self.shape().get(fixed_axis), "Index out of bounds");
assert!(fixed_axis < R, "Axis out of bounds");
assert!(
self.indices()[fixed_axis] < self.shape().get(fixed_axis),
"Index out of bounds"
);
// Try to increment non-fixed axes
// Try to increment non-fixed axes
for i in (0..R).rev() {
if i != fixed_axis {
if self.indices[i] + 1 < self.shape.get(i) {
if self.indices[i] + 1 < self.shape.get(i) {
self.indices[i] += 1;
return ;
return;
} else {
self.indices[i] = 0;
}
}
}
if self.indices[fixed_axis] < self.shape.get(fixed_axis) {
self.indices[fixed_axis] += 1;
for i in 0..R {
if i != fixed_axis {
self.indices[i] = 0;
}
}
return ;
}
if self.indices[fixed_axis] < self.shape.get(fixed_axis) {
self.indices[fixed_axis] += 1;
for i in 0..R {
if i != fixed_axis {
self.indices[i] = 0;
}
}
return;
}
}
// pub fn inc_transposed(&mut self, order: [usize; R]) -> bool {
// // Iterate over axes in the specified order
// for &axis in order.iter().rev() {
// if self.indices[axis] + 1 < self.shape.get(axis) {
// self.indices[axis] += 1;
// return true;
// } else {
// self.indices[axis] = 0;
// }
// }
// false
// }
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool {
if self.indices[order[0]] >= self.shape.get(order[0]) {
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool {
if self.indices()[order[0]] >= self.shape().get(order[0]) {
return false;
}
let mut carry = 1;
for i in
order.iter().rev()
{
let dim_size = self.shape().get(*i);
let i = self.index_mut(*i);
for i in order.iter().rev() {
let dim_size = self.shape().get(*i);
let i = self.index_mut(*i);
if carry == 1 {
*i += 1;
if *i >= dim_size {
@ -144,9 +137,10 @@ impl<'a, const R: usize> Idx<'a, R> {
}
if carry == 1 {
self.indices[order[0]] = self.shape.as_array()[order[0]];
self.indices_mut()[order[0]] = self.shape().get(order[0]);
return true;
}
false
}
@ -171,7 +165,7 @@ impl<'a, const R: usize> Idx<'a, R> {
}
}
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
// Check if the fixed axis index is already in an invalid state
if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
return false;
@ -202,29 +196,29 @@ impl<'a, const R: usize> Idx<'a, R> {
self.indices[fixed_axis] = self.shape.get(fixed_axis);
}
true
true
}
pub fn dec_transposed(&mut self, order: [usize; R]) {
// Iterate over the axes in the specified order
for &axis in &order {
// Try to decrement the current axis
if self.indices[axis] > 0 {
self.indices[axis] -= 1;
// Reset all preceding axes in the order to their maximum
for &prev_axis in &order {
if prev_axis == axis {
break;
}
self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
}
return;
}
}
pub fn dec_transposed(&mut self, order: [usize; R]) {
// Iterate over the axes in the specified order
for &axis in &order {
// Try to decrement the current axis
if self.indices[axis] > 0 {
self.indices[axis] -= 1;
// Reset all preceding axes in the order to their maximum
for &prev_axis in &order {
if prev_axis == axis {
break;
}
self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
}
return;
}
}
// If no axis can be decremented, set the first axis in the order to indicate overflow
self.indices[order[0]] = self.shape.get(order[0]);
}
// If no axis can be decremented, set the first axis in the order to indicate overflow
self.indices[order[0]] = self.shape.get(order[0]);
}
/// Converts the multi-dimensional index to a flat index.
///
@ -289,9 +283,12 @@ impl<'a, const R: usize> Idx<'a, R> {
self.indices[axis]
}
pub fn iter_transposed(&self, order: [usize; R]) -> IdxTransposedIterator<'a, R> {
IdxTransposedIterator::new(self.shape(), order)
}
pub fn iter_transposed(
&self,
order: [usize; R],
) -> IdxTransposedIterator<'a, R> {
IdxTransposedIterator::new(self.shape(), order)
}
}
// --- blanket impls ---
@ -456,7 +453,7 @@ impl<'a, const R: usize> IntoIterator for Idx<'a, R> {
pub struct IdxTransposedIterator<'a, const R: usize> {
current: Idx<'a, R>,
order: [usize; R],
order: [usize; R],
end: bool,
}
@ -465,7 +462,7 @@ impl<'a, const R: usize> IdxTransposedIterator<'a, R> {
Self {
current: Idx::zero(shape),
end: false,
order,
order,
}
}
}

View File

@ -36,6 +36,13 @@ impl<T> Value for T where
{
}
#[macro_export]
macro_rules! tensor {
($array:expr) => {
Tensor::from($array)
};
}
// ---- Tests ----
#[cfg(test)]
@ -45,7 +52,7 @@ mod tests {
#[test]
fn test_tensor_product() {
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
// Fill tensors with some values

View File

@ -1,6 +1,7 @@
use super::*;
use crate::error::*;
use getset::{Getters, MutGetters};
use std::fmt;
#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
pub struct Tensor<T, const R: usize> {
@ -45,8 +46,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
}
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
// let shape = self.shape().reorder(order);
let buffer = Idx::from(self.shape()).iter_transposed(order)
.map(|index| {
println!("index: {}", index);
@ -220,8 +219,6 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
// ---- Display ----
use std::fmt;
impl<T, const R: usize> Tensor<T, R>
where
T: fmt::Display + Clone,
@ -333,11 +330,27 @@ impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
}
}
impl<T: Value, const R: usize> From<[usize; R]> for Tensor<T, R> {
fn from(shape: [usize; R]) -> Self {
let shape = Shape::new(shape);
Self::new(shape.into())
}
impl<T: Value> From<T> for Tensor<T, 0> {
fn from(value: T) -> Self {
let shape = Shape::new([]);
let mut tensor = Tensor::new(shape);
tensor.buffer_mut()[0] = value;
tensor
}
}
impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
fn from(array: [T; X]) -> Self {
let shape = Shape::new([X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, &elem) in array.iter().enumerate() {
buffer[i] = elem;
}
tensor
}
}
impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
@ -377,3 +390,75 @@ impl<T: Value, const X: usize, const Y: usize, const Z: usize>
tensor
}
}
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize>
From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
{
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
let shape = Shape::new([W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, hyperplane) in array.iter().enumerate() {
for (j, plane) in hyperplane.iter().enumerate() {
for (k, row) in plane.iter().enumerate() {
for (l, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem;
}
}
}
}
tensor
}
}
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize>
From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
{
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
let shape = Shape::new([V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, hyperhyperplane) in array.iter().enumerate() {
for (j, hyperplane) in hyperhyperplane.iter().enumerate() {
for (k, plane) in hyperplane.iter().enumerate() {
for (l, row) in plane.iter().enumerate() {
for (m, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z * W + j * X * Y * Z + k * X * Y + l * X + m] = elem;
}
}
}
}
}
tensor
}
}
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize, const U: usize>
From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
{
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
let shape = Shape::new([U, V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, hyperhyperhyperplane) in array.iter().enumerate() {
for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() {
for (k, hyperplane) in hyperhyperplane.iter().enumerate() {
for (l, plane) in hyperplane.iter().enumerate() {
for (m, row) in plane.iter().enumerate() {
for (n, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z * W * V + j * X * Y * Z * W + k * X * Y * Z + l * X * Y + m * X + n] = elem;
}
}
}
}
}
}
tensor
}
}