Add tensor! macro
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
08747a2932
commit
776c58d92a
@ -6,7 +6,7 @@ use manifold::*;
|
||||
|
||||
fn tensor_product() {
|
||||
println!("Tensor Product\n");
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
|
||||
// Fill tensors with some values
|
||||
@ -57,8 +57,8 @@ fn test_tensor_contraction_23x32() {
|
||||
}
|
||||
|
||||
fn test_tensor_contraction_rank3() {
|
||||
let a = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
let b = Tensor::from([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
|
||||
let a = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
let b = tensor!([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
|
||||
let contracted_tensor = contract((&a, [2]), (&b, [0]));
|
||||
|
||||
println!("a: {}", a);
|
||||
@ -73,6 +73,10 @@ fn test_tensor_contraction_rank3() {
|
||||
|
||||
fn transpose() {
|
||||
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
let b = tensor!(
|
||||
[[1, 2, 3],
|
||||
[4, 5, 6]]
|
||||
);
|
||||
|
||||
// let iter = a.idx().iter_transposed([1, 0]);
|
||||
|
||||
|
12
src/axis.rs
12
src/axis.rs
@ -75,18 +75,6 @@ impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
|
||||
self.set_start(level).set_end(level + 1)
|
||||
}
|
||||
|
||||
// pub fn disassemble(self) -> Vec<Self> {
|
||||
// let mut result = Vec::new();
|
||||
// for i in 0..self.axis().len() {
|
||||
// result.push(Self::new(self.axis()).set_level(i));
|
||||
// }
|
||||
// result
|
||||
// }
|
||||
|
||||
// pub fn disassemble(&'a self) -> impl Iterator<Item = Self> + 'a {
|
||||
// (0..self.axis().len()).map(move |i| Self::new(self.axis()).set_level(i))
|
||||
// }
|
||||
|
||||
pub fn level(&'a self, level: usize) -> impl Iterator<Item = &'a T> + 'a {
|
||||
Self::new(self.axis()).set_level(level)
|
||||
}
|
||||
|
47
src/index.rs
47
src/index.rs
@ -1,7 +1,7 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::cmp::Ordering;
|
||||
use std::ops::{Add, Sub};
|
||||
use getset::{Getters, MutGetters};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||
pub struct Idx<'a, const R: usize> {
|
||||
@ -32,7 +32,10 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self { indices, shape: shape }
|
||||
Self {
|
||||
indices,
|
||||
shape: shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_zero(&self) -> bool {
|
||||
@ -54,7 +57,7 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
/// `true` if the increment does not overflow and is still within bounds;
|
||||
/// `false` if it overflows, indicating the end of the tensor.
|
||||
pub fn inc(&mut self) -> bool {
|
||||
if self.indices[0] >= self.shape.get(0) {
|
||||
if self.indices()[0] >= self.shape().get(0) {
|
||||
return false;
|
||||
}
|
||||
let mut carry = 1;
|
||||
@ -85,14 +88,17 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
|
||||
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
||||
assert!(fixed_axis < R, "Axis out of bounds");
|
||||
assert!(self.indices()[fixed_axis] < self.shape().get(fixed_axis), "Index out of bounds");
|
||||
assert!(
|
||||
self.indices()[fixed_axis] < self.shape().get(fixed_axis),
|
||||
"Index out of bounds"
|
||||
);
|
||||
|
||||
// Try to increment non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] + 1 < self.shape.get(i) {
|
||||
self.indices[i] += 1;
|
||||
return ;
|
||||
return;
|
||||
} else {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
@ -106,31 +112,18 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
return ;
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn inc_transposed(&mut self, order: [usize; R]) -> bool {
|
||||
// // Iterate over axes in the specified order
|
||||
// for &axis in order.iter().rev() {
|
||||
// if self.indices[axis] + 1 < self.shape.get(axis) {
|
||||
// self.indices[axis] += 1;
|
||||
// return true;
|
||||
// } else {
|
||||
// self.indices[axis] = 0;
|
||||
// }
|
||||
// }
|
||||
// false
|
||||
// }
|
||||
|
||||
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool {
|
||||
if self.indices[order[0]] >= self.shape.get(order[0]) {
|
||||
if self.indices()[order[0]] >= self.shape().get(order[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut carry = 1;
|
||||
for i in
|
||||
order.iter().rev()
|
||||
{
|
||||
|
||||
for i in order.iter().rev() {
|
||||
let dim_size = self.shape().get(*i);
|
||||
let i = self.index_mut(*i);
|
||||
if carry == 1 {
|
||||
@ -144,9 +137,10 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
}
|
||||
|
||||
if carry == 1 {
|
||||
self.indices[order[0]] = self.shape.as_array()[order[0]];
|
||||
self.indices_mut()[order[0]] = self.shape().get(order[0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
@ -289,7 +283,10 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
self.indices[axis]
|
||||
}
|
||||
|
||||
pub fn iter_transposed(&self, order: [usize; R]) -> IdxTransposedIterator<'a, R> {
|
||||
pub fn iter_transposed(
|
||||
&self,
|
||||
order: [usize; R],
|
||||
) -> IdxTransposedIterator<'a, R> {
|
||||
IdxTransposedIterator::new(self.shape(), order)
|
||||
}
|
||||
}
|
||||
|
@ -36,6 +36,13 @@ impl<T> Value for T where
|
||||
{
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! tensor {
|
||||
($array:expr) => {
|
||||
Tensor::from($array)
|
||||
};
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
#[cfg(test)]
|
||||
@ -45,7 +52,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_tensor_product() {
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
|
||||
// Fill tensors with some values
|
||||
|
101
src/tensor.rs
101
src/tensor.rs
@ -1,6 +1,7 @@
|
||||
use super::*;
|
||||
use crate::error::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
|
||||
pub struct Tensor<T, const R: usize> {
|
||||
@ -45,8 +46,6 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
}
|
||||
|
||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||
// let shape = self.shape().reorder(order);
|
||||
|
||||
let buffer = Idx::from(self.shape()).iter_transposed(order)
|
||||
.map(|index| {
|
||||
println!("index: {}", index);
|
||||
@ -220,8 +219,6 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
|
||||
|
||||
// ---- Display ----
|
||||
|
||||
use std::fmt;
|
||||
|
||||
impl<T, const R: usize> Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
@ -333,10 +330,26 @@ impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> From<[usize; R]> for Tensor<T, R> {
|
||||
fn from(shape: [usize; R]) -> Self {
|
||||
let shape = Shape::new(shape);
|
||||
Self::new(shape.into())
|
||||
impl<T: Value> From<T> for Tensor<T, 0> {
|
||||
fn from(value: T) -> Self {
|
||||
let shape = Shape::new([]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
tensor.buffer_mut()[0] = value;
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
|
||||
fn from(array: [T; X]) -> Self {
|
||||
let shape = Shape::new([X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, &elem) in array.iter().enumerate() {
|
||||
buffer[i] = elem;
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
@ -377,3 +390,75 @@ impl<T: Value, const X: usize, const Y: usize, const Z: usize>
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize>
|
||||
From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
|
||||
{
|
||||
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
|
||||
let shape = Shape::new([W, Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, hyperplane) in array.iter().enumerate() {
|
||||
for (j, plane) in hyperplane.iter().enumerate() {
|
||||
for (k, row) in plane.iter().enumerate() {
|
||||
for (l, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize>
|
||||
From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
|
||||
{
|
||||
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
|
||||
let shape = Shape::new([V, W, Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, hyperhyperplane) in array.iter().enumerate() {
|
||||
for (j, hyperplane) in hyperhyperplane.iter().enumerate() {
|
||||
for (k, plane) in hyperplane.iter().enumerate() {
|
||||
for (l, row) in plane.iter().enumerate() {
|
||||
for (m, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y * Z * W + j * X * Y * Z + k * X * Y + l * X + m] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize, const U: usize>
|
||||
From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
|
||||
{
|
||||
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
|
||||
let shape = Shape::new([U, V, W, Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, hyperhyperhyperplane) in array.iter().enumerate() {
|
||||
for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() {
|
||||
for (k, hyperplane) in hyperhyperplane.iter().enumerate() {
|
||||
for (l, plane) in hyperplane.iter().enumerate() {
|
||||
for (m, row) in plane.iter().enumerate() {
|
||||
for (n, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y * Z * W * V + j * X * Y * Z * W + k * X * Y * Z + l * X * Y + m * X + n] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user