Basic implementation for transposed iteration of a Tensor

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2024-01-01 22:25:54 +02:00
parent f9c29aefd5
commit 08747a2932
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
11 changed files with 280 additions and 53 deletions

64
.vscode/launch.json vendored Normal file
View File

@ -0,0 +1,64 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in library 'manifold'",
"cargo": {
"args": [
"test",
"--no-run",
"--lib",
"--package=manifold"
],
"filter": {
"name": "manifold",
"kind": "lib"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug example 'operations'",
"cargo": {
"args": [
"build",
"--example=operations",
"--package=manifold"
],
"filter": {
"name": "operations",
"kind": "example"
}
},
"args": [],
"cwd": "${workspaceFolder}"
},
{
"type": "lldb",
"request": "launch",
"name": "Debug unit tests in example 'operations'",
"cargo": {
"args": [
"test",
"--no-run",
"--example=operations",
"--package=manifold"
],
"filter": {
"name": "operations",
"kind": "example"
}
},
"args": [],
"cwd": "${workspaceFolder}"
}
]
}

21
Cargo.lock generated
View File

@ -304,6 +304,7 @@ dependencies = [
"serde",
"serde_json",
"static_assertions",
"thiserror",
]
[[package]]
@ -642,6 +643,26 @@ dependencies = [
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "1.0.52"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.52"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.43",
]
[[package]]
name = "tinytemplate"
version = "1.2.1"

View File

@ -25,6 +25,7 @@ num = "0.4.1"
serde = { version = "1.0.193", features = ["derive"] }
serde_json = "1.0.108"
static_assertions = "1.1.0"
thiserror = "1.0.52"
[dev-dependencies]
rand = "0.8.5"

View File

@ -71,8 +71,23 @@ fn test_tensor_contraction_rank3() {
// ... further checks for other elements ...
}
fn transpose() {
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
// let iter = a.idx().iter_transposed([1, 0]);
// for idx in iter {
// println!("{idx}");
// }
let b = a.clone().transpose([1, 0]).unwrap();
println!("a: {}", a);
println!("ta: {}", b);
}
fn main() {
// tensor_product();
test_tensor_contraction_23x32();
test_tensor_contraction_rank3();
// test_tensor_contraction_23x32();
// test_tensor_contraction_rank3();
transpose();
}

View File

@ -111,7 +111,7 @@ impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
if self.axis_idx() == self.axis_max_idx() {
return None;
}
let result = self.axis().tensor().get(self.index);
let result = unsafe { self.axis().tensor().get_unchecked(self.index) };
let axis_dim = self.axis_dim();
self.index_mut().inc_axis(axis_dim);
Some(result)

9
src/error.rs Normal file
View File

@ -0,0 +1,9 @@
use thiserror::Error;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Error, Debug)]
pub enum Error {
#[error("Invalid argument: {0}")]
InvalidArgument(String),
}

View File

@ -15,7 +15,7 @@ impl<'a, const R: usize> Idx<'a, R> {
pub const fn zero(shape: &'a Shape<R>) -> Self {
Self {
indices: [0; R],
shape: shape,
shape,
}
}
@ -110,6 +110,46 @@ impl<'a, const R: usize> Idx<'a, R> {
}
}
// pub fn inc_transposed(&mut self, order: [usize; R]) -> bool {
// // Iterate over axes in the specified order
// for &axis in order.iter().rev() {
// if self.indices[axis] + 1 < self.shape.get(axis) {
// self.indices[axis] += 1;
// return true;
// } else {
// self.indices[axis] = 0;
// }
// }
// false
// }
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool {
if self.indices[order[0]] >= self.shape.get(order[0]) {
return false;
}
let mut carry = 1;
for i in
order.iter().rev()
{
let dim_size = self.shape().get(*i);
let i = self.index_mut(*i);
if carry == 1 {
*i += 1;
if *i >= dim_size {
*i = 0;
} else {
carry = 0;
}
}
}
if carry == 1 {
self.indices[order[0]] = self.shape.as_array()[order[0]];
return true;
}
false
}
pub fn dec(&mut self) {
// Check if already at the start
if self.indices.iter().all(|&i| i == 0) {
@ -165,6 +205,27 @@ impl<'a, const R: usize> Idx<'a, R> {
true
}
pub fn dec_transposed(&mut self, order: [usize; R]) {
// Iterate over the axes in the specified order
for &axis in &order {
// Try to decrement the current axis
if self.indices[axis] > 0 {
self.indices[axis] -= 1;
// Reset all preceding axes in the order to their maximum
for &prev_axis in &order {
if prev_axis == axis {
break;
}
self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
}
return;
}
}
// If no axis can be decremented, set the first axis in the order to indicate overflow
self.indices[order[0]] = self.shape.get(order[0]);
}
/// Converts the multi-dimensional index to a flat index.
///
/// This method calculates the flat index corresponding to the multi-dimensional index
@ -227,6 +288,10 @@ impl<'a, const R: usize> Idx<'a, R> {
assert!(axis < R, "Axis out of bounds");
self.indices[axis]
}
pub fn iter_transposed(&self, order: [usize; R]) -> IdxTransposedIterator<'a, R> {
IdxTransposedIterator::new(self.shape(), order)
}
}
// --- blanket impls ---
@ -388,3 +453,33 @@ impl<'a, const R: usize> IntoIterator for Idx<'a, R> {
}
}
}
pub struct IdxTransposedIterator<'a, const R: usize> {
current: Idx<'a, R>,
order: [usize; R],
end: bool,
}
impl<'a, const R: usize> IdxTransposedIterator<'a, R> {
pub fn new(shape: &'a Shape<R>, order: [usize; R]) -> Self {
Self {
current: Idx::zero(shape),
end: false,
order,
}
}
}
impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> {
type Item = Idx<'a, R>;
fn next(&mut self) -> Option<Self::Item> {
if self.end {
return None;
}
let result = self.current;
self.end = self.current.inc_transposed(&self.order);
Some(result)
}
}

View File

@ -4,7 +4,7 @@ pub mod index;
pub mod shape;
pub mod axis;
pub mod tensor;
pub mod tensor_view;
pub mod error;
pub use index::Idx;
use num::{Num, One, Zero};

View File

@ -11,6 +11,18 @@ impl<const R: usize> Shape<R> {
Self(shape)
}
pub fn axis(&self, index: usize) -> Option<&usize> {
self.0.get(index)
}
pub fn reorder(&self, indices: [usize; R]) -> Self {
let mut new_shape = Shape::new([0; R]);
for (new_index, &index) in indices.iter().enumerate() {
new_shape.0[new_index] = self.0[index];
}
new_shape
}
pub const fn as_array(&self) -> [usize; R] {
self.0
}

View File

@ -1,31 +1,17 @@
use super::*;
use crate::error::*;
use getset::{Getters, MutGetters};
#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
pub struct Tensor<T, const R: usize> {
#[getset(get = "pub", get_mut = "pub")]
#[getset(get = "pub", get_mut = "pub")]
buffer: Vec<T>,
#[getset(get = "pub")]
#[getset(get = "pub")]
shape: Shape<R>,
}
impl<T: Value> Tensor<T, 1> {
pub fn dot(&self, other: &Tensor<T, 1>) -> T {
if self.shape != other.shape {
panic!("Shapes of tensors do not match");
}
let mut result = T::zero();
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
result = result + (*a * *b);
}
result
}
}
impl<T: Value, const R: usize> Tensor<T, R> {
pub fn new(shape: Shape<R>) -> Self {
pub fn new(shape: Shape<R>) -> Self {
// Handle rank 0 tensor (scalar) as a special case
let total_size = if R == 0 {
// A rank 0 tensor should still have a buffer with one element
@ -39,21 +25,49 @@ impl<T: Value, const R: usize> Tensor<T, R> {
Self { buffer, shape }
}
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
Self { buffer, shape }
}
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
Self { buffer, shape }
}
pub fn reshape(self, shape: Shape<R>) -> Result<Self> {
if self.shape().size() != shape.size() {
let (ls, rs) = (self.shape().as_array(), shape.as_array());
let (lsize, rsize) = (self.shape().size(), shape.size());
Err(Error::InvalidArgument(format!(
"Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
)))
} else {
Ok(Self {
buffer: self.buffer,
shape,
})
}
}
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
// let shape = self.shape().reorder(order);
let buffer = Idx::from(self.shape()).iter_transposed(order)
.map(|index| {
println!("index: {}", index);
self.get(index).unwrap().clone()
})
.collect();
Ok(Tensor { buffer, shape: self.shape().reorder(order) })
}
pub fn idx(&self) -> Idx<R> {
Idx::from(self)
}
pub fn get(&self, index: Idx<R>) -> &T {
&self.buffer[index.flat()]
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
Axis::new(self, axis)
}
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
Axis::new(self, axis)
}
pub fn get(&self, index: Idx<R>) -> Option<&T> {
self.buffer.get(index.flat())
}
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
self.buffer.get_unchecked(index.flat())
@ -67,21 +81,21 @@ impl<T: Value, const R: usize> Tensor<T, R> {
self.buffer.get_unchecked_mut(index.flat())
}
pub fn get_flat(&self, index: usize) -> &T {
&self.buffer[index]
}
pub fn get_flat(&self, index: usize) -> Option<&T> {
self.buffer.get(index)
}
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
self.buffer.get_unchecked(index)
}
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
self.buffer.get_unchecked(index)
}
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
self.buffer.get_mut(index)
}
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
self.buffer.get_mut(index)
}
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
self.buffer.get_unchecked_mut(index)
}
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
self.buffer.get_unchecked_mut(index)
}
pub fn rank(&self) -> usize {
R
@ -204,7 +218,6 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
}
}
// ---- Display ----
use std::fmt;
@ -213,7 +226,12 @@ impl<T, const R: usize> Tensor<T, R>
where
T: fmt::Display + Clone,
{
fn fmt_helper(buffer: &[T], shape: &[usize], f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result {
fn fmt_helper(
buffer: &[T],
shape: &[usize],
f: &mut fmt::Formatter<'_>,
level: usize,
) -> fmt::Result {
if shape.is_empty() {
// Base case: print individual elements
write!(f, "{}", buffer[0])

View File

@ -1,8 +0,0 @@
use super::*;
use getset::{Getters, MutGetters};
// #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
pub struct TensorView<'a, T, const R: usize> {
tensor: &'a Tensor<T, R>,
shape: Shape<R>,
}