🔧 Implement as a library

- Make the crate a lib
- Move main to examples as mat_mul_4x4.rs
- Correctly track elapsed time of compute task

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2023-12-23 04:00:45 +02:00
parent 8bf134d3d2
commit 668293d956
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
5 changed files with 148 additions and 55 deletions

44
Cargo.lock generated
View File

@ -44,6 +44,12 @@ version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "android-tzdata"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
[[package]] [[package]]
name = "android_system_properties" name = "android_system_properties"
version = "0.1.5" version = "0.1.5"
@ -180,6 +186,20 @@ version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "chrono"
version = "0.4.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
dependencies = [
"android-tzdata",
"iana-time-zone",
"js-sys",
"num-traits",
"wasm-bindgen",
"windows-targets",
]
[[package]] [[package]]
name = "codespan-reporting" name = "codespan-reporting"
version = "0.11.1" version = "0.11.1"
@ -489,6 +509,29 @@ version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "iana-time-zone"
version = "0.1.58"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
dependencies = [
"android_system_properties",
"core-foundation-sys",
"iana-time-zone-haiku",
"js-sys",
"wasm-bindgen",
"windows-core",
]
[[package]]
name = "iana-time-zone-haiku"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
dependencies = [
"cc",
]
[[package]] [[package]]
name = "indexmap" name = "indexmap"
version = "2.1.0" version = "2.1.0"
@ -1281,6 +1324,7 @@ name = "wgpu_compute_shader"
version = "0.1.0" version = "0.1.0"
dependencies = [ dependencies = [
"bytemuck", "bytemuck",
"chrono",
"env_logger", "env_logger",
"futures-intrusive", "futures-intrusive",
"getset", "getset",

View File

@ -3,7 +3,8 @@ name = "wgpu_compute_shader"
version = "0.1.0" version = "0.1.0"
edition = "2021" edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [lib]
crate-type = ["lib"]
[dependencies] [dependencies]
wgpu = { version = "0.18", features = ["vulkan-portability", "spirv" ] } wgpu = { version = "0.18", features = ["vulkan-portability", "spirv" ] }
@ -14,3 +15,4 @@ bytemuck = { version = "1.12.1", features = ["derive"] }
thiserror = "1.0.51" thiserror = "1.0.51"
tokio = { version = "1.35.1", features = ["sync", "full"] } tokio = { version = "1.35.1", features = ["sync", "full"] }
getset = "0.1.2" getset = "0.1.2"
chrono = "0.4.31"

View File

@ -1,5 +1,4 @@
mod ppu; use wgpu_compute_shader::*;
use ppu::*;
const MM4X4_SHADER: &str = " const MM4X4_SHADER: &str = "
struct Matrix { struct Matrix {
@ -163,5 +162,7 @@ async fn main() -> Result<(), Error> {
println!("{:?}", row); println!("{:?}", row);
}); });
println!("Time elapsed: {} us", results.time_elapsed_us());
Ok(()) Ok(())
} }

2
src/lib.rs Normal file
View File

@ -0,0 +1,2 @@
mod ppu;
pub use ppu::*;

View File

@ -1,5 +1,6 @@
// #![allow(unused)] // #![allow(unused)]
use bytemuck; use bytemuck;
use chrono::Duration;
use getset::{Getters, MutGetters}; use getset::{Getters, MutGetters};
use std::collections::HashMap; use std::collections::HashMap;
use thiserror::Error; use thiserror::Error;
@ -85,16 +86,22 @@ impl ComputeBuffers {
} }
} }
// pub fn add_buffer<T>(&mut self, ppu: &PPU, label: &str, size: usize, usage: wgpu::BufferUsages) { pub fn add_buffer<T>(
// self.buffers.entry(label.to_string()).or_insert_with(|| { &mut self,
// ppu.device.create_buffer(&wgpu::BufferDescriptor { ppu: &PPU,
// label: Some(label), label: &str,
// size: size as u64 * std::mem::size_of::<T>() as u64, size: usize,
// usage, usage: wgpu::BufferUsages,
// mapped_at_creation: false, ) {
// }) self.buffers.entry(label.to_string()).or_insert_with(|| {
// }); ppu.device.create_buffer(&wgpu::BufferDescriptor {
// } label: Some(label),
size: size as u64 * std::mem::size_of::<T>() as u64,
usage,
mapped_at_creation: false,
})
});
}
pub fn add_buffer_init<T: bytemuck::Pod>( pub fn add_buffer_init<T: bytemuck::Pod>(
&mut self, &mut self,
@ -113,21 +120,21 @@ impl ComputeBuffers {
}); });
} }
// pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) { pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
// let buffer = match self.buffers.get(label) { let buffer = match self.buffers.get(label) {
// Some(buffer) => buffer, Some(buffer) => buffer,
// None => return, None => return,
// }; };
// let data_bytes = bytemuck::cast_slice(data); let data_bytes = bytemuck::cast_slice(data);
// let data_len = data_bytes.len() as wgpu::BufferAddress; let data_len = data_bytes.len() as wgpu::BufferAddress;
// if buffer.size() < data_len { if buffer.size() < data_len {
// return; return;
// } }
// ppu.queue.write_buffer(buffer, 0, data_bytes); ppu.queue.write_buffer(buffer, 0, data_bytes);
// } }
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> { pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
self.buffers.get(label) self.buffers.get(label)
@ -136,15 +143,12 @@ impl ComputeBuffers {
pub struct ComputeResult<T> { pub struct ComputeResult<T> {
data: Vec<T>, data: Vec<T>,
time_elapsed_sec: f32, time_elapsed: Duration,
} }
impl<T> ComputeResult<T> { impl<T> ComputeResult<T> {
pub fn new(data: Vec<T>, time_elapsed_sec: f32) -> Self { pub fn new(data: Vec<T>, time_elapsed: Duration) -> Self {
Self { Self { data, time_elapsed }
data,
time_elapsed_sec,
}
} }
pub fn data(&self) -> &Vec<T> { pub fn data(&self) -> &Vec<T> {
@ -152,7 +156,19 @@ impl<T> ComputeResult<T> {
} }
pub fn time_elapsed_sec(&self) -> f32 { pub fn time_elapsed_sec(&self) -> f32 {
self.time_elapsed_sec self.time_elapsed.num_nanoseconds().unwrap() as f32 * 1e-9
}
/// Returns the elapsed time of the compute task in milliseconds.
///
/// The value is derived from the nanosecond count so that tasks shorter than one
/// millisecond keep their fractional part instead of being reported as `0.0`
/// (`Duration::num_milliseconds` only yields whole milliseconds).
///
/// If the duration is too large to be expressed in nanoseconds, the whole-millisecond
/// count is returned instead.
pub fn time_elapsed_ms(&self) -> f32 {
    match self.time_elapsed.num_nanoseconds() {
        // Scale in f64 first so precision is only lost once, at the final cast.
        Some(nanos) => (nanos as f64 * 1e-6) as f32,
        // Nanosecond count overflowed i64: fall back to the coarse value.
        None => self.time_elapsed.num_milliseconds() as f32,
    }
}
/// Returns the elapsed time of the compute task in microseconds.
///
/// The value is derived from the nanosecond count so the sub-microsecond part of the
/// measurement is preserved in the returned `f32`.
///
/// Unlike a bare `num_microseconds().unwrap()`, this never panics: if the duration is
/// too large to be expressed in nanoseconds, the whole-millisecond count scaled to
/// microseconds is returned instead.
pub fn time_elapsed_us(&self) -> f32 {
    match self.time_elapsed.num_nanoseconds() {
        // Scale in f64 first so precision is only lost once, at the final cast.
        Some(nanos) => (nanos as f64 * 1e-3) as f32,
        // Nanosecond count overflowed i64: fall back to a coarse, non-panicking value.
        None => (self.time_elapsed.num_milliseconds() as f64 * 1e3) as f32,
    }
}
pub fn time_elapsed_ns(&self) -> f32 {
self.time_elapsed.num_nanoseconds().unwrap() as f32
} }
} }
@ -234,9 +250,9 @@ impl PPU {
.insert(label.to_string(), bind_group_layout); .insert(label.to_string(), bind_group_layout);
} }
// pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> { pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
// self.bind_group_layouts.get(label) self.bind_group_layouts.get(label)
// } }
/// Create and store a bind group /// Create and store a bind group
pub fn create_bind_group( pub fn create_bind_group(
@ -367,12 +383,21 @@ impl PPU {
); );
} }
// Create a new query set for this task
let query_set = self.device.create_query_set(&wgpu::QuerySetDescriptor {
count: 2, // Two timestamps: start and end
ty: wgpu::QueryType::Timestamp,
label: Some("Timestamp Query Set"),
});
// Record the start timestamp
encoder.write_timestamp(&query_set, 0);
// Copy output to staging buffer // Copy output to staging buffer
let output_buffer = buffers let output_buffer = buffers
.get_buffer(output_buffer_name) .get_buffer(output_buffer_name)
.ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?; .ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?;
// let staging_buffer = buffers.staging();
encoder.copy_buffer_to_buffer( encoder.copy_buffer_to_buffer(
output_buffer, output_buffer,
0, 0,
@ -381,35 +406,54 @@ impl PPU {
buffers.staging().size(), buffers.staging().size(),
); );
// Submit the command encoder and wait for it to complete // Record the end timestamp
encoder.write_timestamp(&query_set, 1);
// Resolve timestamp query and write to timestamp query buffer
encoder.resolve_query_set(&query_set, 0..2, buffers.timestamp_query(), 0);
// Copy timestamp query buffer to readback buffer
encoder.copy_buffer_to_buffer(
buffers.timestamp_query(),
0,
buffers.readback(),
0,
buffers.readback().size(),
);
// Submit the command encoder
self.queue.submit(Some(encoder.finish())); self.queue.submit(Some(encoder.finish()));
// Map the staging buffer asynchronously // Wait for the GPU to finish executing before mapping buffers
let buffer_slice = buffers.staging().slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
sender.send(result).unwrap();
});
// Poll the device to ensure the mapping is processed
self.device.poll(wgpu::Maintain::Wait); self.device.poll(wgpu::Maintain::Wait);
// Wait for the mapping to complete and retrieve the result // Read the staging buffer data
receiver.receive().await.unwrap().unwrap(); let buffer_slice = buffers.staging().slice(..);
buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
self.device.poll(wgpu::Maintain::Wait);
// buffer_slice.unmap();
let data = buffer_slice.get_mapped_range().to_vec(); let data = buffer_slice.get_mapped_range().to_vec();
buffers.staging().unmap();
let result = data let result = data
.chunks_exact(std::mem::size_of::<T>()) .chunks_exact(std::mem::size_of::<T>())
.map(|chunk| bytemuck::from_bytes::<T>(chunk)) .map(|chunk| *bytemuck::from_bytes::<T>(chunk))
.cloned()
.collect::<Vec<_>>(); .collect::<Vec<_>>();
Ok(ComputeResult::new(result, 0.0)) // Read the timestamp query results to the readback buffer
let timestamp_buffer_slice = buffers.readback().slice(..);
timestamp_buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
self.device.poll(wgpu::Maintain::Wait);
// timestamp_buffer_slice.unmap();
let ts_data_raw = timestamp_buffer_slice.get_mapped_range();
let ts_data: &[u64] = bytemuck::cast_slice(&ts_data_raw);
// Calculate the elapsed time in nanoseconds
let elapsed_ns = (ts_data[1] - ts_data[0]) as f64 * self.queue.get_timestamp_period() as f64;
let time_elapsed = Duration::nanoseconds(elapsed_ns as i64);
Ok(ComputeResult::new(result, time_elapsed))
} }
} }