🚀 Implement Tensor-type and basic methods #15
@ -1,7 +1,9 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::cmp::Ordering;
|
||||
use std::ops::{Add, Sub};
|
||||
use std::{
|
||||
ops::{Index, IndexMut, Add, Sub},
|
||||
cmp::Ordering,
|
||||
};
|
||||
|
||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||
pub struct TensorIndex<'a, const R: usize> {
|
||||
|
32
src/lib.rs
32
src/lib.rs
@ -1,41 +1,15 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(generic_const_exprs)]
|
||||
#![warn(clippy::all)]
|
||||
|
||||
pub mod axis;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod shape;
|
||||
pub mod tensor;
|
||||
pub mod value;
|
||||
|
||||
pub use axis::*;
|
||||
pub use index::TensorIndex;
|
||||
// pub use itertools::Itertools;
|
||||
use num::{Num, One, Zero};
|
||||
pub use serde::{Deserialize, Serialize};
|
||||
pub use shape::TensorShape;
|
||||
pub use static_assertions::const_assert;
|
||||
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||
use std::ops::{Index, IndexMut};
|
||||
pub use std::sync::Arc;
|
||||
pub use tensor::{Tensor, TensorIterator};
|
||||
|
||||
pub trait Value:
|
||||
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T> Value for T where
|
||||
T: Num
|
||||
+ Zero
|
||||
+ One
|
||||
+ Copy
|
||||
+ Clone
|
||||
+ Display
|
||||
+ Serialize
|
||||
+ Deserialize<'static>
|
||||
+ std::iter::Sum
|
||||
{
|
||||
}
|
||||
pub use {value::*, axis::*, error::*, index::*, shape::*, tensor::*};
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! tensor {
|
||||
|
11
src/shape.rs
11
src/shape.rs
@ -1,7 +1,8 @@
|
||||
use super::*;
|
||||
use serde::de::{self, Deserialize, Deserializer, SeqAccess, Visitor};
|
||||
use serde::ser::{Serialize, SerializeTuple, Serializer};
|
||||
use std::fmt;
|
||||
use std::fmt::{Result as FmtResult, Formatter};
|
||||
use core::result::Result as SerdeResult;
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct TensorShape<const R: usize>([usize; R]);
|
||||
@ -151,11 +152,11 @@ struct TensorShapeVisitor<const R: usize>;
|
||||
impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor<R> {
|
||||
type Value = TensorShape<R>;
|
||||
|
||||
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
|
||||
fn expecting(&self, formatter: &mut Formatter) -> FmtResult {
|
||||
formatter.write_str(concat!("an array of length ", "{R}"))
|
||||
}
|
||||
|
||||
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
|
||||
fn visit_seq<A>(self, mut seq: A) -> SerdeResult<Self::Value, A::Error>
|
||||
where
|
||||
A: SeqAccess<'de>,
|
||||
{
|
||||
@ -170,7 +171,7 @@ impl<'de, const R: usize> Visitor<'de> for TensorShapeVisitor<R> {
|
||||
}
|
||||
|
||||
impl<'de, const R: usize> Deserialize<'de> for TensorShape<R> {
|
||||
fn deserialize<D>(deserializer: D) -> Result<TensorShape<R>, D::Error>
|
||||
fn deserialize<D>(deserializer: D) -> SerdeResult<TensorShape<R>, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
@ -179,7 +180,7 @@ impl<'de, const R: usize> Deserialize<'de> for TensorShape<R> {
|
||||
}
|
||||
|
||||
impl<const R: usize> Serialize for TensorShape<R> {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
fn serialize<S>(&self, serializer: S) -> SerdeResult<S::Ok, S::Error>
|
||||
where
|
||||
S: Serializer,
|
||||
{
|
||||
|
@ -1,7 +1,11 @@
|
||||
use super::*;
|
||||
use crate::error::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::fmt;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::{Display, Formatter, Result as FmtResult},
|
||||
ops::{Index, IndexMut},
|
||||
};
|
||||
|
||||
/// A tensor is a multi-dimensional array of values. The rank of a tensor is the
|
||||
/// number of dimensions it has. A rank 0 tensor is a scalar, a rank 1 tensor is
|
||||
@ -447,14 +451,14 @@ impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
|
||||
|
||||
impl<T, const R: usize> Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
T: Display + Clone,
|
||||
{
|
||||
fn fmt_helper(
|
||||
buffer: &[T],
|
||||
shape: &[usize],
|
||||
f: &mut fmt::Formatter<'_>,
|
||||
f: &mut Formatter<'_>,
|
||||
level: usize,
|
||||
) -> fmt::Result {
|
||||
) -> FmtResult {
|
||||
if shape.is_empty() {
|
||||
// Base case: print individual elements
|
||||
write!(f, "{}", buffer[0])
|
||||
@ -472,11 +476,11 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const R: usize> fmt::Display for Tensor<T, R>
|
||||
impl<T, const R: usize> Display for Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
T: Display + Clone,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
|
||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
|
||||
}
|
||||
}
|
||||
|
25
src/value.rs
Normal file
25
src/value.rs
Normal file
@ -0,0 +1,25 @@
|
||||
use num::{Num, One, Zero};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::{
|
||||
fmt::Display,
|
||||
iter::Sum,
|
||||
};
|
||||
|
||||
/// A trait for types that can be used as values in a tensor.
|
||||
pub trait Value:
|
||||
Num + Zero + One + Copy + Clone + Display + Serialize + Deserialize<'static>
|
||||
{
|
||||
}
|
||||
|
||||
impl<T> Value for T where
|
||||
T: Num
|
||||
+ Zero
|
||||
+ One
|
||||
+ Copy
|
||||
+ Clone
|
||||
+ Display
|
||||
+ Serialize
|
||||
+ Deserialize<'static>
|
||||
+ Sum
|
||||
{
|
||||
}
|
Loading…
Reference in New Issue
Block a user