🔧 General Plumming #1
64
.vscode/launch.json
vendored
Normal file
64
.vscode/launch.json
vendored
Normal file
@ -0,0 +1,64 @@
|
||||
{
|
||||
// Use IntelliSense to learn about possible attributes.
|
||||
// Hover to view descriptions of existing attributes.
|
||||
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
|
||||
"version": "0.2.0",
|
||||
"configurations": [
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug unit tests in library 'manifold'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--lib",
|
||||
"--package=manifold"
|
||||
],
|
||||
"filter": {
|
||||
"name": "manifold",
|
||||
"kind": "lib"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug example 'operations'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"build",
|
||||
"--example=operations",
|
||||
"--package=manifold"
|
||||
],
|
||||
"filter": {
|
||||
"name": "operations",
|
||||
"kind": "example"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
},
|
||||
{
|
||||
"type": "lldb",
|
||||
"request": "launch",
|
||||
"name": "Debug unit tests in example 'operations'",
|
||||
"cargo": {
|
||||
"args": [
|
||||
"test",
|
||||
"--no-run",
|
||||
"--example=operations",
|
||||
"--package=manifold"
|
||||
],
|
||||
"filter": {
|
||||
"name": "operations",
|
||||
"kind": "example"
|
||||
}
|
||||
},
|
||||
"args": [],
|
||||
"cwd": "${workspaceFolder}"
|
||||
}
|
||||
]
|
||||
}
|
688
Cargo.lock
generated
688
Cargo.lock
generated
@ -2,18 +2,262 @@
|
||||
# It is not intended for manual editing.
|
||||
version = 3
|
||||
|
||||
[[package]]
|
||||
name = "aho-corasick"
|
||||
version = "1.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0"
|
||||
dependencies = [
|
||||
"memchr",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "anes"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"
|
||||
|
||||
[[package]]
|
||||
name = "anstyle"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87"
|
||||
|
||||
[[package]]
|
||||
name = "autocfg"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
|
||||
|
||||
[[package]]
|
||||
name = "bitflags"
|
||||
version = "2.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "327762f6e5a765692301e5bb513e0d9fef63be86bbc14528052b1cd3e6f03e07"
|
||||
|
||||
[[package]]
|
||||
name = "bumpalo"
|
||||
version = "3.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec"
|
||||
|
||||
[[package]]
|
||||
name = "bytemuck"
|
||||
version = "1.14.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "374d28ec25809ee0e23827c2ab573d729e293f281dfe393500e7ad618baa61c6"
|
||||
|
||||
[[package]]
|
||||
name = "cast"
|
||||
version = "0.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5"
|
||||
|
||||
[[package]]
|
||||
name = "cfg-if"
|
||||
version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "ciborium"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"ciborium-ll",
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-io"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656"
|
||||
|
||||
[[package]]
|
||||
name = "ciborium-ll"
|
||||
version = "0.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b"
|
||||
dependencies = [
|
||||
"ciborium-io",
|
||||
"half",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap"
|
||||
version = "4.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dcfab8ba68f3668e89f6ff60f5b205cea56aa7b769451a59f34b8682f51c056d"
|
||||
dependencies = [
|
||||
"clap_builder",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_builder"
|
||||
version = "4.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fb7fb5e4e979aec3be7791562fcba452f94ad85e954da024396433e0e25a79e9"
|
||||
dependencies = [
|
||||
"anstyle",
|
||||
"clap_lex",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "clap_lex"
|
||||
version = "0.6.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "702fc72eb24e5a1e48ce58027a675bc24edd52096d5397d4aea7c6dd9eca0bd1"
|
||||
|
||||
[[package]]
|
||||
name = "criterion"
|
||||
version = "0.5.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f"
|
||||
dependencies = [
|
||||
"anes",
|
||||
"cast",
|
||||
"ciborium",
|
||||
"clap",
|
||||
"criterion-plot",
|
||||
"is-terminal",
|
||||
"itertools 0.10.5",
|
||||
"num-traits",
|
||||
"once_cell",
|
||||
"oorandom",
|
||||
"plotters",
|
||||
"rayon",
|
||||
"regex",
|
||||
"serde",
|
||||
"serde_derive",
|
||||
"serde_json",
|
||||
"tinytemplate",
|
||||
"walkdir",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "criterion-plot"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1"
|
||||
dependencies = [
|
||||
"cast",
|
||||
"itertools 0.10.5",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-deque"
|
||||
version = "0.8.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fca89a0e215bab21874660c67903c5f143333cab1da83d041c7ded6053774751"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"crossbeam-epoch",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-epoch"
|
||||
version = "0.9.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0e3681d554572a651dda4186cd47240627c3d0114d45a95f6ad27f2f22e7548d"
|
||||
dependencies = [
|
||||
"autocfg",
|
||||
"cfg-if",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "crossbeam-utils"
|
||||
version = "0.8.18"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c3a430a770ebd84726f584a90ee7f020d28db52c6d02138900f22341f866d39c"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "either"
|
||||
version = "1.9.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
|
||||
|
||||
[[package]]
|
||||
name = "errno"
|
||||
version = "0.3.8"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
version = "0.2.11"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"libc",
|
||||
"wasi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getset"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9"
|
||||
dependencies = [
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "half"
|
||||
version = "1.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7"
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
|
||||
|
||||
[[package]]
|
||||
name = "is-terminal"
|
||||
version = "0.4.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0bad00257d07be169d870ab665980b06cdb366d792ad690bf2e76876dc503455"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"rustix",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.10.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "25db6b064527c5d482d0423354fcd07a89a2dfe07b67892e62411946db7f07b0"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.10"
|
||||
@ -21,15 +265,54 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b1a46d1a171d865aa5f83f92695765caa047a9b4cbae2cbf37dbd613a793fd4c"
|
||||
|
||||
[[package]]
|
||||
name = "mltensor"
|
||||
name = "js-sys"
|
||||
version = "0.3.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cee9c64da59eae3b50095c18d3e74f8b73c0b86d2792824ff01bbce68ba229ca"
|
||||
dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "libc"
|
||||
version = "0.2.151"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "302d7ab3130588088d277783b1e2d2e10c9e9e4a16dd9050e6ec93fb3e7048f4"
|
||||
|
||||
[[package]]
|
||||
name = "linux-raw-sys"
|
||||
version = "0.4.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c4cd1a83af159aa67994778be9070f0ae1bd732942279cabb14f86f986a21456"
|
||||
|
||||
[[package]]
|
||||
name = "log"
|
||||
version = "0.4.20"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
|
||||
|
||||
[[package]]
|
||||
name = "manifold"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"criterion",
|
||||
"getset",
|
||||
"itertools 0.12.0",
|
||||
"num",
|
||||
"rand",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"static_assertions",
|
||||
"thiserror",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "memchr"
|
||||
version = "2.7.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
|
||||
|
||||
[[package]]
|
||||
name = "num"
|
||||
version = "0.4.1"
|
||||
@ -106,6 +389,76 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "once_cell"
|
||||
version = "1.19.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92"
|
||||
|
||||
[[package]]
|
||||
name = "oorandom"
|
||||
version = "11.1.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575"
|
||||
|
||||
[[package]]
|
||||
name = "plotters"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45"
|
||||
dependencies = [
|
||||
"num-traits",
|
||||
"plotters-backend",
|
||||
"plotters-svg",
|
||||
"wasm-bindgen",
|
||||
"web-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "plotters-backend"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609"
|
||||
|
||||
[[package]]
|
||||
name = "plotters-svg"
|
||||
version = "0.3.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab"
|
||||
dependencies = [
|
||||
"plotters-backend",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ppv-lite86"
|
||||
version = "0.2.17"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||
dependencies = [
|
||||
"proc-macro-error-attr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.71"
|
||||
@ -124,12 +477,113 @@ dependencies = [
|
||||
"proc-macro2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand"
|
||||
version = "0.8.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"rand_chacha",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_chacha"
|
||||
version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88"
|
||||
dependencies = [
|
||||
"ppv-lite86",
|
||||
"rand_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rand_core"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c"
|
||||
dependencies = [
|
||||
"getrandom",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon"
|
||||
version = "1.8.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1"
|
||||
dependencies = [
|
||||
"either",
|
||||
"rayon-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "rayon-core"
|
||||
version = "1.12.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed"
|
||||
dependencies = [
|
||||
"crossbeam-deque",
|
||||
"crossbeam-utils",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex"
|
||||
version = "1.10.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-automata",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-automata"
|
||||
version = "0.4.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f"
|
||||
dependencies = [
|
||||
"aho-corasick",
|
||||
"memchr",
|
||||
"regex-syntax",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "regex-syntax"
|
||||
version = "0.8.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
version = "0.38.28"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"errno",
|
||||
"libc",
|
||||
"linux-raw-sys",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ryu"
|
||||
version = "1.0.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f98d2aa92eebf49b69786be48e4477826b256916e84a57ff2a4f21923b48eb4c"
|
||||
|
||||
[[package]]
|
||||
name = "same-file"
|
||||
version = "1.0.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502"
|
||||
dependencies = [
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "serde"
|
||||
version = "1.0.193"
|
||||
@ -147,7 +601,7 @@ checksum = "43576ca501357b9b071ac53cdc7da8ef0cbd9493d8df094cd821777ea6e894d3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.43",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -161,6 +615,23 @@ dependencies = [
|
||||
"serde",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "static_assertions"
|
||||
version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.43"
|
||||
@ -172,8 +643,221 @@ dependencies = [
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror"
|
||||
version = "1.0.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "83a48fd946b02c0a526b2e9481c8e2a17755e47039164a86c4070446e3a4614d"
|
||||
dependencies = [
|
||||
"thiserror-impl",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "thiserror-impl"
|
||||
version = "1.0.52"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e7fbe9b594d6568a6a1443250a7e67d80b74e1e96f6d1715e1e21cc1888291d3"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.43",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tinytemplate"
|
||||
version = "1.2.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc"
|
||||
dependencies = [
|
||||
"serde",
|
||||
"serde_json",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "unicode-ident"
|
||||
version = "1.0.12"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
|
||||
|
||||
[[package]]
|
||||
name = "version_check"
|
||||
version = "0.9.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f"
|
||||
|
||||
[[package]]
|
||||
name = "walkdir"
|
||||
version = "2.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee"
|
||||
dependencies = [
|
||||
"same-file",
|
||||
"winapi-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasi"
|
||||
version = "0.11.0+wasi-snapshot-preview1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen"
|
||||
version = "0.2.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ed0d4f68a3015cc185aff4db9506a015f4b96f95303897bfa23f846db54064e"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"wasm-bindgen-macro",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-backend"
|
||||
version = "0.2.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1b56f625e64f3a1084ded111c4d5f477df9f8c92df113852fa5a374dbda78826"
|
||||
dependencies = [
|
||||
"bumpalo",
|
||||
"log",
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.43",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro"
|
||||
version = "0.2.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0162dbf37223cd2afce98f3d0785506dcb8d266223983e4b5b525859e6e182b2"
|
||||
dependencies = [
|
||||
"quote",
|
||||
"wasm-bindgen-macro-support",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-macro-support"
|
||||
version = "0.2.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.43",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "wasm-bindgen-shared"
|
||||
version = "0.2.89"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7ab9b36309365056cd639da3134bf87fa8f3d86008abf99e612384a6eecd459f"
|
||||
|
||||
[[package]]
|
||||
name = "web-sys"
|
||||
version = "0.3.66"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "50c24a44ec86bb68fbecd1b3efed7e85ea5621b39b35ef2766b66cd984f8010f"
|
||||
dependencies = [
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi"
|
||||
version = "0.3.9"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419"
|
||||
dependencies = [
|
||||
"winapi-i686-pc-windows-gnu",
|
||||
"winapi-x86_64-pc-windows-gnu",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-i686-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
|
||||
|
||||
[[package]]
|
||||
name = "winapi-util"
|
||||
version = "0.1.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596"
|
||||
dependencies = [
|
||||
"winapi",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "winapi-x86_64-pc-windows-gnu"
|
||||
version = "0.4.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f"
|
||||
|
||||
[[package]]
|
||||
name = "windows-sys"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d"
|
||||
dependencies = [
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows-targets"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8a18201040b24831fbb9e4eb208f8892e1f50a37feb53cc7ff887feb8f50e7cd"
|
||||
dependencies = [
|
||||
"windows_aarch64_gnullvm",
|
||||
"windows_aarch64_msvc",
|
||||
"windows_i686_gnu",
|
||||
"windows_i686_msvc",
|
||||
"windows_x86_64_gnu",
|
||||
"windows_x86_64_gnullvm",
|
||||
"windows_x86_64_msvc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_gnullvm"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "cb7764e35d4db8a7921e09562a0304bf2f93e0a51bfccee0bd0bb0b666b015ea"
|
||||
|
||||
[[package]]
|
||||
name = "windows_aarch64_msvc"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbaa0368d4f1d2aaefc55b6fcfee13f41544ddf36801e793edbbfd7d7df075ef"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_gnu"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a28637cb1fa3560a16915793afb20081aba2c92ee8af57b4d5f28e4b3e7df313"
|
||||
|
||||
[[package]]
|
||||
name = "windows_i686_msvc"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ffe5e8e31046ce6230cc7215707b816e339ff4d4d67c65dffa206fd0f7aa7b9a"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnu"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d6fa32db2bc4a2f5abeacf2b69f7992cd09dca97498da74a151a3132c26befd"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_gnullvm"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1a657e1e9d3f514745a572a6846d3c7aa7dbe1658c056ed9c3344c4109a6949e"
|
||||
|
||||
[[package]]
|
||||
name = "windows_x86_64_msvc"
|
||||
version = "0.52.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dff9641d1cd4be8d1a070daf9e3773c5f67e78b4d9d42263020c057706765c04"
|
||||
|
26
Cargo.toml
26
Cargo.toml
@ -1,12 +1,32 @@
|
||||
[package]
|
||||
name = "mltensor"
|
||||
name = "manifold"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
license = "MIT/Apache-2.0"
|
||||
authors = ["Julius Koskela <julius.koskela@nordic-dev.net>"]
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
description = """
|
||||
GDSL is a graph data-structure library including graph containers,
|
||||
connected node strutures and efficient algorithms on those structures.
|
||||
Nodes are independent of a graph container and can be used as connected
|
||||
smart pointers.
|
||||
"""
|
||||
|
||||
repository = "https://nordic-dev.net/julius/manifold"
|
||||
|
||||
keywords = ["data-structures", "algorithms", "containers"]
|
||||
categories = ["data-structures", "algorithms", "mathematics", "science"]
|
||||
|
||||
[dependencies]
|
||||
bytemuck = "1.14.0"
|
||||
getset = "0.1.2"
|
||||
itertools = "0.12.0"
|
||||
num = "0.4.1"
|
||||
serde = { version = "1.0.193", features = ["derive"] }
|
||||
serde_json = "1.0.108"
|
||||
serde_json = "1.0.108"
|
||||
static_assertions = "1.1.0"
|
||||
thiserror = "1.0.52"
|
||||
|
||||
[dev-dependencies]
|
||||
rand = "0.8.5"
|
||||
criterion = "0.5.1"
|
||||
|
@ -1,4 +1,4 @@
|
||||
# Tensorc
|
||||
# Mainfold
|
||||
|
||||
```rust
|
||||
// Create two tensors with different ranks and shapes
|
||||
|
34
docs/tensor-contraction.md
Normal file
34
docs/tensor-contraction.md
Normal file
@ -0,0 +1,34 @@
|
||||
To understand how the tensor contraction should work for the given tensors `a` and `b`, let's first clarify their shapes and then walk through the contraction steps:
|
||||
|
||||
1. **Tensor Shapes**:
|
||||
- Tensor `a` is a 2x3 matrix (3 rows and 2 columns): \[\begin{matrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{matrix}\]
|
||||
- Tensor `b` is a 3x2 matrix (2 rows and 3 columns): \[\begin{matrix} 1 & 2 & 3 \\ 4 & 5 & 6 \end{matrix}\]
|
||||
|
||||
2. **Tensor Contraction Operation**:
|
||||
- The contraction operation in this case involves multiplying corresponding elements along the shared dimension (the second dimension of `a` and the first dimension of `b`) and summing the results.
|
||||
- The resulting tensor will have the shape determined by the other dimensions of the original tensors, which in this case is 3x3.
|
||||
|
||||
3. **Contraction Steps**:
|
||||
|
||||
- Step 1: Multiply each element of the first row of `a` with each element of the first column of `b`, then sum these products. This forms the first element of the resulting matrix.
|
||||
- \( (1 \times 1) + (2 \times 4) = 1 + 8 = 9 \)
|
||||
- Step 2: Multiply each element of the first row of `a` with each element of the second column of `b`, then sum these products. This forms the second element of the first row of the resulting matrix.
|
||||
- \( (1 \times 2) + (2 \times 5) = 2 + 10 = 12 \)
|
||||
- Step 3: Multiply each element of the first row of `a` with each element of the third column of `b`, then sum these products. This forms the third element of the first row of the resulting matrix.
|
||||
- \( (1 \times 3) + (2 \times 6) = 3 + 12 = 15 \)
|
||||
|
||||
- Continue this process for the remaining rows of `a` and columns of `b`:
|
||||
- For the second row of `a`:
|
||||
- \( (3 \times 1) + (4 \times 4) = 3 + 16 = 19 \)
|
||||
- \( (3 \times 2) + (4 \times 5) = 6 + 20 = 26 \)
|
||||
- \( (3 \times 3) + (4 \times 6) = 9 + 24 = 33 \)
|
||||
- For the third row of `a`:
|
||||
- \( (5 \times 1) + (6 \times 4) = 5 + 24 = 29 \)
|
||||
- \( (5 \times 2) + (6 \times 5) = 10 + 30 = 40 \)
|
||||
- \( (5 \times 3) + (6 \times 6) = 15 + 36 = 51 \)
|
||||
|
||||
4. **Resulting Tensor**:
|
||||
- The resulting 3x3 tensor from the contraction of `a` and `b` will be:
|
||||
\[\begin{matrix} 9 & 12 & 15 \\ 19 & 26 & 33 \\ 29 & 40 & 51 \end{matrix}\]
|
||||
|
||||
These steps provide the detailed calculations for each element of the resulting tensor after contracting tensors `a` and `b`.
|
239
docs/tensor-operations.md
Normal file
239
docs/tensor-operations.md
Normal file
@ -0,0 +1,239 @@
|
||||
# Operations Index
|
||||
|
||||
## 1. Addition
|
||||
|
||||
Element-wize addition of two tensors.
|
||||
|
||||
\( C = A + B \) where \( C_{ijk...} = A_{ijk...} + B_{ijk...} \) for all indices \( i, j, k, ... \).
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([[1, 2], [3, 4]]);
|
||||
let t2 = tensor!([[5, 6], [7, 8]]);
|
||||
let sum = t1 + t2;
|
||||
```
|
||||
|
||||
```sh
|
||||
[[7, 8], [10, 12]]
|
||||
```
|
||||
|
||||
## 2. Subtraction
|
||||
|
||||
Element-wize substraction of two tensors.
|
||||
|
||||
\( C = A - B \) where \( C_{ijk...} = A_{ijk...} - B_{ijk...} \).
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([[1, 2], [3, 4]]);
|
||||
let t2 = tensor!([[5, 6], [7, 8]]);
|
||||
let diff = i1 - t2;
|
||||
```
|
||||
|
||||
```sh
|
||||
[[-4, -4], [-4, -4]]
|
||||
```
|
||||
|
||||
## 3. Multiplication
|
||||
|
||||
Element-wize multiplication of two tensors.
|
||||
|
||||
\( C = A \odot B \) where \( C_{ijk...} = A_{ijk...} \times B_{ijk...} \).
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([[1, 2], [3, 4]]);
|
||||
let t2 = tensor!([[5, 6], [7, 8]]);
|
||||
let prod = t1 * t2;
|
||||
```
|
||||
|
||||
```sh
|
||||
[[5, 12], [21, 32]]
|
||||
```
|
||||
|
||||
## 4. Division
|
||||
|
||||
Element-wize division of two tensors.
|
||||
|
||||
\( C = A \div B \) where \( C_{ijk...} = A_{ijk...} \div B_{ijk...} \).
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([[1, 2], [3, 4]]);
|
||||
let t2 = tensor!([[1, 2], [3, 4]]);
|
||||
let quot = t1 / t2;
|
||||
```
|
||||
|
||||
```sh
|
||||
[[1, 1], [1, 1]]
|
||||
```
|
||||
|
||||
## 5. Contraction
|
||||
|
||||
Contract two tensors over given axes.
|
||||
|
||||
For matrices \( A \) and \( B \), \( C = AB \) where \( C_{ij} = \sum_k A_{ik} B_{kj} \).
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([[1, 2], [3, 4], [5, 6]]);
|
||||
let t2 = tensor!([[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
let cont = contract((t1, [1]), (t2, [0]));
|
||||
```
|
||||
|
||||
```sh
|
||||
TODO!
|
||||
```
|
||||
|
||||
## 6. Reduction (e.g., Sum)
|
||||
|
||||
\( \text{sum}(A) \) where sum over all elements of A.
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([[1, 2], [3, 4]]);
|
||||
let total = t1.sum();
|
||||
```
|
||||
|
||||
```sh
|
||||
10
|
||||
```
|
||||
|
||||
## 7. Broadcasting
|
||||
|
||||
Adjusts tensors with different shapes to make them compatible for element-wise operations automatically
|
||||
when using supported functions.
|
||||
|
||||
## 8. Reshape
|
||||
|
||||
Changing the shape of a tensor without altering its data.
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([1, 2, 3, 4, 5, 6]);
|
||||
let tr = t1.reshape([2, 3]);
|
||||
```
|
||||
|
||||
```sh
|
||||
[[1, 2, 3], [4, 5, 6]]
|
||||
```
|
||||
|
||||
## 9. Transpose
|
||||
|
||||
Transpose a tensor over given axes.
|
||||
|
||||
\( B = A^T \) where \( B_{ij} = A_{ji} \).
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([1, 2, 3, 4]);
|
||||
let transposed = t1.transpose();
|
||||
```
|
||||
|
||||
```sh
|
||||
TODO!
|
||||
```
|
||||
|
||||
## 10. Concatenation
|
||||
|
||||
Joining tensors along a specified dimension.
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([1, 2, 3]);
|
||||
let t2 = tensor!([4, 5, 6]);
|
||||
let cat = t1.concat(&t2, 0);
|
||||
```
|
||||
|
||||
```sh
|
||||
TODO!
|
||||
```
|
||||
|
||||
## 11. Slicing and Indexing
|
||||
|
||||
Extracting parts of tensors based on indices.
|
||||
|
||||
```rust
|
||||
let t1 = tensor!([1, 2, 3, 4, 5, 6]);
|
||||
let slice = t1.slice(s![1, ..]);
|
||||
```
|
||||
|
||||
```sh
|
||||
TODO!
|
||||
```
|
||||
|
||||
## 12. Element-wise Functions (e.g., Sigmoid)
|
||||
|
||||
**Mathematical Definition**:
|
||||
|
||||
Applying a function to each element of a tensor, like \( \sigma(x) = \frac{1}{1 + e^{-x}} \) for sigmoid.
|
||||
|
||||
**Rust Code Example**:
|
||||
|
||||
```rust
|
||||
let tensor = Tensor::<f32, 2>::from([-1.0, 0.0, 1.0, 2.0]); // 2x2 tensor
|
||||
let sigmoid_tensor = tensor.map(|x| 1.0 / (1.0 + (-x).exp())); // Apply sigmoid element-wise
|
||||
```
|
||||
|
||||
## 13. Gradient Computation/Automatic Differentiation
|
||||
|
||||
**Description**:
|
||||
|
||||
Calculating the derivatives of tensors, crucial for training machine learning models.
|
||||
|
||||
**Rust Code Example**: Depends on if your tensor library supports automatic differentiation. This is typically more complex and may involve constructing computational graphs.
|
||||
|
||||
## 14. Normalization Operations (e.g., Batch Normalization)
|
||||
|
||||
**Description**: Standardizing the inputs of a model across the batch dimension.
|
||||
|
||||
**Rust Code Example**: This is specific to deep learning libraries and may not be directly supported in a general-purpose tensor library.
|
||||
|
||||
## 15. Convolution Operations
|
||||
|
||||
**Description**: Essential for image processing and CNNs.
|
||||
|
||||
**Rust Code Example**: If your library supports it, convolutions typically involve using a specialized function that takes the input tensor and a kernel tensor.
|
||||
|
||||
## 16. Pooling Operations (e.g., Max Pooling)
|
||||
|
||||
**Description**: Reducing the spatial dimensions of
|
||||
a tensor, commonly used in CNNs.
|
||||
|
||||
**Rust Code Example**: Again, this depends on your library's support for such operations.
|
||||
|
||||
## 17. Tensor Slicing and Joining
|
||||
|
||||
**Description**: Operations to slice a tensor into sub-tensors or join multiple tensors into a larger tensor.
|
||||
|
||||
**Rust Code Example**: Similar to the slicing and concatenation examples provided above.
|
||||
|
||||
## 18. Dimension Permutation
|
||||
|
||||
**Description**: Rearranging the dimensions of a tensor.
|
||||
|
||||
**Rust Code Example**:
|
||||
|
||||
```rust
|
||||
let tensor = Tensor::<i32, 3>::from([...]); // 3D tensor
|
||||
let permuted_tensor = tensor.permute_dims([2, 0, 1]); // Permute dimensions
|
||||
```
|
||||
|
||||
## 19. Expand and Squeeze Operations
|
||||
|
||||
**Description**: Increasing or decreasing the dimensions of a tensor (adding/removing singleton dimensions).
|
||||
|
||||
**Rust Code Example**: Depends on the specific functions provided by your library.
|
||||
|
||||
## 20. Data Type Conversions
|
||||
|
||||
**Description**: Converting tensors from one data type to another.
|
||||
|
||||
**Rust Code Example**:
|
||||
|
||||
```rust
|
||||
let tensor = Tensor::<i32, 2>::from([1, 2, 3, 4]); // 2x2 tensor
|
||||
let converted_tensor = tensor.to_type::<f32>(); // Convert to f32 tensor
|
||||
```
|
||||
|
||||
These examples provide a general guide. The actual implementation details may vary depending on the specific features and capabilities of the Rust tensor library you're using.
|
||||
|
||||
## 21. Tensor Decompositions
|
||||
|
||||
**CANDECOMP/PARAFAC (CP) Decomposition**: This decomposes a tensor into a sum of component rank-one tensors. For a third-order tensor, it's like expressing it as a sum of outer products of vectors. This is useful in applications like signal processing, psychometrics, and chemometrics.
|
||||
|
||||
**Tucker Decomposition**: Similar to PCA for matrices, Tucker Decomposition decomposes a tensor into a core tensor multiplied by a matrix along each mode (dimension). It's more general than CP Decomposition and is useful in areas like data compression and tensor completion.
|
||||
|
||||
**Higher-Order Singular Value Decomposition (HOSVD)**: A generalization of SVD for higher-order tensors, HOSVD decomposes a tensor into a core tensor and a set of orthogonal matrices for each mode. It's used in image processing, computer vision, and multilinear subspace learning.
|
@ -1,53 +1,94 @@
|
||||
use mltensor::*;
|
||||
#![allow(mixed_script_confusables)]
|
||||
#![allow(non_snake_case)]
|
||||
use bytemuck::cast_slice;
|
||||
use manifold::contract;
|
||||
use manifold::*;
|
||||
|
||||
fn tensor_product() {
|
||||
println!("Tensor Product\n");
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
println!("Tensor Product\n");
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
|
||||
// Fill tensors with some values
|
||||
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||
// Fill tensors with some values
|
||||
tensor1.buffer_mut().copy_from_slice(&[1, 2, 3, 4]);
|
||||
tensor2.buffer_mut().copy_from_slice(&[5, 6]);
|
||||
|
||||
println!("T1: {}", tensor1);
|
||||
println!("T2: {}", tensor2);
|
||||
println!("T1: {}", tensor1);
|
||||
println!("T2: {}", tensor2);
|
||||
|
||||
let product = tensor1.tensor_product(&tensor2);
|
||||
let product = tensor1.tensor_product(&tensor2);
|
||||
|
||||
println!("T1 * T2 = {}", product);
|
||||
println!("T1 * T2 = {}", product);
|
||||
|
||||
// Check shape of the resulting tensor
|
||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||
// Check shape of the resulting tensor
|
||||
assert_eq!(product.shape(), &Shape::new([2, 2, 2]));
|
||||
|
||||
// Check buffer of the resulting tensor
|
||||
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||
assert_eq!(product.buffer(), &expected_buffer);
|
||||
// Check buffer of the resulting tensor
|
||||
let expect: &[i32] =
|
||||
cast_slice(&[[[5, 6], [10, 12]], [[15, 18], [20, 24]]]);
|
||||
assert_eq!(product.buffer(), expect);
|
||||
}
|
||||
|
||||
fn tensor_contraction() {
|
||||
println!("Tensor Contraction\n");
|
||||
// Create two tensors
|
||||
let tensor1 = Tensor::from([[1, 2], [3, 4]]); // 2x2 tensor
|
||||
let tensor2 = Tensor::from([[5, 6], [7, 8]]); // 2x2 tensor
|
||||
fn test_tensor_contraction_23x32() {
|
||||
// Define two 2D tensors (matrices)
|
||||
|
||||
// Specify axes for contraction
|
||||
let axis_lhs = [1]; // Contract over the second dimension of tensor1
|
||||
let axis_rhs = [0]; // Contract over the first dimension of tensor2
|
||||
// Tensor A is 2x3
|
||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
println!("a: {:?}\n{}\n", a.shape(), a);
|
||||
|
||||
// Perform contraction
|
||||
let result = tensor1.contract(&tensor2, axis_lhs, axis_rhs);
|
||||
// Tensor B is 3x2
|
||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
||||
println!("b: {:?}\n{}\n", b.shape(), b);
|
||||
|
||||
println!("T1: {}", tensor1);
|
||||
println!("T2: {}", tensor2);
|
||||
println!("T1 * T2 = {}", result);
|
||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
let ctr10 = contract((&a, [1]), (&b, [0]));
|
||||
|
||||
// Expected result, for example, could be a single number or a new tensor,
|
||||
// depending on how you defined the contraction operation.
|
||||
// Assert the result is as expected
|
||||
// assert_eq!(result, expected_result);
|
||||
println!("[1, 0]: {:?}\n{}\n", ctr10.shape(), ctr10);
|
||||
|
||||
let ctr01 = contract((&a, [0]), (&b, [1]));
|
||||
|
||||
println!("[0, 1]: {:?}\n{}\n", ctr01.shape(), ctr01);
|
||||
// assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
||||
// assert_eq!(
|
||||
// contracted_tensor.buffer(),
|
||||
// &[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
||||
// "Contracted tensor buffer does not match expected"
|
||||
// );
|
||||
}
|
||||
|
||||
fn test_tensor_contraction_rank3() {
|
||||
let a = tensor!([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
let b = tensor!([[[9, 10], [11, 12]], [[13, 14], [15, 16]]]);
|
||||
let contracted_tensor = contract((&a, [2]), (&b, [0]));
|
||||
|
||||
println!("a: {}", a);
|
||||
println!("b: {}", b);
|
||||
println!("contracted_tensor: {}", contracted_tensor);
|
||||
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
|
||||
// Verify specific elements of contracted_tensor
|
||||
// assert_eq!(contracted_tensor[0][0][0][0], 50);
|
||||
// assert_eq!(contracted_tensor[0][0][0][1], 60);
|
||||
// ... further checks for other elements ...
|
||||
}
|
||||
|
||||
fn transpose() {
|
||||
let a = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
let b = tensor!([[1, 2, 3], [4, 5, 6]]);
|
||||
|
||||
// let iter = a.idx().iter_transposed([1, 0]);
|
||||
|
||||
// for idx in iter {
|
||||
// println!("{idx}");
|
||||
// }
|
||||
let b = a.clone().transpose([1, 0]).unwrap();
|
||||
println!("a: {}", a);
|
||||
println!("ta: {}", b);
|
||||
}
|
||||
|
||||
fn main() {
|
||||
tensor_product();
|
||||
tensor_contraction();
|
||||
}
|
||||
// tensor_product();
|
||||
// test_tensor_contraction_23x32();
|
||||
// test_tensor_contraction_rank3();
|
||||
|
||||
transpose();
|
||||
}
|
||||
|
389
src/axis.rs
Normal file
389
src/axis.rs
Normal file
@ -0,0 +1,389 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
|
||||
#[derive(Clone, Debug, Getters)]
|
||||
pub struct Axis<'a, T: Value, const R: usize> {
|
||||
#[getset(get = "pub")]
|
||||
tensor: &'a Tensor<T, R>,
|
||||
#[getset(get = "pub")]
|
||||
dim: usize,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Axis<'a, T, R> {
|
||||
pub fn new(tensor: &'a Tensor<T, R>, dim: usize) -> Self {
|
||||
assert!(dim < R, "Axis out of bounds");
|
||||
Self { tensor, dim }
|
||||
}
|
||||
|
||||
pub fn len(&self) -> usize {
|
||||
self.tensor.shape().get(self.dim)
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> &Shape<R> {
|
||||
self.tensor.shape()
|
||||
}
|
||||
|
||||
pub fn iter_level(&'a self, level: usize) -> AxisIterator<'a, T, R> {
|
||||
assert!(level < self.len(), "Level out of bounds");
|
||||
let mut index = Idx::new(self.shape(), [0; R]);
|
||||
index.set_axis(self.dim, level);
|
||||
AxisIterator::new(self).set_start(level).set_end(level + 1)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Getters, MutGetters)]
|
||||
pub struct AxisIterator<'a, T: Value, const R: usize> {
|
||||
#[getset(get = "pub")]
|
||||
axis: &'a Axis<'a, T, R>,
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
index: Idx<'a, R>,
|
||||
#[getset(get = "pub")]
|
||||
end: Option<usize>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> AxisIterator<'a, T, R> {
|
||||
pub fn new(axis: &'a Axis<'a, T, R>) -> Self {
|
||||
Self {
|
||||
axis,
|
||||
index: Idx::new(axis.shape(), [0; R]),
|
||||
end: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_start(self, start: usize) -> Self {
|
||||
assert!(start < self.axis().len(), "Start out of bounds");
|
||||
let mut index = Idx::new(self.axis().shape(), [0; R]);
|
||||
index.set_axis(self.axis.dim, start);
|
||||
Self {
|
||||
axis: self.axis(),
|
||||
index,
|
||||
end: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_end(self, end: usize) -> Self {
|
||||
assert!(end <= self.axis().len(), "End out of bounds");
|
||||
Self {
|
||||
axis: self.axis(),
|
||||
index: self.index().clone(),
|
||||
end: Some(end),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn set_level(self, level: usize) -> Self {
|
||||
assert!(level < self.axis().len(), "Level out of bounds");
|
||||
self.set_start(level).set_end(level + 1)
|
||||
}
|
||||
|
||||
pub fn level(&'a self, level: usize) -> impl Iterator<Item = &'a T> + 'a {
|
||||
Self::new(self.axis()).set_level(level)
|
||||
}
|
||||
|
||||
pub fn axis_max_idx(&self) -> usize {
|
||||
self.end().unwrap_or(self.axis().len())
|
||||
}
|
||||
|
||||
pub fn axis_idx(&self) -> usize {
|
||||
self.index().get_axis(*self.axis().dim())
|
||||
}
|
||||
|
||||
pub fn axis_dim(&self) -> usize {
|
||||
self.axis().dim().clone()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> Iterator for AxisIterator<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.axis_idx() == self.axis_max_idx() {
|
||||
return None;
|
||||
}
|
||||
let result = unsafe { self.axis().tensor().get_unchecked(self.index) };
|
||||
let axis_dim = self.axis_dim();
|
||||
self.index_mut().inc_axis(axis_dim);
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> IntoIterator for &'a Axis<'a, T, R> {
|
||||
type Item = &'a T;
|
||||
type IntoIter = AxisIterator<'a, T, R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
AxisIterator::new(&self)
|
||||
}
|
||||
}
|
||||
|
||||
pub fn contract<
|
||||
'a,
|
||||
T: Value + std::fmt::Debug,
|
||||
const R: usize,
|
||||
const S: usize,
|
||||
const N: usize,
|
||||
>(
|
||||
lhs: (&'a Tensor<T, R>, [usize; N]),
|
||||
rhs: (&'a Tensor<T, S>, [usize; N]),
|
||||
) -> Tensor<T, { R + S - 2 * N }>
|
||||
where
|
||||
[(); R - N]:,
|
||||
[(); S - N]:,
|
||||
[(); R + S - 2 * N]:,
|
||||
{
|
||||
let (lhs, la) = lhs;
|
||||
let (rhs, ra) = rhs;
|
||||
let lnc = (0..R).filter(|i| !la.contains(i)).collect::<Vec<_>>();
|
||||
let rnc = (0..S).filter(|i| !ra.contains(i)).collect::<Vec<_>>();
|
||||
|
||||
let lnc = lnc.into_iter().map(|i| lhs.axis(i)).collect::<Vec<_>>();
|
||||
let rnc = rnc.into_iter().map(|i| rhs.axis(i)).collect::<Vec<_>>();
|
||||
|
||||
let mut shape = Vec::new();
|
||||
shape.extend_from_slice(&rhs.shape().remove_dims::<{ N }>(ra).as_array());
|
||||
shape.extend_from_slice(&lhs.shape().remove_dims::<{ N }>(la).as_array());
|
||||
let shape: [usize; R + S - 2 * N] =
|
||||
shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
let shape = Shape::new(shape);
|
||||
|
||||
let result = contract_axes(&lnc, &rnc);
|
||||
|
||||
Tensor::new_with_buffer(shape, result)
|
||||
}
|
||||
|
||||
pub fn contract_axes<
|
||||
'a,
|
||||
T: Value + std::fmt::Debug,
|
||||
const R: usize,
|
||||
const S: usize,
|
||||
const N: usize,
|
||||
>(
|
||||
laxes: &'a [Axis<'a, T, R>],
|
||||
raxes: &'a [Axis<'a, T, S>],
|
||||
) -> Vec<T>
|
||||
where
|
||||
[(); R - N]:,
|
||||
[(); S - N]:,
|
||||
{
|
||||
let mut result = vec![];
|
||||
|
||||
let axes = laxes.into_iter().zip(raxes);
|
||||
|
||||
for (laxis, raxis) in axes {
|
||||
let mut axes_result: Vec<T> = vec![];
|
||||
for i in 0..raxis.len() {
|
||||
for j in 0..laxis.len() {
|
||||
let mut sum = T::zero();
|
||||
let llevel = laxis.into_iter();
|
||||
let rlevel = raxis.into_iter();
|
||||
let zip = llevel.level(j).zip(rlevel.level(i));
|
||||
for (lv, rv) in zip {
|
||||
sum = sum + *lv * *rv;
|
||||
}
|
||||
axes_result.push(sum);
|
||||
}
|
||||
}
|
||||
result.extend_from_slice(&axes_result);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_tensor_contraction_simple() {
|
||||
// Define two 2D tensors (matrices)
|
||||
// Tensor A is 2x3
|
||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
||||
|
||||
// Tensor B is 1x3x2
|
||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4]]);
|
||||
|
||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
||||
assert_eq!(contracted_tensor.shape(), &Shape::new([2, 2]));
|
||||
assert_eq!(
|
||||
contracted_tensor.buffer(),
|
||||
&[7, 10, 15, 22],
|
||||
"Contracted tensor buffer does not match expected"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_contraction_23x32() {
|
||||
// Define two 2D tensors (matrices)
|
||||
|
||||
// Tensor A is 2x3
|
||||
let b: Tensor<i32, 2> = Tensor::from([[1, 2, 3], [4, 5, 6]]);
|
||||
println!("b: {}", b);
|
||||
|
||||
// Tensor B is 3x2
|
||||
let a: Tensor<i32, 2> = Tensor::from([[1, 2], [3, 4], [5, 6]]);
|
||||
println!("a: {}", a);
|
||||
|
||||
// Contract over the last axis of A (axis 1) and the first axis of B (axis 0)
|
||||
let contracted_tensor: Tensor<i32, 2> = contract((&a, [1]), (&b, [0]));
|
||||
|
||||
println!("contracted_tensor: {}", contracted_tensor);
|
||||
assert_eq!(contracted_tensor.shape(), &Shape::new([3, 3]));
|
||||
assert_eq!(
|
||||
contracted_tensor.buffer(),
|
||||
&[9, 12, 15, 19, 26, 33, 29, 40, 51],
|
||||
"Contracted tensor buffer does not match expected"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_tensor_contraction_rank3() {
|
||||
let a: Tensor<i32, 3> =
|
||||
Tensor::new_with_buffer(Shape::from([2, 3, 4]), (1..25).collect()); // Fill with elements 1 to 24
|
||||
let b: Tensor<i32, 3> =
|
||||
Tensor::new_with_buffer(Shape::from([4, 3, 2]), (1..25).collect()); // Fill with elements 1 to 24
|
||||
let contracted_tensor: Tensor<i32, 4> = contract((&a, [2]), (&b, [0]));
|
||||
|
||||
println!("a: {}", a);
|
||||
println!("b: {}", b);
|
||||
println!("contracted_tensor: {}", contracted_tensor);
|
||||
// assert_eq!(contracted_tensor.shape(), &[2, 4, 3, 2]);
|
||||
// Verify specific elements of contracted_tensor
|
||||
// assert_eq!(contracted_tensor[0][0][0][0], 50);
|
||||
// assert_eq!(contracted_tensor[0][0][0][1], 60);
|
||||
// ... further checks for other elements ...
|
||||
}
|
||||
|
||||
// #[test]
|
||||
// fn test_axis_iterator_disassemble() {
|
||||
// // Creating a 2x2 Tensor for testing
|
||||
// let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// // Testing iteration over the first axis (axis = 0)
|
||||
// let axis = Axis::new(&tensor, 0);
|
||||
|
||||
// let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
// assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter[0].next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter[0].next(), None);
|
||||
// assert_eq!(axis_iter[1].next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
// assert_eq!(axis_iter[1].next(), None);
|
||||
|
||||
// // Resetting the iterator for the second axis (axis = 1)
|
||||
// let axis = Axis::new(&tensor, 1);
|
||||
|
||||
// let mut axis_iter = axis.into_iter().disassemble();
|
||||
|
||||
// assert_eq!(axis_iter[0].next(), Some(&1.0));
|
||||
// assert_eq!(axis_iter[0].next(), Some(&3.0));
|
||||
// assert_eq!(axis_iter[0].next(), None);
|
||||
// assert_eq!(axis_iter[1].next(), Some(&2.0));
|
||||
// assert_eq!(axis_iter[1].next(), Some(&4.0));
|
||||
// assert_eq!(axis_iter[1].next(), None);
|
||||
// }
|
||||
|
||||
#[test]
|
||||
fn test_axis_iterator() {
|
||||
// Creating a 2x2 Tensor for testing
|
||||
let tensor = Tensor::from([[1.0, 2.0], [3.0, 4.0]]);
|
||||
|
||||
// Testing iteration over the first axis (axis = 0)
|
||||
let axis = Axis::new(&tensor, 0);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
// Resetting the iterator for the second axis (axis = 1)
|
||||
let axis = Axis::new(&tensor, 1);
|
||||
|
||||
let mut axis_iter = axis.into_iter();
|
||||
|
||||
assert_eq!(axis_iter.next(), Some(&1.0));
|
||||
assert_eq!(axis_iter.next(), Some(&3.0));
|
||||
assert_eq!(axis_iter.next(), Some(&2.0));
|
||||
assert_eq!(axis_iter.next(), Some(&4.0));
|
||||
|
||||
let shape = tensor.shape();
|
||||
|
||||
let mut a: Idx<2> = (shape, [0, 0]).into();
|
||||
let b: Idx<2> = (shape, [1, 1]).into();
|
||||
|
||||
while a <= b {
|
||||
println!("a: {}", a);
|
||||
a.inc();
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_3d_tensor_axis_iteration() {
|
||||
// Create a 3D Tensor with specific values
|
||||
// Tensor shape is 2x2x2 for simplicity
|
||||
let t = Tensor::from([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]);
|
||||
|
||||
// Axis 0 (Layer-wise):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 3, 4, 5, 6, 7, 8]
|
||||
//
|
||||
// This order suggests that for each "layer" (first level of arrays),
|
||||
// the iterator goes through all rows and columns. It first completes
|
||||
// the entire first layer, then moves to the second.
|
||||
|
||||
let a0 = Axis::new(&t, 0);
|
||||
let a0_order = a0.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a0_order, [1, 2, 3, 4, 5, 6, 7, 8]);
|
||||
|
||||
// Axis 1 (Row-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][0][1] = 2
|
||||
// t[1][0][0] = 5
|
||||
// t[1][0][1] = 6
|
||||
// t[0][1][0] = 3
|
||||
// t[0][1][1] = 4
|
||||
// t[1][1][0] = 7
|
||||
// t[1][1][1] = 8
|
||||
// [1, 2, 5, 6, 3, 4, 7, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first row across all layers, then the second row
|
||||
// across all layers.
|
||||
|
||||
let a1 = Axis::new(&t, 1);
|
||||
let a1_order = a1.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a1_order, [1, 2, 5, 6, 3, 4, 7, 8]);
|
||||
|
||||
// Axis 2 (Column-wise within each layer):
|
||||
//
|
||||
// t[0][0][0] = 1
|
||||
// t[0][1][0] = 3
|
||||
// t[1][0][0] = 5
|
||||
// t[1][1][0] = 7
|
||||
// t[0][0][1] = 2
|
||||
// t[0][1][1] = 4
|
||||
// t[1][0][1] = 6
|
||||
// t[1][1][1] = 8
|
||||
// [1, 3, 5, 7, 2, 4, 6, 8]
|
||||
//
|
||||
// This indicates that within each "layer", the iterator first
|
||||
// completes the first column across all layers, then the second
|
||||
// column across all layers.
|
||||
|
||||
let a2 = Axis::new(&t, 2);
|
||||
let a2_order = a2.into_iter().cloned().collect::<Vec<_>>();
|
||||
assert_eq!(a2_order, [1, 3, 5, 7, 2, 4, 6, 8]);
|
||||
}
|
||||
}
|
9
src/error.rs
Normal file
9
src/error.rs
Normal file
@ -0,0 +1,9 @@
|
||||
use thiserror::Error;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Invalid argument: {0}")]
|
||||
InvalidArgument(String),
|
||||
}
|
281
src/index.rs
281
src/index.rs
@ -1,35 +1,41 @@
|
||||
use super::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::cmp::Ordering;
|
||||
use std::ops::{Add, Sub};
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Idx<const R: usize> {
|
||||
#[derive(Clone, Copy, Debug, Getters, MutGetters)]
|
||||
pub struct Idx<'a, const R: usize> {
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
indices: [usize; R],
|
||||
shape: Shape<R>,
|
||||
#[getset(get = "pub")]
|
||||
shape: &'a Shape<R>,
|
||||
}
|
||||
|
||||
impl<const R: usize> Idx<R> {
|
||||
pub const fn zero(shape: Shape<R>) -> Self {
|
||||
impl<'a, const R: usize> Idx<'a, R> {
|
||||
pub const fn zero(shape: &'a Shape<R>) -> Self {
|
||||
Self {
|
||||
indices: [0; R],
|
||||
shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max(shape: Shape<R>) -> Self {
|
||||
pub fn last(shape: &'a Shape<R>) -> Self {
|
||||
let max_indices =
|
||||
shape.as_array().map(|dim_size| dim_size.saturating_sub(1));
|
||||
Self {
|
||||
indices: max_indices,
|
||||
shape,
|
||||
shape: shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn new(shape: Shape<R>, indices: [usize; R]) -> Self {
|
||||
pub fn new(shape: &'a Shape<R>, indices: [usize; R]) -> Self {
|
||||
if !shape.check_indices(indices) {
|
||||
panic!("indices out of bounds");
|
||||
}
|
||||
Self { indices, shape }
|
||||
Self {
|
||||
indices,
|
||||
shape: shape,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_zero(&self) -> bool {
|
||||
@ -51,6 +57,9 @@ impl<const R: usize> Idx<R> {
|
||||
/// `true` if the increment does not overflow and is still within bounds;
|
||||
/// `false` if it overflows, indicating the end of the tensor.
|
||||
pub fn inc(&mut self) -> bool {
|
||||
if self.indices()[0] >= self.shape().get(0) {
|
||||
return false;
|
||||
}
|
||||
let mut carry = 1;
|
||||
for (i, &dim_size) in
|
||||
self.indices.iter_mut().zip(&self.shape.as_array()).rev()
|
||||
@ -67,13 +76,74 @@ impl<const R: usize> Idx<R> {
|
||||
|
||||
// If carry is still 1 after the loop, it means we've incremented past the last dimension
|
||||
if carry == 1 {
|
||||
// Set the index to an invalid state (e.g., all indices to their max values)
|
||||
// Set the index to an invalid state to indicate the end of the iteration indicated
|
||||
// by setting the first index to the size of the first dimension
|
||||
self.indices[0] = self.shape.as_array()[0];
|
||||
return true; // Indicate that the iteration is complete
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
// fn inc_axis
|
||||
|
||||
pub fn inc_axis(&mut self, fixed_axis: usize) {
|
||||
assert!(fixed_axis < R, "Axis out of bounds");
|
||||
assert!(
|
||||
self.indices()[fixed_axis] < self.shape().get(fixed_axis),
|
||||
"Index out of bounds"
|
||||
);
|
||||
|
||||
// Try to increment non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] + 1 < self.shape.get(i) {
|
||||
self.indices[i] += 1;
|
||||
return;
|
||||
} else {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if self.indices[fixed_axis] < self.shape.get(fixed_axis) {
|
||||
self.indices[fixed_axis] += 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = 0;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
pub fn inc_transposed(&mut self, order: &[usize; R]) -> bool {
|
||||
if self.indices()[order[0]] >= self.shape().get(order[0]) {
|
||||
return false;
|
||||
}
|
||||
|
||||
let mut carry = 1;
|
||||
|
||||
for i in order.iter().rev() {
|
||||
let dim_size = self.shape().get(*i);
|
||||
let i = self.index_mut(*i);
|
||||
if carry == 1 {
|
||||
*i += 1;
|
||||
if *i >= dim_size {
|
||||
*i = 0;
|
||||
} else {
|
||||
carry = 0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if carry == 1 {
|
||||
self.indices_mut()[order[0]] = self.shape().get(order[0]);
|
||||
return true;
|
||||
}
|
||||
|
||||
false
|
||||
}
|
||||
|
||||
pub fn dec(&mut self) {
|
||||
// Check if already at the start
|
||||
if self.indices.iter().all(|&i| i == 0) {
|
||||
@ -95,6 +165,61 @@ impl<const R: usize> Idx<R> {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn dec_axis(&mut self, fixed_axis: usize) -> bool {
|
||||
// Check if the fixed axis index is already in an invalid state
|
||||
if self.indices[fixed_axis] == self.shape.get(fixed_axis) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Try to decrement non-fixed axes
|
||||
for i in (0..R).rev() {
|
||||
if i != fixed_axis {
|
||||
if self.indices[i] > 0 {
|
||||
self.indices[i] -= 1;
|
||||
return true;
|
||||
} else {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Decrement the fixed axis if possible and reset other axes to their max
|
||||
if self.indices[fixed_axis] > 0 {
|
||||
self.indices[fixed_axis] -= 1;
|
||||
for i in 0..R {
|
||||
if i != fixed_axis {
|
||||
self.indices[i] = self.shape.get(i) - 1;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fixed axis already at minimum, set to invalid state
|
||||
self.indices[fixed_axis] = self.shape.get(fixed_axis);
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
pub fn dec_transposed(&mut self, order: [usize; R]) {
|
||||
// Iterate over the axes in the specified order
|
||||
for &axis in &order {
|
||||
// Try to decrement the current axis
|
||||
if self.indices[axis] > 0 {
|
||||
self.indices[axis] -= 1;
|
||||
// Reset all preceding axes in the order to their maximum
|
||||
for &prev_axis in &order {
|
||||
if prev_axis == axis {
|
||||
break;
|
||||
}
|
||||
self.indices[prev_axis] = self.shape.get(prev_axis) - 1;
|
||||
}
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
// If no axis can be decremented, set the first axis in the order to indicate overflow
|
||||
self.indices[order[0]] = self.shape.get(order[0]);
|
||||
}
|
||||
|
||||
/// Converts the multi-dimensional index to a flat index.
|
||||
///
|
||||
/// This method calculates the flat index corresponding to the multi-dimensional index
|
||||
@ -136,31 +261,59 @@ impl<const R: usize> Idx<R> {
|
||||
})
|
||||
.0
|
||||
}
|
||||
|
||||
pub fn set_axis(&mut self, axis: usize, value: usize) {
|
||||
assert!(axis < R, "Axis out of bounds");
|
||||
// assert!(value < self.shape.get(axis), "Value out of bounds");
|
||||
self.indices[axis] = value;
|
||||
}
|
||||
|
||||
pub fn try_set_axis(&mut self, axis: usize, value: usize) -> bool {
|
||||
assert!(axis < R, "Axis out of bounds");
|
||||
if value < self.shape.get(axis) {
|
||||
self.indices[axis] = value;
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
pub fn get_axis(&self, axis: usize) -> usize {
|
||||
assert!(axis < R, "Axis out of bounds");
|
||||
self.indices[axis]
|
||||
}
|
||||
|
||||
pub fn iter_transposed(
|
||||
&self,
|
||||
order: [usize; R],
|
||||
) -> IdxTransposedIterator<'a, R> {
|
||||
IdxTransposedIterator::new(self.shape(), order)
|
||||
}
|
||||
}
|
||||
|
||||
// --- blanket impls ---
|
||||
|
||||
impl<const R: usize> PartialEq for Idx<R> {
|
||||
impl<'a, const R: usize> PartialEq for Idx<'a, R> {
|
||||
fn eq(&self, other: &Self) -> bool {
|
||||
self.flat() == other.flat()
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Eq for Idx<R> {}
|
||||
impl<'a, const R: usize> Eq for Idx<'a, R> {}
|
||||
|
||||
impl<const R: usize> PartialOrd for Idx<R> {
|
||||
impl<'a, const R: usize> PartialOrd for Idx<'a, R> {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
self.flat().partial_cmp(&other.flat())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Ord for Idx<R> {
|
||||
impl<'a, const R: usize> Ord for Idx<'a, R> {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
self.flat().cmp(&other.flat())
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Index<usize> for Idx<R> {
|
||||
impl<'a, const R: usize> Index<usize> for Idx<'a, R> {
|
||||
type Output = usize;
|
||||
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
@ -168,33 +321,39 @@ impl<const R: usize> Index<usize> for Idx<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> IndexMut<usize> for Idx<R> {
|
||||
impl<'a, const R: usize> IndexMut<usize> for Idx<'a, R> {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
&mut self.indices[index]
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(Shape<R>, [usize; R])> for Idx<R> {
|
||||
fn from((shape, indices): (Shape<R>, [usize; R])) -> Self {
|
||||
impl<'a, const R: usize> From<(&'a Shape<R>, [usize; R])> for Idx<'a, R> {
|
||||
fn from((shape, indices): (&'a Shape<R>, [usize; R])) -> Self {
|
||||
assert!(shape.check_indices(indices));
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<(Shape<R>, usize)> for Idx<R> {
|
||||
fn from((shape, flat_index): (Shape<R>, usize)) -> Self {
|
||||
impl<'a, const R: usize> From<(&'a Shape<R>, usize)> for Idx<'a, R> {
|
||||
fn from((shape, flat_index): (&'a Shape<R>, usize)) -> Self {
|
||||
let indices = shape.index_from_flat(flat_index).indices;
|
||||
Self::new(shape, indices)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> From<Shape<R>> for Idx<R> {
|
||||
fn from(shape: Shape<R>) -> Self {
|
||||
impl<'a, const R: usize> From<&'a Shape<R>> for Idx<'a, R> {
|
||||
fn from(shape: &'a Shape<R>) -> Self {
|
||||
Self::zero(shape)
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> std::fmt::Display for Idx<R> {
|
||||
impl<'a, T: Value, const R: usize> From<&'a Tensor<T, R>> for Idx<'a, R> {
|
||||
fn from(tensor: &'a Tensor<T, R>) -> Self {
|
||||
Self::zero(tensor.shape())
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> std::fmt::Display for Idx<'a, R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "[")?;
|
||||
for (i, (&idx, &dim_size)) in self
|
||||
@ -214,7 +373,7 @@ impl<const R: usize> std::fmt::Display for Idx<R> {
|
||||
|
||||
// ---- Arithmetic Operations ----
|
||||
|
||||
impl<const R: usize> Add for Idx<R> {
|
||||
impl<'a, const R: usize> Add for Idx<'a, R> {
|
||||
type Output = Self;
|
||||
|
||||
fn add(self, rhs: Self) -> Self::Output {
|
||||
@ -232,7 +391,7 @@ impl<const R: usize> Add for Idx<R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<const R: usize> Sub for Idx<R> {
|
||||
impl<'a, const R: usize> Sub for Idx<'a, R> {
|
||||
type Output = Self;
|
||||
|
||||
fn sub(self, rhs: Self) -> Self::Output {
|
||||
@ -249,3 +408,75 @@ impl<const R: usize> Sub for Idx<R> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Iterator ----
|
||||
|
||||
pub struct IdxIterator<'a, const R: usize> {
|
||||
current: Idx<'a, R>,
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> IdxIterator<'a, R> {
|
||||
pub fn new(shape: &'a Shape<R>) -> Self {
|
||||
Self {
|
||||
current: Idx::zero(shape),
|
||||
end: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Iterator for IdxIterator<'a, R> {
|
||||
type Item = Idx<'a, R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = self.current;
|
||||
self.end = self.current.inc();
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> IntoIterator for Idx<'a, R> {
|
||||
type Item = Idx<'a, R>;
|
||||
type IntoIter = IdxIterator<'a, R>;
|
||||
|
||||
fn into_iter(self) -> Self::IntoIter {
|
||||
IdxIterator {
|
||||
current: self,
|
||||
end: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub struct IdxTransposedIterator<'a, const R: usize> {
|
||||
current: Idx<'a, R>,
|
||||
order: [usize; R],
|
||||
end: bool,
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> IdxTransposedIterator<'a, R> {
|
||||
pub fn new(shape: &'a Shape<R>, order: [usize; R]) -> Self {
|
||||
Self {
|
||||
current: Idx::zero(shape),
|
||||
end: false,
|
||||
order,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a, const R: usize> Iterator for IdxTransposedIterator<'a, R> {
|
||||
type Item = Idx<'a, R>;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
if self.end {
|
||||
return None;
|
||||
}
|
||||
|
||||
let result = self.current;
|
||||
self.end = self.current.inc_transposed(&self.order);
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
26
src/lib.rs
26
src/lib.rs
@ -1,15 +1,21 @@
|
||||
#![allow(incomplete_features)]
|
||||
#![feature(generic_const_exprs)]
|
||||
pub mod axis;
|
||||
pub mod error;
|
||||
pub mod index;
|
||||
pub mod shape;
|
||||
pub mod tensor;
|
||||
|
||||
pub use axis::*;
|
||||
pub use index::Idx;
|
||||
pub use itertools::Itertools;
|
||||
use num::{Num, One, Zero};
|
||||
pub use serde::{Deserialize, Serialize};
|
||||
pub use shape::Shape;
|
||||
pub use static_assertions::const_assert;
|
||||
pub use std::fmt::{Display, Formatter, Result as FmtResult};
|
||||
use std::ops::{Index, IndexMut};
|
||||
pub use std::sync::Arc;
|
||||
pub use tensor::{Tensor, TensorIterator};
|
||||
|
||||
pub trait Value:
|
||||
@ -26,9 +32,17 @@ impl<T> Value for T where
|
||||
+ Display
|
||||
+ Serialize
|
||||
+ Deserialize<'static>
|
||||
+ std::iter::Sum
|
||||
{
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
macro_rules! tensor {
|
||||
($array:expr) => {
|
||||
Tensor::from($array)
|
||||
};
|
||||
}
|
||||
|
||||
// ---- Tests ----
|
||||
|
||||
#[cfg(test)]
|
||||
@ -38,7 +52,7 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_tensor_product() {
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([2, 2]); // 2x2 tensor
|
||||
let mut tensor1 = Tensor::<i32, 2>::from([[2], [2]]); // 2x2 tensor
|
||||
let mut tensor2 = Tensor::<i32, 1>::from([2]); // 2-element vector
|
||||
|
||||
// Fill tensors with some values
|
||||
@ -48,7 +62,7 @@ mod tests {
|
||||
let product = tensor1.tensor_product(&tensor2);
|
||||
|
||||
// Check shape of the resulting tensor
|
||||
assert_eq!(product.shape(), Shape::new([2, 2, 2]));
|
||||
assert_eq!(*product.shape(), Shape::new([2, 2, 2]));
|
||||
|
||||
// Check buffer of the resulting tensor
|
||||
let expected_buffer = [5, 6, 10, 12, 15, 18, 20, 24];
|
||||
@ -153,7 +167,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_dec_method() {
|
||||
let shape = Shape::new([3, 3, 3]); // Example shape for a 3x3x3 tensor
|
||||
let mut index = Idx::zero(shape);
|
||||
let mut index = Idx::zero(&shape);
|
||||
|
||||
// Increment the index to the maximum
|
||||
for _ in 0..26 {
|
||||
@ -162,7 +176,7 @@ mod tests {
|
||||
}
|
||||
|
||||
// Check if the index is at the maximum
|
||||
assert_eq!(index, Idx::new(shape, [2, 2, 2]));
|
||||
assert_eq!(index, Idx::new(&shape, [2, 2, 2]));
|
||||
|
||||
// Decrement step by step and check the index
|
||||
let expected_indices = [
|
||||
@ -198,7 +212,7 @@ mod tests {
|
||||
for (i, &expected) in expected_indices.iter().enumerate() {
|
||||
assert_eq!(
|
||||
index,
|
||||
Idx::new(shape, expected),
|
||||
Idx::new(&shape, expected),
|
||||
"Failed at index {}",
|
||||
i
|
||||
);
|
||||
@ -207,6 +221,6 @@ mod tests {
|
||||
|
||||
// Finally, the index should reach [0, 0, 0]
|
||||
index.dec();
|
||||
assert_eq!(index, Idx::zero(shape));
|
||||
assert_eq!(index, Idx::zero(&shape));
|
||||
}
|
||||
}
|
||||
|
45
src/shape.rs
45
src/shape.rs
@ -11,6 +11,18 @@ impl<const R: usize> Shape<R> {
|
||||
Self(shape)
|
||||
}
|
||||
|
||||
pub fn axis(&self, index: usize) -> Option<&usize> {
|
||||
self.0.get(index)
|
||||
}
|
||||
|
||||
pub fn reorder(&self, indices: [usize; R]) -> Self {
|
||||
let mut new_shape = Shape::new([0; R]);
|
||||
for (new_index, &index) in indices.iter().enumerate() {
|
||||
new_shape.0[new_index] = self.0[index];
|
||||
}
|
||||
new_shape
|
||||
}
|
||||
|
||||
pub const fn as_array(&self) -> [usize; R] {
|
||||
self.0
|
||||
}
|
||||
@ -65,18 +77,18 @@ impl<const R: usize> Shape<R> {
|
||||
}
|
||||
|
||||
indices.reverse(); // Reverse the indices to match the original dimension order
|
||||
Idx::new(*self, indices)
|
||||
Idx::new(self, indices)
|
||||
}
|
||||
|
||||
pub const fn index_zero(&self) -> Idx<R> {
|
||||
Idx::zero(*self)
|
||||
Idx::zero(self)
|
||||
}
|
||||
|
||||
pub fn index_max(&self) -> Idx<R> {
|
||||
let max_indices =
|
||||
self.0
|
||||
.map(|dim_size| if dim_size > 0 { dim_size - 1 } else { 0 });
|
||||
Idx::new(*self, max_indices)
|
||||
Idx::new(self, max_indices)
|
||||
}
|
||||
|
||||
pub fn remove_dims<const NAX: usize>(
|
||||
@ -101,6 +113,31 @@ impl<const R: usize> Shape<R> {
|
||||
|
||||
Shape(new_shape)
|
||||
}
|
||||
|
||||
pub fn remove_axes<'a, T: Value, const NAX: usize>(
|
||||
&self,
|
||||
axes_to_remove: &'a [Axis<'a, T, R>; NAX],
|
||||
) -> Shape<{ R - NAX }> {
|
||||
// Create a new array to store the remaining dimensions
|
||||
let mut new_shape = [0; R - NAX];
|
||||
let mut new_index = 0;
|
||||
|
||||
// Iterate over the original dimensions
|
||||
for (index, &dim) in self.0.iter().enumerate() {
|
||||
// Skip dimensions that are in the axes_to_remove array
|
||||
for axis in axes_to_remove {
|
||||
if *axis.dim() == index {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Add the dimension to the new shape array
|
||||
new_shape[new_index] = dim;
|
||||
new_index += 1;
|
||||
}
|
||||
|
||||
Shape(new_shape)
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Serialize and Deserialize ----
|
||||
@ -173,6 +210,6 @@ where
|
||||
T: Value,
|
||||
{
|
||||
fn from(tensor: Tensor<T, R>) -> Self {
|
||||
tensor.shape()
|
||||
*tensor.shape()
|
||||
}
|
||||
}
|
||||
|
395
src/tensor.rs
395
src/tensor.rs
@ -1,47 +1,72 @@
|
||||
use super::*;
|
||||
use crate::error::*;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::fmt;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, Getters, MutGetters)]
|
||||
pub struct Tensor<T, const R: usize> {
|
||||
#[getset(get = "pub", get_mut = "pub")]
|
||||
buffer: Vec<T>,
|
||||
#[getset(get = "pub")]
|
||||
shape: Shape<R>,
|
||||
}
|
||||
|
||||
impl<T: Value> Tensor<T, 1> {
|
||||
pub fn dot(&self, other: &Tensor<T, 1>) -> T {
|
||||
if self.shape != other.shape {
|
||||
panic!("Shapes of tensors do not match");
|
||||
}
|
||||
|
||||
let mut result = T::zero();
|
||||
for (a, b) in self.buffer.iter().zip(other.buffer.iter()) {
|
||||
result = result + (*a * *b);
|
||||
}
|
||||
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
pub fn new(shape: Shape<R>) -> Self {
|
||||
let total_size: usize = shape.iter().product();
|
||||
// Handle rank 0 tensor (scalar) as a special case
|
||||
let total_size = if R == 0 {
|
||||
// A rank 0 tensor should still have a buffer with one element
|
||||
1
|
||||
} else {
|
||||
// For tensors of rank 1 or higher, calculate the total size normally
|
||||
shape.iter().product()
|
||||
};
|
||||
|
||||
let buffer = vec![T::zero(); total_size];
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
pub fn shape(&self) -> Shape<R> {
|
||||
self.shape
|
||||
pub fn new_with_buffer(shape: Shape<R>, buffer: Vec<T>) -> Self {
|
||||
Self { buffer, shape }
|
||||
}
|
||||
|
||||
pub fn buffer(&self) -> &[T] {
|
||||
&self.buffer
|
||||
pub fn reshape(self, shape: Shape<R>) -> Result<Self> {
|
||||
if self.shape().size() != shape.size() {
|
||||
let (ls, rs) = (self.shape().as_array(), shape.as_array());
|
||||
let (lsize, rsize) = (self.shape().size(), shape.size());
|
||||
Err(Error::InvalidArgument(format!(
|
||||
"Shape size mismatch: ( {ls:?} = {lsize} ) != ( {rs:?} = {rsize} )",
|
||||
)))
|
||||
} else {
|
||||
Ok(Self {
|
||||
buffer: self.buffer,
|
||||
shape,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn buffer_mut(&mut self) -> &mut [T] {
|
||||
&mut self.buffer
|
||||
pub fn transpose(self, order: [usize; R]) -> Result<Self> {
|
||||
let buffer = Idx::from(self.shape())
|
||||
.iter_transposed(order)
|
||||
.map(|index| self.get(index).unwrap().clone())
|
||||
.collect();
|
||||
|
||||
Ok(Tensor {
|
||||
buffer,
|
||||
shape: self.shape().reorder(order),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn get(&self, index: Idx<R>) -> &T {
|
||||
&self.buffer[index.flat()]
|
||||
pub fn idx(&self) -> Idx<R> {
|
||||
Idx::from(self)
|
||||
}
|
||||
|
||||
pub fn axis<'a>(&'a self, axis: usize) -> Axis<'a, T, R> {
|
||||
Axis::new(self, axis)
|
||||
}
|
||||
|
||||
pub fn get(&self, index: Idx<R>) -> Option<&T> {
|
||||
self.buffer.get(index.flat())
|
||||
}
|
||||
|
||||
pub unsafe fn get_unchecked(&self, index: Idx<R>) -> &T {
|
||||
@ -56,6 +81,22 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
self.buffer.get_unchecked_mut(index.flat())
|
||||
}
|
||||
|
||||
pub fn get_flat(&self, index: usize) -> Option<&T> {
|
||||
self.buffer.get(index)
|
||||
}
|
||||
|
||||
pub unsafe fn get_flat_unchecked(&self, index: usize) -> &T {
|
||||
self.buffer.get_unchecked(index)
|
||||
}
|
||||
|
||||
pub fn get_flat_mut(&mut self, index: usize) -> Option<&mut T> {
|
||||
self.buffer.get_mut(index)
|
||||
}
|
||||
|
||||
pub unsafe fn get_flat_unchecked_mut(&mut self, index: usize) -> &mut T {
|
||||
self.buffer.get_unchecked_mut(index)
|
||||
}
|
||||
|
||||
pub fn rank(&self) -> usize {
|
||||
R
|
||||
}
|
||||
@ -111,93 +152,13 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
// `self_dims` and `other_dims` specify the dimensions to contract over
|
||||
pub fn contract<const S: usize, const NAXL: usize, const NAXR: usize>(
|
||||
&self,
|
||||
rhs: &Tensor<T, S>,
|
||||
axis_lhs: [usize; NAXL],
|
||||
axis_rhs: [usize; NAXR],
|
||||
) -> Tensor<T, { R + S - NAXL - NAXR }>
|
||||
where
|
||||
[(); R - NAXL]:,
|
||||
[(); S - NAXR]:,
|
||||
{
|
||||
// Step 1: Validate the axes for both tensors to ensure they are within bounds
|
||||
for &axis in &axis_lhs {
|
||||
if axis >= R {
|
||||
panic!(
|
||||
"Axis {} is out of bounds for the left-hand tensor",
|
||||
axis
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
for &axis in &axis_rhs {
|
||||
if axis >= S {
|
||||
panic!(
|
||||
"Axis {} is out of bounds for the right-hand tensor",
|
||||
axis
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Step 2: Iterate over the tensors, multiplying and summing elements across contracted dimensions
|
||||
let mut result_buffer = Vec::new();
|
||||
|
||||
for i in 0..self.shape.size() {
|
||||
for j in 0..rhs.shape.size() {
|
||||
// Debug: Print indices being processed
|
||||
println!("Processing Indices: lhs = {}, rhs = {}", i, j);
|
||||
|
||||
if !axis_lhs.contains(&i) && !axis_rhs.contains(&j) {
|
||||
let mut product_sum = T::zero();
|
||||
|
||||
// Debug: Print axes of contraction
|
||||
println!("Contracting Axes: lhs = {:?}, rhs = {:?}", axis_lhs, axis_rhs);
|
||||
|
||||
for (&axis_l, &axis_r) in axis_lhs.iter().zip(axis_rhs.iter()) {
|
||||
// Debug: Print values being multiplied
|
||||
let value_lhs = self.get_by_axis(axis_l, i).unwrap();
|
||||
let value_rhs = rhs.get_by_axis(axis_r, j).unwrap();
|
||||
println!("Multiplying: lhs_value = {}, rhs_value = {}", value_lhs, value_rhs);
|
||||
|
||||
product_sum = product_sum + value_lhs * value_rhs;
|
||||
}
|
||||
|
||||
// Debug: Print the product sum for the current indices
|
||||
println!("Product Sum for indices (lhs = {}, rhs = {}) = {}", i, j, product_sum);
|
||||
|
||||
result_buffer.push(product_sum);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Step 3: Remove contracted dimensions to create new shapes for both tensors
|
||||
let new_shape_lhs = self.shape.remove_dims::<{ NAXL }>(axis_lhs);
|
||||
let new_shape_rhs = rhs.shape.remove_dims::<{ NAXR }>(axis_rhs);
|
||||
|
||||
// Step 4: Concatenate the shapes to form the shape of the resultant tensor
|
||||
let mut new_shape = Vec::new();
|
||||
|
||||
new_shape.extend_from_slice(&new_shape_lhs.as_array());
|
||||
new_shape.extend_from_slice(&new_shape_rhs.as_array());
|
||||
|
||||
let new_shape_array: [usize; R + S - NAXL - NAXR] =
|
||||
new_shape.try_into().expect("Failed to create shape array");
|
||||
|
||||
Tensor {
|
||||
buffer: result_buffer,
|
||||
shape: Shape::new(new_shape_array),
|
||||
}
|
||||
}
|
||||
|
||||
// Retrieve an element based on a specific axis and index
|
||||
pub fn get_by_axis(&self, axis: usize, index: usize) -> Option<T> {
|
||||
// Convert axis and index to a flat index
|
||||
let flat_index = self.axis_to_flat_index(axis, index);
|
||||
if flat_index >= self.buffer.len() {
|
||||
return None;
|
||||
}
|
||||
if flat_index >= self.buffer.len() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(self.buffer[flat_index])
|
||||
}
|
||||
@ -214,7 +175,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
|
||||
// Calculate the stride for each dimension and accumulate the flat index
|
||||
for (i, &dim_size) in self.shape.as_array().iter().enumerate().rev() {
|
||||
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
|
||||
println!("i: {}, dim_size: {}, stride: {}", i, dim_size, stride);
|
||||
if i > axis {
|
||||
stride *= dim_size;
|
||||
} else if i == axis {
|
||||
@ -229,7 +190,7 @@ impl<T: Value, const R: usize> Tensor<T, R> {
|
||||
|
||||
// ---- Indexing ----
|
||||
|
||||
impl<T: Value, const R: usize> Index<Idx<R>> for Tensor<T, R> {
|
||||
impl<'a, T: Value, const R: usize> Index<Idx<'a, R>> for Tensor<T, R> {
|
||||
type Output = T;
|
||||
|
||||
fn index(&self, index: Idx<R>) -> &Self::Output {
|
||||
@ -237,33 +198,61 @@ impl<T: Value, const R: usize> Index<Idx<R>> for Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> IndexMut<Idx<R>> for Tensor<T, R> {
|
||||
impl<'a, T: Value, const R: usize> IndexMut<Idx<'a, R>> for Tensor<T, R> {
|
||||
fn index_mut(&mut self, index: Idx<R>) -> &mut Self::Output {
|
||||
&mut self.buffer[index.flat()]
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
// Print the shape of the tensor
|
||||
write!(f, "Shape: [")?;
|
||||
for (i, dim_size) in self.shape.as_array().iter().enumerate() {
|
||||
write!(f, "{}", dim_size)?;
|
||||
if i < R - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
write!(f, "], Elements: [")?;
|
||||
impl<T: Value, const R: usize> Index<usize> for Tensor<T, R> {
|
||||
type Output = T;
|
||||
|
||||
// Print the elements in a flattened form
|
||||
for (i, elem) in self.buffer.iter().enumerate() {
|
||||
write!(f, "{}", elem)?;
|
||||
if i < self.buffer.len() - 1 {
|
||||
write!(f, ", ")?;
|
||||
}
|
||||
}
|
||||
fn index(&self, index: usize) -> &Self::Output {
|
||||
&self.buffer[index]
|
||||
}
|
||||
}
|
||||
|
||||
write!(f, "]")
|
||||
impl<T: Value, const R: usize> IndexMut<usize> for Tensor<T, R> {
|
||||
fn index_mut(&mut self, index: usize) -> &mut Self::Output {
|
||||
&mut self.buffer[index]
|
||||
}
|
||||
}
|
||||
|
||||
// ---- Display ----
|
||||
|
||||
impl<T, const R: usize> Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
{
|
||||
fn fmt_helper(
|
||||
buffer: &[T],
|
||||
shape: &[usize],
|
||||
f: &mut fmt::Formatter<'_>,
|
||||
level: usize,
|
||||
) -> fmt::Result {
|
||||
if shape.is_empty() {
|
||||
// Base case: print individual elements
|
||||
write!(f, "{}", buffer[0])
|
||||
} else {
|
||||
let sub_len = shape[1..].iter().product::<usize>();
|
||||
write!(f, "[")?;
|
||||
for (i, chunk) in buffer.chunks(sub_len).enumerate() {
|
||||
if i > 0 {
|
||||
write!(f, ",")?;
|
||||
}
|
||||
Tensor::<T, R>::fmt_helper(chunk, &shape[1..], f, level + 1)?;
|
||||
}
|
||||
write!(f, "]")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const R: usize> fmt::Display for Tensor<T, R>
|
||||
where
|
||||
T: fmt::Display + Clone,
|
||||
{
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
Tensor::<T, R>::fmt_helper(&self.buffer, &self.shape.as_array(), f, 1)
|
||||
}
|
||||
}
|
||||
|
||||
@ -271,7 +260,7 @@ impl<T: Value, const R: usize> std::fmt::Display for Tensor<T, R> {
|
||||
|
||||
pub struct TensorIterator<'a, T: Value, const R: usize> {
|
||||
tensor: &'a Tensor<T, R>,
|
||||
index: Idx<R>,
|
||||
index: Idx<'a, R>,
|
||||
}
|
||||
|
||||
impl<'a, T: Value, const R: usize> TensorIterator<'a, T, R> {
|
||||
@ -342,10 +331,26 @@ impl<T: Value, const R: usize> From<Shape<R>> for Tensor<T, R> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const R: usize> From<[usize; R]> for Tensor<T, R> {
|
||||
fn from(shape: [usize; R]) -> Self {
|
||||
let shape = Shape::new(shape);
|
||||
Self::new(shape.into())
|
||||
impl<T: Value> From<T> for Tensor<T, 0> {
|
||||
fn from(value: T) -> Self {
|
||||
let shape = Shape::new([]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
tensor.buffer_mut()[0] = value;
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize> From<[T; X]> for Tensor<T, 1> {
|
||||
fn from(array: [T; X]) -> Self {
|
||||
let shape = Shape::new([X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, &elem) in array.iter().enumerate() {
|
||||
buffer[i] = elem;
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
@ -366,3 +371,123 @@ impl<T: Value, const X: usize, const Y: usize> From<[[T; X]; Y]>
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Value, const X: usize, const Y: usize, const Z: usize>
|
||||
From<[[[T; X]; Y]; Z]> for Tensor<T, 3>
|
||||
{
|
||||
fn from(array: [[[T; X]; Y]; Z]) -> Self {
|
||||
let shape = Shape::new([Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, plane) in array.iter().enumerate() {
|
||||
for (j, row) in plane.iter().enumerate() {
|
||||
for (k, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y + j * X + k] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
T: Value,
|
||||
const X: usize,
|
||||
const Y: usize,
|
||||
const Z: usize,
|
||||
const W: usize,
|
||||
> From<[[[[T; X]; Y]; Z]; W]> for Tensor<T, 4>
|
||||
{
|
||||
fn from(array: [[[[T; X]; Y]; Z]; W]) -> Self {
|
||||
let shape = Shape::new([W, Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, hyperplane) in array.iter().enumerate() {
|
||||
for (j, plane) in hyperplane.iter().enumerate() {
|
||||
for (k, row) in plane.iter().enumerate() {
|
||||
for (l, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y * Z + j * X * Y + k * X + l] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
T: Value,
|
||||
const X: usize,
|
||||
const Y: usize,
|
||||
const Z: usize,
|
||||
const W: usize,
|
||||
const V: usize,
|
||||
> From<[[[[[T; X]; Y]; Z]; W]; V]> for Tensor<T, 5>
|
||||
{
|
||||
fn from(array: [[[[[T; X]; Y]; Z]; W]; V]) -> Self {
|
||||
let shape = Shape::new([V, W, Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, hyperhyperplane) in array.iter().enumerate() {
|
||||
for (j, hyperplane) in hyperhyperplane.iter().enumerate() {
|
||||
for (k, plane) in hyperplane.iter().enumerate() {
|
||||
for (l, row) in plane.iter().enumerate() {
|
||||
for (m, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y * Z * W
|
||||
+ j * X * Y * Z
|
||||
+ k * X * Y
|
||||
+ l * X
|
||||
+ m] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
||||
impl<
|
||||
T: Value,
|
||||
const X: usize,
|
||||
const Y: usize,
|
||||
const Z: usize,
|
||||
const W: usize,
|
||||
const V: usize,
|
||||
const U: usize,
|
||||
> From<[[[[[[T; X]; Y]; Z]; W]; V]; U]> for Tensor<T, 6>
|
||||
{
|
||||
fn from(array: [[[[[[T; X]; Y]; Z]; W]; V]; U]) -> Self {
|
||||
let shape = Shape::new([U, V, W, Z, Y, X]);
|
||||
let mut tensor = Tensor::new(shape);
|
||||
let buffer = tensor.buffer_mut();
|
||||
|
||||
for (i, hyperhyperhyperplane) in array.iter().enumerate() {
|
||||
for (j, hyperhyperplane) in hyperhyperhyperplane.iter().enumerate()
|
||||
{
|
||||
for (k, hyperplane) in hyperhyperplane.iter().enumerate() {
|
||||
for (l, plane) in hyperplane.iter().enumerate() {
|
||||
for (m, row) in plane.iter().enumerate() {
|
||||
for (n, &elem) in row.iter().enumerate() {
|
||||
buffer[i * X * Y * Z * W * V
|
||||
+ j * X * Y * Z * W
|
||||
+ k * X * Y * Z
|
||||
+ l * X * Y
|
||||
+ m * X
|
||||
+ n] = elem;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tensor
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user