🚀 Implement Tensor-type and basic methods #15

Merged
julius merged 8 commits from core-types into master 2024-01-03 21:52:54 +00:00
5 changed files with 49 additions and 43 deletions
Showing only changes of commit e0c4e0407f - Show all commits

View File

@ -1,7 +1,9 @@
use super::*;
use getset::{Getters, MutGetters};
use std::cmp::Ordering;
use std::ops::{Add, Sub};
use std::{
ops::{Index, IndexMut, Add, Sub},
cmp::Ordering,
};
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
pub struct TensorIndex<'a, const R: usize> {

View File

@ -1,41 +1,15 @@
#![allow(incomplete_features)]
#![feature(generic_const_exprs)]
#![warn(clippy::all)]
pub mod axis;
pub mod error;
pub mod index;
pub mod shape;
pub mod tensor;
pub mod value;
pub use axis::*;
pub use index::TensorIndex;
// pub use itertools::Itertools;
use num::{Num, One, Zero};
pub use serde::{Deserialize, Serialize};
pub use shape::TensorShape;
pub use static_assertions::const_assert;
pub use std::fmt::{Display, Formatter, Result as FmtResult};
use std::ops::{Index, IndexMut};
pub use std::sync::Arc;
pub use tensor::{Tensor, TensorIterator};
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
{
}
impl<T> Value for T where
T: Num
+ Zero
+ One
+ Copy
+ Clone
+ Display
+ Serialize
+ Deserialize<'static>
+ std::iter::Sum
{
}
pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*};
#[macro_export]
macro_rules! tensor {

View File

@ -1,7 +1,8 @@
use super::*;
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
use serde::ser::{Serialize, SerializeTuple, Serializer};
use std::fmt;
use std::fmt::{Result as FmtResult, Formatter};
use core::result::Result as SerdeResult;
#[derive(Clone, Copy, Debug)]
pub struct TensorShape<const R: usize>([usize; R]);
@ -151,11 +152,11 @@ struct TensorShapeVisitor<const R: usize>;
impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor<R> {
type Value = TensorShape<R>;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
formatter.write_str(concat!("an array of length ", "{R}"))
}
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
fn visit_seq<A>(self, mut seq: A) -> SerdeResult<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
@ -170,7 +171,7 @@ impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor<R> {
}
impl<'de, const R: usize> Deserialize<'de> for TensorShape<R> {
fn deserialize<D>(deserializer: D) -> Result<TensorShape<R>, D::Error>
fn deserialize<D>(deserializer: D) -> SerdeResult<TensorShape<R>, D::Error>
where
D: Deserializer<'de>,
{
@ -179,7 +180,7 @@ impl<'de, const R: usize> Deserialize<'de> for TensorShape<R> {
}
impl<const R: usize> Serialize for TensorShape<R> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
fn serialize<S>(&self, serializer: S) -> SerdeResult<S::Ok, S::Error>
where
S: Serializer,
{

View File

@ -1,7 +1,11 @@
use super::*;
use crate::error::*;
use getset::{Getters, MutGetters};
use std::fmt;
use serde::{Deserialize, Serialize};
use std::{
fmt::{Display, Formatter, Result as FmtResult},
ops::{Index, IndexMut},
};
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
/// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is
@ -447,14 +451,14 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
impl<T, const R: usize> Tensor<T, R>
where
T: fmt::Display + Clone,
T: Display + Clone,
{
fn fmt_helper(
buffer: &[T],
shape: &[usize],
f: &mut fmt::Formatter<'_>,
f: &mut Formatter<'_>,
level: usize,
) -> fmt::Result {
) -> FmtResult {
if shape.is_empty() {
// Base case: print individual elements
write!(f, "{}", buffer[0])
@ -472,11 +476,11 @@ where
}
}
impl<T, const R: usize> fmt::Display for Tensor<T, R>
impl<T, const R: usize> Display for Tensor<T, R>
where
T: fmt::Display + Clone,
T: Display + Clone,
{
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
}
}

25
src/value.rs Normal file
View File

@ -0,0 +1,25 @@
use num::{Num, One, Zero};
use serde::{Deserialize, Serialize};
use std::{
fmt::Display,
iter::Sum,
};
/// A trait for types that can be used as values in a tensor.
pub trait Value:
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
{
}
impl<T> Value for T where
T: Num
+ Zero
+ One
+ Copy
+ Clone
+ Display
+ Serialize
+ Deserialize<'static>
+ Sum
{
}