Basic implementation for transposed iteration of a Tensor
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
f9c29aefd5
commit
08747a2932
64
.vscode/launch.json
vendored
Normal file
64
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug unit tests in library 'manifold'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--lib",
|
||||
"--package=manifold"
|
||||
],
|
||||
"filter": {
|
||||
"name": "manifold",
|
||||
"kind": "lib"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug example 'operations'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"build",
|
||||
"--example=operations",
|
||||
"--package=manifold"
|
||||
],
|
||||
"filter": {
|
||||
"name": "operations",
|
||||
"kind": "example"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug unit tests in example 'operations'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--example=operations",
|
||||
"--package=manifold"
|
||||
],
|
||||
"filter": {
|
||||
"name": "operations",
|
||||
"kind": "example"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
]
|
||||
}
|
21
Cargo.lock
generated
21
Cargo.lock
generated
@ -304,6 +304,7 @@ dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
"static_assertions",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -642,6 +643,26 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.43",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinytemplate"
|
||||
version = "1.2.1"
|
||||
|
@ -25,6 +25,7 @@ num = "0.4.1"
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
serde_json = "1.0.108"
|
||||
static_assertions = "1.1.0"
|
||||
thiserror = "1.0.52"
|
||||
|
||||
[dev-dependencies]
|
||||
rand = "0.8.5"
|
||||
|
@ -71,8 +71,23 @@ fn test_tensor_contraction_rank3() {
|
||||
// ... further checks for other elements ...
|
||||
}
|
||||
|
||||
fn transpose() {
|
||||
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
// let iter = a.idx().iter_transposed([1, 0]);
|
||||
|
||||
// for idx in iter {
|
||||
// println!("{idx}");
|
||||
// }
|
||||
let b = a.clone().transpose([1, 0]).unwrap();
|
||||
println!("a: {}", a);
|
||||
println!("ta: {}", b);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
// tensor_product();
|
||||
test_tensor_contraction_23x32();
|
||||
test_tensor_contraction_rank3();
|
||||
// test_tensor_contraction_23x32();
|
||||
// test_tensor_contraction_rank3();
|
||||
|
||||
transpose();
|
||||
}
|
||||
|
@ -111,7 +111,7 @@ impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
|
||||
if self.axis_idx() == self.axis_max_idx() {
|
||||
return None;
|
||||
}
|
||||
let result = self.axis().tensor().get(self.index);
|
||||
let result = unsafe { self.axis().tensor().get_unchecked(self.index) };
|
||||
let axis_dim = self.axis_dim();
|
||||
self.index_mut().inc_axis(axis_dim);
|
||||
Some(result)
|
||||
|
9
src/error.rs
Normal file
9
src/error.rs
Normal file
@ -0,0 +1,9 @@
|
||||
use thiserror::Error;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Invalid argument: {0}")]
|
||||
InvalidArgument(String),
|
||||
}
|
97
src/index.rs
97
src/index.rs
@ -15,7 +15,7 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
pub const fn zero(shape: &'a Shape<R>) -> Self {
|
||||
Self {
|
||||
indices: [0; R],
|
||||
shape: shape,
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
@ -110,6 +110,46 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn inc_transposed(&mut self, order: [usize; R]) -> bool {
|
||||
// // Iterate over axes in the specified order
|
||||
// for &axis in order.iter().rev() {
|
||||
// if self.indices[axis] + 1 < self.shape.get(axis) {
|
||||
// self.indices[axis] += 1;
|
||||
// return true;
|
||||
// } else {
|
||||
// self.indices[axis] = 0;
|
||||
// }
|
||||
// }
|
||||
// false
|
||||
// }
|
||||
|
||||
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool {
|
||||
if self.indices[order[0]] >= self.shape.get(order[0]) {
|
||||
return false;
|
||||
}
|
||||
let mut carry = 1;
|
||||
for i in
|
||||
order.iter().rev()
|
||||
{
|
||||
let dim_size = self.shape().get(*i);
|
||||
let i = self.index_mut(*i);
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
if *i >= dim_size {
|
||||
*i = 0;
|
||||
} else {
|
||||
carry = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if carry == 1 {
|
||||
self.indices[order[0]] = self.shape.as_array()[order[0]];
|
||||
return true;
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
pub fn dec(&mut self) {
|
||||
// Check if already at the start
|
||||
if self.indices.iter().all(|&i| i == 0) {
|
||||
@ -165,6 +205,27 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
true
|
||||
}
|
||||
|
||||
pub fn dec_transposed(&mut self, order: [usize; R]) {
|
||||
// Iterate over the axes in the specified order
|
||||
for &axis in &order {
|
||||
// Try to decrement the current axis
|
||||
if self.indices[axis] > 0 {
|
||||
self.indices[axis] -= 1;
|
||||
// Reset all preceding axes in the order to their maximum
|
||||
for &prev_axis in &order {
|
||||
if prev_axis == axis {
|
||||
break;
|
||||
}
|
||||
self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If no axis can be decremented, set the first axis in the order to indicate overflow
|
||||
self.indices[order[0]] = self.shape.get(order[0]);
|
||||
}
|
||||
|
||||
/// Converts the multi-dimensional index to a flat index.
|
||||
///
|
||||
/// This method calculates the flat index corresponding to the multi-dimensional index
|
||||
@ -227,6 +288,10 @@ impl<'a, const R: usize> Idx<'a, R> {
|
||||
assert!(axis < R, "Axis out of bounds");
|
||||
self.indices[axis]
|
||||
}
|
||||
|
||||
pub fn iter_transposed(&self, order: [usize; R]) -> IdxTransposedIterator<'a, R> {
|
||||
IdxTransposedIterator::new(self.shape(), order)
|
||||
}
|
||||
}
|
||||
|
||||
// --- blanket impls ---
|
||||
@ -388,3 +453,33 @@ impl<'a, const R: usize> IntoIterator for Idx<'a, R> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IdxTransposedIterator<'a, const R: usize> {
|
||||
current: Idx<'a, R>,
|
||||
order: [usize; R],
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> IdxTransposedIterator<'a, R> {
|
||||
pub fn new(shape: &'a Shape<R>, order: [usize; R]) -> Self {
|
||||
Self {
|
||||
current: Idx::zero(shape),
|
||||
end: false,
|
||||
order,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> {
|
||||
type Item = Idx<'a, R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = self.current;
|
||||
self.end = self.current.inc_transposed(&self.order);
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
@ -4,7 +4,7 @@ pub mod index;
|
||||
pub mod shape;
|
||||
pub mod axis;
|
||||
pub mod tensor;
|
||||
pub mod tensor_view;
|
||||
pub mod error;
|
||||
|
||||
pub use index::Idx;
|
||||
use num::{Num, One, Zero};
|
||||
|
12
src/shape.rs
12
src/shape.rs
@ -11,6 +11,18 @@ impl<const R: usize> Shape<R> {
|
||||
Self(shape)
|
||||
}
|
||||
|
||||
pub fn axis(&self, index: usize) -> Option<&usize> {
|
||||
self.0.get(index)
|
||||
}
|
||||
|
||||
pub fn reorder(&self, indices: [usize; R]) -> Self {
|
||||
let mut new_shape = Shape::new([0; R]);
|
||||
for (new_index, &index) in indices.iter().enumerate() {
|
||||
new_shape.0[new_index] = self.0[index];
|
||||
}
|
||||
new_shape
|
||||
}
|
||||
|
||||
pub const fn as_array(&self) -> [usize; R] {
|
||||
self.0
|
||||
}
|
||||
|
@ -1,31 +1,17 @@
|
||||
use super::*;
|
||||
use crate::error::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
|
||||
pub struct Tensor<T, const R: usize> {
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
buffer: Vec<T>,
|
||||
#[getset(get = "pub")]
|
||||
#[getset(get = "pub")]
|
||||
shape: Shape<R>,
|
||||
}
|
||||
|
||||
impl<T: Value> Tensor<T, 1> {
|
||||
pub fn dot(&self, other: &Tensor<T, 1>) -> T {
|
||||
if self.shape != other.shape {
|
||||
panic!("Shapes of tensors do not match");
|
||||
}
|
||||
|
||||
let mut result = T::zero();
|
||||
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
|
||||
result = result + (*a * *b);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
pub fn new(shape: Shape<R>) -> Self {
|
||||
pub fn new(shape: Shape<R>) -> Self {
|
||||
// Handle rank 0 tensor (scalar) as a special case
|
||||
let total_size = if R == 0 {
|
||||
// A rank 0 tensor should still have a buffer with one element
|
||||
@ -39,21 +25,49 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
|
||||
Self { buffer, shape }
|
||||
}
|
||||
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
pub fn reshape(self, shape: Shape<R>) -> Result<Self> {
|
||||
if self.shape().size() != shape.size() {
|
||||
let (ls, rs) = (self.shape().as_array(), shape.as_array());
|
||||
let (lsize, rsize) = (self.shape().size(), shape.size());
|
||||
Err(Error::InvalidArgument(format!(
|
||||
"Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
|
||||
)))
|
||||
} else {
|
||||
Ok(Self {
|
||||
buffer: self.buffer,
|
||||
shape,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||
// let shape = self.shape().reorder(order);
|
||||
|
||||
let buffer = Idx::from(self.shape()).iter_transposed(order)
|
||||
.map(|index| {
|
||||
println!("index: {}", index);
|
||||
self.get(index).unwrap().clone()
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Tensor { buffer, shape: self.shape().reorder(order) })
|
||||
}
|
||||
|
||||
pub fn idx(&self) -> Idx<R> {
|
||||
Idx::from(self)
|
||||
}
|
||||
|
||||
pub fn get(&self, index: Idx<R>) -> &T {
|
||||
&self.buffer[index.flat()]
|
||||
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
|
||||
Axis::new(self, axis)
|
||||
}
|
||||
|
||||
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
|
||||
Axis::new(self, axis)
|
||||
}
|
||||
pub fn get(&self, index: Idx<R>) -> Option<&T> {
|
||||
self.buffer.get(index.flat())
|
||||
}
|
||||
|
||||
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
|
||||
self.buffer.get_unchecked(index.flat())
|
||||
@ -67,21 +81,21 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
self.buffer.get_unchecked_mut(index.flat())
|
||||
}
|
||||
|
||||
pub fn get_flat(&self, index: usize) -> &T {
|
||||
&self.buffer[index]
|
||||
}
|
||||
pub fn get_flat(&self, index: usize) -> Option<&T> {
|
||||
self.buffer.get(index)
|
||||
}
|
||||
|
||||
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
|
||||
self.buffer.get_unchecked(index)
|
||||
}
|
||||
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
|
||||
self.buffer.get_unchecked(index)
|
||||
}
|
||||
|
||||
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
|
||||
self.buffer.get_mut(index)
|
||||
}
|
||||
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
|
||||
self.buffer.get_mut(index)
|
||||
}
|
||||
|
||||
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
|
||||
self.buffer.get_unchecked_mut(index)
|
||||
}
|
||||
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
|
||||
self.buffer.get_unchecked_mut(index)
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
R
|
||||
@ -204,7 +218,6 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// ---- Display ----
|
||||
|
||||
use std::fmt;
|
||||
@ -213,7 +226,12 @@ impl<T, const R: usize> Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
{
|
||||
fn fmt_helper(buffer: &[T], shape: &[usize], f: &mut fmt::Formatter<'_>, level: usize) -> fmt::Result {
|
||||
fn fmt_helper(
|
||||
buffer: &[T],
|
||||
shape: &[usize],
|
||||
f: &mut fmt::Formatter<'_>,
|
||||
level: usize,
|
||||
) -> fmt::Result {
|
||||
if shape.is_empty() {
|
||||
// Base case: print individual elements
|
||||
write!(f, "{}", buffer[0])
|
||||
|
@ -1,8 +0,0 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
|
||||
// #[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
|
||||
pub struct TensorView<'a, T, const R: usize> {
|
||||
tensor: &'a Tensor<T, R>,
|
||||
shape: Shape<R>,
|
||||
}
|
Loading…
Reference in New Issue
Block a user