🔧 Implement as a library
- Make the crate a lib
- Move main to examples as mat_mul_4x4.rs
- Correctly track elapsed time of compute task

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
8bf134d3d2
commit
668293d956
44
Cargo.lock
generated
44
Cargo.lock
generated
@ -44,6 +44,12 @@ version = "0.2.16"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "android-tzdata"
|
||||||
|
version = "0.1.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "android_system_properties"
|
name = "android_system_properties"
|
||||||
version = "0.1.5"
|
version = "0.1.5"
|
||||||
@ -180,6 +186,20 @@ version = "1.0.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "chrono"
|
||||||
|
version = "0.4.31"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7f2c685bad3eb3d45a01354cedb7d5faa66194d1d58ba6e267a8de788f79db38"
|
||||||
|
dependencies = [
|
||||||
|
"android-tzdata",
|
||||||
|
"iana-time-zone",
|
||||||
|
"js-sys",
|
||||||
|
"num-traits",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"windows-targets",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "codespan-reporting"
|
name = "codespan-reporting"
|
||||||
version = "0.11.1"
|
version = "0.11.1"
|
||||||
@ -489,6 +509,29 @@ version = "2.1.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iana-time-zone"
|
||||||
|
version = "0.1.58"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8326b86b6cff230b97d0d312a6c40a60726df3332e721f72a1b035f451663b20"
|
||||||
|
dependencies = [
|
||||||
|
"android_system_properties",
|
||||||
|
"core-foundation-sys",
|
||||||
|
"iana-time-zone-haiku",
|
||||||
|
"js-sys",
|
||||||
|
"wasm-bindgen",
|
||||||
|
"windows-core",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "iana-time-zone-haiku"
|
||||||
|
version = "0.1.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f"
|
||||||
|
dependencies = [
|
||||||
|
"cc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "indexmap"
|
name = "indexmap"
|
||||||
version = "2.1.0"
|
version = "2.1.0"
|
||||||
@ -1281,6 +1324,7 @@ name = "wgpu_compute_shader"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bytemuck",
|
"bytemuck",
|
||||||
|
"chrono",
|
||||||
"env_logger",
|
"env_logger",
|
||||||
"futures-intrusive",
|
"futures-intrusive",
|
||||||
"getset",
|
"getset",
|
||||||
|
@ -3,7 +3,8 @@ name = "wgpu_compute_shader"
|
|||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
edition = "2021"
|
edition = "2021"
|
||||||
|
|
||||||
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
[lib]
|
||||||
|
crate-type = ["lib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
wgpu = { version = "0.18", features = ["vulkan-portability", "spirv" ] }
|
wgpu = { version = "0.18", features = ["vulkan-portability", "spirv" ] }
|
||||||
@ -14,3 +15,4 @@ bytemuck = { version = "1.12.1", features = ["derive"] }
|
|||||||
thiserror = "1.0.51"
|
thiserror = "1.0.51"
|
||||||
tokio = { version = "1.35.1", features = ["sync", "full"] }
|
tokio = { version = "1.35.1", features = ["sync", "full"] }
|
||||||
getset = "0.1.2"
|
getset = "0.1.2"
|
||||||
|
chrono = "0.4.31"
|
||||||
|
@ -1,5 +1,4 @@
|
|||||||
mod ppu;
|
use wgpu_compute_shader::*;
|
||||||
use ppu::*;
|
|
||||||
|
|
||||||
const MM4X4_SHADER: &str = "
|
const MM4X4_SHADER: &str = "
|
||||||
struct Matrix {
|
struct Matrix {
|
||||||
@ -163,5 +162,7 @@ async fn main() -> Result<(), Error> {
|
|||||||
println!("{:?}", row);
|
println!("{:?}", row);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
println!("Time elapsed: {} us", results.time_elapsed_us());
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
2
src/lib.rs
Normal file
2
src/lib.rs
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
mod ppu;
|
||||||
|
pub use ppu::*;
|
148
src/ppu.rs
148
src/ppu.rs
@ -1,5 +1,6 @@
|
|||||||
// #![allow(unused)]
|
// #![allow(unused)]
|
||||||
use bytemuck;
|
use bytemuck;
|
||||||
|
use chrono::Duration;
|
||||||
use getset::{Getters, MutGetters};
|
use getset::{Getters, MutGetters};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use thiserror::Error;
|
use thiserror::Error;
|
||||||
@ -85,16 +86,22 @@ impl ComputeBuffers {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn add_buffer<T>(&mut self, ppu: &PPU, label: &str, size: usize, usage: wgpu::BufferUsages) {
|
pub fn add_buffer<T>(
|
||||||
// self.buffers.entry(label.to_string()).or_insert_with(|| {
|
&mut self,
|
||||||
// ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
ppu: &PPU,
|
||||||
// label: Some(label),
|
label: &str,
|
||||||
// size: size as u64 * std::mem::size_of::<T>() as u64,
|
size: usize,
|
||||||
// usage,
|
usage: wgpu::BufferUsages,
|
||||||
// mapped_at_creation: false,
|
) {
|
||||||
// })
|
self.buffers.entry(label.to_string()).or_insert_with(|| {
|
||||||
// });
|
ppu.device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
// }
|
label: Some(label),
|
||||||
|
size: size as u64 * std::mem::size_of::<T>() as u64,
|
||||||
|
usage,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
})
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
pub fn add_buffer_init<T: bytemuck::Pod>(
|
pub fn add_buffer_init<T: bytemuck::Pod>(
|
||||||
&mut self,
|
&mut self,
|
||||||
@ -113,21 +120,21 @@ impl ComputeBuffers {
|
|||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
|
pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
|
||||||
// let buffer = match self.buffers.get(label) {
|
let buffer = match self.buffers.get(label) {
|
||||||
// Some(buffer) => buffer,
|
Some(buffer) => buffer,
|
||||||
// None => return,
|
None => return,
|
||||||
// };
|
};
|
||||||
|
|
||||||
// let data_bytes = bytemuck::cast_slice(data);
|
let data_bytes = bytemuck::cast_slice(data);
|
||||||
// let data_len = data_bytes.len() as wgpu::BufferAddress;
|
let data_len = data_bytes.len() as wgpu::BufferAddress;
|
||||||
|
|
||||||
// if buffer.size() < data_len {
|
if buffer.size() < data_len {
|
||||||
// return;
|
return;
|
||||||
// }
|
}
|
||||||
|
|
||||||
// ppu.queue.write_buffer(buffer, 0, data_bytes);
|
ppu.queue.write_buffer(buffer, 0, data_bytes);
|
||||||
// }
|
}
|
||||||
|
|
||||||
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
|
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
|
||||||
self.buffers.get(label)
|
self.buffers.get(label)
|
||||||
@ -136,15 +143,12 @@ impl ComputeBuffers {
|
|||||||
|
|
||||||
pub struct ComputeResult<T> {
|
pub struct ComputeResult<T> {
|
||||||
data: Vec<T>,
|
data: Vec<T>,
|
||||||
time_elapsed_sec: f32,
|
time_elapsed: Duration,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T> ComputeResult<T> {
|
impl<T> ComputeResult<T> {
|
||||||
pub fn new(data: Vec<T>, time_elapsed_sec: f32) -> Self {
|
pub fn new(data: Vec<T>, time_elapsed: Duration) -> Self {
|
||||||
Self {
|
Self { data, time_elapsed }
|
||||||
data,
|
|
||||||
time_elapsed_sec,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn data(&self) -> &Vec<T> {
|
pub fn data(&self) -> &Vec<T> {
|
||||||
@ -152,7 +156,19 @@ impl<T> ComputeResult<T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn time_elapsed_sec(&self) -> f32 {
|
pub fn time_elapsed_sec(&self) -> f32 {
|
||||||
self.time_elapsed_sec
|
self.time_elapsed.num_nanoseconds().unwrap() as f32 * 1e-9
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn time_elapsed_ms(&self) -> f32 {
|
||||||
|
self.time_elapsed.num_milliseconds() as f32
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn time_elapsed_us(&self) -> f32 {
|
||||||
|
self.time_elapsed.num_microseconds().unwrap() as f32
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn time_elapsed_ns(&self) -> f32 {
|
||||||
|
self.time_elapsed.num_nanoseconds().unwrap() as f32
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -234,9 +250,9 @@ impl PPU {
|
|||||||
.insert(label.to_string(), bind_group_layout);
|
.insert(label.to_string(), bind_group_layout);
|
||||||
}
|
}
|
||||||
|
|
||||||
// pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
|
pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
|
||||||
// self.bind_group_layouts.get(label)
|
self.bind_group_layouts.get(label)
|
||||||
// }
|
}
|
||||||
|
|
||||||
/// Create and store a bind group
|
/// Create and store a bind group
|
||||||
pub fn create_bind_group(
|
pub fn create_bind_group(
|
||||||
@ -367,12 +383,21 @@ impl PPU {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Create a new query set for this task
|
||||||
|
let query_set = self.device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||||
|
count: 2, // Two timestamps: start and end
|
||||||
|
ty: wgpu::QueryType::Timestamp,
|
||||||
|
label: Some("Timestamp Query Set"),
|
||||||
|
});
|
||||||
|
|
||||||
|
// Record the start timestamp
|
||||||
|
encoder.write_timestamp(&query_set, 0);
|
||||||
|
|
||||||
// Copy output to staging buffer
|
// Copy output to staging buffer
|
||||||
let output_buffer = buffers
|
let output_buffer = buffers
|
||||||
.get_buffer(output_buffer_name)
|
.get_buffer(output_buffer_name)
|
||||||
.ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?;
|
.ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?;
|
||||||
|
|
||||||
// let staging_buffer = buffers.staging();
|
|
||||||
encoder.copy_buffer_to_buffer(
|
encoder.copy_buffer_to_buffer(
|
||||||
output_buffer,
|
output_buffer,
|
||||||
0,
|
0,
|
||||||
@ -381,35 +406,54 @@ impl PPU {
|
|||||||
buffers.staging().size(),
|
buffers.staging().size(),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Submit the command encoder and wait for it to complete
|
// Record the end timestamp
|
||||||
|
encoder.write_timestamp(&query_set, 1);
|
||||||
|
|
||||||
|
// Resolve timestamp query and write to timestamp query buffer
|
||||||
|
encoder.resolve_query_set(&query_set, 0..2, buffers.timestamp_query(), 0);
|
||||||
|
|
||||||
|
// Copy timestamp query buffer to readback buffer
|
||||||
|
encoder.copy_buffer_to_buffer(
|
||||||
|
buffers.timestamp_query(),
|
||||||
|
0,
|
||||||
|
buffers.readback(),
|
||||||
|
0,
|
||||||
|
buffers.readback().size(),
|
||||||
|
);
|
||||||
|
|
||||||
|
// Submit the command encoder
|
||||||
self.queue.submit(Some(encoder.finish()));
|
self.queue.submit(Some(encoder.finish()));
|
||||||
|
|
||||||
// Map the staging buffer asynchronously
|
// Wait for the GPU to finish executing before mapping buffers
|
||||||
|
|
||||||
let buffer_slice = buffers.staging().slice(..);
|
|
||||||
|
|
||||||
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
|
||||||
|
|
||||||
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
|
|
||||||
sender.send(result).unwrap();
|
|
||||||
});
|
|
||||||
|
|
||||||
// Poll the device to ensure the mapping is processed
|
|
||||||
self.device.poll(wgpu::Maintain::Wait);
|
self.device.poll(wgpu::Maintain::Wait);
|
||||||
|
|
||||||
// Wait for the mapping to complete and retrieve the result
|
// Read the staging buffer data
|
||||||
receiver.receive().await.unwrap().unwrap();
|
let buffer_slice = buffers.staging().slice(..);
|
||||||
|
buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
|
||||||
|
self.device.poll(wgpu::Maintain::Wait);
|
||||||
|
// buffer_slice.unmap();
|
||||||
|
|
||||||
let data = buffer_slice.get_mapped_range().to_vec();
|
let data = buffer_slice.get_mapped_range().to_vec();
|
||||||
buffers.staging().unmap();
|
|
||||||
|
|
||||||
let result = data
|
let result = data
|
||||||
.chunks_exact(std::mem::size_of::<T>())
|
.chunks_exact(std::mem::size_of::<T>())
|
||||||
.map(|chunk| bytemuck::from_bytes::<T>(chunk))
|
.map(|chunk| *bytemuck::from_bytes::<T>(chunk))
|
||||||
.cloned()
|
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
Ok(ComputeResult::new(result, 0.0))
|
// Read the timestamp query results to the readback buffer
|
||||||
|
|
||||||
|
let timestamp_buffer_slice = buffers.readback().slice(..);
|
||||||
|
timestamp_buffer_slice.map_async(wgpu::MapMode::Read, |_| {});
|
||||||
|
self.device.poll(wgpu::Maintain::Wait);
|
||||||
|
// timestamp_buffer_slice.unmap();
|
||||||
|
|
||||||
|
let ts_data_raw = timestamp_buffer_slice.get_mapped_range();
|
||||||
|
let ts_data: &[u64] = bytemuck::cast_slice(&ts_data_raw);
|
||||||
|
|
||||||
|
// Calculate the elapsed time in nanoseconds
|
||||||
|
let elapsed_ns = (ts_data[1] - ts_data[0]) as f64 * self.queue.get_timestamp_period() as f64;
|
||||||
|
let time_elapsed = Duration::nanoseconds(elapsed_ns as i64);
|
||||||
|
|
||||||
|
Ok(ComputeResult::new(result, time_elapsed))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user