Formatting

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>

commit d099625b97, parent 26bd866403
@@ -72,20 +72,17 @@ fn test_tensor_contraction_rank3() {
 }

 fn transpose() {
     let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
-    let b = tensor!(
-        [[1, 2, 3],
-        [4, 5, 6]]
-    );
+    let b = tensor!([[1, 2, 3], [4, 5, 6]]);

     // let iter = a.idx().iter_transposed([1, 0]);

     // for idx in iter {
     //     println!("{idx}");
     // }
     let b = a.clone().transpose([1, 0]).unwrap();
     println!("a: {}", a);
     println!("ta: {}", b);
 }

@@ -93,5 +90,5 @@ fn main() {
     // test_tensor_contraction_23x32();
     // test_tensor_contraction_rank3();

     transpose();
 }
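For orientation, the function above builds the same 2x3 tensor twice (the `tensor!` macro forwards to `Tensor::from`, see src/lib.rs below) and then transposes it. A minimal sketch of those calls, assuming this crate's API as shown in the diff:

// Two equivalent constructions of a 2x3 tensor (shape [2, 3]).
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let b = tensor!([[1, 2, 3], [4, 5, 6]]);

// Swap axes 0 and 1: the result has shape [3, 2].
let t = a.clone().transpose([1, 0]).unwrap();
println!("{}", t);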
src/axis.rs (24 lines changed)
@@ -132,23 +132,23 @@ where
 {
     let (lhs, la) = lhs;
     let (rhs, ra) = rhs;
     let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
     let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();

     let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
     let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();

     let mut shape = Vec::new();
     shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
     shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
     let shape: [usize; R + S - 2 * N] =
         shape.try_into().expect("Failed to create shape array");

     let shape = Shape::new(shape);

     let result = contract_axes(&lnc, &rnc);

     Tensor::new_with_buffer(shape, result)
 }

 pub fn contract_axes<
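For context on the `[usize; R + S - 2 * N]` annotation (this is what `generic_const_exprs` is enabled for in src/lib.rs): contracting N axis pairs removes N dimensions from each operand, leaving (R - N) + (S - N). Note also that `shape` is extended with the rhs remainder before the lhs remainder, so the rhs's free dimensions come first in the result. A standalone sketch of the rank arithmetic, not crate code:

// Contracting n axis pairs of a rank-r and a rank-s tensor leaves
// (r - n) + (s - n) = r + s - 2n dimensions.
fn contracted_rank(r: usize, s: usize, n: usize) -> usize {
    r + s - 2 * n
}

// Matrix multiply is the n = 1 case: contracting [2, 3] with [3, 2]
// over one axis pair gives contracted_rank(2, 2, 1) == 2.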
src/error.rs

@@ -4,6 +4,6 @@ pub type Result<T> = std::result::Result<T, Error>;

 #[derive(Error, Debug)]
 pub enum Error {
     #[error("Invalid argument: {0}")]
     InvalidArgument(String),
 }
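A minimal sketch of how the `Result` alias and the `thiserror`-derived enum compose at a call site; `check_order` is a hypothetical helper, not part of the crate:

// Assumes the module's own `Result<T>` alias and `Error` enum from above.
fn check_order(order: &[usize]) -> Result<()> {
    if order.is_empty() {
        return Err(Error::InvalidArgument("empty axis order".to_string()));
    }
    Ok(())
}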
src/lib.rs (20 lines changed)
@@ -1,22 +1,22 @@
 #![allow(incomplete_features)]
 #![feature(generic_const_exprs)]
+pub mod axis;
+pub mod error;
 pub mod index;
 pub mod shape;
-pub mod axis;
 pub mod tensor;
-pub mod error;

+pub use axis::*;
 pub use index::Idx;
+pub use itertools::Itertools;
 use num::{Num, One, Zero};
 pub use serde::{Deserialize, Serialize};
 pub use shape::Shape;
+pub use static_assertions::const_assert;
 pub use std::fmt::{Display, Formatter, Result as FmtResult};
 use std::ops::{Index, IndexMut};
+pub use tensor::{Tensor, TensorIterator};
-pub use static_assertions::const_assert;
-pub use itertools::Itertools;
-pub use std::sync::Arc;
-pub use axis::*;
-pub use tensor::{Tensor, TensorIterator};

 pub trait Value:
     Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>

@@ -32,15 +32,15 @@ impl<T> Value for T where
     + Display
     + Serialize
     + Deserialize<'static>
     + std::iter::Sum
 {
 }

 #[macro_export]
 macro_rules! tensor {
     ($array:expr) => {
         Tensor::from($array)
     };
 }

 // ---- Tests ----
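The hunks above sort the module and import lists and drop the `pub use std::sync::Arc;` re-export; the `tensor!` macro itself is unchanged and simply forwards to `Tensor::from`, so the two construction styles seen in the test file are interchangeable:

// Expanded by the macro to the identical call:
let a = tensor!([[1, 2], [3, 4]]);
let b = Tensor::from([[1, 2], [3, 4]]);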
src/shape.rs (30 lines changed)
@@ -11,17 +11,17 @@ impl<const R: usize> Shape<R> {
         Self(shape)
     }

     pub fn axis(&self, index: usize) -> Option<&usize> {
         self.0.get(index)
     }

     pub fn reorder(&self, indices: [usize; R]) -> Self {
         let mut new_shape = Shape::new([0; R]);
         for (new_index, &index) in indices.iter().enumerate() {
             new_shape.0[new_index] = self.0[index];
         }
         new_shape
     }

     pub const fn as_array(&self) -> [usize; R] {
         self.0

@@ -114,7 +114,7 @@ impl<const R: usize> Shape<R> {
         Shape(new_shape)
     }

     pub fn remove_axes<'a, T: Value, const NAX: usize>(
         &self,
         axes_to_remove: &'a [Axis<'a, T, R>; NAX],
     ) -> Shape<{ R - NAX }> {

@@ -126,10 +126,10 @@ impl<const R: usize> Shape<R> {
         for (index, &dim) in self.0.iter().enumerate() {
             // Skip dimensions that are in the axes_to_remove array
             for axis in axes_to_remove {
                 if *axis.dim() == index {
                     continue;
                 }
             }

             // Add the dimension to the new shape array
             new_shape[new_index] = dim;
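One flag while reading the last hunk (pre-existing behavior, not changed by this commit): `continue` binds to the innermost loop, so it only advances the inner `for axis in axes_to_remove` scan; control still reaches `new_shape[new_index] = dim;`, and no dimension is actually skipped. A sketch of the presumably intended check, using only the calls visible in the hunk:

for (index, &dim) in self.0.iter().enumerate() {
    // Skip dimensions that are in the axes_to_remove array.
    if axes_to_remove.iter().any(|axis| *axis.dim() == index) {
        continue; // now continues the outer loop
    }

    // Add the dimension to the new shape array.
    new_shape[new_index] = dim;
    new_index += 1; // assumed: bookkeeping as in the elided surrounding code
}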
src/tensor.rs (187 lines changed)
@@ -46,14 +46,15 @@ impl<T: Value, const R: usize> Tensor<T, R> {
 }

 pub fn transpose(self, order: [usize; R]) -> Result<Self> {
-    let buffer = Idx::from(self.shape()).iter_transposed(order)
-        .map(|index| {
-            println!("index: {}", index);
-            self.get(index).unwrap().clone()
-        })
-        .collect();
+    let buffer = Idx::from(self.shape())
+        .iter_transposed(order)
+        .map(|index| self.get(index).unwrap().clone())
+        .collect();

-    Ok(Tensor { buffer, shape: self.shape().reorder(order) })
+    Ok(Tensor {
+        buffer,
+        shape: self.shape().reorder(order),
+    })
 }

 pub fn idx(&self) -> Idx<R> {
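Besides reflowing the builder chain, the new version drops a debug `println!` from the gather loop. The observable result is a plain transpose; for a 2x3 input (expected values, assuming `iter_transposed` walks the target layout in row-major order as the rest of the code suggests):

// transpose([1, 0]) on a 2x3 tensor swaps the axes:
// [[1, 2, 3], [4, 5, 6]]  ->  [[1, 4], [2, 5], [3, 6]]
let t = Tensor::from([[1, 2, 3], [4, 5, 6]]);
let tt = t.transpose([1, 0]).unwrap(); // shape [2, 3] -> [3, 2]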
@@ -331,26 +332,26 @@ impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
 }

 impl<T: Value> From<T> for Tensor<T, 0> {
     fn from(value: T) -> Self {
         let shape = Shape::new([]);
         let mut tensor = Tensor::new(shape);
         tensor.buffer_mut()[0] = value;
         tensor
     }
 }

 impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
     fn from(array: [T; X]) -> Self {
         let shape = Shape::new([X]);
         let mut tensor = Tensor::new(shape);
         let buffer = tensor.buffer_mut();

         for (i, &elem) in array.iter().enumerate() {
             buffer[i] = elem;
         }

         tensor
     }
 }

 impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
@@ -391,74 +392,102 @@ impl<T: Value, const X: usize, const Y: usize, const Z: usize>
     }
 }

-impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize>
-    From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
+impl<
+    T: Value,
+    const X: usize,
+    const Y: usize,
+    const Z: usize,
+    const W: usize,
+> From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
 {
     fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
         let shape = Shape::new([W, Z, Y, X]);
         let mut tensor = Tensor::new(shape);
         let buffer = tensor.buffer_mut();

         for (i, hyperplane) in array.iter().enumerate() {
             for (j, plane) in hyperplane.iter().enumerate() {
                 for (k, row) in plane.iter().enumerate() {
                     for (l, &elem) in row.iter().enumerate() {
                         buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem;
                     }
                 }
             }
         }

         tensor
     }
 }

-impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize>
-    From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
+impl<
+    T: Value,
+    const X: usize,
+    const Y: usize,
+    const Z: usize,
+    const W: usize,
+    const V: usize,
+> From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
 {
     fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
         let shape = Shape::new([V, W, Z, Y, X]);
         let mut tensor = Tensor::new(shape);
         let buffer = tensor.buffer_mut();

         for (i, hyperhyperplane) in array.iter().enumerate() {
             for (j, hyperplane) in hyperhyperplane.iter().enumerate() {
                 for (k, plane) in hyperplane.iter().enumerate() {
                     for (l, row) in plane.iter().enumerate() {
                         for (m, &elem) in row.iter().enumerate() {
-                            buffer[i * X * Y * Z * W + j * X * Y * Z + k * X * Y + l * X + m] = elem;
+                            buffer[i * X * Y * Z * W
+                                + j * X * Y * Z
+                                + k * X * Y
+                                + l * X
+                                + m] = elem;
                         }
                     }
                 }
             }
         }

         tensor
     }
 }

-impl<T: Value, const X: usize, const Y: usize, const Z: usize, const W: usize, const V: usize, const U: usize>
-    From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
+impl<
+    T: Value,
+    const X: usize,
+    const Y: usize,
+    const Z: usize,
+    const W: usize,
+    const V: usize,
+    const U: usize,
+> From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
 {
     fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
         let shape = Shape::new([U, V, W, Z, Y, X]);
         let mut tensor = Tensor::new(shape);
         let buffer = tensor.buffer_mut();

         for (i, hyperhyperhyperplane) in array.iter().enumerate() {
-            for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() {
+            for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate()
+            {
                 for (k, hyperplane) in hyperhyperplane.iter().enumerate() {
                     for (l, plane) in hyperplane.iter().enumerate() {
                         for (m, row) in plane.iter().enumerate() {
                             for (n, &elem) in row.iter().enumerate() {
-                                buffer[i * X * Y * Z * W * V + j * X * Y * Z * W + k * X * Y * Z + l * X * Y + m * X + n] = elem;
+                                buffer[i * X * Y * Z * W * V
+                                    + j * X * Y * Z * W
+                                    + k * X * Y * Z
+                                    + l * X * Y
+                                    + m * X
+                                    + n] = elem;
                             }
                         }
                     }
                 }
             }
         }

         tensor
     }
 }
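All three impls flatten the nested arrays with the same row-major rule: with the shape stored outermost-first (e.g. `[W, Z, Y, X]`), the offset of `(i, j, k, l)` is `i*Z*Y*X + j*Y*X + k*X + l`, exactly the sums that the reformatted lines now split across lines. A standalone sketch of the rank-4 case, not crate code:

// Row-major offset for shape [W, Z, Y, X]; mirrors
// buffer[i * X * Y * Z + j * X * Y + k * X + l] in the impls above
// (the factor order differs, the product is the same).
fn flat_offset_4(i: usize, j: usize, k: usize, l: usize, shape: [usize; 4]) -> usize {
    let [_w, z, y, x] = shape;
    ((i * z + j) * y + k) * x + l
}

fn main() {
    // Shape [2, 2, 3, 4]: offset of (1, 0, 2, 3) is 24 + 0 + 8 + 3 = 35.
    assert_eq!(flat_offset_4(1, 0, 2, 3, [2, 2, 3, 4]), 35);
}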