From 35da61c619bd966f0a94d01bce767df1661e5553 Mon Sep 17 00:00:00 2001 From: Julius Koskela Date: Tue, 2 Jan 2024 19:20:10 +0000 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A7=20General=20Plumming=20(#1)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit General plumming and types as well as weakly tested implemnentations. Reviewed-on: https://nordic-dev.net/julius/manifold/pulls/1 Co-authored-by: Julius Koskela Co-committed-by: Julius Koskela --- .vscode/launch.json | 64 ++++ Cargo.lock | 688 ++++++++++++++++++++++++++++++++++++- Cargo.toml | 26 +- README.md | 2 +- docs/tensor-contraction.md | 34 ++ docs/tensor-operations.md | 239 +++++++++++++ examples/operations.rs | 113 ++++-- src/axis.rs | 389 +++++++++++++++++++++ src/error.rs | 9 + src/index.rs | 281 +++++++++++++-- src/lib.rs | 26 +- src/shape.rs | 45 ++- src/tensor.rs | 395 +++++++++++++-------- 13 files changed, 2099 insertions(+), 212 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 docs/tensor-contraction.md create mode 100644 docs/tensor-operations.md create mode 100644 src/axis.rs create mode 100644 src/error.rs diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 0000000..354a0b7 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,64 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'manifold'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=manifold" + ], + "filter": { + "name": "manifold", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug example 'operations'", + "cargo": { + "args": [ + "build", + "--example=operations", + "--package=manifold" + ], + "filter": { + "name": "operations", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in example 'operations'", + "cargo": { + "args": [ + "test", + "--no-run", + "--example=operations", + "--package=manifold" + ], + "filter": { + "name": "operations", + "kind": "example" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock index 6871f48..c86a244 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,18 +2,262 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "aho-corasick" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" + [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "bitflags" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07" + +[[package]] +name = "bumpalo" +version = "3.14.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" + [[package]] name = "bytemuck" version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6" +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" + +[[package]] +name = "ciborium-ll" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcfab8ba68f3668e89f6ff60f5b205cea56aa7b769451a59f34b8682f51c056d" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb7fb5e4e979aec3be7791562fcba452f94ad85e954da024396433e0e25a79e9" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools 0.10.5", + "num-traits", + "once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools 0.10.5", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751" +dependencies = [ + "cfg-if", + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d" +dependencies = [ + "autocfg", + "cfg-if", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "either" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" + +[[package]] +name = "errno" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "getrandom" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "getset" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9" +dependencies = [ + "proc-macro-error", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] +name = "half" +version = "1.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" + +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + +[[package]] +name = "is-terminal" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455" +dependencies = [ + "hermit-abi", + "rustix", + "windows-sys", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itertools" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.10" @@ -21,15 +265,54 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c" [[package]] -name = "mltensor" +name = "js-sys" +version = "0.3.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.151" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4" + +[[package]] +name = "linux-raw-sys" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456" + +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "manifold" version = "0.1.0" dependencies = [ "bytemuck", + "criterion", + "getset", + "itertools 0.12.0", "num", + "rand", "serde", "serde_json", + "static_assertions", + "thiserror", ] +[[package]] +name = "memchr" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" + [[package]] name = "num" version = "0.4.1" @@ -106,6 +389,76 @@ dependencies = [ "autocfg", ] +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "oorandom" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "plotters" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" + +[[package]] +name = "plotters-svg" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro-error" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c" +dependencies = [ + "proc-macro-error-attr", + "proc-macro2", + "quote", + "syn 1.0.109", + "version_check", +] + +[[package]] +name = "proc-macro-error-attr" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869" +dependencies = [ + "proc-macro2", + "quote", + "version_check", +] + [[package]] name = "proc-macro2" version = "1.0.71" @@ -124,12 +477,113 @@ dependencies = [ "proc-macro2", ] +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rayon" +version = "1.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" + +[[package]] +name = "rustix" +version = "0.38.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" +dependencies = [ + "bitflags", + "errno", + "libc", + "linux-raw-sys", + "windows-sys", +] + [[package]] name = "ryu" version = "1.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c" +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + [[package]] name = "serde" version = "1.0.193" @@ -147,7 +601,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3" dependencies = [ "proc-macro2", "quote", - "syn", + "syn 2.0.43", ] [[package]] @@ -161,6 +615,23 @@ dependencies = [ "serde", ] +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + +[[package]] +name = "syn" +version = "1.0.109" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + [[package]] name = "syn" version = "2.0.43" @@ -172,8 +643,221 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "thiserror" +version = "1.0.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.52" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "version_check" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" + +[[package]] +name = "walkdir" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e" +dependencies = [ + "cfg-if", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn 2.0.43", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.43", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.89" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f" + +[[package]] +name = "web-sys" +version = "0.3.66" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + +[[package]] +name = "winapi-util" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +dependencies = [ + "winapi", +] + +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04" diff --git a/Cargo.toml b/Cargo.toml index 17c0c54..040901a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,12 +1,32 @@ [package] -name = "mltensor" +name = "manifold" version = "0.1.0" edition = "2021" +license = "MIT/Apache-2.0" +authors = ["Julius Koskela "] -# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +description = """ +GDSL is a graph data-structure library including graph containers, +connected node strutures and efficient algorithms on those structures. +Nodes are independent of a graph container and can be used as connected +smart pointers. +""" + +repository = "https://nordic-dev.net/julius/manifold" + +keywords = ["data-structures", "algorithms", "containers"] +categories = ["data-structures", "algorithms", "mathematics", "science"] [dependencies] bytemuck = "1.14.0" +getset = "0.1.2" +itertools = "0.12.0" num = "0.4.1" serde = { version = "1.0.193", features = ["derive"] } -serde_json = "1.0.108" \ No newline at end of file +serde_json = "1.0.108" +static_assertions = "1.1.0" +thiserror = "1.0.52" + +[dev-dependencies] +rand = "0.8.5" +criterion = "0.5.1" diff --git a/README.md b/README.md index 22e0174..afbb613 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Tensorc +# Mainfold ```rust // Create two tensors with different ranks and shapes diff --git a/docs/tensor-contraction.md b/docs/tensor-contraction.md new file mode 100644 index 0000000..0361388 --- /dev/null +++ b/docs/tensor-contraction.md @@ -0,0 +1,34 @@ +To understand how the tensor contraction should work for the given tensors `a` and `b`, let's first clarify their shapes and then walk through the contraction steps: + +1. **Tensor Shapes**: + - Tensor `a` is a 2x3 matrix (3 rows and 2 columns): \[\begin{matrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{matrix}\] + - Tensor `b` is a 3x2 matrix (2 rows and 3 columns): \[\begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix}\] + +2. **Tensor Contraction Operation**: + - The contraction operation in this case involves multiplying corresponding elements along the shared dimension (the second dimension of `a` and the first dimension of `b`) and summing the results. + - The resulting tensor will have the shape determined by the other dimensions of the original tensors, which in this case is 3x3. + +3. **Contraction Steps**: + + - Step 1: Multiply each element of the first row of `a` with each element of the first column of `b`, then sum these products. This forms the first element of the resulting matrix. + - \( (1 \times 1) + (2 \times 4) = 1 + 8 = 9 \) + - Step 2: Multiply each element of the first row of `a` with each element of the second column of `b`, then sum these products. This forms the second element of the first row of the resulting matrix. + - \( (1 \times 2) + (2 \times 5) = 2 + 10 = 12 \) + - Step 3: Multiply each element of the first row of `a` with each element of the third column of `b`, then sum these products. This forms the third element of the first row of the resulting matrix. + - \( (1 \times 3) + (2 \times 6) = 3 + 12 = 15 \) + + - Continue this process for the remaining rows of `a` and columns of `b`: + - For the second row of `a`: + - \( (3 \times 1) + (4 \times 4) = 3 + 16 = 19 \) + - \( (3 \times 2) + (4 \times 5) = 6 + 20 = 26 \) + - \( (3 \times 3) + (4 \times 6) = 9 + 24 = 33 \) + - For the third row of `a`: + - \( (5 \times 1) + (6 \times 4) = 5 + 24 = 29 \) + - \( (5 \times 2) + (6 \times 5) = 10 + 30 = 40 \) + - \( (5 \times 3) + (6 \times 6) = 15 + 36 = 51 \) + +4. **Resulting Tensor**: + - The resulting 3x3 tensor from the contraction of `a` and `b` will be: + \[\begin{matrix} 9 & 12 & 15 \\ 19 & 26 & 33 \\ 29 & 40 & 51 \end{matrix}\] + +These steps provide the detailed calculations for each element of the resulting tensor after contracting tensors `a` and `b`. \ No newline at end of file diff --git a/docs/tensor-operations.md b/docs/tensor-operations.md new file mode 100644 index 0000000..277cefd --- /dev/null +++ b/docs/tensor-operations.md @@ -0,0 +1,239 @@ +# Operations Index + +## 1. Addition + +Element-wize addition of two tensors. + +\( C = A + B \) where \( C_{ijk...} = A_{ijk...} + B_{ijk...} \) for all indices \( i, j, k, ... \). + +```rust +let t1 = tensor!([[1, 2], [3, 4]]); +let t2 = tensor!([[5, 6], [7, 8]]); +let sum = t1 + t2; +``` + +```sh +[[7, 8], [10, 12]] +``` + +## 2. Subtraction + +Element-wize substraction of two tensors. + +\( C = A - B \) where \( C_{ijk...} = A_{ijk...} - B_{ijk...} \). + +```rust +let t1 = tensor!([[1, 2], [3, 4]]); +let t2 = tensor!([[5, 6], [7, 8]]); +let diff = i1 - t2; +``` + +```sh +[[-4, -4], [-4, -4]] +``` + +## 3. Multiplication + +Element-wize multiplication of two tensors. + +\( C = A \odot B \) where \( C_{ijk...} = A_{ijk...} \times B_{ijk...} \). + +```rust +let t1 = tensor!([[1, 2], [3, 4]]); +let t2 = tensor!([[5, 6], [7, 8]]); +let prod = t1 * t2; +``` + +```sh +[[5, 12], [21, 32]] +``` + +## 4. Division + +Element-wize division of two tensors. + +\( C = A \div B \) where \( C_{ijk...} = A_{ijk...} \div B_{ijk...} \). + +```rust +let t1 = tensor!([[1, 2], [3, 4]]); +let t2 = tensor!([[1, 2], [3, 4]]); +let quot = t1 / t2; +``` + +```sh +[[1, 1], [1, 1]] +``` + +## 5. Contraction + +Contract two tensors over given axes. + +For matrices \( A \) and \( B \), \( C = AB \) where \( C_{ij} = \sum_k A_{ik} B_{kj} \). + +```rust +let t1 = tensor!([[1, 2], [3, 4], [5, 6]]); +let t2 = tensor!([[1, 2, 3], [4, 5, 6]]); + +let cont = contract((t1, [1]), (t2, [0])); +``` + +```sh +TODO! +``` + +## 6. Reduction (e.g., Sum) + +\( \text{sum}(A) \) where sum over all elements of A. + +```rust +let t1 = tensor!([[1, 2], [3, 4]]); +let total = t1.sum(); +``` + +```sh +10 +``` + +## 7. Broadcasting + +Adjusts tensors with different shapes to make them compatible for element-wise operations automatically +when using supported functions. + +## 8. Reshape + +Changing the shape of a tensor without altering its data. + +```rust +let t1 = tensor!([1, 2, 3, 4, 5, 6]); +let tr = t1.reshape([2, 3]); +``` + +```sh +[[1, 2, 3], [4, 5, 6]] +``` + +## 9. Transpose + +Transpose a tensor over given axes. + +\( B = A^T \) where \( B_{ij} = A_{ji} \). + +```rust +let t1 = tensor!([1, 2, 3, 4]); +let transposed = t1.transpose(); +``` + +```sh +TODO! +``` + +## 10. Concatenation + +Joining tensors along a specified dimension. + +```rust +let t1 = tensor!([1, 2, 3]); +let t2 = tensor!([4, 5, 6]); +let cat = t1.concat(&t2, 0); +``` + +```sh +TODO! +``` + +## 11. Slicing and Indexing + +Extracting parts of tensors based on indices. + +```rust +let t1 = tensor!([1, 2, 3, 4, 5, 6]); +let slice = t1.slice(s![1, ..]); +``` + +```sh +TODO! +``` + +## 12. Element-wise Functions (e.g., Sigmoid) + +**Mathematical Definition**: + +Applying a function to each element of a tensor, like \( \sigma(x) = \frac{1}{1 + e^{-x}} \) for sigmoid. + +**Rust Code Example**: + +```rust +let tensor = Tensor::::from([-1.0, 0.0, 1.0, 2.0]); // 2x2 tensor +let sigmoid_tensor = tensor.map(|x| 1.0 / (1.0 + (-x).exp())); // Apply sigmoid element-wise +``` + +## 13. Gradient Computation/Automatic Differentiation + +**Description**: + +Calculating the derivatives of tensors, crucial for training machine learning models. + +**Rust Code Example**: Depends on if your tensor library supports automatic differentiation. This is typically more complex and may involve constructing computational graphs. + +## 14. Normalization Operations (e.g., Batch Normalization) + +**Description**: Standardizing the inputs of a model across the batch dimension. + +**Rust Code Example**: This is specific to deep learning libraries and may not be directly supported in a general-purpose tensor library. + +## 15. Convolution Operations + +**Description**: Essential for image processing and CNNs. + +**Rust Code Example**: If your library supports it, convolutions typically involve using a specialized function that takes the input tensor and a kernel tensor. + +## 16. Pooling Operations (e.g., Max Pooling) + +**Description**: Reducing the spatial dimensions of + a tensor, commonly used in CNNs. + +**Rust Code Example**: Again, this depends on your library's support for such operations. + +## 17. Tensor Slicing and Joining + +**Description**: Operations to slice a tensor into sub-tensors or join multiple tensors into a larger tensor. + +**Rust Code Example**: Similar to the slicing and concatenation examples provided above. + +## 18. Dimension Permutation + +**Description**: Rearranging the dimensions of a tensor. + +**Rust Code Example**: + +```rust +let tensor = Tensor::::from([...]); // 3D tensor +let permuted_tensor = tensor.permute_dims([2, 0, 1]); // Permute dimensions +``` + +## 19. Expand and Squeeze Operations + +**Description**: Increasing or decreasing the dimensions of a tensor (adding/removing singleton dimensions). + +**Rust Code Example**: Depends on the specific functions provided by your library. + +## 20. Data Type Conversions + +**Description**: Converting tensors from one data type to another. + +**Rust Code Example**: + +```rust +let tensor = Tensor::::from([1, 2, 3, 4]); // 2x2 tensor +let converted_tensor = tensor.to_type::(); // Convert to f32 tensor +``` + +These examples provide a general guide. The actual implementation details may vary depending on the specific features and capabilities of the Rust tensor library you're using. + +## 21. Tensor Decompositions + +**CANDECOMP/PARAFAC (CP) Decomposition**: This decomposes a tensor into a sum of component rank-one tensors. For a third-order tensor, it's like expressing it as a sum of outer products of vectors. This is useful in applications like signal processing, psychometrics, and chemometrics. + +**Tucker Decomposition**: Similar to PCA for matrices, Tucker Decomposition decomposes a tensor into a core tensor multiplied by a matrix along each mode (dimension). It's more general than CP Decomposition and is useful in areas like data compression and tensor completion. + +**Higher-Order Singular Value Decomposition (HOSVD)**: A generalization of SVD for higher-order tensors, HOSVD decomposes a tensor into a core tensor and a set of orthogonal matrices for each mode. It's used in image processing, computer vision, and multilinear subspace learning. diff --git a/examples/operations.rs b/examples/operations.rs index 6f77676..b0babae 100644 --- a/examples/operations.rs +++ b/examples/operations.rs @@ -1,53 +1,94 @@ -use mltensor::*; +#![allow(mixed_script_confusables)] +#![allow(non_snake_case)] +use bytemuck::cast_slice; +use manifold::contract; +use manifold::*; fn tensor_product() { - println!("Tensor Product\n"); - let mut tensor1 = Tensor::::from([2, 2]); // 2x2 tensor - let mut tensor2 = Tensor::::from([2]); // 2-element vector + println!("Tensor Product\n"); + let mut tensor1 = Tensor::::from([[2], [2]]); // 2x2 tensor + let mut tensor2 = Tensor::::from([2]); // 2-element vector - // Fill tensors with some values - tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); - tensor2.buffer_mut().copy_from_slice(&[5, 6]); + // Fill tensors with some values + tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]); + tensor2.buffer_mut().copy_from_slice(&[5, 6]); - println!("T1: {}", tensor1); - println!("T2: {}", tensor2); + println!("T1: {}", tensor1); + println!("T2: {}", tensor2); - let product = tensor1.tensor_product(&tensor2); + let product = tensor1.tensor_product(&tensor2); - println!("T1 * T2 = {}", product); + println!("T1 * T2 = {}", product); - // Check shape of the resulting tensor - assert_eq!(product.shape(), Shape::new([2, 2, 2])); + // Check shape of the resulting tensor + assert_eq!(product.shape(), &Shape::new([2, 2, 2])); - // Check buffer of the resulting tensor - let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24]; - assert_eq!(product.buffer(), &expected_buffer); + // Check buffer of the resulting tensor + let expect: &[i32] = + cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]); + assert_eq!(product.buffer(), expect); } -fn tensor_contraction() { - println!("Tensor Contraction\n"); - // Create two tensors - let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor - let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor +fn test_tensor_contraction_23x32() { + // Define two 2D tensors (matrices) - // Specify axes for contraction - let axis_lhs = [1]; // Contract over the second dimension of tensor1 - let axis_rhs = [0]; // Contract over the first dimension of tensor2 + // Tensor A is 2x3 + let a: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + println!("a: {:?}\n{}\n", a.shape(), a); - // Perform contraction - let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs); + // Tensor B is 3x2 + let b: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); + println!("b: {:?}\n{}\n", b.shape(), b); - println!("T1: {}", tensor1); - println!("T2: {}", tensor2); - println!("T1 * T2 = {}", result); + // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) + let ctr10 = contract((&a, [1]), (&b, [0])); - // Expected result, for example, could be a single number or a new tensor, - // depending on how you defined the contraction operation. - // Assert the result is as expected - // assert_eq!(result, expected_result); + println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10); + + let ctr01 = contract((&a, [0]), (&b, [1])); + + println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01); + // assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); + // assert_eq!( + // contracted_tensor.buffer(), + // &[9, 12, 15, 19, 26, 33, 29, 40, 51], + // "Contracted tensor buffer does not match expected" + // ); +} + +fn test_tensor_contraction_rank3() { + let a = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + let b = tensor!([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]); + let contracted_tensor = contract((&a, [2]), (&b, [0])); + + println!("a: {}", a); + println!("b: {}", b); + println!("contracted_tensor: {}", contracted_tensor); + // assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]); + // Verify specific elements of contracted_tensor + // assert_eq!(contracted_tensor[0][0][0][0], 50); + // assert_eq!(contracted_tensor[0][0][0][1], 60); + // ... further checks for other elements ... +} + +fn transpose() { + let a = Tensor::from([[1, 2, 3], [4, 5, 6]]); + let b = tensor!([[1, 2, 3], [4, 5, 6]]); + + // let iter = a.idx().iter_transposed([1, 0]); + + // for idx in iter { + // println!("{idx}"); + // } + let b = a.clone().transpose([1, 0]).unwrap(); + println!("a: {}", a); + println!("ta: {}", b); } fn main() { - tensor_product(); - tensor_contraction(); -} \ No newline at end of file + // tensor_product(); + // test_tensor_contraction_23x32(); + // test_tensor_contraction_rank3(); + + transpose(); +} diff --git a/src/axis.rs b/src/axis.rs new file mode 100644 index 0000000..97472e7 --- /dev/null +++ b/src/axis.rs @@ -0,0 +1,389 @@ +use super::*; +use getset::{Getters, MutGetters}; + +#[derive(Clone, Debug, Getters)] +pub struct Axis<'a, T: Value, const R: usize> { + #[getset(get = "pub")] + tensor: &'a Tensor, + #[getset(get = "pub")] + dim: usize, +} + +impl<'a, T: Value, const R: usize> Axis<'a, T, R> { + pub fn new(tensor: &'a Tensor, dim: usize) -> Self { + assert!(dim < R, "Axis out of bounds"); + Self { tensor, dim } + } + + pub fn len(&self) -> usize { + self.tensor.shape().get(self.dim) + } + + pub fn shape(&self) -> &Shape { + self.tensor.shape() + } + + pub fn iter_level(&'a self, level: usize) -> AxisIterator<'a, T, R> { + assert!(level < self.len(), "Level out of bounds"); + let mut index = Idx::new(self.shape(), [0; R]); + index.set_axis(self.dim, level); + AxisIterator::new(self).set_start(level).set_end(level + 1) + } +} + +#[derive(Clone, Debug, Getters, MutGetters)] +pub struct AxisIterator<'a, T: Value, const R: usize> { + #[getset(get = "pub")] + axis: &'a Axis<'a, T, R>, + #[getset(get = "pub", get_mut = "pub")] + index: Idx<'a, R>, + #[getset(get = "pub")] + end: Option, +} + +impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> { + pub fn new(axis: &'a Axis<'a, T, R>) -> Self { + Self { + axis, + index: Idx::new(axis.shape(), [0; R]), + end: None, + } + } + + pub fn set_start(self, start: usize) -> Self { + assert!(start < self.axis().len(), "Start out of bounds"); + let mut index = Idx::new(self.axis().shape(), [0; R]); + index.set_axis(self.axis.dim, start); + Self { + axis: self.axis(), + index, + end: None, + } + } + + pub fn set_end(self, end: usize) -> Self { + assert!(end <= self.axis().len(), "End out of bounds"); + Self { + axis: self.axis(), + index: self.index().clone(), + end: Some(end), + } + } + + pub fn set_level(self, level: usize) -> Self { + assert!(level < self.axis().len(), "Level out of bounds"); + self.set_start(level).set_end(level + 1) + } + + pub fn level(&'a self, level: usize) -> impl Iterator + 'a { + Self::new(self.axis()).set_level(level) + } + + pub fn axis_max_idx(&self) -> usize { + self.end().unwrap_or(self.axis().len()) + } + + pub fn axis_idx(&self) -> usize { + self.index().get_axis(*self.axis().dim()) + } + + pub fn axis_dim(&self) -> usize { + self.axis().dim().clone() + } +} + +impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> { + type Item = &'a T; + + fn next(&mut self) -> Option { + if self.axis_idx() == self.axis_max_idx() { + return None; + } + let result = unsafe { self.axis().tensor().get_unchecked(self.index) }; + let axis_dim = self.axis_dim(); + self.index_mut().inc_axis(axis_dim); + Some(result) + } +} + +impl<'a, T: Value, const R: usize> IntoIterator for &'a Axis<'a, T, R> { + type Item = &'a T; + type IntoIter = AxisIterator<'a, T, R>; + + fn into_iter(self) -> Self::IntoIter { + AxisIterator::new(&self) + } +} + +pub fn contract< + 'a, + T: Value + std::fmt::Debug, + const R: usize, + const S: usize, + const N: usize, +>( + lhs: (&'a Tensor, [usize; N]), + rhs: (&'a Tensor, [usize; N]), +) -> Tensor +where + [(); R - N]:, + [(); S - N]:, + [(); R + S - 2 * N]:, +{ + let (lhs, la) = lhs; + let (rhs, ra) = rhs; + let lnc = (0..R).filter(|i| !la.contains(i)).collect::>(); + let rnc = (0..S).filter(|i| !ra.contains(i)).collect::>(); + + let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::>(); + let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::>(); + + let mut shape = Vec::new(); + shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array()); + shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array()); + let shape: [usize; R + S - 2 * N] = + shape.try_into().expect("Failed to create shape array"); + + let shape = Shape::new(shape); + + let result = contract_axes(&lnc, &rnc); + + Tensor::new_with_buffer(shape, result) +} + +pub fn contract_axes< + 'a, + T: Value + std::fmt::Debug, + const R: usize, + const S: usize, + const N: usize, +>( + laxes: &'a [Axis<'a, T, R>], + raxes: &'a [Axis<'a, T, S>], +) -> Vec +where + [(); R - N]:, + [(); S - N]:, +{ + let mut result = vec![]; + + let axes = laxes.into_iter().zip(raxes); + + for (laxis, raxis) in axes { + let mut axes_result: Vec = vec![]; + for i in 0..raxis.len() { + for j in 0..laxis.len() { + let mut sum = T::zero(); + let llevel = laxis.into_iter(); + let rlevel = raxis.into_iter(); + let zip = llevel.level(j).zip(rlevel.level(i)); + for (lv, rv) in zip { + sum = sum + *lv * *rv; + } + axes_result.push(sum); + } + } + result.extend_from_slice(&axes_result); + } + + result +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tensor_contraction_simple() { + // Define two 2D tensors (matrices) + // Tensor A is 2x3 + let a: Tensor = Tensor::from([[1, 2], [3, 4]]); + + // Tensor B is 1x3x2 + let b: Tensor = Tensor::from([[1, 2], [3, 4]]); + + // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) + let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); + assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2])); + assert_eq!( + contracted_tensor.buffer(), + &[7, 10, 15, 22], + "Contracted tensor buffer does not match expected" + ); + } + + #[test] + fn test_tensor_contraction_23x32() { + // Define two 2D tensors (matrices) + + // Tensor A is 2x3 + let b: Tensor = Tensor::from([[1, 2, 3], [4, 5, 6]]); + println!("b: {}", b); + + // Tensor B is 3x2 + let a: Tensor = Tensor::from([[1, 2], [3, 4], [5, 6]]); + println!("a: {}", a); + + // Contract over the last axis of A (axis 1) and the first axis of B (axis 0) + let contracted_tensor: Tensor = contract((&a, [1]), (&b, [0])); + + println!("contracted_tensor: {}", contracted_tensor); + assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3])); + assert_eq!( + contracted_tensor.buffer(), + &[9, 12, 15, 19, 26, 33, 29, 40, 51], + "Contracted tensor buffer does not match expected" + ); + } + + #[test] + fn test_tensor_contraction_rank3() { + let a: Tensor = + Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24 + let b: Tensor = + Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24 + let contracted_tensor: Tensor = contract((&a, [2]), (&b, [0])); + + println!("a: {}", a); + println!("b: {}", b); + println!("contracted_tensor: {}", contracted_tensor); + // assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]); + // Verify specific elements of contracted_tensor + // assert_eq!(contracted_tensor[0][0][0][0], 50); + // assert_eq!(contracted_tensor[0][0][0][1], 60); + // ... further checks for other elements ... + } + + // #[test] + // fn test_axis_iterator_disassemble() { + // // Creating a 2x2 Tensor for testing + // let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); + + // // Testing iteration over the first axis (axis = 0) + // let axis = Axis::new(&tensor, 0); + + // let mut axis_iter = axis.into_iter().disassemble(); + + // assert_eq!(axis_iter[0].next(), Some(&1.0)); + // assert_eq!(axis_iter[0].next(), Some(&2.0)); + // assert_eq!(axis_iter[0].next(), None); + // assert_eq!(axis_iter[1].next(), Some(&3.0)); + // assert_eq!(axis_iter[1].next(), Some(&4.0)); + // assert_eq!(axis_iter[1].next(), None); + + // // Resetting the iterator for the second axis (axis = 1) + // let axis = Axis::new(&tensor, 1); + + // let mut axis_iter = axis.into_iter().disassemble(); + + // assert_eq!(axis_iter[0].next(), Some(&1.0)); + // assert_eq!(axis_iter[0].next(), Some(&3.0)); + // assert_eq!(axis_iter[0].next(), None); + // assert_eq!(axis_iter[1].next(), Some(&2.0)); + // assert_eq!(axis_iter[1].next(), Some(&4.0)); + // assert_eq!(axis_iter[1].next(), None); + // } + + #[test] + fn test_axis_iterator() { + // Creating a 2x2 Tensor for testing + let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]); + + // Testing iteration over the first axis (axis = 0) + let axis = Axis::new(&tensor, 0); + + let mut axis_iter = axis.into_iter(); + + assert_eq!(axis_iter.next(), Some(&1.0)); + assert_eq!(axis_iter.next(), Some(&2.0)); + assert_eq!(axis_iter.next(), Some(&3.0)); + assert_eq!(axis_iter.next(), Some(&4.0)); + + // Resetting the iterator for the second axis (axis = 1) + let axis = Axis::new(&tensor, 1); + + let mut axis_iter = axis.into_iter(); + + assert_eq!(axis_iter.next(), Some(&1.0)); + assert_eq!(axis_iter.next(), Some(&3.0)); + assert_eq!(axis_iter.next(), Some(&2.0)); + assert_eq!(axis_iter.next(), Some(&4.0)); + + let shape = tensor.shape(); + + let mut a: Idx<2> = (shape, [0, 0]).into(); + let b: Idx<2> = (shape, [1, 1]).into(); + + while a <= b { + println!("a: {}", a); + a.inc(); + } + } + + #[test] + fn test_3d_tensor_axis_iteration() { + // Create a 3D Tensor with specific values + // Tensor shape is 2x2x2 for simplicity + let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]); + + // Axis 0 (Layer-wise): + // + // t[0][0][0] = 1 + // t[0][0][1] = 2 + // t[0][1][0] = 3 + // t[0][1][1] = 4 + // t[1][0][0] = 5 + // t[1][0][1] = 6 + // t[1][1][0] = 7 + // t[1][1][1] = 8 + // [1, 2, 3, 4, 5, 6, 7, 8] + // + // This order suggests that for each "layer" (first level of arrays), + // the iterator goes through all rows and columns. It first completes + // the entire first layer, then moves to the second. + + let a0 = Axis::new(&t, 0); + let a0_order = a0.into_iter().cloned().collect::>(); + assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]); + + // Axis 1 (Row-wise within each layer): + // + // t[0][0][0] = 1 + // t[0][0][1] = 2 + // t[1][0][0] = 5 + // t[1][0][1] = 6 + // t[0][1][0] = 3 + // t[0][1][1] = 4 + // t[1][1][0] = 7 + // t[1][1][1] = 8 + // [1, 2, 5, 6, 3, 4, 7, 8] + // + // This indicates that within each "layer", the iterator first + // completes the first row across all layers, then the second row + // across all layers. + + let a1 = Axis::new(&t, 1); + let a1_order = a1.into_iter().cloned().collect::>(); + assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]); + + // Axis 2 (Column-wise within each layer): + // + // t[0][0][0] = 1 + // t[0][1][0] = 3 + // t[1][0][0] = 5 + // t[1][1][0] = 7 + // t[0][0][1] = 2 + // t[0][1][1] = 4 + // t[1][0][1] = 6 + // t[1][1][1] = 8 + // [1, 3, 5, 7, 2, 4, 6, 8] + // + // This indicates that within each "layer", the iterator first + // completes the first column across all layers, then the second + // column across all layers. + + let a2 = Axis::new(&t, 2); + let a2_order = a2.into_iter().cloned().collect::>(); + assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]); + } +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..76f24a8 --- /dev/null +++ b/src/error.rs @@ -0,0 +1,9 @@ +use thiserror::Error; + +pub type Result = std::result::Result; + +#[derive(Error, Debug)] +pub enum Error { + #[error("Invalid argument: {0}")] + InvalidArgument(String), +} diff --git a/src/index.rs b/src/index.rs index 7bc71e0..cde85cd 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,35 +1,41 @@ use super::*; +use getset::{Getters, MutGetters}; use std::cmp::Ordering; use std::ops::{Add, Sub}; -#[derive(Clone, Copy, Debug)] -pub struct Idx { +#[derive(Clone, Copy, Debug, Getters, MutGetters)] +pub struct Idx<'a, const R: usize> { + #[getset(get = "pub", get_mut = "pub")] indices: [usize; R], - shape: Shape, + #[getset(get = "pub")] + shape: &'a Shape, } -impl Idx { - pub const fn zero(shape: Shape) -> Self { +impl<'a, const R: usize> Idx<'a, R> { + pub const fn zero(shape: &'a Shape) -> Self { Self { indices: [0; R], shape, } } - pub fn max(shape: Shape) -> Self { + pub fn last(shape: &'a Shape) -> Self { let max_indices = shape.as_array().map(|dim_size| dim_size.saturating_sub(1)); Self { indices: max_indices, - shape, + shape: shape, } } - pub fn new(shape: Shape, indices: [usize; R]) -> Self { + pub fn new(shape: &'a Shape, indices: [usize; R]) -> Self { if !shape.check_indices(indices) { panic!("indices out of bounds"); } - Self { indices, shape } + Self { + indices, + shape: shape, + } } pub fn is_zero(&self) -> bool { @@ -51,6 +57,9 @@ impl Idx { /// `true` if the increment does not overflow and is still within bounds; /// `false` if it overflows, indicating the end of the tensor. pub fn inc(&mut self) -> bool { + if self.indices()[0] >= self.shape().get(0) { + return false; + } let mut carry = 1; for (i, &dim_size) in self.indices.iter_mut().zip(&self.shape.as_array()).rev() @@ -67,13 +76,74 @@ impl Idx { // If carry is still 1 after the loop, it means we've incremented past the last dimension if carry == 1 { - // Set the index to an invalid state (e.g., all indices to their max values) + // Set the index to an invalid state to indicate the end of the iteration indicated + // by setting the first index to the size of the first dimension self.indices[0] = self.shape.as_array()[0]; return true; // Indicate that the iteration is complete } false } + // fn inc_axis + + pub fn inc_axis(&mut self, fixed_axis: usize) { + assert!(fixed_axis < R, "Axis out of bounds"); + assert!( + self.indices()[fixed_axis] < self.shape().get(fixed_axis), + "Index out of bounds" + ); + + // Try to increment non-fixed axes + for i in (0..R).rev() { + if i != fixed_axis { + if self.indices[i] + 1 < self.shape.get(i) { + self.indices[i] += 1; + return; + } else { + self.indices[i] = 0; + } + } + } + + if self.indices[fixed_axis] < self.shape.get(fixed_axis) { + self.indices[fixed_axis] += 1; + for i in 0..R { + if i != fixed_axis { + self.indices[i] = 0; + } + } + return; + } + } + + pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool { + if self.indices()[order[0]] >= self.shape().get(order[0]) { + return false; + } + + let mut carry = 1; + + for i in order.iter().rev() { + let dim_size = self.shape().get(*i); + let i = self.index_mut(*i); + if carry == 1 { + *i += 1; + if *i >= dim_size { + *i = 0; + } else { + carry = 0; + } + } + } + + if carry == 1 { + self.indices_mut()[order[0]] = self.shape().get(order[0]); + return true; + } + + false + } + pub fn dec(&mut self) { // Check if already at the start if self.indices.iter().all(|&i| i == 0) { @@ -95,6 +165,61 @@ impl Idx { } } + pub fn dec_axis(&mut self, fixed_axis: usize) -> bool { + // Check if the fixed axis index is already in an invalid state + if self.indices[fixed_axis] == self.shape.get(fixed_axis) { + return false; + } + + // Try to decrement non-fixed axes + for i in (0..R).rev() { + if i != fixed_axis { + if self.indices[i] > 0 { + self.indices[i] -= 1; + return true; + } else { + self.indices[i] = self.shape.get(i) - 1; + } + } + } + + // Decrement the fixed axis if possible and reset other axes to their max + if self.indices[fixed_axis] > 0 { + self.indices[fixed_axis] -= 1; + for i in 0..R { + if i != fixed_axis { + self.indices[i] = self.shape.get(i) - 1; + } + } + } else { + // Fixed axis already at minimum, set to invalid state + self.indices[fixed_axis] = self.shape.get(fixed_axis); + } + + true + } + + pub fn dec_transposed(&mut self, order: [usize; R]) { + // Iterate over the axes in the specified order + for &axis in &order { + // Try to decrement the current axis + if self.indices[axis] > 0 { + self.indices[axis] -= 1; + // Reset all preceding axes in the order to their maximum + for &prev_axis in &order { + if prev_axis == axis { + break; + } + self.indices[prev_axis] = self.shape.get(prev_axis) - 1; + } + return; + } + } + + // If no axis can be decremented, set the first axis in the order to indicate overflow + self.indices[order[0]] = self.shape.get(order[0]); + } + /// Converts the multi-dimensional index to a flat index. /// /// This method calculates the flat index corresponding to the multi-dimensional index @@ -136,31 +261,59 @@ impl Idx { }) .0 } + + pub fn set_axis(&mut self, axis: usize, value: usize) { + assert!(axis < R, "Axis out of bounds"); + // assert!(value < self.shape.get(axis), "Value out of bounds"); + self.indices[axis] = value; + } + + pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool { + assert!(axis < R, "Axis out of bounds"); + if value < self.shape.get(axis) { + self.indices[axis] = value; + true + } else { + false + } + } + + pub fn get_axis(&self, axis: usize) -> usize { + assert!(axis < R, "Axis out of bounds"); + self.indices[axis] + } + + pub fn iter_transposed( + &self, + order: [usize; R], + ) -> IdxTransposedIterator<'a, R> { + IdxTransposedIterator::new(self.shape(), order) + } } // --- blanket impls --- -impl PartialEq for Idx { +impl<'a, const R: usize> PartialEq for Idx<'a, R> { fn eq(&self, other: &Self) -> bool { self.flat() == other.flat() } } -impl Eq for Idx {} +impl<'a, const R: usize> Eq for Idx<'a, R> {} -impl PartialOrd for Idx { +impl<'a, const R: usize> PartialOrd for Idx<'a, R> { fn partial_cmp(&self, other: &Self) -> Option { self.flat().partial_cmp(&other.flat()) } } -impl Ord for Idx { +impl<'a, const R: usize> Ord for Idx<'a, R> { fn cmp(&self, other: &Self) -> Ordering { self.flat().cmp(&other.flat()) } } -impl Index for Idx { +impl<'a, const R: usize> Index for Idx<'a, R> { type Output = usize; fn index(&self, index: usize) -> &Self::Output { @@ -168,33 +321,39 @@ impl Index for Idx { } } -impl IndexMut for Idx { +impl<'a, const R: usize> IndexMut for Idx<'a, R> { fn index_mut(&mut self, index: usize) -> &mut Self::Output { &mut self.indices[index] } } -impl From<(Shape, [usize; R])> for Idx { - fn from((shape, indices): (Shape, [usize; R])) -> Self { +impl<'a, const R: usize> From<(&'a Shape, [usize; R])> for Idx<'a, R> { + fn from((shape, indices): (&'a Shape, [usize; R])) -> Self { assert!(shape.check_indices(indices)); Self::new(shape, indices) } } -impl From<(Shape, usize)> for Idx { - fn from((shape, flat_index): (Shape, usize)) -> Self { +impl<'a, const R: usize> From<(&'a Shape, usize)> for Idx<'a, R> { + fn from((shape, flat_index): (&'a Shape, usize)) -> Self { let indices = shape.index_from_flat(flat_index).indices; Self::new(shape, indices) } } -impl From> for Idx { - fn from(shape: Shape) -> Self { +impl<'a, const R: usize> From<&'a Shape> for Idx<'a, R> { + fn from(shape: &'a Shape) -> Self { Self::zero(shape) } } -impl std::fmt::Display for Idx { +impl<'a, T: Value, const R: usize> From<&'a Tensor> for Idx<'a, R> { + fn from(tensor: &'a Tensor) -> Self { + Self::zero(tensor.shape()) + } +} + +impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[")?; for (i, (&idx, &dim_size)) in self @@ -214,7 +373,7 @@ impl std::fmt::Display for Idx { // ---- Arithmetic Operations ---- -impl Add for Idx { +impl<'a, const R: usize> Add for Idx<'a, R> { type Output = Self; fn add(self, rhs: Self) -> Self::Output { @@ -232,7 +391,7 @@ impl Add for Idx { } } -impl Sub for Idx { +impl<'a, const R: usize> Sub for Idx<'a, R> { type Output = Self; fn sub(self, rhs: Self) -> Self::Output { @@ -249,3 +408,75 @@ impl Sub for Idx { } } } + +// ---- Iterator ---- + +pub struct IdxIterator<'a, const R: usize> { + current: Idx<'a, R>, + end: bool, +} + +impl<'a, const R: usize> IdxIterator<'a, R> { + pub fn new(shape: &'a Shape) -> Self { + Self { + current: Idx::zero(shape), + end: false, + } + } +} + +impl<'a, const R: usize> Iterator for IdxIterator<'a, R> { + type Item = Idx<'a, R>; + + fn next(&mut self) -> Option { + if self.end { + return None; + } + + let result = self.current; + self.end = self.current.inc(); + Some(result) + } +} + +impl<'a, const R: usize> IntoIterator for Idx<'a, R> { + type Item = Idx<'a, R>; + type IntoIter = IdxIterator<'a, R>; + + fn into_iter(self) -> Self::IntoIter { + IdxIterator { + current: self, + end: false, + } + } +} + +pub struct IdxTransposedIterator<'a, const R: usize> { + current: Idx<'a, R>, + order: [usize; R], + end: bool, +} + +impl<'a, const R: usize> IdxTransposedIterator<'a, R> { + pub fn new(shape: &'a Shape, order: [usize; R]) -> Self { + Self { + current: Idx::zero(shape), + end: false, + order, + } + } +} + +impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> { + type Item = Idx<'a, R>; + + fn next(&mut self) -> Option { + if self.end { + return None; + } + + let result = self.current; + self.end = self.current.inc_transposed(&self.order); + Some(result) + } +} diff --git a/src/lib.rs b/src/lib.rs index be92164..52e6cf6 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,15 +1,21 @@ #![allow(incomplete_features)] #![feature(generic_const_exprs)] +pub mod axis; +pub mod error; pub mod index; pub mod shape; pub mod tensor; +pub use axis::*; pub use index::Idx; +pub use itertools::Itertools; use num::{Num, One, Zero}; pub use serde::{Deserialize, Serialize}; pub use shape::Shape; +pub use static_assertions::const_assert; pub use std::fmt::{Display, Formatter, Result as FmtResult}; use std::ops::{Index, IndexMut}; +pub use std::sync::Arc; pub use tensor::{Tensor, TensorIterator}; pub trait Value: @@ -26,9 +32,17 @@ impl Value for T where + Display + Serialize + Deserialize<'static> + + std::iter::Sum { } +#[macro_export] +macro_rules! tensor { + ($array:expr) => { + Tensor::from($array) + }; +} + // ---- Tests ---- #[cfg(test)] @@ -38,7 +52,7 @@ mod tests { #[test] fn test_tensor_product() { - let mut tensor1 = Tensor::::from([2, 2]); // 2x2 tensor + let mut tensor1 = Tensor::::from([[2], [2]]); // 2x2 tensor let mut tensor2 = Tensor::::from([2]); // 2-element vector // Fill tensors with some values @@ -48,7 +62,7 @@ mod tests { let product = tensor1.tensor_product(&tensor2); // Check shape of the resulting tensor - assert_eq!(product.shape(), Shape::new([2, 2, 2])); + assert_eq!(*product.shape(), Shape::new([2, 2, 2])); // Check buffer of the resulting tensor let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24]; @@ -153,7 +167,7 @@ mod tests { #[test] fn test_dec_method() { let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor - let mut index = Idx::zero(shape); + let mut index = Idx::zero(&shape); // Increment the index to the maximum for _ in 0..26 { @@ -162,7 +176,7 @@ mod tests { } // Check if the index is at the maximum - assert_eq!(index, Idx::new(shape, [2, 2, 2])); + assert_eq!(index, Idx::new(&shape, [2, 2, 2])); // Decrement step by step and check the index let expected_indices = [ @@ -198,7 +212,7 @@ mod tests { for (i, &expected) in expected_indices.iter().enumerate() { assert_eq!( index, - Idx::new(shape, expected), + Idx::new(&shape, expected), "Failed at index {}", i ); @@ -207,6 +221,6 @@ mod tests { // Finally, the index should reach [0, 0, 0] index.dec(); - assert_eq!(index, Idx::zero(shape)); + assert_eq!(index, Idx::zero(&shape)); } } diff --git a/src/shape.rs b/src/shape.rs index e6409ea..7d2ff7d 100644 --- a/src/shape.rs +++ b/src/shape.rs @@ -11,6 +11,18 @@ impl Shape { Self(shape) } + pub fn axis(&self, index: usize) -> Option<&usize> { + self.0.get(index) + } + + pub fn reorder(&self, indices: [usize; R]) -> Self { + let mut new_shape = Shape::new([0; R]); + for (new_index, &index) in indices.iter().enumerate() { + new_shape.0[new_index] = self.0[index]; + } + new_shape + } + pub const fn as_array(&self) -> [usize; R] { self.0 } @@ -65,18 +77,18 @@ impl Shape { } indices.reverse(); // Reverse the indices to match the original dimension order - Idx::new(*self, indices) + Idx::new(self, indices) } pub const fn index_zero(&self) -> Idx { - Idx::zero(*self) + Idx::zero(self) } pub fn index_max(&self) -> Idx { let max_indices = self.0 .map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 }); - Idx::new(*self, max_indices) + Idx::new(self, max_indices) } pub fn remove_dims( @@ -101,6 +113,31 @@ impl Shape { Shape(new_shape) } + + pub fn remove_axes<'a, T: Value, const NAX: usize>( + &self, + axes_to_remove: &'a [Axis<'a, T, R>; NAX], + ) -> Shape<{ R - NAX }> { + // Create a new array to store the remaining dimensions + let mut new_shape = [0; R - NAX]; + let mut new_index = 0; + + // Iterate over the original dimensions + for (index, &dim) in self.0.iter().enumerate() { + // Skip dimensions that are in the axes_to_remove array + for axis in axes_to_remove { + if *axis.dim() == index { + continue; + } + } + + // Add the dimension to the new shape array + new_shape[new_index] = dim; + new_index += 1; + } + + Shape(new_shape) + } } // ---- Serialize and Deserialize ---- @@ -173,6 +210,6 @@ where T: Value, { fn from(tensor: Tensor) -> Self { - tensor.shape() + *tensor.shape() } } diff --git a/src/tensor.rs b/src/tensor.rs index d8c78f6..0413b12 100644 --- a/src/tensor.rs +++ b/src/tensor.rs @@ -1,47 +1,72 @@ use super::*; +use crate::error::*; +use getset::{Getters, MutGetters}; +use std::fmt; -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)] pub struct Tensor { + #[getset(get = "pub", get_mut = "pub")] buffer: Vec, + #[getset(get = "pub")] shape: Shape, } -impl Tensor { - pub fn dot(&self, other: &Tensor) -> T { - if self.shape != other.shape { - panic!("Shapes of tensors do not match"); - } - - let mut result = T::zero(); - for (a, b) in self.buffer.iter().zip(other.buffer.iter()) { - result = result + (*a * *b); - } - - result - } -} - impl Tensor { pub fn new(shape: Shape) -> Self { - let total_size: usize = shape.iter().product(); + // Handle rank 0 tensor (scalar) as a special case + let total_size = if R == 0 { + // A rank 0 tensor should still have a buffer with one element + 1 + } else { + // For tensors of rank 1 or higher, calculate the total size normally + shape.iter().product() + }; + let buffer = vec![T::zero(); total_size]; Self { buffer, shape } } - pub fn shape(&self) -> Shape { - self.shape + pub fn new_with_buffer(shape: Shape, buffer: Vec) -> Self { + Self { buffer, shape } } - pub fn buffer(&self) -> &[T] { - &self.buffer + pub fn reshape(self, shape: Shape) -> Result { + if self.shape().size() != shape.size() { + let (ls, rs) = (self.shape().as_array(), shape.as_array()); + let (lsize, rsize) = (self.shape().size(), shape.size()); + Err(Error::InvalidArgument(format!( + "Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )", + ))) + } else { + Ok(Self { + buffer: self.buffer, + shape, + }) + } } - pub fn buffer_mut(&mut self) -> &mut [T] { - &mut self.buffer + pub fn transpose(self, order: [usize; R]) -> Result { + let buffer = Idx::from(self.shape()) + .iter_transposed(order) + .map(|index| self.get(index).unwrap().clone()) + .collect(); + + Ok(Tensor { + buffer, + shape: self.shape().reorder(order), + }) } - pub fn get(&self, index: Idx) -> &T { - &self.buffer[index.flat()] + pub fn idx(&self) -> Idx { + Idx::from(self) + } + + pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> { + Axis::new(self, axis) + } + + pub fn get(&self, index: Idx) -> Option<&T> { + self.buffer.get(index.flat()) } pub unsafe fn get_unchecked(&self, index: Idx) -> &T { @@ -56,6 +81,22 @@ impl Tensor { self.buffer.get_unchecked_mut(index.flat()) } + pub fn get_flat(&self, index: usize) -> Option<&T> { + self.buffer.get(index) + } + + pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T { + self.buffer.get_unchecked(index) + } + + pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> { + self.buffer.get_mut(index) + } + + pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T { + self.buffer.get_unchecked_mut(index) + } + pub fn rank(&self) -> usize { R } @@ -111,93 +152,13 @@ impl Tensor { } } - // `self_dims` and `other_dims` specify the dimensions to contract over - pub fn contract( - &self, - rhs: &Tensor, - axis_lhs: [usize; NAXL], - axis_rhs: [usize; NAXR], - ) -> Tensor - where - [(); R - NAXL]:, - [(); S - NAXR]:, - { - // Step 1: Validate the axes for both tensors to ensure they are within bounds - for &axis in &axis_lhs { - if axis >= R { - panic!( - "Axis {} is out of bounds for the left-hand tensor", - axis - ); - } - } - - for &axis in &axis_rhs { - if axis >= S { - panic!( - "Axis {} is out of bounds for the right-hand tensor", - axis - ); - } - } - - // Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions - let mut result_buffer = Vec::new(); - - for i in 0..self.shape.size() { - for j in 0..rhs.shape.size() { - // Debug: Print indices being processed - println!("Processing Indices: lhs = {}, rhs = {}", i, j); - - if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) { - let mut product_sum = T::zero(); - - // Debug: Print axes of contraction - println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs); - - for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) { - // Debug: Print values being multiplied - let value_lhs = self.get_by_axis(axis_l, i).unwrap(); - let value_rhs = rhs.get_by_axis(axis_r, j).unwrap(); - println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs); - - product_sum = product_sum + value_lhs * value_rhs; - } - - // Debug: Print the product sum for the current indices - println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum); - - result_buffer.push(product_sum); - } - } - } - - // Step 3: Remove contracted dimensions to create new shapes for both tensors - let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs); - let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs); - - // Step 4: Concatenate the shapes to form the shape of the resultant tensor - let mut new_shape = Vec::new(); - - new_shape.extend_from_slice(&new_shape_lhs.as_array()); - new_shape.extend_from_slice(&new_shape_rhs.as_array()); - - let new_shape_array: [usize; R + S - NAXL - NAXR] = - new_shape.try_into().expect("Failed to create shape array"); - - Tensor { - buffer: result_buffer, - shape: Shape::new(new_shape_array), - } - } - // Retrieve an element based on a specific axis and index pub fn get_by_axis(&self, axis: usize, index: usize) -> Option { // Convert axis and index to a flat index let flat_index = self.axis_to_flat_index(axis, index); - if flat_index >= self.buffer.len() { - return None; - } + if flat_index >= self.buffer.len() { + return None; + } Some(self.buffer[flat_index]) } @@ -214,7 +175,7 @@ impl Tensor { // Calculate the stride for each dimension and accumulate the flat index for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() { - println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride); + println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride); if i > axis { stride *= dim_size; } else if i == axis { @@ -229,7 +190,7 @@ impl Tensor { // ---- Indexing ---- -impl Index> for Tensor { +impl<'a, T: Value, const R: usize> Index> for Tensor { type Output = T; fn index(&self, index: Idx) -> &Self::Output { @@ -237,33 +198,61 @@ impl Index> for Tensor { } } -impl IndexMut> for Tensor { +impl<'a, T: Value, const R: usize> IndexMut> for Tensor { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { &mut self.buffer[index.flat()] } } -impl std::fmt::Display for Tensor { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - // Print the shape of the tensor - write!(f, "Shape: [")?; - for (i, dim_size) in self.shape.as_array().iter().enumerate() { - write!(f, "{}", dim_size)?; - if i < R - 1 { - write!(f, ", ")?; - } - } - write!(f, "], Elements: [")?; +impl Index for Tensor { + type Output = T; - // Print the elements in a flattened form - for (i, elem) in self.buffer.iter().enumerate() { - write!(f, "{}", elem)?; - if i < self.buffer.len() - 1 { - write!(f, ", ")?; - } - } + fn index(&self, index: usize) -> &Self::Output { + &self.buffer[index] + } +} - write!(f, "]") +impl IndexMut for Tensor { + fn index_mut(&mut self, index: usize) -> &mut Self::Output { + &mut self.buffer[index] + } +} + +// ---- Display ---- + +impl Tensor +where + T: fmt::Display + Clone, +{ + fn fmt_helper( + buffer: &[T], + shape: &[usize], + f: &mut fmt::Formatter<'_>, + level: usize, + ) -> fmt::Result { + if shape.is_empty() { + // Base case: print individual elements + write!(f, "{}", buffer[0]) + } else { + let sub_len = shape[1..].iter().product::(); + write!(f, "[")?; + for (i, chunk) in buffer.chunks(sub_len).enumerate() { + if i > 0 { + write!(f, ",")?; + } + Tensor::::fmt_helper(chunk, &shape[1..], f, level + 1)?; + } + write!(f, "]") + } + } +} + +impl fmt::Display for Tensor +where + T: fmt::Display + Clone, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + Tensor::::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1) } } @@ -271,7 +260,7 @@ impl std::fmt::Display for Tensor { pub struct TensorIterator<'a, T: Value, const R: usize> { tensor: &'a Tensor, - index: Idx, + index: Idx<'a, R>, } impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> { @@ -342,10 +331,26 @@ impl From> for Tensor { } } -impl From<[usize; R]> for Tensor { - fn from(shape: [usize; R]) -> Self { - let shape = Shape::new(shape); - Self::new(shape.into()) +impl From for Tensor { + fn from(value: T) -> Self { + let shape = Shape::new([]); + let mut tensor = Tensor::new(shape); + tensor.buffer_mut()[0] = value; + tensor + } +} + +impl From<[T; X]> for Tensor { + fn from(array: [T; X]) -> Self { + let shape = Shape::new([X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, &elem) in array.iter().enumerate() { + buffer[i] = elem; + } + + tensor } } @@ -366,3 +371,123 @@ impl From<[[T; X]; Y]> tensor } } + +impl + From<[[[T; X]; Y]; Z]> for Tensor +{ + fn from(array: [[[T; X]; Y]; Z]) -> Self { + let shape = Shape::new([Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, plane) in array.iter().enumerate() { + for (j, row) in plane.iter().enumerate() { + for (k, &elem) in row.iter().enumerate() { + buffer[i * X * Y + j * X + k] = elem; + } + } + } + + tensor + } +} + +impl< + T: Value, + const X: usize, + const Y: usize, + const Z: usize, + const W: usize, + > From<[[[[T; X]; Y]; Z]; W]> for Tensor +{ + fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self { + let shape = Shape::new([W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, hyperplane) in array.iter().enumerate() { + for (j, plane) in hyperplane.iter().enumerate() { + for (k, row) in plane.iter().enumerate() { + for (l, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem; + } + } + } + } + + tensor + } +} + +impl< + T: Value, + const X: usize, + const Y: usize, + const Z: usize, + const W: usize, + const V: usize, + > From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor +{ + fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self { + let shape = Shape::new([V, W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, hyperhyperplane) in array.iter().enumerate() { + for (j, hyperplane) in hyperhyperplane.iter().enumerate() { + for (k, plane) in hyperplane.iter().enumerate() { + for (l, row) in plane.iter().enumerate() { + for (m, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z * W + + j * X * Y * Z + + k * X * Y + + l * X + + m] = elem; + } + } + } + } + } + + tensor + } +} + +impl< + T: Value, + const X: usize, + const Y: usize, + const Z: usize, + const W: usize, + const V: usize, + const U: usize, + > From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor +{ + fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self { + let shape = Shape::new([U, V, W, Z, Y, X]); + let mut tensor = Tensor::new(shape); + let buffer = tensor.buffer_mut(); + + for (i, hyperhyperhyperplane) in array.iter().enumerate() { + for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate() + { + for (k, hyperplane) in hyperhyperplane.iter().enumerate() { + for (l, plane) in hyperplane.iter().enumerate() { + for (m, row) in plane.iter().enumerate() { + for (n, &elem) in row.iter().enumerate() { + buffer[i * X * Y * Z * W * V + + j * X * Y * Z * W + + k * X * Y * Z + + l * X * Y + + m * X + + n] = elem; + } + } + } + } + } + } + + tensor + } +}