Formatting

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-02 21:16:56 +02:00
parent 26bd866403
commit d099625b97
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
6 changed files with 158 additions and 132 deletions

View File

@ -72,20 +72,17 @@ fn test_tensor_contraction_rank3() {
}
fn transpose() {
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let b = tensor!(
[[1, 2, 3],
[4, 5, 6]]
);
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let b = tensor!([[1, 2, 3], [4, 5, 6]]);
// let iter = a.idx().iter_transposed([1, 0]);
// let iter = a.idx().iter_transposed([1, 0]);
// for idx in iter {
// println!("{idx}");
// }
let b = a.clone().transpose([1, 0]).unwrap();
println!("a: {}", a);
println!("ta: {}", b);
// for idx in iter {
// println!("{idx}");
// }
let b = a.clone().transpose([1, 0]).unwrap();
println!("a: {}", a);
println!("ta: {}", b);
}
fn main() {
@ -93,5 +90,5 @@ fn main() {
// test_tensor_contraction_23x32();
// test_tensor_contraction_rank3();
transpose();
transpose();
}

View File

@ -132,23 +132,23 @@ where
{
let (lhs, la) = lhs;
let (rhs, ra) = rhs;
let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();
let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();
let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
let mut shape = Vec::new();
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
let shape: [usize; R + S - 2 * N] =
shape.try_into().expect("Failed to create shape array");
let mut shape = Vec::new();
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
let shape: [usize; R + S - 2 * N] =
shape.try_into().expect("Failed to create shape array");
let shape = Shape::new(shape);
let shape = Shape::new(shape);
let result = contract_axes(&lnc, &rnc);
let result = contract_axes(&lnc, &rnc);
Tensor::new_with_buffer(shape, result)
Tensor::new_with_buffer(shape, result)
}
pub fn contract_axes<

View File

@ -4,6 +4,6 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
#[error("Invalid argument: {0}")]
InvalidArgument(String),
}
#[error("Invalid argument: {0}")]
InvalidArgument(String),
}

View File

@ -1,22 +1,22 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
pub mod axis;
pub mod error;
pub mod index;
pub mod shape;
pub mod axis;
pub mod tensor;
pub mod error;
pub use axis::*;
pub use index::Idx;
pub use itertools::Itertools;
use num::{Num, One, Zero};
pub use serde::{Deserialize, Serialize};
pub use shape::Shape;
pub use static_assertions::const_assert;
pub use std::fmt::{Display, Formatter, Result as FmtResult};
use std::ops::{Index, IndexMut};
pub use tensor::{Tensor, TensorIterator};
pub use static_assertions::const_assert;
pub use itertools::Itertools;
pub use std::sync::Arc;
pub use axis::*;
pub use tensor::{Tensor, TensorIterator};
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
@ -32,15 +32,15 @@ impl<T> Value for T where
+ Display
+ Serialize
+ Deserialize<'static>
+ std::iter::Sum
+ std::iter::Sum
{
}
#[macro_export]
macro_rules! tensor {
($array:expr) => {
Tensor::from($array)
};
($array:expr) => {
Tensor::from($array)
};
}
// ---- Tests ----

View File

@ -11,17 +11,17 @@ impl<const R: usize> Shape<R> {
Self(shape)
}
pub fn axis(&self, index: usize) -> Option<&usize> {
self.0.get(index)
}
pub fn axis(&self, index: usize) -> Option<&usize> {
self.0.get(index)
}
pub fn reorder(&self, indices: [usize; R]) -> Self {
let mut new_shape = Shape::new([0; R]);
for (new_index, &index) in indices.iter().enumerate() {
new_shape.0[new_index] = self.0[index];
}
new_shape
}
pub fn reorder(&self, indices: [usize; R]) -> Self {
let mut new_shape = Shape::new([0; R]);
for (new_index, &index) in indices.iter().enumerate() {
new_shape.0[new_index] = self.0[index];
}
new_shape
}
pub const fn as_array(&self) -> [usize; R] {
self.0
@ -114,7 +114,7 @@ impl<const R: usize> Shape<R> {
Shape(new_shape)
}
pub fn remove_axes<'a, T: Value, const NAX: usize>(
pub fn remove_axes<'a, T: Value, const NAX: usize>(
&self,
axes_to_remove: &'a [Axis<'a, T, R>; NAX],
) -> Shape<{ R - NAX }> {
@ -126,10 +126,10 @@ impl<const R: usize> Shape<R> {
for (index, &dim) in self.0.iter().enumerate() {
// Skip dimensions that are in the axes_to_remove array
for axis in axes_to_remove {
if *axis.dim() == index {
continue;
}
}
if *axis.dim() == index {
continue;
}
}
// Add the dimension to the new shape array
new_shape[new_index] = dim;

View File

@ -46,14 +46,15 @@ impl<T: Value, const R: usize> Tensor<T, R> {
}
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
let buffer = Idx::from(self.shape()).iter_transposed(order)
.map(|index| {
println!("index: {}", index);
self.get(index).unwrap().clone()
})
.collect();
let buffer = Idx::from(self.shape())
.iter_transposed(order)
.map(|index| self.get(index).unwrap().clone())
.collect();
Ok(Tensor { buffer, shape: self.shape().reorder(order) })
Ok(Tensor {
buffer,
shape: self.shape().reorder(order),
})
}
pub fn idx(&self) -> Idx<R> {
@ -331,26 +332,26 @@ impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
}
impl<T: Value> From<T> for Tensor<T, 0> {
fn from(value: T) -> Self {
let shape = Shape::new([]);
let mut tensor = Tensor::new(shape);
tensor.buffer_mut()[0] = value;
tensor
}
fn from(value: T) -> Self {
let shape = Shape::new([]);
let mut tensor = Tensor::new(shape);
tensor.buffer_mut()[0] = value;
tensor
}
}
impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
fn from(array: [T; X]) -> Self {
let shape = Shape::new([X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
fn from(array: [T; X]) -> Self {
let shape = Shape::new([X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, &elem) in array.iter().enumerate() {
buffer[i] = elem;
}
for (i, &elem) in array.iter().enumerate() {
buffer[i] = elem;
}
tensor
}
tensor
}
}
impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
@ -391,74 +392,102 @@ impl<T: Value, const X: usize, const Y: usize, const Z: usize>
}
}
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize>
From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
impl<
T: Value,
const X: usize,
const Y: usize,
const Z: usize,
const W: usize,
> From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
{
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
let shape = Shape::new([W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
let shape = Shape::new([W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, hyperplane) in array.iter().enumerate() {
for (j, plane) in hyperplane.iter().enumerate() {
for (k, row) in plane.iter().enumerate() {
for (l, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem;
}
}
}
}
for (i, hyperplane) in array.iter().enumerate() {
for (j, plane) in hyperplane.iter().enumerate() {
for (k, row) in plane.iter().enumerate() {
for (l, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem;
}
}
}
}
tensor
}
tensor
}
}
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize>
From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
impl<
T: Value,
const X: usize,
const Y: usize,
const Z: usize,
const W: usize,
const V: usize,
> From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
{
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
let shape = Shape::new([V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
let shape = Shape::new([V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, hyperhyperplane) in array.iter().enumerate() {
for (j, hyperplane) in hyperhyperplane.iter().enumerate() {
for (k, plane) in hyperplane.iter().enumerate() {
for (l, row) in plane.iter().enumerate() {
for (m, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z * W + j * X * Y * Z + k * X * Y + l * X + m] = elem;
}
}
}
}
}
for (i, hyperhyperplane) in array.iter().enumerate() {
for (j, hyperplane) in hyperhyperplane.iter().enumerate() {
for (k, plane) in hyperplane.iter().enumerate() {
for (l, row) in plane.iter().enumerate() {
for (m, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z * W
+ j * X * Y * Z
+ k * X * Y
+ l * X
+ m] = elem;
}
}
}
}
}
tensor
}
tensor
}
}
impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize, const U: usize>
From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
impl<
T: Value,
const X: usize,
const Y: usize,
const Z: usize,
const W: usize,
const V: usize,
const U: usize,
> From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
{
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
let shape = Shape::new([U, V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
let shape = Shape::new([U, V, W, Z, Y, X]);
let mut tensor = Tensor::new(shape);
let buffer = tensor.buffer_mut();
for (i, hyperhyperhyperplane) in array.iter().enumerate() {
for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() {
for (k, hyperplane) in hyperhyperplane.iter().enumerate() {
for (l, plane) in hyperplane.iter().enumerate() {
for (m, row) in plane.iter().enumerate() {
for (n, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z * W * V + j * X * Y * Z * W + k * X * Y * Z + l * X * Y + m * X + n] = elem;
}
}
}
}
}
}
for (i, hyperhyperhyperplane) in array.iter().enumerate() {
for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate()
{
for (k, hyperplane) in hyperhyperplane.iter().enumerate() {
for (l, plane) in hyperplane.iter().enumerate() {
for (m, row) in plane.iter().enumerate() {
for (n, &elem) in row.iter().enumerate() {
buffer[i * X * Y * Z * W * V
+ j * X * Y * Z * W
+ k * X * Y * Z
+ l * X * Y
+ m * X
+ n] = elem;
}
}
}
}
}
}
tensor
}
tensor
}
}