🔧 Implement as a library
- Make the crate a lib - Move main to examples as mat_mul_4x4.rs - Correctly track elapsed time of compute task Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
8bf134d3d2
commit
668293d956
44
Cargo.lock
generated
44
Cargo.lock
generated
@ -44,6 +44,12 @@ version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
||||
|
||||
[[package]]
|
||||
name = "android-tzdata"
|
||||
version = "0.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||
|
||||
[[package]]
|
||||
name = "android_system_properties"
|
||||
version = "0.1.5"
|
||||
@ -180,6 +186,20 @@ version = "1.0.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||
|
||||
[[package]]
|
||||
name = "chrono"
|
||||
version = "0.4.31"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
|
||||
dependencies = [
|
||||
"android-tzdata",
|
||||
"iana-time-zone",
|
||||
"js-sys",
|
||||
"num-traits",
|
||||
"wasm-bindgen",
|
||||
"windows-targets",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codespan-reporting"
|
||||
version = "0.11.1"
|
||||
@ -489,6 +509,29 @@ version = "2.1.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone"
|
||||
version = "0.1.58"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
|
||||
dependencies = [
|
||||
"android_system_properties",
|
||||
"core-foundation-sys",
|
||||
"iana-time-zone-haiku",
|
||||
"js-sys",
|
||||
"wasm-bindgen",
|
||||
"windows-core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "iana-time-zone-haiku"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||
dependencies = [
|
||||
"cc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "indexmap"
|
||||
version = "2.1.0"
|
||||
@ -1281,6 +1324,7 @@ name = "wgpu_compute_shader"
|
||||
version = "0.1.0"
|
||||
dependencies = [
|
||||
"bytemuck",
|
||||
"chrono",
|
||||
"env_logger",
|
||||
"futures-intrusive",
|
||||
"getset",
|
||||
|
@ -3,7 +3,8 @@ name = "wgpu_compute_shader"
|
||||
version = "0.1.0"
|
||||
edition = "2021"
|
||||
|
||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||
[lib]
|
||||
crate-type = ["lib"]
|
||||
|
||||
[dependencies]
|
||||
wgpu = { version = "0.18", features = ["vulkan-portability", "spirv" ] }
|
||||
@ -14,3 +15,4 @@ bytemuck = { version = "1.12.1", features = ["derive"] }
|
||||
thiserror = "1.0.51"
|
||||
tokio = { version = "1.35.1", features = ["sync", "full"] }
|
||||
getset = "0.1.2"
|
||||
chrono = "0.4.31"
|
||||
|
@ -1,5 +1,4 @@
|
||||
mod ppu;
|
||||
use ppu::*;
|
||||
use wgpu_compute_shader::*;
|
||||
|
||||
const MM4X4_SHADER: &str = "
|
||||
struct Matrix {
|
||||
@ -163,5 +162,7 @@ async fn main() -> Result<(), Error> {
|
||||
println!("{:?}", row);
|
||||
});
|
||||
|
||||
println!("Time elapsed: {} us", results.time_elapsed_us());
|
||||
|
||||
Ok(())
|
||||
}
|
2
src/lib.rs
Normal file
2
src/lib.rs
Normal file
@ -0,0 +1,2 @@
|
||||
mod ppu;
|
||||
pub use ppu::*;
|
148
src/ppu.rs
148
src/ppu.rs
@ -1,5 +1,6 @@
|
||||
// #![allow(unused)]
|
||||
use bytemuck;
|
||||
use chrono::Duration;
|
||||
use getset::{Getters, MutGetters};
|
||||
use std::collections::HashMap;
|
||||
use thiserror::Error;
|
||||
@ -85,16 +86,22 @@ impl ComputeBuffers {
|
||||
}
|
||||
}
|
||||
|
||||
// pub fn add_buffer<T>(&mut self, ppu: &PPU, label: &str, size: usize, usage: wgpu::BufferUsages) {
|
||||
// self.buffers.entry(label.to_string()).or_insert_with(|| {
|
||||
// ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
// label: Some(label),
|
||||
// size: size as u64 * std::mem::size_of::<T>() as u64,
|
||||
// usage,
|
||||
// mapped_at_creation: false,
|
||||
// })
|
||||
// });
|
||||
// }
|
||||
pub fn add_buffer<T>(
|
||||
&mut self,
|
||||
ppu: &PPU,
|
||||
label: &str,
|
||||
size: usize,
|
||||
usage: wgpu::BufferUsages,
|
||||
) {
|
||||
self.buffers.entry(label.to_string()).or_insert_with(|| {
|
||||
ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some(label),
|
||||
size: size as u64 * std::mem::size_of::<T>() as u64,
|
||||
usage,
|
||||
mapped_at_creation: false,
|
||||
})
|
||||
});
|
||||
}
|
||||
|
||||
pub fn add_buffer_init<T: bytemuck::Pod>(
|
||||
&mut self,
|
||||
@ -113,21 +120,21 @@ impl ComputeBuffers {
|
||||
});
|
||||
}
|
||||
|
||||
// pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
|
||||
// let buffer = match self.buffers.get(label) {
|
||||
// Some(buffer) => buffer,
|
||||
// None => return,
|
||||
// };
|
||||
pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
|
||||
let buffer = match self.buffers.get(label) {
|
||||
Some(buffer) => buffer,
|
||||
None => return,
|
||||
};
|
||||
|
||||
// let data_bytes = bytemuck::cast_slice(data);
|
||||
// let data_len = data_bytes.len() as wgpu::BufferAddress;
|
||||
let data_bytes = bytemuck::cast_slice(data);
|
||||
let data_len = data_bytes.len() as wgpu::BufferAddress;
|
||||
|
||||
// if buffer.size() < data_len {
|
||||
// return;
|
||||
// }
|
||||
if buffer.size() < data_len {
|
||||
return;
|
||||
}
|
||||
|
||||
// ppu.queue.write_buffer(buffer, 0, data_bytes);
|
||||
// }
|
||||
ppu.queue.write_buffer(buffer, 0, data_bytes);
|
||||
}
|
||||
|
||||
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
|
||||
self.buffers.get(label)
|
||||
@ -136,15 +143,12 @@ impl ComputeBuffers {
|
||||
|
||||
pub struct ComputeResult<T> {
|
||||
data: Vec<T>,
|
||||
time_elapsed_sec: f32,
|
||||
time_elapsed: Duration,
|
||||
}
|
||||
|
||||
impl<T> ComputeResult<T> {
|
||||
pub fn new(data: Vec<T>, time_elapsed_sec: f32) -> Self {
|
||||
Self {
|
||||
data,
|
||||
time_elapsed_sec,
|
||||
}
|
||||
pub fn new(data: Vec<T>, time_elapsed: Duration) -> Self {
|
||||
Self { data, time_elapsed }
|
||||
}
|
||||
|
||||
pub fn data(&self) -> &Vec<T> {
|
||||
@ -152,7 +156,19 @@ impl<T> ComputeResult<T> {
|
||||
}
|
||||
|
||||
pub fn time_elapsed_sec(&self) -> f32 {
|
||||
self.time_elapsed_sec
|
||||
self.time_elapsed.num_nanoseconds().unwrap() as f32 * 1e-9
|
||||
}
|
||||
|
||||
pub fn time_elapsed_ms(&self) -> f32 {
|
||||
self.time_elapsed.num_milliseconds() as f32
|
||||
}
|
||||
|
||||
pub fn time_elapsed_us(&self) -> f32 {
|
||||
self.time_elapsed.num_microseconds().unwrap() as f32
|
||||
}
|
||||
|
||||
pub fn time_elapsed_ns(&self) -> f32 {
|
||||
self.time_elapsed.num_nanoseconds().unwrap() as f32
|
||||
}
|
||||
}
|
||||
|
||||
@ -234,9 +250,9 @@ impl PPU {
|
||||
.insert(label.to_string(), bind_group_layout);
|
||||
}
|
||||
|
||||
// pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
|
||||
// self.bind_group_layouts.get(label)
|
||||
// }
|
||||
pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
|
||||
self.bind_group_layouts.get(label)
|
||||
}
|
||||
|
||||
/// Create and store a bind group
|
||||
pub fn create_bind_group(
|
||||
@ -367,12 +383,21 @@ impl PPU {
|
||||
);
|
||||
}
|
||||
|
||||
// Create a new query set for this task
|
||||
let query_set = self.device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||
count: 2, // Two timestamps: start and end
|
||||
ty: wgpu::QueryType::Timestamp,
|
||||
label: Some("Timestamp Query Set"),
|
||||
});
|
||||
|
||||
// Record the start timestamp
|
||||
encoder.write_timestamp(&query_set, 0);
|
||||
|
||||
// Copy output to staging buffer
|
||||
let output_buffer = buffers
|
||||
.get_buffer(output_buffer_name)
|
||||
.ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?;
|
||||
|
||||
// let staging_buffer = buffers.staging();
|
||||
encoder.copy_buffer_to_buffer(
|
||||
output_buffer,
|
||||
0,
|
||||
@ -381,35 +406,54 @@ impl PPU {
|
||||
buffers.staging().size(),
|
||||
);
|
||||
|
||||
// Submit the command encoder and wait for it to complete
|
||||
// Record the end timestamp
|
||||
encoder.write_timestamp(&query_set, 1);
|
||||
|
||||
// Resolve timestamp query and write to timestamp query buffer
|
||||
encoder.resolve_query_set(&query_set, 0..2, buffers.timestamp_query(), 0);
|
||||
|
||||
// Copy timestamp query buffer to readback buffer
|
||||
encoder.copy_buffer_to_buffer(
|
||||
buffers.timestamp_query(),
|
||||
0,
|
||||
buffers.readback(),
|
||||
0,
|
||||
buffers.readback().size(),
|
||||
);
|
||||
|
||||
// Submit the command encoder
|
||||
self.queue.submit(Some(encoder.finish()));
|
||||
|
||||
// Map the staging buffer asynchronously
|
||||
|
||||
let buffer_slice = buffers.staging().slice(..);
|
||||
|
||||
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
|
||||
sender.send(result).unwrap();
|
||||
});
|
||||
|
||||
// Poll the device to ensure the mapping is processed
|
||||
// Wait for the GPU to finish executing before mapping buffers
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
|
||||
// Wait for the mapping to complete and retrieve the result
|
||||
receiver.receive().await.unwrap().unwrap();
|
||||
// Read the staging buffer data
|
||||
let buffer_slice = buffers.staging().slice(..);
|
||||
buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
// buffer_slice.unmap();
|
||||
|
||||
let data = buffer_slice.get_mapped_range().to_vec();
|
||||
buffers.staging().unmap();
|
||||
|
||||
let result = data
|
||||
.chunks_exact(std::mem::size_of::<T>())
|
||||
.map(|chunk| bytemuck::from_bytes::<T>(chunk))
|
||||
.cloned()
|
||||
.map(|chunk| *bytemuck::from_bytes::<T>(chunk))
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok(ComputeResult::new(result, 0.0))
|
||||
// Read the timestamp query results to the readback buffer
|
||||
|
||||
let timestamp_buffer_slice = buffers.readback().slice(..);
|
||||
timestamp_buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
|
||||
self.device.poll(wgpu::Maintain::Wait);
|
||||
// timestamp_buffer_slice.unmap();
|
||||
|
||||
let ts_data_raw = timestamp_buffer_slice.get_mapped_range();
|
||||
let ts_data: &[u64] = bytemuck::cast_slice(&ts_data_raw);
|
||||
|
||||
// Calculate the elapsed time in nanoseconds
|
||||
let elapsed_ns = (ts_data[1] - ts_data[0]) as f64 * self.queue.get_timestamp_period() as f64;
|
||||
let time_elapsed = Duration::nanoseconds(elapsed_ns as i64);
|
||||
|
||||
Ok(ComputeResult::new(result, time_elapsed))
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user