Compare commits
6 Commits
Author | SHA1 | Date | |
---|---|---|---|
668293d956 | |||
8bf134d3d2 | |||
0d701a49b6 | |||
f7b29baf95 | |||
5b3ef81c84 | |||
bf16edd3aa |
3
.gitignore
vendored
3
.gitignore
vendored
@ -1 +1,2 @@
|
||||
/target
|
||||
target
|
||||
result
|
||||
|
217
Cargo.lock
generated
217
Cargo.lock
generated
@ -44,6 +44,12 @@ version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
@ -74,7 +80,7 @@ version = "0.2.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||
dependencies = [
|
||||
"hermit-abi",
|
||||
"hermit-abi 0.1.19",
|
||||
"libc",
|
||||
"winapi",
|
||||
]
|
||||
@ -156,9 +162,15 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.42",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "bytes"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
|
||||
|
||||
[[package]]
|
||||
name = "cc"
|
||||
version = "1.0.83"
|
||||
@ -174,6 +186,20 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"wasm-bindgen",
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codespan-reporting"
|
||||
version = "0.11.1"
|
||||
@ -224,7 +250,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20"
|
||||
dependencies = [
|
||||
"bitflags 2.4.1",
|
||||
"libloading 0.7.4",
|
||||
"libloading 0.8.1",
|
||||
"winapi",
|
||||
]
|
||||
|
||||
@ -247,6 +273,12 @@ version = "1.0.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
|
||||
|
||||
[[package]]
|
||||
name = "fixedbitset"
|
||||
version = "0.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
|
||||
|
||||
[[package]]
|
||||
name = "flume"
|
||||
version = "0.11.0"
|
||||
@ -277,7 +309,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.42",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -322,6 +354,18 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "getset"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9"
|
||||
dependencies = [
|
||||
"proc-macro-error",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "gimli"
|
||||
version = "0.28.1"
|
||||
@ -447,6 +491,12 @@ dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "hermit-abi"
|
||||
version = "0.3.3"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
|
||||
|
||||
[[package]]
|
||||
name = "hexf-parse"
|
||||
version = "0.2.1"
|
||||
@ -459,6 +509,29 @@ version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.58"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.1.0"
|
||||
@ -585,6 +658,17 @@ dependencies = [
|
||||
"adler",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "mio"
|
||||
version = "0.8.10"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"wasi",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "naga"
|
||||
version = "0.14.2"
|
||||
@ -598,6 +682,7 @@ dependencies = [
|
||||
"indexmap",
|
||||
"log",
|
||||
"num-traits",
|
||||
"petgraph",
|
||||
"rustc-hash",
|
||||
"spirv",
|
||||
"termcolor",
|
||||
@ -623,6 +708,16 @@ dependencies = [
|
||||
"autocfg",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "num_cpus"
|
||||
version = "1.16.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
|
||||
dependencies = [
|
||||
"hermit-abi 0.3.3",
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "objc"
|
||||
version = "0.2.7"
|
||||
@ -711,6 +806,22 @@ version = "1.0.14"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
|
||||
|
||||
[[package]]
|
||||
name = "petgraph"
|
||||
version = "0.6.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9"
|
||||
dependencies = [
|
||||
"fixedbitset",
|
||||
"indexmap",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.13"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
|
||||
|
||||
[[package]]
|
||||
name = "pkg-config"
|
||||
version = "0.3.28"
|
||||
@ -729,6 +840,30 @@ version = "0.3.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa"
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
|
||||
dependencies = [
|
||||
"proc-macro-error-attr",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 1.0.109",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro-error-attr"
|
||||
version = "1.0.4"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"version_check",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "proc-macro2"
|
||||
version = "1.0.70"
|
||||
@ -836,6 +971,15 @@ version = "1.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||
|
||||
[[package]]
|
||||
name = "signal-hook-registry"
|
||||
version = "1.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
|
||||
dependencies = [
|
||||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "slotmap"
|
||||
version = "1.0.7"
|
||||
@ -851,6 +995,16 @@ version = "1.11.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
|
||||
|
||||
[[package]]
|
||||
name = "socket2"
|
||||
version = "0.5.5"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
|
||||
dependencies = [
|
||||
"libc",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spin"
|
||||
version = "0.9.8"
|
||||
@ -876,6 +1030,17 @@ version = "1.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "1.0.109"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"unicode-ident",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "syn"
|
||||
version = "2.0.42"
|
||||
@ -913,7 +1078,37 @@ checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.42",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio"
|
||||
version = "1.35.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104"
|
||||
dependencies = [
|
||||
"backtrace",
|
||||
"bytes",
|
||||
"libc",
|
||||
"mio",
|
||||
"num_cpus",
|
||||
"parking_lot 0.12.1",
|
||||
"pin-project-lite",
|
||||
"signal-hook-registry",
|
||||
"socket2",
|
||||
"tokio-macros",
|
||||
"windows-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-macros"
|
||||
version = "2.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.42",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
@ -967,7 +1162,7 @@ dependencies = [
|
||||
"once_cell",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.42",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
|
||||
@ -1001,7 +1196,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.42",
|
||||
"wasm-bindgen-backend",
|
||||
"wasm-bindgen-shared",
|
||||
]
|
||||
@ -1093,7 +1288,7 @@ dependencies = [
|
||||
"js-sys",
|
||||
"khronos-egl",
|
||||
"libc",
|
||||
"libloading 0.7.4",
|
||||
"libloading 0.8.1",
|
||||
"log",
|
||||
"metal",
|
||||
"naga",
|
||||
@ -1129,9 +1324,13 @@ name = "wgpu_compute_shader"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"chrono",
|
||||
"env_logger",
|
||||
"futures-intrusive",
|
||||
"getset",
|
||||
"pollster",
|
||||
"thiserror",
|
||||
"tokio",
|
||||
"wgpu",
|
||||
]
|
||||
|
||||
@ -1280,5 +1479,5 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn",
|
||||
"syn 2.0.42",
|
||||
]
|
||||
|
@ -3,11 +3,16 @@ name = "wgpu_compute_shader"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[lib]
|
||||
crate-type = ["lib"]
|
||||
|
||||
[dependencies]
|
||||
wgpu = { version = "0.18", features = ["vulkan-portability"] }
|
||||
wgpu = { version = "0.18", features = ["vulkan-portability", "spirv" ] }
|
||||
env_logger = "0.9.1"
|
||||
pollster = "0.2.5"
|
||||
futures-intrusive = "0.4"
|
||||
bytemuck = { version = "1.12.1", features = ["derive"] }
|
||||
thiserror = "1.0.51"
|
||||
tokio = { version = "1.35.1", features = ["sync", "full"] }
|
||||
getset = "0.1.2"
|
||||
chrono = "0.4.31"
|
||||
|
21
LICENSE
Normal file
21
LICENSE
Normal file
@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2023 Julius Koskela
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
72
README.md
Normal file
72
README.md
Normal file
@ -0,0 +1,72 @@
|
||||
# Rust wGPU Matrix Multiplication Shader Test
|
||||
|
||||
## Overview
|
||||
|
||||
This project demonstrates a matrix multiplication operation using compute shaders in Rust with the wGPU library. The primary goal is to showcase how to perform a simple matrix multiplication on the GPU, leveraging the power of parallel computation provided by shaders.
|
||||
|
||||
## Features
|
||||
|
||||
- Matrix multiplication using a compute shader.
|
||||
- Implementation using Rust and the wGPU library.
|
||||
- Efficient GPU-based computation for matrix operations.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- Rust programming language.
|
||||
- wGPU library for Rust.
|
||||
- GPU that supports WebGPU.
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Installation
|
||||
|
||||
1. Clone the repository:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/your-repository/rust-wgpu-matrix-multiplication.git
|
||||
cd rust-wgpu-matrix-multiplication
|
||||
```
|
||||
|
||||
2. Ensure you have the latest version of Rust installed:
|
||||
|
||||
```bash
|
||||
rustup update
|
||||
```
|
||||
|
||||
### Build and Run
|
||||
|
||||
1. Build the project using Cargo, Rust's package manager and build system:
|
||||
|
||||
```bash
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
2. Run the compiled binary:
|
||||
|
||||
```bash
|
||||
cargo run --release
|
||||
```
|
||||
|
||||
## Project Structure
|
||||
|
||||
- `src/main.rs`: The main entry point for the application.
|
||||
- `src/ppu.rs`: Contains the `PPU` struct for handling GPU tasks.
|
||||
- `shaders/`: Folder containing WGSL shader files.
|
||||
|
||||
## Shader
|
||||
|
||||
The matrix multiplication logic is implemented in a WGSL shader. The shader takes two input matrices (`matrixA` and `matrixB`) and outputs the product (`matrixC`).
|
||||
|
||||
## How It Works
|
||||
|
||||
- The compute shader (`main`) is dispatched with a workgroup size of 4x4.
|
||||
- Each work item calculates one element of the product matrix.
|
||||
- The result is written back to a buffer that is read by the Rust application.
|
||||
|
||||
## Contributing
|
||||
|
||||
Contributions to this project are welcome. Please open an issue or pull request on the GitHub repository.
|
||||
|
||||
## License
|
||||
|
||||
This project is open-source and available under the [MIT License](LICENSE).
|
9
build.rs
9
build.rs
@ -1,9 +0,0 @@
|
||||
fn main() {
|
||||
if cfg!(target_os = "linux") {
|
||||
// println!("cargo:rustc-link-lib=X11");
|
||||
// println!("cargo:rustc-link-lib=Xcursor");
|
||||
// println!("cargo:rustc-link-lib=Xrandr");
|
||||
// println!("cargo:rustc-link-lib=Xi");
|
||||
println!("cargo:rustc-link-lib=vulkan");
|
||||
}
|
||||
}
|
76
default.nix
76
default.nix
@ -1,45 +1,41 @@
|
||||
{pkgs}:
|
||||
with pkgs;
|
||||
rustPlatform.buildRustPackage {
|
||||
pname = "matmul-vshader";
|
||||
version = "0.1.0";
|
||||
{pkgs}: let
|
||||
manifest = (pkgs.lib.importTOML ./Cargo.toml).package;
|
||||
buildInputs = with pkgs; [
|
||||
udev
|
||||
alsa-lib
|
||||
vulkan-loader
|
||||
vulkan-headers
|
||||
vulkan-tools
|
||||
vulkan-validation-layers
|
||||
xorg.libX11
|
||||
xorg.libXcursor
|
||||
xorg.libXi
|
||||
xorg.libXrandr # To use the x11 feature
|
||||
libxkbcommon
|
||||
wayland # To use the wayland feature
|
||||
];
|
||||
in
|
||||
pkgs.rustPlatform.buildRustPackage {
|
||||
inherit buildInputs;
|
||||
pname = manifest.name;
|
||||
version = manifest.version;
|
||||
|
||||
src = ./.;
|
||||
src = pkgs.lib.cleanSource ./.;
|
||||
|
||||
packages = [cmake shaderc];
|
||||
buildInputs = [cmake shaderc];
|
||||
nativeBuildInputs = [
|
||||
vulkan-headers
|
||||
vulkan-loader
|
||||
vulkan-validation-layers
|
||||
vulkan-tools
|
||||
pkg-config
|
||||
git
|
||||
gcc
|
||||
cmake
|
||||
glibc
|
||||
python3
|
||||
shaderc
|
||||
];
|
||||
nativeBuildInputs = with pkgs; [
|
||||
pkg-config
|
||||
];
|
||||
|
||||
RUST_BACKTRACE = "1";
|
||||
# LD_LIBRARY_PATH = "${vulkan-loader}/lib:${vulkan-validation-layers}/lib:${vulkan-tools}/lib:${vulkan-headers}/lib:${pkgs.stdenv.cc.cc.lib}:${pkgs.stdenv.cc.cc.lib64}";
|
||||
# VK_ICD_FILENAMES = "${vulkan-loader}/share/vulkan/icd.d/radeon_icd64.json";
|
||||
# VK_LAYER_PATH = "${vulkan-validation-layers}/share/vulkan/explicit_layer.d";
|
||||
# VK_INSTANCE_LAYERS = "VK_LAYER_KHRONOS_validation";
|
||||
# VK_DEVICE_LAYERS = "VK_LAYER_KHRONOS_validation";
|
||||
# VK_LOADER_DEBUG = "all";
|
||||
# VK_LOADER_DEBUG_FILE = "/tmp/vulkan.log";
|
||||
# VK_INSTANCE_EXTENSIONS = "VK_EXT_debug_utils";
|
||||
# VK_DEVICE_EXTENSIONS = "VK_EXT_debug_utils";
|
||||
# VK_LAYER_ENABLES = "VK_LAYER_KHRONOS_validation";
|
||||
# VK_LAYER_DISABLES = "VK_LAYER_LUNARG_api_dump";
|
||||
# VK_LAYER_PATH = "${vulkan-validation-layers}/share/vulkan/explicit_layer.d";
|
||||
preConfigure = ''
|
||||
LD_LIBRARY_PATH=${pkgs.lib.makeLibraryPath buildInputs}
|
||||
'';
|
||||
|
||||
cargoBuildFlags = ["--release"];
|
||||
LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
|
||||
|
||||
cargoLock = {
|
||||
lockFile = ./Cargo.lock;
|
||||
allowBuiltinFetchGit = true;
|
||||
};
|
||||
}
|
||||
cargoBuildFlags = ["--release"];
|
||||
|
||||
cargoLock = {
|
||||
lockFile = ./Cargo.lock;
|
||||
allowBuiltinFetchGit = true;
|
||||
};
|
||||
}
|
||||
|
168
examples/mat_mul_4x4.rs
Normal file
168
examples/mat_mul_4x4.rs
Normal file
@ -0,0 +1,168 @@
|
||||
use wgpu_compute_shader::*;
|
||||
|
||||
const MM4X4_SHADER: &str = "
|
||||
struct Matrix {
|
||||
data: array<array<f32, 4>, 4>,
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read> matrixA: Matrix;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read> matrixB: Matrix;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> matrixC: Matrix;
|
||||
|
||||
// Consider setting workgroup size to power of 2 for better efficiency
|
||||
@compute @workgroup_size(4, 4, 1)
|
||||
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
// Ensure the computation is within the bounds of the 4x4 matrix
|
||||
if (global_id.x < 4u && global_id.y < 4u) {
|
||||
let row: u32 = global_id.y;
|
||||
let col: u32 = global_id.x;
|
||||
|
||||
var sum: f32 = 0.0;
|
||||
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
|
||||
sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
|
||||
}
|
||||
matrixC.data[row][col] = sum;
|
||||
}
|
||||
}
|
||||
";
|
||||
|
||||
const MATRIX_A: [[f32; 4]; 4] = [
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
[5.0, 6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0, 12.0],
|
||||
[13.0, 14.0, 15.0, 16.0],
|
||||
];
|
||||
|
||||
const MATRIX_B: [[f32; 4]; 4] = [
|
||||
[1.0, 2.0, 3.0, 4.0],
|
||||
[5.0, 6.0, 7.0, 8.0],
|
||||
[9.0, 10.0, 11.0, 12.0],
|
||||
[13.0, 14.0, 15.0, 16.0],
|
||||
];
|
||||
|
||||
// const EXPECT: [[f32; 4]; 4] = [
|
||||
// [90.0, 100.0, 110.0, 120.0],
|
||||
// [202.0, 228.0, 254.0, 280.0],
|
||||
// [314.0, 356.0, 398.0, 440.0],
|
||||
// [426.0, 484.0, 542.0, 600.0],
|
||||
// ];
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<(), Error> {
|
||||
// Create PPU
|
||||
let mut ppu = PPU::new().await?;
|
||||
ppu.load_shader("MM4X4_SHADER", MM4X4_SHADER)?;
|
||||
|
||||
let mut buffers = ComputeBuffers::new::<f32>(&ppu, 16);
|
||||
|
||||
buffers.add_buffer_init(
|
||||
&ppu,
|
||||
"MATRIX_A",
|
||||
&MATRIX_A,
|
||||
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
||||
);
|
||||
buffers.add_buffer_init(
|
||||
&ppu,
|
||||
"MATRIX_B",
|
||||
&MATRIX_B,
|
||||
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
||||
);
|
||||
buffers.add_buffer_init(
|
||||
&ppu,
|
||||
"MATRIX_C",
|
||||
&[0.0; 16],
|
||||
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
|
||||
);
|
||||
|
||||
// Create Bind Group Layout
|
||||
let bind_group_layout_entries = [
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 1,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
wgpu::BindGroupLayoutEntry {
|
||||
binding: 2,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
},
|
||||
];
|
||||
|
||||
ppu.load_bind_group_layout("MM4X4_BIND_GROUP_LAYOUT", &bind_group_layout_entries);
|
||||
|
||||
// Finally, create the bind group
|
||||
ppu.create_bind_group(
|
||||
"MM4X4_BIND_GROUP",
|
||||
"MM4X4_BIND_GROUP_LAYOUT",
|
||||
&[
|
||||
wgpu::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: buffers.get_buffer("MATRIX_A").unwrap().as_entire_binding(),
|
||||
},
|
||||
wgpu::BindGroupEntry {
|
||||
binding: 1,
|
||||
resource: buffers.get_buffer("MATRIX_B").unwrap().as_entire_binding(),
|
||||
},
|
||||
wgpu::BindGroupEntry {
|
||||
binding: 2,
|
||||
resource: buffers.get_buffer("MATRIX_C").unwrap().as_entire_binding(),
|
||||
},
|
||||
],
|
||||
)?;
|
||||
|
||||
// Load Pipeline
|
||||
ppu.load_pipeline(
|
||||
"MM4X4_PIPELINE",
|
||||
"MM4X4_SHADER",
|
||||
"main",
|
||||
&["MM4X4_BIND_GROUP_LAYOUT"],
|
||||
)?;
|
||||
|
||||
// Execute the compute task
|
||||
let workgroup_count = (1, 1, 1);
|
||||
let results: ComputeResult<f32> = ppu
|
||||
.execute_compute_task(
|
||||
"MM4X4_PIPELINE",
|
||||
"MM4X4_BIND_GROUP",
|
||||
&buffers,
|
||||
"MATRIX_C",
|
||||
workgroup_count,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// Process results
|
||||
println!("Matrix C:");
|
||||
|
||||
results.data().chunks(4).for_each(|row| {
|
||||
println!("{:?}", row);
|
||||
});
|
||||
|
||||
println!("Time elapsed: {} us", results.time_elapsed_us());
|
||||
|
||||
Ok(())
|
||||
}
|
@ -7,7 +7,6 @@
|
||||
nixpkgs,
|
||||
}: {
|
||||
packages.x86_64-linux.default = nixpkgs.legacyPackages.x86_64-linux.callPackage ./default.nix {};
|
||||
devShells.x86_64-linux.default = nixpkgs.legacyPackages.x86_64-linux.callPackage ./shell.nix {};
|
||||
formatter.x86_64-linux = nixpkgs.legacyPackages.x86_64-linux.alejandra;
|
||||
};
|
||||
}
|
||||
|
51
shell.nix
51
shell.nix
@ -1,51 +0,0 @@
|
||||
{pkgs}:
|
||||
with pkgs; let
|
||||
build = pkgs.callPackage ./default.nix {};
|
||||
in
|
||||
mkShell {
|
||||
inherit build;
|
||||
packages = [
|
||||
libX11
|
||||
libXcursor
|
||||
libXrandr
|
||||
libXi
|
||||
vulkan-headers
|
||||
vulkan-loader
|
||||
vulkan-validation-layers
|
||||
vulkan-tools
|
||||
pkg-config
|
||||
git
|
||||
gcc
|
||||
gnumake
|
||||
cmake
|
||||
glibc
|
||||
python3
|
||||
shaderc
|
||||
];
|
||||
|
||||
inputsFrom = [
|
||||
cmake
|
||||
shaderc
|
||||
];
|
||||
|
||||
RUST_BACKTRACE = "1";
|
||||
LD_LIBRARY_PATH="${pkgs.libX11}/lib:${pkgs.libXcursor}/lib:${pkgs.libXrandr}/lib:${pkgs.libXi}/lib:${pkgs.vulkan-loader}/lib:${pkgs.vulkan-validation-layers}/lib:${pkgs.vulkan-tools}/lib:${pkgs.vulkan-headers}/lib:${pkgs.stdenv.cc.cc.lib}:${pkgs.stdenv.cc.cc.lib64}:$LD_LIBRARY_PATH";
|
||||
|
||||
cargoBuildFlags = ["--release --features build-from-source"];
|
||||
|
||||
shellHook = ''
|
||||
export VK_ICD_FILENAMES=${vulkan-loader}/share/vulkan/icd.d/radeon_icd64.json
|
||||
export VK_LAYER_PATH=${vulkan-validation-layers}/share/vulkan/explicit_layer.d
|
||||
export VK_INSTANCE_LAYERS=VK_LAYER_KHRONOS_validation
|
||||
export VK_DEVICE_LAYERS=VK_LAYER_KHRONOS_validation
|
||||
export VK_LOADER_DEBUG=all
|
||||
export VK_LOADER_DEBUG_FILE=/tmp/vulkan.log
|
||||
export VK_INSTANCE_EXTENSIONS=VK_EXT_debug_utils
|
||||
export VK_DEVICE_EXTENSIONS=VK_EXT_debug_utils
|
||||
export VK_LAYER_ENABLES=VK_LAYER_KHRONOS_validation
|
||||
export VK_LAYER_DISABLES=VK_LAYER_LUNARG_api_dump
|
||||
export VK_LAYER_PATH=${vulkan-validation-layers}/share/vulkan/explicit_layer.d
|
||||
|
||||
pkg-config --cflags fontconfig fontconfig >= 2.11.1 --libs vulkan
|
||||
'';
|
||||
}
|
2
src/lib.rs
Normal file
2
src/lib.rs
Normal file
@ -0,0 +1,2 @@
|
||||
mod ppu;
|
||||
pub use ppu::*;
|
196
src/main.rs
196
src/main.rs
@ -1,196 +0,0 @@
|
||||
// ... existing imports ...
|
||||
use std::time::Instant;
|
||||
|
||||
use wgpu::util::DeviceExt;
|
||||
|
||||
use bytemuck;
|
||||
|
||||
async fn run() {
|
||||
let instance = wgpu::Instance::default();
|
||||
println!("instance {:?}", instance);
|
||||
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||
force_fallback_adapter: false,
|
||||
compatible_surface: None,
|
||||
})
|
||||
.await;
|
||||
let adapter = match adapter {
|
||||
Some(adapter) => adapter,
|
||||
None => {
|
||||
println!("No suitable GPU adapters found on the system!");
|
||||
return;
|
||||
}
|
||||
};
|
||||
let features = adapter.features();
|
||||
let (device, queue) = adapter
|
||||
.request_device(
|
||||
&wgpu::DeviceDescriptor {
|
||||
label: None,
|
||||
features: features & wgpu::Features::TIMESTAMP_QUERY,
|
||||
limits: Default::default(),
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||
Some(device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||
count: 2,
|
||||
ty: wgpu::QueryType::Timestamp,
|
||||
label: None,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let start_instant = Instant::now();
|
||||
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: None,
|
||||
//source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()),
|
||||
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
|
||||
});
|
||||
println!("shader compilation {:?}", start_instant.elapsed());
|
||||
|
||||
let matrix_size = std::mem::size_of::<[[f32; 4]; 4]>();
|
||||
let matrix_a_data = [[1.0; 4]; 4]; // Example data
|
||||
let matrix_b_data = [[2.0; 4]; 4]; // Example data
|
||||
|
||||
let matrix_a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: Some("Matrix A Buffer"),
|
||||
contents: bytemuck::bytes_of(&matrix_a_data),
|
||||
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
||||
});
|
||||
let matrix_b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: Some("Matrix B Buffer"),
|
||||
contents: bytemuck::bytes_of(&matrix_b_data),
|
||||
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
||||
});
|
||||
let matrix_c_buf = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("Matrix C Buffer"),
|
||||
size: matrix_size as u64,
|
||||
usage: wgpu::BufferUsages::STORAGE
|
||||
| wgpu::BufferUsages::MAP_READ
|
||||
| wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let query_buf = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||
Some(device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("Query Buffer"),
|
||||
size: 16, // Enough for two 64-bit timestamps
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
}))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: None,
|
||||
entries: &[wgpu::BindGroupLayoutEntry {
|
||||
binding: 0,
|
||||
visibility: wgpu::ShaderStages::COMPUTE,
|
||||
ty: wgpu::BindingType::Buffer {
|
||||
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||
has_dynamic_offset: false,
|
||||
min_binding_size: None,
|
||||
},
|
||||
count: None,
|
||||
}],
|
||||
});
|
||||
let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: None,
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: None,
|
||||
layout: Some(&compute_pipeline_layout),
|
||||
module: &cs_module,
|
||||
entry_point: "main",
|
||||
});
|
||||
|
||||
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: Some("Compute Bind Group"),
|
||||
layout: &bind_group_layout,
|
||||
entries: &[
|
||||
wgpu::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: matrix_a_buf.as_entire_binding(),
|
||||
},
|
||||
wgpu::BindGroupEntry {
|
||||
binding: 1,
|
||||
resource: matrix_b_buf.as_entire_binding(),
|
||||
},
|
||||
wgpu::BindGroupEntry {
|
||||
binding: 2,
|
||||
resource: matrix_c_buf.as_entire_binding(),
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
let mut encoder = device.create_command_encoder(&Default::default());
|
||||
if let Some(query_set) = &query_set {
|
||||
encoder.write_timestamp(query_set, 0);
|
||||
}
|
||||
{
|
||||
let mut cpass = encoder.begin_compute_pass(&Default::default());
|
||||
cpass.set_pipeline(&pipeline);
|
||||
cpass.set_bind_group(0, &bind_group, &[]);
|
||||
cpass.dispatch_workgroups(4, 4, 1); // Dispatch for a 4x4 matrix
|
||||
}
|
||||
|
||||
if let Some(query_set) = &query_set {
|
||||
encoder.write_timestamp(query_set, 1);
|
||||
}
|
||||
encoder.copy_buffer_to_buffer(&matrix_c_buf, 0, &matrix_c_buf, 0, matrix_size as u64);
|
||||
if let Some(query_set) = &query_set {
|
||||
if let Some(query_buf) = &query_buf {
|
||||
encoder.write_timestamp(query_set, 1);
|
||||
encoder.resolve_query_set(query_set, 0..2, query_buf, 0);
|
||||
}
|
||||
}
|
||||
queue.submit(Some(encoder.finish()));
|
||||
|
||||
// Assuming query_buf has been properly initialized earlier
|
||||
|
||||
let buf_slice = matrix_c_buf.slice(..);
|
||||
let query_slice = query_buf.as_ref().map(|buf| buf.slice(..)); // Adjust this line
|
||||
|
||||
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||
|
||||
// Map the buffer for reading (matrix_c_buf)
|
||||
buf_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
|
||||
|
||||
if let Some(q_slice) = query_slice {
|
||||
// Map the query buffer if it exists
|
||||
let _query_future = q_slice.map_async(wgpu::MapMode::Read, |_| ());
|
||||
}
|
||||
|
||||
println!("pre-poll {:?}", std::time::Instant::now());
|
||||
device.poll(wgpu::Maintain::Wait);
|
||||
println!("post-poll {:?}", std::time::Instant::now());
|
||||
|
||||
if let Some(Ok(())) = receiver.receive().await {
|
||||
let data_raw = &*buf_slice.get_mapped_range();
|
||||
println!("compute shader result: {:?}", data_raw);
|
||||
}
|
||||
|
||||
if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||
if let Some(q_slice) = query_slice {
|
||||
let ts_period = queue.get_timestamp_period();
|
||||
let ts_data_raw = &*q_slice.get_mapped_range();
|
||||
let ts_data: &[u64] = bytemuck::cast_slice(ts_data_raw);
|
||||
println!(
|
||||
"compute shader elapsed: {:?}ms",
|
||||
(ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn main() {
|
||||
pollster::block_on(run());
|
||||
}
|
470
src/ppu.rs
Normal file
470
src/ppu.rs
Normal file
@ -0,0 +1,470 @@
|
||||
// #![allow(unused)]
|
||||
use bytemuck;
|
||||
use chrono::Duration;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
use wgpu::util::DeviceExt;
|
||||
|
||||
#[derive(Error, Debug)]
|
||||
pub enum Error {
|
||||
#[error("Failed to find an appropriate adapter.")]
|
||||
AdapterNotFound,
|
||||
#[error("Failed to create device.")]
|
||||
DeviceCreationFailed(#[from] wgpu::RequestDeviceError),
|
||||
#[error("Failed to create query set. {0}")]
|
||||
QuerySetCreationFailed(String),
|
||||
#[error("Failed to async map a buffer.")]
|
||||
BufferAsyncError(#[from] wgpu::BufferAsyncError),
|
||||
#[error("Failed to create compute pipeline, shader module {0} was not found.")]
|
||||
ShaderModuleNotFound(String),
|
||||
#[error("No buffer found with name {0}.")]
|
||||
BufferNotFound(String),
|
||||
// #[error("Buffer {0} is too small, expected {1} bytes found {2} bytes.")]
|
||||
// BufferTooSmall(String, usize, usize),
|
||||
#[error("No bind group layout found with name {0}.")]
|
||||
BindGroupLayoutNotFound(String),
|
||||
#[error("No pipeline found with name {0}.")]
|
||||
PipelineNotFound(String),
|
||||
#[error("No bind group found with name {0}.")]
|
||||
BindGroupNotFound(String),
|
||||
}
|
||||
|
||||
pub struct PPU {
|
||||
instance: wgpu::Instance,
|
||||
adapter: wgpu::Adapter,
|
||||
device: wgpu::Device,
|
||||
queue: wgpu::Queue,
|
||||
query_set: wgpu::QuerySet,
|
||||
// buffers: HashMap<String, wgpu::Buffer>,
|
||||
shader_modules: HashMap<String, wgpu::ShaderModule>,
|
||||
pipelines: HashMap<String, wgpu::ComputePipeline>,
|
||||
bind_group_layouts: HashMap<String, wgpu::BindGroupLayout>,
|
||||
bind_groups: HashMap<String, wgpu::BindGroup>,
|
||||
}
|
||||
|
||||
#[derive(Getters, MutGetters)]
|
||||
pub struct ComputeBuffers {
|
||||
#[getset(get, get_mut)]
|
||||
buffers: HashMap<String, wgpu::Buffer>,
|
||||
#[getset(get)]
|
||||
staging: wgpu::Buffer,
|
||||
#[getset(get)]
|
||||
readback: wgpu::Buffer,
|
||||
#[getset(get)]
|
||||
timestamp_query: wgpu::Buffer,
|
||||
}
|
||||
|
||||
impl ComputeBuffers {
|
||||
pub fn new<T: bytemuck::Pod>(ppu: &PPU, output_size: usize) -> Self {
|
||||
let staging = ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("COMPUTE_STAGING_BUFFER"),
|
||||
size: output_size as u64 * std::mem::size_of::<T>() as u64,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let timestamp_query = ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("COMPUTE_TIMESTAMP_QUERY_BUFFER"),
|
||||
size: 16,
|
||||
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let readback = ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("COMPUTE_READBACK_BUFFER"),
|
||||
size: 16,
|
||||
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
Self {
|
||||
buffers: HashMap::new(),
|
||||
staging,
|
||||
readback,
|
||||
timestamp_query,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn add_buffer<T>(
|
||||
&mut self,
|
||||
ppu: &PPU,
|
||||
label: &str,
|
||||
size: usize,
|
||||
usage: wgpu::BufferUsages,
|
||||
) {
|
||||
self.buffers.entry(label.to_string()).or_insert_with(|| {
|
||||
ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some(label),
|
||||
size: size as u64 * std::mem::size_of::<T>() as u64,
|
||||
usage,
|
||||
mapped_at_creation: false,
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn add_buffer_init<T: bytemuck::Pod>(
|
||||
&mut self,
|
||||
ppu: &PPU,
|
||||
label: &str,
|
||||
data: &[T],
|
||||
usage: wgpu::BufferUsages,
|
||||
) {
|
||||
self.buffers.entry(label.to_string()).or_insert_with(|| {
|
||||
ppu.device
|
||||
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: Some(label),
|
||||
contents: bytemuck::cast_slice(data),
|
||||
usage,
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
|
||||
let buffer = match self.buffers.get(label) {
|
||||
Some(buffer) => buffer,
|
||||
None => return,
|
||||
};
|
||||
|
||||
let data_bytes = bytemuck::cast_slice(data);
|
||||
let data_len = data_bytes.len() as wgpu::BufferAddress;
|
||||
|
||||
if buffer.size() < data_len {
|
||||
return;
|
||||
}
|
||||
|
||||
ppu.queue.write_buffer(buffer, 0, data_bytes);
|
||||
}
|
||||
|
||||
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
|
||||
self.buffers.get(label)
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ComputeResult<T> {
|
||||
data: Vec<T>,
|
||||
time_elapsed: Duration,
|
||||
}
|
||||
|
||||
impl<T> ComputeResult<T> {
|
||||
pub fn new(data: Vec<T>, time_elapsed: Duration) -> Self {
|
||||
Self { data, time_elapsed }
|
||||
}
|
||||
|
||||
pub fn data(&self) -> &Vec<T> {
|
||||
&self.data
|
||||
}
|
||||
|
||||
pub fn time_elapsed_sec(&self) -> f32 {
|
||||
self.time_elapsed.num_nanoseconds().unwrap() as f32 * 1e-9
|
||||
}
|
||||
|
||||
pub fn time_elapsed_ms(&self) -> f32 {
|
||||
self.time_elapsed.num_milliseconds() as f32
|
||||
}
|
||||
|
||||
pub fn time_elapsed_us(&self) -> f32 {
|
||||
self.time_elapsed.num_microseconds().unwrap() as f32
|
||||
}
|
||||
|
||||
pub fn time_elapsed_ns(&self) -> f32 {
|
||||
self.time_elapsed.num_nanoseconds().unwrap() as f32
|
||||
}
|
||||
}
|
||||
|
||||
impl PPU {
|
||||
/// Initialize a new PPU instance
|
||||
///
|
||||
/// # Examples
|
||||
///
|
||||
/// ```
|
||||
/// use ppu::PPU;
|
||||
///
|
||||
/// #[tokio::main]
|
||||
/// async fn main() {
|
||||
/// let ppu = PPU::new().await.unwrap();
|
||||
/// }
|
||||
pub async fn new() -> Result<Self, Error> {
|
||||
let instance = wgpu::Instance::default();
|
||||
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||
force_fallback_adapter: true,
|
||||
compatible_surface: None,
|
||||
})
|
||||
.await
|
||||
.ok_or(Error::AdapterNotFound)?;
|
||||
|
||||
let features = adapter.features();
|
||||
|
||||
let (device, queue) = adapter
|
||||
.request_device(
|
||||
&wgpu::DeviceDescriptor {
|
||||
label: None,
|
||||
features: features & wgpu::Features::TIMESTAMP_QUERY,
|
||||
limits: Default::default(),
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||
device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||
count: 2,
|
||||
ty: wgpu::QueryType::Timestamp,
|
||||
label: None,
|
||||
})
|
||||
} else {
|
||||
return Err(Error::QuerySetCreationFailed(
|
||||
"Timestamp query is not supported".to_string(),
|
||||
));
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
instance,
|
||||
adapter,
|
||||
device,
|
||||
queue,
|
||||
query_set,
|
||||
shader_modules: HashMap::new(),
|
||||
pipelines: HashMap::new(),
|
||||
bind_group_layouts: HashMap::new(),
|
||||
bind_groups: HashMap::new(),
|
||||
})
|
||||
}
|
||||
|
||||
pub fn load_bind_group_layout(
|
||||
&mut self,
|
||||
label: &str,
|
||||
layout_entries: &[wgpu::BindGroupLayoutEntry],
|
||||
) {
|
||||
let bind_group_layout =
|
||||
self.device
|
||||
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: Some(label),
|
||||
entries: layout_entries,
|
||||
});
|
||||
|
||||
self.bind_group_layouts
|
||||
.insert(label.to_string(), bind_group_layout);
|
||||
}
|
||||
|
||||
pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
|
||||
self.bind_group_layouts.get(label)
|
||||
}
|
||||
|
||||
/// Create and store a bind group
|
||||
pub fn create_bind_group(
|
||||
&mut self,
|
||||
name: &str,
|
||||
layout_name: &str,
|
||||
entries: &[wgpu::BindGroupEntry],
|
||||
) -> Result<(), Error> {
|
||||
let layout = self
|
||||
.bind_group_layouts
|
||||
.get(layout_name)
|
||||
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))?;
|
||||
|
||||
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
label: Some(name),
|
||||
layout,
|
||||
entries,
|
||||
});
|
||||
self.bind_groups.insert(name.to_string(), bind_group);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Retrieve a bind group by name
|
||||
pub fn get_bind_group(&self, name: &str) -> Option<&wgpu::BindGroup> {
|
||||
self.bind_groups.get(name)
|
||||
}
|
||||
|
||||
/// Load a shader and store it in the hash map
|
||||
pub fn load_shader(&mut self, name: &str, source: &str) -> Result<(), Error> {
|
||||
let shader_module = self
|
||||
.device
|
||||
.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||
label: Some(name),
|
||||
source: wgpu::ShaderSource::Wgsl(source.into()),
|
||||
});
|
||||
self.shader_modules.insert(name.to_string(), shader_module);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Get a shader from the hash map
|
||||
pub fn get_shader(&self, name: &str) -> Option<&wgpu::ShaderModule> {
|
||||
self.shader_modules.get(name)
|
||||
}
|
||||
|
||||
pub fn load_pipeline(
|
||||
&mut self,
|
||||
name: &str,
|
||||
shader_module_name: &str,
|
||||
entry_point: &str,
|
||||
bind_group_layout_names: &[&str], // Use names of bind group layouts
|
||||
) -> Result<(), Error> {
|
||||
// Retrieve the shader module
|
||||
let shader_module = self
|
||||
.get_shader(shader_module_name)
|
||||
.ok_or(Error::ShaderModuleNotFound(shader_module_name.to_string()))?;
|
||||
|
||||
// Retrieve the bind group layouts
|
||||
let bind_group_layouts = bind_group_layout_names
|
||||
.iter()
|
||||
.map(|layout_name| {
|
||||
self.bind_group_layouts
|
||||
.get(*layout_name) // Assuming you have a HashMap for BindGroupLayouts
|
||||
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))
|
||||
})
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
// Create the pipeline layout
|
||||
let pipeline_layout = self
|
||||
.device
|
||||
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
label: Some(name),
|
||||
bind_group_layouts: &bind_group_layouts
|
||||
.iter()
|
||||
.map(|layout| *layout)
|
||||
.collect::<Vec<_>>(),
|
||||
push_constant_ranges: &[],
|
||||
});
|
||||
|
||||
// Create the compute pipeline
|
||||
let pipeline = self
|
||||
.device
|
||||
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
label: Some(name),
|
||||
layout: Some(&pipeline_layout),
|
||||
module: shader_module,
|
||||
entry_point,
|
||||
});
|
||||
|
||||
// Store the pipeline
|
||||
self.pipelines.insert(name.to_string(), pipeline);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn get_pipeline(&self, name: &str) -> Option<&wgpu::ComputePipeline> {
|
||||
self.pipelines.get(name)
|
||||
}
|
||||
|
||||
/// Execute a compute task and retrieve the result
|
||||
pub async fn execute_compute_task<T>(
|
||||
&self,
|
||||
pipeline_name: &str,
|
||||
bind_group_name: &str,
|
||||
buffers: &ComputeBuffers,
|
||||
output_buffer_name: &str,
|
||||
workgroup_count: (u32, u32, u32),
|
||||
) -> Result<ComputeResult<T>, Error>
|
||||
where
|
||||
T: bytemuck::Pod + bytemuck::Zeroable + Send + Sync + std::fmt::Debug, // Added Debug trait
|
||||
{
|
||||
// Retrieve the pipeline and bind group
|
||||
let pipeline = self
|
||||
.get_pipeline(pipeline_name)
|
||||
.ok_or_else(|| Error::PipelineNotFound(pipeline_name.to_string()))?;
|
||||
let bind_group = self
|
||||
.get_bind_group(bind_group_name)
|
||||
.ok_or_else(|| Error::BindGroupNotFound(bind_group_name.to_string()))?;
|
||||
|
||||
// Create a command encoder and dispatch the compute shader
|
||||
let mut encoder = self.device.create_command_encoder(&Default::default());
|
||||
{
|
||||
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
|
||||
compute_pass.set_pipeline(pipeline);
|
||||
compute_pass.set_bind_group(0, bind_group, &[]);
|
||||
compute_pass.dispatch_workgroups(
|
||||
workgroup_count.0,
|
||||
workgroup_count.1,
|
||||
workgroup_count.2,
|
||||
);
|
||||
}
|
||||
|
||||
// Create a new query set for this task
|
||||
let query_set = self.device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||
count: 2, // Two timestamps: start and end
|
||||
ty: wgpu::QueryType::Timestamp,
|
||||
label: Some("Timestamp Query Set"),
|
||||
});
|
||||
|
||||
// Record the start timestamp
|
||||
encoder.write_timestamp(&query_set, 0);
|
||||
|
||||
// Copy output to staging buffer
|
||||
let output_buffer = buffers
|
||||
.get_buffer(output_buffer_name)
|
||||
.ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?;
|
||||
|
||||
encoder.copy_buffer_to_buffer(
|
||||
output_buffer,
|
||||
0,
|
||||
buffers.staging(),
|
||||
0,
|
||||
buffers.staging().size(),
|
||||
);
|
||||
|
||||
// Record the end timestamp
|
||||
encoder.write_timestamp(&query_set, 1);
|
||||
|
||||
// Resolve timestamp query and write to timestamp query buffer
|
||||
encoder.resolve_query_set(&query_set, 0..2, buffers.timestamp_query(), 0);
|
||||
|
||||
// Copy timestamp query buffer to readback buffer
|
||||
encoder.copy_buffer_to_buffer(
|
||||
buffers.timestamp_query(),
|
||||
0,
|
||||
buffers.readback(),
|
||||
0,
|
||||
buffers.readback().size(),
|
||||
);
|
||||
|
||||
// Submit the command encoder
|
||||
self.queue.submit(Some(encoder.finish()));
|
||||
|
||||
// Wait for the GPU to finish executing before mapping buffers
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
|
||||
// Read the staging buffer data
|
||||
let buffer_slice = buffers.staging().slice(..);
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
// buffer_slice.unmap();
|
||||
|
||||
let data = buffer_slice.get_mapped_range().to_vec();
|
||||
let result = data
|
||||
.chunks_exact(std::mem::size_of::<T>())
|
||||
.map(|chunk| *bytemuck::from_bytes::<T>(chunk))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
// Read the timestamp query results to the readback buffer
|
||||
|
||||
let timestamp_buffer_slice = buffers.readback().slice(..);
|
||||
timestamp_buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
// timestamp_buffer_slice.unmap();
|
||||
|
||||
let ts_data_raw = timestamp_buffer_slice.get_mapped_range();
|
||||
let ts_data: &[u64] = bytemuck::cast_slice(&ts_data_raw);
|
||||
|
||||
// Calculate the elapsed time in nanoseconds
|
||||
let elapsed_ns = (ts_data[1] - ts_data[0]) as f64 * self.queue.get_timestamp_period() as f64;
|
||||
let time_elapsed = Duration::nanoseconds(elapsed_ns as i64);
|
||||
|
||||
Ok(ComputeResult::new(result, time_elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PPU {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
write!(f, "PPU")?;
|
||||
write!(f, "instance {:?}", self.instance)?;
|
||||
write!(f, "adapter {:?}", self.adapter)?;
|
||||
write!(f, "device {:?}", self.device)?;
|
||||
write!(f, "queue {:?}", self.queue)?;
|
||||
write!(f, "query_set {:?}", self.query_set)?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
@ -1,5 +1,5 @@
|
||||
struct Matrix {
|
||||
data: array<array<f32, 4>, 4>; // Assuming 4x4 matrices
|
||||
data: array<array<f32, 4>, 4>, // Corrected: Use a comma instead of a semicolon
|
||||
};
|
||||
|
||||
@group(0) @binding(0)
|
||||
@ -17,7 +17,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||
let col: u32 = global_id.x;
|
||||
|
||||
var sum: f32 = 0.0;
|
||||
for (var k: u32 = 0; k < 4u; k = k + 1u) {
|
||||
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
|
||||
sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
|
||||
}
|
||||
matrixC.data[row][col] = sum;
|
||||
|
Loading…
Reference in New Issue
Block a user