Compare commits

..

6 Commits

Author SHA1 Message Date
668293d956
🔧 Implement as a library
- Make the crate a lib
- Move main to examples as mat_mul_4x4.rs
- Correctly track elapsed time of compute task

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
2023-12-23 04:00:45 +02:00
8bf134d3d2
📝 Add LICENSE and README.md
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
2023-12-23 02:48:40 +02:00
0d701a49b6
🔧 Make nix develop work
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
2023-12-23 02:41:35 +02:00
f7b29baf95
🚀 4x4 matrix multiplication working
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
2023-12-23 02:10:51 +02:00
5b3ef81c84
🚀 Builds on Linux, work on abstraction
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
2023-12-22 16:01:40 +02:00
bf16edd3aa
🚀 Builds and runs on Linux and MacOS, result still wrong
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
2023-12-22 10:51:26 +02:00
14 changed files with 989 additions and 312 deletions

3
.gitignore vendored
View File

@ -1 +1,2 @@
/target target
result

217
Cargo.lock generated
View File

@ -44,6 +44,12 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]] [[package]]
name = "android_system_properties" name = "android_system_properties"
version = "0.1.5" version = "0.1.5"
@ -74,7 +80,7 @@ version = "0.2.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
dependencies = [ dependencies = [
"hermit-abi", "hermit-abi 0.1.19",
"libc", "libc",
"winapi", "winapi",
] ]
@ -156,9 +162,15 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.42",
] ]
[[package]]
name = "bytes"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
[[package]] [[package]]
name = "cc" name = "cc"
version = "1.0.83" version = "1.0.83"
@ -174,6 +186,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets",
]
[[package]] [[package]]
name = "codespan-reporting" name = "codespan-reporting"
version = "0.11.1" version = "0.11.1"
@ -224,7 +250,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20" checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20"
dependencies = [ dependencies = [
"bitflags 2.4.1", "bitflags 2.4.1",
"libloading 0.7.4", "libloading 0.8.1",
"winapi", "winapi",
] ]
@ -247,6 +273,12 @@ version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "fixedbitset"
version = "0.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ce7134b9999ecaf8bcd65542e436736ef32ddca1b3e06094cb6ec5755203b80"
[[package]] [[package]]
name = "flume" name = "flume"
version = "0.11.0" version = "0.11.0"
@ -277,7 +309,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.42",
] ]
[[package]] [[package]]
@ -322,6 +354,18 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "getset"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]] [[package]]
name = "gimli" name = "gimli"
version = "0.28.1" version = "0.28.1"
@ -447,6 +491,12 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "hermit-abi"
version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
[[package]] [[package]]
name = "hexf-parse" name = "hexf-parse"
version = "0.2.1" version = "0.2.1"
@ -459,6 +509,29 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "iana-time-zone"
version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.1.0" version = "2.1.0"
@ -585,6 +658,17 @@ dependencies = [
"adler", "adler",
] ]
[[package]]
name = "mio"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09"
dependencies = [
"libc",
"wasi",
"windows-sys",
]
[[package]] [[package]]
name = "naga" name = "naga"
version = "0.14.2" version = "0.14.2"
@ -598,6 +682,7 @@ dependencies = [
"indexmap", "indexmap",
"log", "log",
"num-traits", "num-traits",
"petgraph",
"rustc-hash", "rustc-hash",
"spirv", "spirv",
"termcolor", "termcolor",
@ -623,6 +708,16 @@ dependencies = [
"autocfg", "autocfg",
] ]
[[package]]
name = "num_cpus"
version = "1.16.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
dependencies = [
"hermit-abi 0.3.3",
"libc",
]
[[package]] [[package]]
name = "objc" name = "objc"
version = "0.2.7" version = "0.2.7"
@ -711,6 +806,22 @@ version = "1.0.14"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c"
[[package]]
name = "petgraph"
version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9"
dependencies = [
"fixedbitset",
"indexmap",
]
[[package]]
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]] [[package]]
name = "pkg-config" name = "pkg-config"
version = "0.3.28" version = "0.3.28"
@ -729,6 +840,30 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa" checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn 1.0.109",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]] [[package]]
name = "proc-macro2" name = "proc-macro2"
version = "1.0.70" version = "1.0.70"
@ -836,6 +971,15 @@ version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "signal-hook-registry"
version = "1.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
dependencies = [
"libc",
]
[[package]] [[package]]
name = "slotmap" name = "slotmap"
version = "1.0.7" version = "1.0.7"
@ -851,6 +995,16 @@ version = "1.11.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
[[package]]
name = "socket2"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
dependencies = [
"libc",
"windows-sys",
]
[[package]] [[package]]
name = "spin" name = "spin"
version = "0.9.8" version = "0.9.8"
@ -876,6 +1030,17 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.42" version = "2.0.42"
@ -913,7 +1078,37 @@ checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.42",
]
[[package]]
name = "tokio"
version = "1.35.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104"
dependencies = [
"backtrace",
"bytes",
"libc",
"mio",
"num_cpus",
"parking_lot 0.12.1",
"pin-project-lite",
"signal-hook-registry",
"socket2",
"tokio-macros",
"windows-sys",
]
[[package]]
name = "tokio-macros"
version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.42",
] ]
[[package]] [[package]]
@ -967,7 +1162,7 @@ dependencies = [
"once_cell", "once_cell",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.42",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -1001,7 +1196,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.42",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -1093,7 +1288,7 @@ dependencies = [
"js-sys", "js-sys",
"khronos-egl", "khronos-egl",
"libc", "libc",
"libloading 0.7.4", "libloading 0.8.1",
"log", "log",
"metal", "metal",
"naga", "naga",
@ -1129,9 +1324,13 @@ name = "wgpu_compute_shader"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"chrono",
"env_logger", "env_logger",
"futures-intrusive", "futures-intrusive",
"getset",
"pollster", "pollster",
"thiserror",
"tokio",
"wgpu", "wgpu",
] ]
@ -1280,5 +1479,5 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn", "syn 2.0.42",
] ]

View File

@ -3,11 +3,16 @@ name = "wgpu_compute_shader"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib]
crate-type = ["lib"]
[dependencies] [dependencies]
wgpu = { version = "0.18", features = ["vulkan-portability"] } wgpu = { version = "0.18", features = ["vulkan-portability", "spirv" ] }
env_logger = "0.9.1" env_logger = "0.9.1"
pollster = "0.2.5" pollster = "0.2.5"
futures-intrusive = "0.4" futures-intrusive = "0.4"
bytemuck = { version = "1.12.1", features = ["derive"] } bytemuck = { version = "1.12.1", features = ["derive"] }
thiserror = "1.0.51"
tokio = { version = "1.35.1", features = ["sync", "full"] }
getset = "0.1.2"
chrono = "0.4.31"

21
LICENSE Normal file
View File

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Julius Koskela
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

72
README.md Normal file
View File

@ -0,0 +1,72 @@
# Rust wGPU Matrix Multiplication Shader Test
## Overview
This project demonstrates a matrix multiplication operation using compute shaders in Rust with the wGPU library. The primary goal is to showcase how to perform a simple matrix multiplication on the GPU, leveraging the power of parallel computation provided by shaders.
## Features
- Matrix multiplication using a compute shader.
- Implementation using Rust and the wGPU library.
- Efficient GPU-based computation for matrix operations.
## Prerequisites
- Rust programming language.
- wGPU library for Rust.
- GPU that supports WebGPU.
## Getting Started
### Installation
1. Clone the repository:
```bash
git clone https://github.com/your-repository/rust-wgpu-matrix-multiplication.git
cd rust-wgpu-matrix-multiplication
```
2. Ensure you have the latest version of Rust installed:
```bash
rustup update
```
### Build and Run
1. Build the project using Cargo, Rust's package manager and build system:
```bash
cargo build --release
```
2. Run the compiled binary:
```bash
cargo run --release
```
## Project Structure
- `src/main.rs`: The main entry point for the application.
- `src/ppu.rs`: Contains the `PPU` struct for handling GPU tasks.
- `shaders/`: Folder containing WGSL shader files.
## Shader
The matrix multiplication logic is implemented in a WGSL shader. The shader takes two input matrices (`matrixA` and `matrixB`) and outputs the product (`matrixC`).
## How It Works
- The compute shader (`main`) is dispatched with a workgroup size of 4x4.
- Each work item calculates one element of the product matrix.
- The result is written back to a buffer that is read by the Rust application.
## Contributing
Contributions to this project are welcome. Please open an issue or pull request on the GitHub repository.
## License
This project is open-source and available under the [MIT License](LICENSE).

View File

@ -1,9 +0,0 @@
fn main() {
if cfg!(target_os = "linux") {
// println!("cargo:rustc-link-lib=X11");
// println!("cargo:rustc-link-lib=Xcursor");
// println!("cargo:rustc-link-lib=Xrandr");
// println!("cargo:rustc-link-lib=Xi");
println!("cargo:rustc-link-lib=vulkan");
}
}

View File

@ -1,40 +1,36 @@
{pkgs}: {pkgs}: let
with pkgs; manifest = (pkgs.lib.importTOML ./Cargo.toml).package;
rustPlatform.buildRustPackage { buildInputs = with pkgs; [
pname = "matmul-vshader"; udev
version = "0.1.0"; alsa-lib
src = ./.;
packages = [cmake shaderc];
buildInputs = [cmake shaderc];
nativeBuildInputs = [
vulkan-headers
vulkan-loader vulkan-loader
vulkan-validation-layers vulkan-headers
vulkan-tools vulkan-tools
vulkan-validation-layers
xorg.libX11
xorg.libXcursor
xorg.libXi
xorg.libXrandr # To use the x11 feature
libxkbcommon
wayland # To use the wayland feature
];
in
pkgs.rustPlatform.buildRustPackage {
inherit buildInputs;
pname = manifest.name;
version = manifest.version;
src = pkgs.lib.cleanSource ./.;
nativeBuildInputs = with pkgs; [
pkg-config pkg-config
git
gcc
cmake
glibc
python3
shaderc
]; ];
RUST_BACKTRACE = "1"; preConfigure = ''
# LD_LIBRARY_PATH = "${vulkan-loader}/lib:${vulkan-validation-layers}/lib:${vulkan-tools}/lib:${vulkan-headers}/lib:${pkgs.stdenv.cc.cc.lib}:${pkgs.stdenv.cc.cc.lib64}"; LD_LIBRARY_PATH=${pkgs.lib.makeLibraryPath buildInputs}
# VK_ICD_FILENAMES = "${vulkan-loader}/share/vulkan/icd.d/radeon_icd64.json"; '';
# VK_LAYER_PATH = "${vulkan-validation-layers}/share/vulkan/explicit_layer.d";
# VK_INSTANCE_LAYERS = "VK_LAYER_KHRONOS_validation"; LD_LIBRARY_PATH = pkgs.lib.makeLibraryPath buildInputs;
# VK_DEVICE_LAYERS = "VK_LAYER_KHRONOS_validation";
# VK_LOADER_DEBUG = "all";
# VK_LOADER_DEBUG_FILE = "/tmp/vulkan.log";
# VK_INSTANCE_EXTENSIONS = "VK_EXT_debug_utils";
# VK_DEVICE_EXTENSIONS = "VK_EXT_debug_utils";
# VK_LAYER_ENABLES = "VK_LAYER_KHRONOS_validation";
# VK_LAYER_DISABLES = "VK_LAYER_LUNARG_api_dump";
# VK_LAYER_PATH = "${vulkan-validation-layers}/share/vulkan/explicit_layer.d";
cargoBuildFlags = ["--release"]; cargoBuildFlags = ["--release"];

168
examples/mat_mul_4x4.rs Normal file
View File

@ -0,0 +1,168 @@
use wgpu_compute_shader::*;
const MM4X4_SHADER: &str = "
struct Matrix {
data: array<array<f32, 4>, 4>,
};
@group(0) @binding(0)
var<storage, read> matrixA: Matrix;
@group(0) @binding(1)
var<storage, read> matrixB: Matrix;
@group(0) @binding(2)
var<storage, read_write> matrixC: Matrix;
// Consider setting workgroup size to power of 2 for better efficiency
@compute @workgroup_size(4, 4, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Ensure the computation is within the bounds of the 4x4 matrix
if (global_id.x < 4u && global_id.y < 4u) {
let row: u32 = global_id.y;
let col: u32 = global_id.x;
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
}
matrixC.data[row][col] = sum;
}
}
";
const MATRIX_A: [[f32; 4]; 4] = [
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
const MATRIX_B: [[f32; 4]; 4] = [
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
// const EXPECT: [[f32; 4]; 4] = [
// [90.0, 100.0, 110.0, 120.0],
// [202.0, 228.0, 254.0, 280.0],
// [314.0, 356.0, 398.0, 440.0],
// [426.0, 484.0, 542.0, 600.0],
// ];
#[tokio::main]
async fn main() -> Result<(), Error> {
// Create PPU
let mut ppu = PPU::new().await?;
ppu.load_shader("MM4X4_SHADER", MM4X4_SHADER)?;
let mut buffers = ComputeBuffers::new::<f32>(&ppu, 16);
buffers.add_buffer_init(
&ppu,
"MATRIX_A",
&MATRIX_A,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
);
buffers.add_buffer_init(
&ppu,
"MATRIX_B",
&MATRIX_B,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
);
buffers.add_buffer_init(
&ppu,
"MATRIX_C",
&[0.0; 16],
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
);
// Create Bind Group Layout
let bind_group_layout_entries = [
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
];
ppu.load_bind_group_layout("MM4X4_BIND_GROUP_LAYOUT", &bind_group_layout_entries);
// Finally, create the bind group
ppu.create_bind_group(
"MM4X4_BIND_GROUP",
"MM4X4_BIND_GROUP_LAYOUT",
&[
wgpu::BindGroupEntry {
binding: 0,
resource: buffers.get_buffer("MATRIX_A").unwrap().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: buffers.get_buffer("MATRIX_B").unwrap().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: buffers.get_buffer("MATRIX_C").unwrap().as_entire_binding(),
},
],
)?;
// Load Pipeline
ppu.load_pipeline(
"MM4X4_PIPELINE",
"MM4X4_SHADER",
"main",
&["MM4X4_BIND_GROUP_LAYOUT"],
)?;
// Execute the compute task
let workgroup_count = (1, 1, 1);
let results: ComputeResult<f32> = ppu
.execute_compute_task(
"MM4X4_PIPELINE",
"MM4X4_BIND_GROUP",
&buffers,
"MATRIX_C",
workgroup_count,
)
.await?;
// Process results
println!("Matrix C:");
results.data().chunks(4).for_each(|row| {
println!("{:?}", row);
});
println!("Time elapsed: {} us", results.time_elapsed_us());
Ok(())
}

View File

@ -7,7 +7,6 @@
nixpkgs, nixpkgs,
}: { }: {
packages.x86_64-linux.default = nixpkgs.legacyPackages.x86_64-linux.callPackage ./default.nix {}; packages.x86_64-linux.default = nixpkgs.legacyPackages.x86_64-linux.callPackage ./default.nix {};
devShells.x86_64-linux.default = nixpkgs.legacyPackages.x86_64-linux.callPackage ./shell.nix {};
formatter.x86_64-linux = nixpkgs.legacyPackages.x86_64-linux.alejandra; formatter.x86_64-linux = nixpkgs.legacyPackages.x86_64-linux.alejandra;
}; };
} }

View File

@ -1,51 +0,0 @@
{pkgs}:
with pkgs; let
build = pkgs.callPackage ./default.nix {};
in
mkShell {
inherit build;
packages = [
libX11
libXcursor
libXrandr
libXi
vulkan-headers
vulkan-loader
vulkan-validation-layers
vulkan-tools
pkg-config
git
gcc
gnumake
cmake
glibc
python3
shaderc
];
inputsFrom = [
cmake
shaderc
];
RUST_BACKTRACE = "1";
LD_LIBRARY_PATH="${pkgs.libX11}/lib:${pkgs.libXcursor}/lib:${pkgs.libXrandr}/lib:${pkgs.libXi}/lib:${pkgs.vulkan-loader}/lib:${pkgs.vulkan-validation-layers}/lib:${pkgs.vulkan-tools}/lib:${pkgs.vulkan-headers}/lib:${pkgs.stdenv.cc.cc.lib}:${pkgs.stdenv.cc.cc.lib64}:$LD_LIBRARY_PATH";
cargoBuildFlags = ["--release --features build-from-source"];
shellHook = ''
export VK_ICD_FILENAMES=${vulkan-loader}/share/vulkan/icd.d/radeon_icd64.json
export VK_LAYER_PATH=${vulkan-validation-layers}/share/vulkan/explicit_layer.d
export VK_INSTANCE_LAYERS=VK_LAYER_KHRONOS_validation
export VK_DEVICE_LAYERS=VK_LAYER_KHRONOS_validation
export VK_LOADER_DEBUG=all
export VK_LOADER_DEBUG_FILE=/tmp/vulkan.log
export VK_INSTANCE_EXTENSIONS=VK_EXT_debug_utils
export VK_DEVICE_EXTENSIONS=VK_EXT_debug_utils
export VK_LAYER_ENABLES=VK_LAYER_KHRONOS_validation
export VK_LAYER_DISABLES=VK_LAYER_LUNARG_api_dump
export VK_LAYER_PATH=${vulkan-validation-layers}/share/vulkan/explicit_layer.d
pkg-config --cflags fontconfig fontconfig >= 2.11.1 --libs vulkan
'';
}

2
src/lib.rs Normal file
View File

@ -0,0 +1,2 @@
mod ppu;
pub use ppu::*;

View File

@ -1,196 +0,0 @@
// ... existing imports ...
use std::time::Instant;
use wgpu::util::DeviceExt;
use bytemuck;
async fn run() {
let instance = wgpu::Instance::default();
println!("instance {:?}", instance);
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: false,
compatible_surface: None,
})
.await;
let adapter = match adapter {
Some(adapter) => adapter,
None => {
println!("No suitable GPU adapters found on the system!");
return;
}
};
let features = adapter.features();
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: features & wgpu::Features::TIMESTAMP_QUERY,
limits: Default::default(),
},
None,
)
.await
.unwrap();
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
Some(device.create_query_set(&wgpu::QuerySetDescriptor {
count: 2,
ty: wgpu::QueryType::Timestamp,
label: None,
}))
} else {
None
};
let start_instant = Instant::now();
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
//source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()),
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
});
println!("shader compilation {:?}", start_instant.elapsed());
let matrix_size = std::mem::size_of::<[[f32; 4]; 4]>();
let matrix_a_data = [[1.0; 4]; 4]; // Example data
let matrix_b_data = [[2.0; 4]; 4]; // Example data
let matrix_a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Matrix A Buffer"),
contents: bytemuck::bytes_of(&matrix_a_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let matrix_b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Matrix B Buffer"),
contents: bytemuck::bytes_of(&matrix_b_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let matrix_c_buf = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Matrix C Buffer"),
size: matrix_size as u64,
usage: wgpu::BufferUsages::STORAGE
| wgpu::BufferUsages::MAP_READ
| wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let query_buf = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
Some(device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Query Buffer"),
size: 16, // Enough for two 64-bit timestamps
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}))
} else {
None
};
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
}],
});
let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&compute_pipeline_layout),
module: &cs_module,
entry_point: "main",
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Compute Bind Group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: matrix_a_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: matrix_b_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: matrix_c_buf.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&Default::default());
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 0);
}
{
let mut cpass = encoder.begin_compute_pass(&Default::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups(4, 4, 1); // Dispatch for a 4x4 matrix
}
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 1);
}
encoder.copy_buffer_to_buffer(&matrix_c_buf, 0, &matrix_c_buf, 0, matrix_size as u64);
if let Some(query_set) = &query_set {
if let Some(query_buf) = &query_buf {
encoder.write_timestamp(query_set, 1);
encoder.resolve_query_set(query_set, 0..2, query_buf, 0);
}
}
queue.submit(Some(encoder.finish()));
// Assuming query_buf has been properly initialized earlier
let buf_slice = matrix_c_buf.slice(..);
let query_slice = query_buf.as_ref().map(|buf| buf.slice(..)); // Adjust this line
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
// Map the buffer for reading (matrix_c_buf)
buf_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
if let Some(q_slice) = query_slice {
// Map the query buffer if it exists
let _query_future = q_slice.map_async(wgpu::MapMode::Read, |_| ());
}
println!("pre-poll {:?}", std::time::Instant::now());
device.poll(wgpu::Maintain::Wait);
println!("post-poll {:?}", std::time::Instant::now());
if let Some(Ok(())) = receiver.receive().await {
let data_raw = &*buf_slice.get_mapped_range();
println!("compute shader result: {:?}", data_raw);
}
if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
if let Some(q_slice) = query_slice {
let ts_period = queue.get_timestamp_period();
let ts_data_raw = &*q_slice.get_mapped_range();
let ts_data: &[u64] = bytemuck::cast_slice(ts_data_raw);
println!(
"compute shader elapsed: {:?}ms",
(ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6
);
}
}
}
fn main() {
pollster::block_on(run());
}

470
src/ppu.rs Normal file
View File

@ -0,0 +1,470 @@
// #![allow(unused)]
use bytemuck;
use chrono::Duration;
use getset::{Getters, MutGetters};
use std::collections::HashMap;
use thiserror::Error;
use wgpu::util::DeviceExt;
#[derive(Error, Debug)]
pub enum Error {
#[error("Failed to find an appropriate adapter.")]
AdapterNotFound,
#[error("Failed to create device.")]
DeviceCreationFailed(#[from] wgpu::RequestDeviceError),
#[error("Failed to create query set. {0}")]
QuerySetCreationFailed(String),
#[error("Failed to async map a buffer.")]
BufferAsyncError(#[from] wgpu::BufferAsyncError),
#[error("Failed to create compute pipeline, shader module {0} was not found.")]
ShaderModuleNotFound(String),
#[error("No buffer found with name {0}.")]
BufferNotFound(String),
// #[error("Buffer {0} is too small, expected {1} bytes found {2} bytes.")]
// BufferTooSmall(String, usize, usize),
#[error("No bind group layout found with name {0}.")]
BindGroupLayoutNotFound(String),
#[error("No pipeline found with name {0}.")]
PipelineNotFound(String),
#[error("No bind group found with name {0}.")]
BindGroupNotFound(String),
}
pub struct PPU {
instance: wgpu::Instance,
adapter: wgpu::Adapter,
device: wgpu::Device,
queue: wgpu::Queue,
query_set: wgpu::QuerySet,
// buffers: HashMap<String, wgpu::Buffer>,
shader_modules: HashMap<String, wgpu::ShaderModule>,
pipelines: HashMap<String, wgpu::ComputePipeline>,
bind_group_layouts: HashMap<String, wgpu::BindGroupLayout>,
bind_groups: HashMap<String, wgpu::BindGroup>,
}
#[derive(Getters, MutGetters)]
pub struct ComputeBuffers {
#[getset(get, get_mut)]
buffers: HashMap<String, wgpu::Buffer>,
#[getset(get)]
staging: wgpu::Buffer,
#[getset(get)]
readback: wgpu::Buffer,
#[getset(get)]
timestamp_query: wgpu::Buffer,
}
impl ComputeBuffers {
pub fn new<T: bytemuck::Pod>(ppu: &PPU, output_size: usize) -> Self {
let staging = ppu.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("COMPUTE_STAGING_BUFFER"),
size: output_size as u64 * std::mem::size_of::<T>() as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let timestamp_query = ppu.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("COMPUTE_TIMESTAMP_QUERY_BUFFER"),
size: 16,
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let readback = ppu.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("COMPUTE_READBACK_BUFFER"),
size: 16,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
Self {
buffers: HashMap::new(),
staging,
readback,
timestamp_query,
}
}
pub fn add_buffer<T>(
&mut self,
ppu: &PPU,
label: &str,
size: usize,
usage: wgpu::BufferUsages,
) {
self.buffers.entry(label.to_string()).or_insert_with(|| {
ppu.device.create_buffer(&wgpu::BufferDescriptor {
label: Some(label),
size: size as u64 * std::mem::size_of::<T>() as u64,
usage,
mapped_at_creation: false,
})
});
}
pub fn add_buffer_init<T: bytemuck::Pod>(
&mut self,
ppu: &PPU,
label: &str,
data: &[T],
usage: wgpu::BufferUsages,
) {
self.buffers.entry(label.to_string()).or_insert_with(|| {
ppu.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some(label),
contents: bytemuck::cast_slice(data),
usage,
})
});
}
pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
let buffer = match self.buffers.get(label) {
Some(buffer) => buffer,
None => return,
};
let data_bytes = bytemuck::cast_slice(data);
let data_len = data_bytes.len() as wgpu::BufferAddress;
if buffer.size() < data_len {
return;
}
ppu.queue.write_buffer(buffer, 0, data_bytes);
}
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
self.buffers.get(label)
}
}
pub struct ComputeResult<T> {
data: Vec<T>,
time_elapsed: Duration,
}
impl<T> ComputeResult<T> {
pub fn new(data: Vec<T>, time_elapsed: Duration) -> Self {
Self { data, time_elapsed }
}
pub fn data(&self) -> &Vec<T> {
&self.data
}
pub fn time_elapsed_sec(&self) -> f32 {
self.time_elapsed.num_nanoseconds().unwrap() as f32 * 1e-9
}
pub fn time_elapsed_ms(&self) -> f32 {
self.time_elapsed.num_milliseconds() as f32
}
pub fn time_elapsed_us(&self) -> f32 {
self.time_elapsed.num_microseconds().unwrap() as f32
}
pub fn time_elapsed_ns(&self) -> f32 {
self.time_elapsed.num_nanoseconds().unwrap() as f32
}
}
impl PPU {
/// Initialize a new PPU instance
///
/// # Examples
///
/// ```
/// use ppu::PPU;
///
/// #[tokio::main]
/// async fn main() {
/// let ppu = PPU::new().await.unwrap();
/// }
pub async fn new() -> Result<Self, Error> {
let instance = wgpu::Instance::default();
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: true,
compatible_surface: None,
})
.await
.ok_or(Error::AdapterNotFound)?;
let features = adapter.features();
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: features & wgpu::Features::TIMESTAMP_QUERY,
limits: Default::default(),
},
None,
)
.await?;
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
device.create_query_set(&wgpu::QuerySetDescriptor {
count: 2,
ty: wgpu::QueryType::Timestamp,
label: None,
})
} else {
return Err(Error::QuerySetCreationFailed(
"Timestamp query is not supported".to_string(),
));
};
Ok(Self {
instance,
adapter,
device,
queue,
query_set,
shader_modules: HashMap::new(),
pipelines: HashMap::new(),
bind_group_layouts: HashMap::new(),
bind_groups: HashMap::new(),
})
}
pub fn load_bind_group_layout(
&mut self,
label: &str,
layout_entries: &[wgpu::BindGroupLayoutEntry],
) {
let bind_group_layout =
self.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some(label),
entries: layout_entries,
});
self.bind_group_layouts
.insert(label.to_string(), bind_group_layout);
}
pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
self.bind_group_layouts.get(label)
}
/// Create and store a bind group
pub fn create_bind_group(
&mut self,
name: &str,
layout_name: &str,
entries: &[wgpu::BindGroupEntry],
) -> Result<(), Error> {
let layout = self
.bind_group_layouts
.get(layout_name)
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))?;
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some(name),
layout,
entries,
});
self.bind_groups.insert(name.to_string(), bind_group);
Ok(())
}
/// Retrieve a bind group by name
pub fn get_bind_group(&self, name: &str) -> Option<&wgpu::BindGroup> {
self.bind_groups.get(name)
}
/// Load a shader and store it in the hash map
pub fn load_shader(&mut self, name: &str, source: &str) -> Result<(), Error> {
let shader_module = self
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some(name),
source: wgpu::ShaderSource::Wgsl(source.into()),
});
self.shader_modules.insert(name.to_string(), shader_module);
Ok(())
}
/// Get a shader from the hash map
pub fn get_shader(&self, name: &str) -> Option<&wgpu::ShaderModule> {
self.shader_modules.get(name)
}
pub fn load_pipeline(
&mut self,
name: &str,
shader_module_name: &str,
entry_point: &str,
bind_group_layout_names: &[&str], // Use names of bind group layouts
) -> Result<(), Error> {
// Retrieve the shader module
let shader_module = self
.get_shader(shader_module_name)
.ok_or(Error::ShaderModuleNotFound(shader_module_name.to_string()))?;
// Retrieve the bind group layouts
let bind_group_layouts = bind_group_layout_names
.iter()
.map(|layout_name| {
self.bind_group_layouts
.get(*layout_name) // Assuming you have a HashMap for BindGroupLayouts
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))
})
.collect::<Result<Vec<_>, _>>()?;
// Create the pipeline layout
let pipeline_layout = self
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(name),
bind_group_layouts: &bind_group_layouts
.iter()
.map(|layout| *layout)
.collect::<Vec<_>>(),
push_constant_ranges: &[],
});
// Create the compute pipeline
let pipeline = self
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(name),
layout: Some(&pipeline_layout),
module: shader_module,
entry_point,
});
// Store the pipeline
self.pipelines.insert(name.to_string(), pipeline);
Ok(())
}
pub fn get_pipeline(&self, name: &str) -> Option<&wgpu::ComputePipeline> {
self.pipelines.get(name)
}
/// Execute a compute task and retrieve the result
pub async fn execute_compute_task<T>(
&self,
pipeline_name: &str,
bind_group_name: &str,
buffers: &ComputeBuffers,
output_buffer_name: &str,
workgroup_count: (u32, u32, u32),
) -> Result<ComputeResult<T>, Error>
where
T: bytemuck::Pod + bytemuck::Zeroable + Send + Sync + std::fmt::Debug, // Added Debug trait
{
// Retrieve the pipeline and bind group
let pipeline = self
.get_pipeline(pipeline_name)
.ok_or_else(|| Error::PipelineNotFound(pipeline_name.to_string()))?;
let bind_group = self
.get_bind_group(bind_group_name)
.ok_or_else(|| Error::BindGroupNotFound(bind_group_name.to_string()))?;
// Create a command encoder and dispatch the compute shader
let mut encoder = self.device.create_command_encoder(&Default::default());
{
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
compute_pass.set_pipeline(pipeline);
compute_pass.set_bind_group(0, bind_group, &[]);
compute_pass.dispatch_workgroups(
workgroup_count.0,
workgroup_count.1,
workgroup_count.2,
);
}
// Create a new query set for this task
let query_set = self.device.create_query_set(&wgpu::QuerySetDescriptor {
count: 2, // Two timestamps: start and end
ty: wgpu::QueryType::Timestamp,
label: Some("Timestamp Query Set"),
});
// Record the start timestamp
encoder.write_timestamp(&query_set, 0);
// Copy output to staging buffer
let output_buffer = buffers
.get_buffer(output_buffer_name)
.ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?;
encoder.copy_buffer_to_buffer(
output_buffer,
0,
buffers.staging(),
0,
buffers.staging().size(),
);
// Record the end timestamp
encoder.write_timestamp(&query_set, 1);
// Resolve timestamp query and write to timestamp query buffer
encoder.resolve_query_set(&query_set, 0..2, buffers.timestamp_query(), 0);
// Copy timestamp query buffer to readback buffer
encoder.copy_buffer_to_buffer(
buffers.timestamp_query(),
0,
buffers.readback(),
0,
buffers.readback().size(),
);
// Submit the command encoder
self.queue.submit(Some(encoder.finish()));
// Wait for the GPU to finish executing before mapping buffers
self.device.poll(wgpu::Maintain::Wait);
// Read the staging buffer data
let buffer_slice = buffers.staging().slice(..);
buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
self.device.poll(wgpu::Maintain::Wait);
// buffer_slice.unmap();
let data = buffer_slice.get_mapped_range().to_vec();
let result = data
.chunks_exact(std::mem::size_of::<T>())
.map(|chunk| *bytemuck::from_bytes::<T>(chunk))
.collect::<Vec<_>>();
// Read the timestamp query results to the readback buffer
let timestamp_buffer_slice = buffers.readback().slice(..);
timestamp_buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
self.device.poll(wgpu::Maintain::Wait);
// timestamp_buffer_slice.unmap();
let ts_data_raw = timestamp_buffer_slice.get_mapped_range();
let ts_data: &[u64] = bytemuck::cast_slice(&ts_data_raw);
// Calculate the elapsed time in nanoseconds
let elapsed_ns = (ts_data[1] - ts_data[0]) as f64 * self.queue.get_timestamp_period() as f64;
let time_elapsed = Duration::nanoseconds(elapsed_ns as i64);
Ok(ComputeResult::new(result, time_elapsed))
}
}
impl std::fmt::Display for PPU {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(f, "PPU")?;
write!(f, "instance {:?}", self.instance)?;
write!(f, "adapter {:?}", self.adapter)?;
write!(f, "device {:?}", self.device)?;
write!(f, "queue {:?}", self.queue)?;
write!(f, "query_set {:?}", self.query_set)?;
Ok(())
}
}

View File

@ -1,5 +1,5 @@
struct Matrix { struct Matrix {
data: array<array<f32, 4>, 4>; // Assuming 4x4 matrices data: array<array<f32, 4>, 4>, // Corrected: Use a comma instead of a semicolon
}; };
@group(0) @binding(0) @group(0) @binding(0)
@ -17,7 +17,7 @@ fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let col: u32 = global_id.x; let col: u32 = global_id.x;
var sum: f32 = 0.0; var sum: f32 = 0.0;
for (var k: u32 = 0; k < 4u; k = k + 1u) { for (var k: u32 = 0u; k < 4u; k = k + 1u) {
sum = sum + matrixA.data[row][k] * matrixB.data[k][col]; sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
} }
matrixC.data[row][col] = sum; matrixC.data[row][col] = sum;