🚀 4x4 matrix multiplication working

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
Julius Koskela 2023-12-23 02:10:51 +02:00
parent 5b3ef81c84
commit f7b29baf95
Signed by: julius
GPG Key ID: 5A7B7F4897C2914B
4 changed files with 411 additions and 595 deletions

62
Cargo.lock generated
View File

@ -156,7 +156,7 @@ checksum = "965ab7eb5f8f97d2a083c799f3a1b994fc397b2fe2da5d1da1626ce15a39f2b1"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.42",
]
[[package]]
@ -289,7 +289,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.42",
]
[[package]]
@ -334,6 +334,18 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "getset"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e45727250e75cc04ff2846a66397da8ef2b3db8e40e0cef4df67950a07621eb9"
dependencies = [
"proc-macro-error",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "gimli"
version = "0.28.1"
@ -785,6 +797,30 @@ version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8cf8e6a8aa66ce33f63993ffc4ea4271eb5b0530a9002db8455ea6050c77bfa"
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da25490ff9892aab3fcf7c36f08cfb902dd3e71ca0f9f9517bea02a73a5ce38c"
dependencies = [
"proc-macro-error-attr",
"proc-macro2",
"quote",
"syn 1.0.109",
"version_check",
]
[[package]]
name = "proc-macro-error-attr"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1be40180e52ecc98ad80b184934baf3d0d29f979574e439af5a55274b35f869"
dependencies = [
"proc-macro2",
"quote",
"version_check",
]
[[package]]
name = "proc-macro2"
version = "1.0.70"
@ -951,6 +987,17 @@ version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
[[package]]
name = "syn"
version = "1.0.109"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b64191b275b66ffe2469e8af2c1cfe3bafa67b529ead792a6d0160888b4237"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "syn"
version = "2.0.42"
@ -988,7 +1035,7 @@ checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.42",
]
[[package]]
@ -1018,7 +1065,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.42",
]
[[package]]
@ -1072,7 +1119,7 @@ dependencies = [
"once_cell",
"proc-macro2",
"quote",
"syn",
"syn 2.0.42",
"wasm-bindgen-shared",
]
@ -1106,7 +1153,7 @@ checksum = "f0eb82fcb7930ae6219a7ecfd55b217f5f0893484b7a13022ebb2b2bf20b5283"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.42",
"wasm-bindgen-backend",
"wasm-bindgen-shared",
]
@ -1236,6 +1283,7 @@ dependencies = [
"bytemuck",
"env_logger",
"futures-intrusive",
"getset",
"pollster",
"thiserror",
"tokio",
@ -1387,5 +1435,5 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [
"proc-macro2",
"quote",
"syn",
"syn 2.0.42",
]

View File

@ -13,3 +13,4 @@ futures-intrusive = "0.4"
bytemuck = { version = "1.12.1", features = ["derive"] }
thiserror = "1.0.51"
tokio = { version = "1.35.1", features = ["sync", "full"] }
getset = "0.1.2"

View File

@ -1,311 +1,84 @@
mod ppu;
use ppu::*;
use bytemuck;
use std::time::Instant;
use wgpu::util::DeviceExt;
const MM4X4_SHADER: &str = "
struct Matrix {
data: array<array<f32, 4>, 4>,
};
async fn run() {
let instance_desc = wgpu::InstanceDescriptor {
backends: wgpu::Backends::VULKAN,
dx12_shader_compiler: wgpu::Dx12Compiler::Fxc,
gles_minor_version: wgpu::Gles3MinorVersion::default(),
flags: wgpu::InstanceFlags::empty(),
};
let instance = wgpu::Instance::new(instance_desc);
println!("instance {:?}", instance);
@group(0) @binding(0)
var<storage, read> matrixA: Matrix;
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::HighPerformance,
force_fallback_adapter: true,
compatible_surface: None,
})
.await;
let adapter = match adapter {
Some(adapter) => adapter,
None => {
println!("No suitable GPU adapters found on the system!");
return;
}
};
let features = adapter.features();
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: features & wgpu::Features::TIMESTAMP_QUERY,
limits: Default::default(),
},
None,
)
.await
.unwrap();
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
Some(device.create_query_set(&wgpu::QuerySetDescriptor {
count: 2,
ty: wgpu::QueryType::Timestamp,
label: None,
}))
} else {
None
};
let start_instant = Instant::now();
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
label: None,
// source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("shader.spv")).into()),
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
});
println!("shader compilation {:?}", start_instant.elapsed());
let matrix_size = std::mem::size_of::<[[f32; 4]; 4]>();
let matrix_a_data = [
[1.0, 2.0, 3.0, 4.0], // Example data
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
let matrix_b_data = [
[1.0, 2.0, 3.0, 4.0], // Example data
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
let matrix_a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Matrix A Buffer"),
contents: bytemuck::bytes_of(&matrix_a_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
let matrix_b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Matrix B Buffer"),
contents: bytemuck::bytes_of(&matrix_b_data),
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
});
// Create Matrix C as a storage buffer
let matrix_c_buf = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Matrix C Storage Buffer"),
size: matrix_size as u64,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, // Changed to COPY_SRC
mapped_at_creation: false,
});
// Create a separate staging buffer for reading back data to CPU
let matrix_c_staging_buf = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Matrix C Staging Buffer"),
size: matrix_size as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, // Staging buffer
mapped_at_creation: false,
});
// Query Result Buffer - For GPU to write query results
let query_result_buf = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
Some(device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Query Result Buffer"),
size: 16, // Enough for two 64-bit timestamps
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
}))
} else {
None
};
// Readback Buffer - For CPU to read query results
let readback_buf = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
Some(device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Readback Buffer"),
size: 16, // Enough for two 64-bit timestamps
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
}))
} else {
None
};
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[
wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true }, // Adjust based on shader needs
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true }, // Adjust based on shader needs
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
wgpu::BindGroupLayoutEntry {
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None,
},
count: None,
},
// ... any other entries needed ...
],
});
let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: None,
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None,
layout: Some(&compute_pipeline_layout),
module: &cs_module,
entry_point: "main",
});
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("Compute Bind Group"),
layout: &bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: matrix_a_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: matrix_b_buf.as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: matrix_c_buf.as_entire_binding(),
},
],
});
let mut encoder = device.create_command_encoder(&Default::default());
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 0);
}
{
let mut cpass = encoder.begin_compute_pass(&Default::default());
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, &bind_group, &[]);
cpass.dispatch_workgroups(4, 4, 1); // Dispatch for a 4x4 matrix
}
if let Some(query_set) = &query_set {
encoder.write_timestamp(query_set, 1);
}
encoder.copy_buffer_to_buffer(
&matrix_c_buf,
0,
&matrix_c_staging_buf,
0,
matrix_size as u64,
);
if let Some(query_set) = &query_set {
if let Some(query_result_buf) = &query_result_buf {
encoder.write_timestamp(query_set, 1);
encoder.resolve_query_set(query_set, 0..2, query_result_buf, 0);
// Ensure to copy after the query set is resolved
if let Some(readback_buf) = &readback_buf {
encoder.copy_buffer_to_buffer(query_result_buf, 0, readback_buf, 0, 16);
}
}
}
queue.submit(Some(encoder.finish()));
// Assuming readback_buf has been properly initialized earlier
let buf_slice = matrix_c_staging_buf.slice(..);
let query_slice = readback_buf.as_ref().map(|buf| buf.slice(..)); // Use readback_buf instead of query_buf
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
// Map the buffer for reading (matrix_c_buf)
buf_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
if let Some(q_slice) = query_slice {
// Map the query buffer if it exists
let _query_future = q_slice.map_async(wgpu::MapMode::Read, |_| ());
}
println!("pre-poll {:?}", std::time::Instant::now());
device.poll(wgpu::Maintain::Wait);
println!("post-poll {:?}", std::time::Instant::now());
if let Some(Ok(())) = receiver.receive().await {
let data_raw = &*buf_slice.get_mapped_range();
// map the data to a slice of f64s
let parsed = data_raw
.chunks_exact(std::mem::size_of::<f32>())
.map(|b| {
let a = bytemuck::from_bytes::<f32>(b);
a.clone()
})
.collect::<Vec<f32>>();
println!("matrix_c_staging_buf {:?}", parsed);
}
if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
if let Some(q_slice) = query_slice {
let ts_period = queue.get_timestamp_period();
let ts_data_raw = &*q_slice.get_mapped_range();
let ts_data: &[u64] = bytemuck::cast_slice(ts_data_raw);
println!(
"compute shader elapsed: {:?}ms",
(ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6
);
@group(0) @binding(1)
var<storage, read> matrixB: Matrix;
@group(0) @binding(2)
var<storage, read_write> matrixC: Matrix;
// Consider setting workgroup size to power of 2 for better efficiency
@compute @workgroup_size(4, 4, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Ensure the computation is within the bounds of the 4x4 matrix
if (global_id.x < 4u && global_id.y < 4u) {
let row: u32 = global_id.y;
let col: u32 = global_id.x;
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
}
matrixC.data[row][col] = sum;
}
}
";
use ppu::*;
use std::fs::File;
// 4x4 left-hand operand for the matrix-multiply test (outer index = row).
const MATRIX_A: [[f32; 4]; 4] = [
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
// 4x4 right-hand operand; intentionally identical to MATRIX_A so the
// expected product is easy to verify by hand.
const MATRIX_B: [[f32; 4]; 4] = [
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
// Expected product MATRIX_A * MATRIX_B (e.g. row 0, col 0:
// 1*1 + 2*5 + 3*9 + 4*13 = 90), for checking the GPU result.
const EXPECT: [[f32; 4]; 4] = [
[90.0, 100.0, 110.0, 120.0],
[202.0, 228.0, 254.0, 280.0],
[314.0, 356.0, 398.0, 440.0],
[426.0, 484.0, 542.0, 600.0],
];
#[tokio::main]
async fn main() -> Result<(), Error> {
// Initialize PPU
// Create PPU
let mut ppu = PPU::new().await?;
ppu.load_shader("MM4X4_SHADER", MM4X4_SHADER)?;
// Load Shader from a file
let shader_source = "
struct Matrix {
data: array<array<f32, 4>, 4>, // Corrected: Use a comma instead of a semicolon
};
let mut buffers = ComputeBuffers::new::<f32>(&ppu, 16);
@group(0) @binding(0)
var<storage, read> matrixA: Matrix;
@group(0) @binding(1)
var<storage, read> matrixB: Matrix;
@group(0) @binding(2)
var<storage, read_write> matrixC: Matrix;
@compute @workgroup_size(4, 4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let row: u32 = global_id.y;
let col: u32 = global_id.x;
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
}
matrixC.data[row][col] = sum;
}
"; // Replace with your shader source
ppu.load_shader("compute_shader", shader_source)?;
buffers.add_buffer_init(
&ppu,
"MATRIX_A",
&MATRIX_A,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
);
buffers.add_buffer_init(
&ppu,
"MATRIX_B",
&MATRIX_B,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
);
buffers.add_buffer_init(
&ppu,
"MATRIX_C",
&[0.0; 16],
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
);
// Create Bind Group Layout
let bind_group_layout_entries = [
@ -340,139 +113,55 @@ async fn main() -> Result<(), Error> {
count: None,
},
];
ppu.load_bind_group_layout("my_bind_group_layout", &bind_group_layout_entries);
// Create Bind Group
let bind_group_entries = [
wgpu::BindGroupEntry {
binding: 0,
resource: ppu.get_buffer("matrix_a").unwrap().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: ppu.get_buffer("matrix_b").unwrap().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: ppu.get_buffer("matrix_c").unwrap().as_entire_binding(),
},
];
ppu.create_bind_group("my_bind_group", "my_bind_group_layout", &bind_group_entries)?;
ppu.load_bind_group_layout("MM4X4_BIND_GROUP_LAYOUT", &bind_group_layout_entries);
// Finally, create the bind group
ppu.create_bind_group(
"MM4X4_BIND_GROUP",
"MM4X4_BIND_GROUP_LAYOUT",
&[
wgpu::BindGroupEntry {
binding: 0,
resource: buffers.get_buffer("MATRIX_A").unwrap().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 1,
resource: buffers.get_buffer("MATRIX_B").unwrap().as_entire_binding(),
},
wgpu::BindGroupEntry {
binding: 2,
resource: buffers.get_buffer("MATRIX_C").unwrap().as_entire_binding(),
},
],
)?;
// Load Pipeline
ppu.load_pipeline(
"my_pipeline",
"compute_shader",
"main", // Entry point in the shader
&["my_bind_group_layout"],
)?;
// Create a buffer to store the output data
let output_buffer_label = "output_buffer";
let output_data: Vec<u32> = vec![0; 10]; // Example: buffer to store 10 unsigned integers
ppu.load_buffer(
output_buffer_label,
&output_data,
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::MAP_READ,
"MM4X4_PIPELINE",
"MM4X4_SHADER",
"main",
&["MM4X4_BIND_GROUP_LAYOUT"],
)?;
// Execute the compute task
let workgroup_count = (1, 1, 1); // Adjust according to your task's requirements
let results: Vec<u32> = ppu
let workgroup_count = (1, 1, 1);
let results: ComputeResult<f32> = ppu
.execute_compute_task(
"my_pipeline",
"my_bind_group",
"MM4X4_PIPELINE",
"MM4X4_BIND_GROUP",
&buffers,
"MATRIX_C",
workgroup_count,
output_buffer_label,
std::mem::size_of_val(&output_data[..]),
)
.await?;
// Process results
println!("Results: {:?}", results);
println!("Matrix C:");
results.data().chunks(4).for_each(|row| {
println!("{:?}", row);
});
Ok(())
}
// /// Submit a compute task
// pub fn submit_task(&self, pipeline: &str, bind_group: &str, workgroup_count: (u32, u32, u32)) {
// let pipeline = self.get_pipeline(pipeline).unwrap();
// let bind_group = self.get_bind_group(bind_group).unwrap();
// let mut encoder = self
// .device
// .create_command_encoder(&wgpu::CommandEncoderDescriptor {
// label: Some("Compute Encoder"),
// });
// {
// let mut compute_pass = encoder.begin_compute_pass(&Default::default());
// compute_pass.set_pipeline(pipeline);
// compute_pass.set_bind_group(0, bind_group, &[]);
// compute_pass.dispatch_workgroups(
// workgroup_count.0,
// workgroup_count.1,
// workgroup_count.2,
// );
// }
// self.queue.submit(Some(encoder.finish()));
// }
// /// Complete a compute task
// pub async fn complete_task<T>(
// &self,
// output_buffer: &str,
// buffer_size: usize,
// ) -> Result<Vec<T>, Error>
// where
// T: bytemuck::Pod + bytemuck::Zeroable,
// {
// // Create a staging buffer to read back data into
// let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
// label: Some("Staging Buffer"),
// size: buffer_size as u64,
// usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
// mapped_at_creation: false,
// });
// let output_buffer = match self.buffers.get(output_buffer) {
// Some(buffer) => buffer,
// None => return Err(Error::BufferNotFound(output_buffer.to_string())),
// };
// // Copy data from the compute output buffer to the staging buffer
// let mut encoder = self
// .device
// .create_command_encoder(&wgpu::CommandEncoderDescriptor {
// label: Some("Readback Encoder"),
// });
// encoder.copy_buffer_to_buffer(output_buffer, 0, &staging_buffer, 0, buffer_size as u64);
// self.queue.submit(Some(encoder.finish()));
// // Map the buffer asynchronously
// let buffer_slice = staging_buffer.slice(..);
// let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
// buffer_slice.map_async(
// wgpu::MapMode::Read,
// move |v: Result<(), wgpu::BufferAsyncError>| {
// sender.send(v).unwrap();
// },
// );
// // Wait for the buffer to be ready
// self.device.poll(wgpu::Maintain::Wait);
// receiver.receive().await.unwrap()?;
// // Read data from the buffer
// let data = buffer_slice.get_mapped_range();
// let result = bytemuck::cast_slice::<_, T>(&data).to_vec();
// // Unmap the buffer
// drop(data);
// staging_buffer.unmap();
// Ok(result)
// }

View File

@ -1,8 +1,7 @@
#![allow(unused)]
// #![allow(unused)]
use bytemuck;
// use core::slice::SlicePattern;
use std::collections::HashMap;
use std::time::Instant;
use getset::{Getters, MutGetters};
use std::{collections::HashMap, result};
use thiserror::Error;
use wgpu::util::DeviceExt;
@ -20,14 +19,14 @@ pub enum Error {
ShaderModuleNotFound(String),
#[error("No buffer found with name {0}.")]
BufferNotFound(String),
#[error("Buffer {0} is too small, expected {1} bytes found {2} bytes.")]
BufferTooSmall(String, usize, usize),
#[error("No bind group layout found with name {0}.")]
BindGroupLayoutNotFound(String),
#[error("No pipeline found with name {0}.")]
PipelineNotFound(String),
#[error("No bind group found with name {0}.")]
BindGroupNotFound(String),
// #[error("Buffer {0} is too small, expected {1} bytes found {2} bytes.")]
// BufferTooSmall(String, usize, usize),
#[error("No bind group layout found with name {0}.")]
BindGroupLayoutNotFound(String),
#[error("No pipeline found with name {0}.")]
PipelineNotFound(String),
#[error("No bind group found with name {0}.")]
BindGroupNotFound(String),
}
pub struct PPU {
@ -36,15 +35,139 @@ pub struct PPU {
device: wgpu::Device,
queue: wgpu::Queue,
query_set: wgpu::QuerySet,
buffers: HashMap<String, wgpu::Buffer>,
// buffers: HashMap<String, wgpu::Buffer>,
shader_modules: HashMap<String, wgpu::ShaderModule>,
pipelines: HashMap<String, wgpu::ComputePipeline>,
bind_group_layouts: HashMap<String, wgpu::BindGroupLayout>,
bind_groups: HashMap<String, wgpu::BindGroup>,
bind_groups: HashMap<String, wgpu::BindGroup>,
}
/// Buffer set for one compute task: user-created storage buffers plus the
/// fixed staging/readback buffers used to copy results and timestamp
/// queries back to the CPU. Accessors are generated by `getset`.
#[derive(Getters, MutGetters)]
pub struct ComputeBuffers {
// User-created buffers, keyed by the label they were registered under.
#[getset(get, get_mut)]
buffers: HashMap<String, wgpu::Buffer>,
// CPU-mappable buffer (MAP_READ | COPY_DST) the compute output is copied into.
#[getset(get)]
staging: wgpu::Buffer,
// CPU-mappable buffer for resolved timestamp results; 16 bytes = two u64 timestamps.
#[getset(get)]
readback: wgpu::Buffer,
// GPU-side buffer (QUERY_RESOLVE | COPY_SRC) timestamp queries resolve into.
#[getset(get)]
timestamp_query: wgpu::Buffer,
}
impl ComputeBuffers {
/// Create the buffer set for a task whose output holds `output_size`
/// elements of type `T`. The staging buffer is sized to that output;
/// the timestamp-query and readback buffers are fixed at 16 bytes
/// (two 64-bit timestamps). No user buffers are created here — add
/// them with [`add_buffer_init`].
pub fn new<T: bytemuck::Pod>(ppu: &PPU, output_size: usize) -> Self {
let staging = ppu.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("COMPUTE_STAGING_BUFFER"),
// Byte size = element count * element size.
size: output_size as u64 * std::mem::size_of::<T>() as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
let timestamp_query = ppu.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("COMPUTE_TIMESTAMP_QUERY_BUFFER"),
size: 16,
usage: wgpu::BufferUsages::QUERY_RESOLVE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let readback = ppu.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("COMPUTE_READBACK_BUFFER"),
size: 16,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
Self {
buffers: HashMap::new(),
staging,
readback,
timestamp_query,
}
}
// pub fn add_buffer<T>(&mut self, ppu: &PPU, label: &str, size: usize, usage: wgpu::BufferUsages) {
// self.buffers.entry(label.to_string()).or_insert_with(|| {
// ppu.device.create_buffer(&wgpu::BufferDescriptor {
// label: Some(label),
// size: size as u64 * std::mem::size_of::<T>() as u64,
// usage,
// mapped_at_creation: false,
// })
// });
// }
/// Create a buffer initialized with `data` and register it under `label`.
/// NOTE(review): if `label` is already registered this is a silent no-op
/// (`or_insert_with` keeps the existing buffer) — confirm that is intended.
pub fn add_buffer_init<T: bytemuck::Pod>(
&mut self,
ppu: &PPU,
label: &str,
data: &[T],
usage: wgpu::BufferUsages,
) {
self.buffers.entry(label.to_string()).or_insert_with(|| {
ppu.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some(label),
contents: bytemuck::cast_slice(data),
usage,
})
});
}
// pub fn update_buffer<T: bytemuck::Pod>(&self, ppu: &PPU, label: &str, data: &[T]) {
// let buffer = match self.buffers.get(label) {
// Some(buffer) => buffer,
// None => return,
// };
// let data_bytes = bytemuck::cast_slice(data);
// let data_len = data_bytes.len() as wgpu::BufferAddress;
// if buffer.size() < data_len {
// return;
// }
// ppu.queue.write_buffer(buffer, 0, data_bytes);
// }
/// Look up a registered buffer by label; `None` if it was never added.
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
self.buffers.get(label)
}
}
/// Output of a compute dispatch: the values read back from the GPU plus a
/// timing measurement.
///
/// Derives `Debug`/`Clone`/`PartialEq` (when `T` does) so callers can print
/// results with `{:?}` and compare them in tests.
#[derive(Debug, Clone, PartialEq)]
pub struct ComputeResult<T> {
    // Flattened output values copied back from the staging buffer.
    data: Vec<T>,
    // Shader execution time in seconds. NOTE(review): the visible call site
    // always passes 0.0 — timestamp resolution appears to be a TODO.
    time_elapsed_sec: f32,
}

impl<T> ComputeResult<T> {
    /// Bundle readback data with its elapsed-time measurement.
    pub fn new(data: Vec<T>, time_elapsed_sec: f32) -> Self {
        Self {
            data,
            time_elapsed_sec,
        }
    }

    /// Borrow the result values.
    pub fn data(&self) -> &Vec<T> {
        &self.data
    }

    /// Elapsed shader time in seconds.
    pub fn time_elapsed_sec(&self) -> f32 {
        self.time_elapsed_sec
    }
}
impl PPU {
/// Initialize a new PPU instance
///
/// # Examples
///
/// ```
/// use ppu::PPU;
///
/// #[tokio::main]
/// async fn main() {
/// let ppu = PPU::new().await.unwrap();
/// }
pub async fn new() -> Result<Self, Error> {
let instance = wgpu::Instance::default();
@ -78,8 +201,8 @@ impl PPU {
})
} else {
return Err(Error::QuerySetCreationFailed(
"Timestamp query is not supported".to_string(),
));
"Timestamp query is not supported".to_string(),
));
};
Ok(Self {
@ -88,89 +211,43 @@ impl PPU {
device,
queue,
query_set,
buffers: HashMap::new(),
shader_modules: HashMap::new(),
pipelines: HashMap::new(),
bind_group_layouts: HashMap::new(),
bind_groups: HashMap::new(),
bind_groups: HashMap::new(),
})
}
/// Create a compute buffer
pub fn load_buffer<T: bytemuck::Pod>(
&mut self,
label: &str,
data: &[T],
usage: wgpu::BufferUsages,
) -> Result<(), Error> {
let buffer = self
.device
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some(label),
contents: bytemuck::cast_slice(data),
usage,
});
self.buffers.insert(label.to_string(), buffer);
Ok(())
}
/// Update a compute buffer
pub fn update_buffer<T>(&self, label: &str, data: &[T]) -> Result<(), Error>
where
T: bytemuck::Pod,
{
let buffer = match self.buffers.get(label) {
Some(buffer) => buffer,
None => return Err(Error::BufferNotFound(label.to_string())),
};
let data_bytes = bytemuck::cast_slice(data);
let data_len = data_bytes.len() as wgpu::BufferAddress;
if buffer.size() < data_len {
return Err(Error::BufferTooSmall(
label.to_string(),
data_len as usize,
buffer.size() as usize,
));
}
self.queue.write_buffer(buffer, 0, data_bytes);
Ok(())
}
/// Retrieve a buffer by name
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
self.buffers.get(label)
}
pub fn load_bind_group_layout(
&mut self,
label: &str,
layout_entries: &[wgpu::BindGroupLayoutEntry],
) {
let bind_group_layout = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some(label),
entries: layout_entries,
});
) {
let bind_group_layout =
self.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some(label),
entries: layout_entries,
});
self.bind_group_layouts.insert(label.to_string(), bind_group_layout);
self.bind_group_layouts
.insert(label.to_string(), bind_group_layout);
}
pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
self.bind_group_layouts.get(label)
}
// pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
// self.bind_group_layouts.get(label)
// }
/// Create and store a bind group
pub fn create_bind_group(
&self,
/// Create and store a bind group
pub fn create_bind_group(
&mut self,
name: &str,
layout_name: &str,
entries: &[wgpu::BindGroupEntry],
) -> Result<wgpu::BindGroup, Error> {
let layout = self.bind_group_layouts.get(layout_name)
) -> Result<(), Error> {
let layout = self
.bind_group_layouts
.get(layout_name)
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))?;
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
@ -178,7 +255,8 @@ impl PPU {
layout,
entries,
});
Ok(bind_group)
self.bind_groups.insert(name.to_string(), bind_group);
Ok(())
}
/// Retrieve a bind group by name
@ -203,135 +281,135 @@ impl PPU {
self.shader_modules.get(name)
}
pub fn load_pipeline(
&mut self,
name: &str,
shader_module_name: &str,
entry_point: &str,
bind_group_layout_names: &[&str], // Use names of bind group layouts
) -> Result<(), Error> {
// Retrieve the shader module
let shader_module = self.get_shader(shader_module_name)
.ok_or(Error::ShaderModuleNotFound(shader_module_name.to_string()))?;
pub fn load_pipeline(
&mut self,
name: &str,
shader_module_name: &str,
entry_point: &str,
bind_group_layout_names: &[&str], // Use names of bind group layouts
) -> Result<(), Error> {
// Retrieve the shader module
let shader_module = self
.get_shader(shader_module_name)
.ok_or(Error::ShaderModuleNotFound(shader_module_name.to_string()))?;
// Retrieve the bind group layouts
let bind_group_layouts = bind_group_layout_names
.iter()
.map(|layout_name| {
self.bind_group_layouts.get(*layout_name) // Assuming you have a HashMap for BindGroupLayouts
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))
})
.collect::<Result<Vec<_>, _>>()?;
// Retrieve the bind group layouts
let bind_group_layouts = bind_group_layout_names
.iter()
.map(|layout_name| {
self.bind_group_layouts
.get(*layout_name) // Assuming you have a HashMap for BindGroupLayouts
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))
})
.collect::<Result<Vec<_>, _>>()?;
// Create the pipeline layout
let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(name),
bind_group_layouts: &bind_group_layouts.iter().map(|layout| *layout).collect::<Vec<_>>(),
push_constant_ranges: &[],
});
// Create the pipeline layout
let pipeline_layout = self
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some(name),
bind_group_layouts: &bind_group_layouts
.iter()
.map(|layout| *layout)
.collect::<Vec<_>>(),
push_constant_ranges: &[],
});
// Create the compute pipeline
let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(name),
layout: Some(&pipeline_layout),
module: shader_module,
entry_point,
});
// Create the compute pipeline
let pipeline = self
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some(name),
layout: Some(&pipeline_layout),
module: shader_module,
entry_point,
});
// Store the pipeline
self.pipelines.insert(name.to_string(), pipeline);
Ok(())
}
// Store the pipeline
self.pipelines.insert(name.to_string(), pipeline);
Ok(())
}
pub fn get_pipeline(&self, name: &str) -> Option<&wgpu::ComputePipeline> {
self.pipelines.get(name)
}
/// Execute a compute task and retrieve the result
/// Execute a compute task and retrieve the result
pub async fn execute_compute_task<T>(
&self,
pipeline_name: &str,
bind_group_name: &str,
workgroup_count: (u32, u32, u32),
buffers: &ComputeBuffers,
output_buffer_name: &str,
buffer_size: usize,
) -> Result<Vec<T>, Error>
workgroup_count: (u32, u32, u32),
) -> Result<ComputeResult<T>, Error>
where
T: bytemuck::Pod + bytemuck::Zeroable,
T: bytemuck::Pod + bytemuck::Zeroable + Send + Sync + std::fmt::Debug, // Added Debug trait
{
// Retrieve the pipeline and bind group
let pipeline = self.get_pipeline(pipeline_name)
let pipeline = self
.get_pipeline(pipeline_name)
.ok_or_else(|| Error::PipelineNotFound(pipeline_name.to_string()))?;
let bind_group = self.get_bind_group(bind_group_name)
let bind_group = self
.get_bind_group(bind_group_name)
.ok_or_else(|| Error::BindGroupNotFound(bind_group_name.to_string()))?;
// Dispatch the compute task
let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Compute Task Encoder"),
});
// Create a command encoder and dispatch the compute shader
let mut encoder = self.device.create_command_encoder(&Default::default());
{
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
compute_pass.set_pipeline(pipeline);
compute_pass.set_bind_group(0, bind_group, &[]);
compute_pass.dispatch_workgroups(workgroup_count.0, workgroup_count.1, workgroup_count.2);
compute_pass.dispatch_workgroups(
workgroup_count.0,
workgroup_count.1,
workgroup_count.2,
);
}
// Copy output to staging buffer
let output_buffer = buffers
.get_buffer(output_buffer_name)
.ok_or(Error::BufferNotFound(output_buffer_name.to_string()))?;
// let staging_buffer = buffers.staging();
encoder.copy_buffer_to_buffer(
output_buffer,
0,
buffers.staging(),
0,
buffers.staging().size(),
);
// Submit the command encoder and wait for it to complete
self.queue.submit(Some(encoder.finish()));
// Wait for the task to complete
self.device.poll(wgpu::Maintain::Wait);
// Map the staging buffer asynchronously
// Retrieve and return the data from the output buffer
self.complete_task::<T>(output_buffer_name, buffer_size).await
}
let buffer_slice = buffers.staging().slice(..);
/// Complete a compute task and retrieve the data from an output buffer
pub async fn complete_task<T>(
&self,
output_buffer_name: &str,
buffer_size: usize,
) -> Result<Vec<T>, Error>
where
T: bytemuck::Pod + bytemuck::Zeroable,
{
let output_buffer = self.buffers.get(output_buffer_name)
.ok_or_else(|| Error::BufferNotFound(output_buffer_name.to_string()))?;
// Create a staging buffer to read back data into
let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging Buffer"),
size: buffer_size as u64,
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
// Copy data from the compute output buffer to the staging buffer
let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Readback Encoder"),
});
encoder.copy_buffer_to_buffer(output_buffer, 0, &staging_buffer, 0, buffer_size as u64);
self.queue.submit(Some(encoder.finish()));
// Map the buffer asynchronously
let buffer_slice = staging_buffer.slice(..);
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
buffer_slice.map_async(wgpu::MapMode::Read, move |v: Result<(), wgpu::BufferAsyncError>| {
sender.send(v).unwrap();
buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
sender.send(result).unwrap();
});
// Wait for the buffer to be ready
// Poll the device to ensure the mapping is processed
self.device.poll(wgpu::Maintain::Wait);
receiver.receive().await.unwrap()?;
let data = buffer_slice.get_mapped_range();
// Read data from the buffer
let result = bytemuck::cast_slice::<_, T>(&data).to_vec();
// Wait for the mapping to complete and retrieve the result
receiver.receive().await.unwrap().unwrap();
// Unmap the buffer
drop(data);
staging_buffer.unmap();
let data = buffer_slice.get_mapped_range().to_vec();
// buffers.staging().unmap();
Ok(result)
let result = data
.chunks_exact(std::mem::size_of::<T>())
.map(|chunk| bytemuck::from_bytes::<T>(chunk))
.cloned()
.collect::<Vec<_>>();
Ok(ComputeResult::new(result, 0.0))
}
}