🚀 Builds on Linux, work on abstraction
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
parent
bf16edd3aa
commit
5b3ef81c84
96
Cargo.lock
generated
96
Cargo.lock
generated
@ -74,7 +74,7 @@ version = "0.2.14"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"hermit-abi",
|
"hermit-abi 0.1.19",
|
||||||
"libc",
|
"libc",
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
@ -159,6 +159,12 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "bytes"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "cc"
|
name = "cc"
|
||||||
version = "1.0.83"
|
version = "1.0.83"
|
||||||
@ -224,7 +230,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||||||
checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20"
|
checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 2.4.1",
|
"bitflags 2.4.1",
|
||||||
"libloading 0.7.4",
|
"libloading 0.8.1",
|
||||||
"winapi",
|
"winapi",
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -453,6 +459,12 @@ dependencies = [
|
|||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hermit-abi"
|
||||||
|
version = "0.3.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "hexf-parse"
|
name = "hexf-parse"
|
||||||
version = "0.2.1"
|
version = "0.2.1"
|
||||||
@ -591,6 +603,17 @@ dependencies = [
|
|||||||
"adler",
|
"adler",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "mio"
|
||||||
|
version = "0.8.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"wasi",
|
||||||
|
"windows-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "naga"
|
name = "naga"
|
||||||
version = "0.14.2"
|
version = "0.14.2"
|
||||||
@ -630,6 +653,16 @@ dependencies = [
|
|||||||
"autocfg",
|
"autocfg",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "num_cpus"
|
||||||
|
version = "1.16.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43"
|
||||||
|
dependencies = [
|
||||||
|
"hermit-abi 0.3.3",
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "objc"
|
name = "objc"
|
||||||
version = "0.2.7"
|
version = "0.2.7"
|
||||||
@ -728,6 +761,12 @@ dependencies = [
|
|||||||
"indexmap",
|
"indexmap",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "pin-project-lite"
|
||||||
|
version = "0.2.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pkg-config"
|
name = "pkg-config"
|
||||||
version = "0.3.28"
|
version = "0.3.28"
|
||||||
@ -853,6 +892,15 @@ version = "1.2.0"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "signal-hook-registry"
|
||||||
|
version = "1.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "slotmap"
|
name = "slotmap"
|
||||||
version = "1.0.7"
|
version = "1.0.7"
|
||||||
@ -868,6 +916,16 @@ version = "1.11.2"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
|
checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "socket2"
|
||||||
|
version = "0.5.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
|
||||||
|
dependencies = [
|
||||||
|
"libc",
|
||||||
|
"windows-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "spin"
|
name = "spin"
|
||||||
version = "0.9.8"
|
version = "0.9.8"
|
||||||
@ -933,6 +991,36 @@ dependencies = [
|
|||||||
"syn",
|
"syn",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio"
|
||||||
|
version = "1.35.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104"
|
||||||
|
dependencies = [
|
||||||
|
"backtrace",
|
||||||
|
"bytes",
|
||||||
|
"libc",
|
||||||
|
"mio",
|
||||||
|
"num_cpus",
|
||||||
|
"parking_lot 0.12.1",
|
||||||
|
"pin-project-lite",
|
||||||
|
"signal-hook-registry",
|
||||||
|
"socket2",
|
||||||
|
"tokio-macros",
|
||||||
|
"windows-sys",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-macros"
|
||||||
|
version = "2.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.12"
|
version = "1.0.12"
|
||||||
@ -1110,7 +1198,7 @@ dependencies = [
|
|||||||
"js-sys",
|
"js-sys",
|
||||||
"khronos-egl",
|
"khronos-egl",
|
||||||
"libc",
|
"libc",
|
||||||
"libloading 0.7.4",
|
"libloading 0.8.1",
|
||||||
"log",
|
"log",
|
||||||
"metal",
|
"metal",
|
||||||
"naga",
|
"naga",
|
||||||
@ -1149,6 +1237,8 @@ dependencies = [
|
|||||||
"env_logger",
|
"env_logger",
|
||||||
"futures-intrusive",
|
"futures-intrusive",
|
||||||
"pollster",
|
"pollster",
|
||||||
|
"thiserror",
|
||||||
|
"tokio",
|
||||||
"wgpu",
|
"wgpu",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -11,3 +11,5 @@ env_logger = "0.9.1"
|
|||||||
pollster = "0.2.5"
|
pollster = "0.2.5"
|
||||||
futures-intrusive = "0.4"
|
futures-intrusive = "0.4"
|
||||||
bytemuck = { version = "1.12.1", features = ["derive"] }
|
bytemuck = { version = "1.12.1", features = ["derive"] }
|
||||||
|
thiserror = "1.0.51"
|
||||||
|
tokio = { version = "1.35.1", features = ["sync", "full"] }
|
||||||
|
212
src/main.rs
212
src/main.rs
@ -1,8 +1,8 @@
|
|||||||
mod ppu;
|
mod ppu;
|
||||||
|
|
||||||
|
use bytemuck;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use wgpu::util::DeviceExt;
|
use wgpu::util::DeviceExt;
|
||||||
use bytemuck;
|
|
||||||
|
|
||||||
async fn run() {
|
async fn run() {
|
||||||
let instance_desc = wgpu::InstanceDescriptor {
|
let instance_desc = wgpu::InstanceDescriptor {
|
||||||
@ -252,7 +252,8 @@ async fn run() {
|
|||||||
.map(|b| {
|
.map(|b| {
|
||||||
let a = bytemuck::from_bytes::<f32>(b);
|
let a = bytemuck::from_bytes::<f32>(b);
|
||||||
a.clone()
|
a.clone()
|
||||||
}).collect::<Vec<f32>>();
|
})
|
||||||
|
.collect::<Vec<f32>>();
|
||||||
println!("matrix_c_staging_buf {:?}", parsed);
|
println!("matrix_c_staging_buf {:?}", parsed);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -269,6 +270,209 @@ async fn run() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn main() {
|
use ppu::*;
|
||||||
pollster::block_on(run());
|
use std::fs::File;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> Result<(), Error> {
|
||||||
|
// Initialize PPU
|
||||||
|
let mut ppu = PPU::new().await?;
|
||||||
|
|
||||||
|
// Load Shader from a file
|
||||||
|
let shader_source = "
|
||||||
|
struct Matrix {
|
||||||
|
data: array<array<f32, 4>, 4>, // Corrected: Use a comma instead of a semicolon
|
||||||
|
};
|
||||||
|
|
||||||
|
@group(0) @binding(0)
|
||||||
|
var<storage, read> matrixA: Matrix;
|
||||||
|
|
||||||
|
@group(0) @binding(1)
|
||||||
|
var<storage, read> matrixB: Matrix;
|
||||||
|
|
||||||
|
@group(0) @binding(2)
|
||||||
|
var<storage, read_write> matrixC: Matrix;
|
||||||
|
|
||||||
|
@compute @workgroup_size(4, 4)
|
||||||
|
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
|
||||||
|
let row: u32 = global_id.y;
|
||||||
|
let col: u32 = global_id.x;
|
||||||
|
|
||||||
|
var sum: f32 = 0.0;
|
||||||
|
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
|
||||||
|
sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
|
||||||
|
}
|
||||||
|
matrixC.data[row][col] = sum;
|
||||||
|
}
|
||||||
|
"; // Replace with your shader source
|
||||||
|
ppu.load_shader("compute_shader", shader_source)?;
|
||||||
|
|
||||||
|
// Create Bind Group Layout
|
||||||
|
let bind_group_layout_entries = [
|
||||||
|
wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 0,
|
||||||
|
visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Buffer {
|
||||||
|
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||||
|
has_dynamic_offset: false,
|
||||||
|
min_binding_size: None,
|
||||||
|
},
|
||||||
|
count: None,
|
||||||
|
},
|
||||||
|
wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 1,
|
||||||
|
visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Buffer {
|
||||||
|
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
||||||
|
has_dynamic_offset: false,
|
||||||
|
min_binding_size: None,
|
||||||
|
},
|
||||||
|
count: None,
|
||||||
|
},
|
||||||
|
wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 2,
|
||||||
|
visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Buffer {
|
||||||
|
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||||
|
has_dynamic_offset: false,
|
||||||
|
min_binding_size: None,
|
||||||
|
},
|
||||||
|
count: None,
|
||||||
|
},
|
||||||
|
];
|
||||||
|
ppu.load_bind_group_layout("my_bind_group_layout", &bind_group_layout_entries);
|
||||||
|
|
||||||
|
// Create Bind Group
|
||||||
|
let bind_group_entries = [
|
||||||
|
wgpu::BindGroupEntry {
|
||||||
|
binding: 0,
|
||||||
|
resource: ppu.get_buffer("matrix_a").unwrap().as_entire_binding(),
|
||||||
|
},
|
||||||
|
wgpu::BindGroupEntry {
|
||||||
|
binding: 1,
|
||||||
|
resource: ppu.get_buffer("matrix_b").unwrap().as_entire_binding(),
|
||||||
|
},
|
||||||
|
wgpu::BindGroupEntry {
|
||||||
|
binding: 2,
|
||||||
|
resource: ppu.get_buffer("matrix_c").unwrap().as_entire_binding(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
ppu.create_bind_group("my_bind_group", "my_bind_group_layout", &bind_group_entries)?;
|
||||||
|
|
||||||
|
|
||||||
|
// Load Pipeline
|
||||||
|
ppu.load_pipeline(
|
||||||
|
"my_pipeline",
|
||||||
|
"compute_shader",
|
||||||
|
"main", // Entry point in the shader
|
||||||
|
&["my_bind_group_layout"],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Create a buffer to store the output data
|
||||||
|
let output_buffer_label = "output_buffer";
|
||||||
|
let output_data: Vec<u32> = vec![0; 10]; // Example: buffer to store 10 unsigned integers
|
||||||
|
ppu.load_buffer(
|
||||||
|
output_buffer_label,
|
||||||
|
&output_data,
|
||||||
|
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::MAP_READ,
|
||||||
|
)?;
|
||||||
|
|
||||||
|
// Execute the compute task
|
||||||
|
let workgroup_count = (1, 1, 1); // Adjust according to your task's requirements
|
||||||
|
let results: Vec<u32> = ppu
|
||||||
|
.execute_compute_task(
|
||||||
|
"my_pipeline",
|
||||||
|
"my_bind_group",
|
||||||
|
workgroup_count,
|
||||||
|
output_buffer_label,
|
||||||
|
std::mem::size_of_val(&output_data[..]),
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Process results
|
||||||
|
println!("Results: {:?}", results);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// /// Submit a compute task
|
||||||
|
// pub fn submit_task(&self, pipeline: &str, bind_group: &str, workgroup_count: (u32, u32, u32)) {
|
||||||
|
// let pipeline = self.get_pipeline(pipeline).unwrap();
|
||||||
|
// let bind_group = self.get_bind_group(bind_group).unwrap();
|
||||||
|
// let mut encoder = self
|
||||||
|
// .device
|
||||||
|
// .create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||||
|
// label: Some("Compute Encoder"),
|
||||||
|
// });
|
||||||
|
|
||||||
|
// {
|
||||||
|
// let mut compute_pass = encoder.begin_compute_pass(&Default::default());
|
||||||
|
// compute_pass.set_pipeline(pipeline);
|
||||||
|
// compute_pass.set_bind_group(0, bind_group, &[]);
|
||||||
|
// compute_pass.dispatch_workgroups(
|
||||||
|
// workgroup_count.0,
|
||||||
|
// workgroup_count.1,
|
||||||
|
// workgroup_count.2,
|
||||||
|
// );
|
||||||
|
// }
|
||||||
|
|
||||||
|
// self.queue.submit(Some(encoder.finish()));
|
||||||
|
// }
|
||||||
|
|
||||||
|
// /// Complete a compute task
|
||||||
|
// pub async fn complete_task<T>(
|
||||||
|
// &self,
|
||||||
|
// output_buffer: &str,
|
||||||
|
// buffer_size: usize,
|
||||||
|
// ) -> Result<Vec<T>, Error>
|
||||||
|
// where
|
||||||
|
// T: bytemuck::Pod + bytemuck::Zeroable,
|
||||||
|
// {
|
||||||
|
// // Create a staging buffer to read back data into
|
||||||
|
// let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
// label: Some("Staging Buffer"),
|
||||||
|
// size: buffer_size as u64,
|
||||||
|
// usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||||
|
// mapped_at_creation: false,
|
||||||
|
// });
|
||||||
|
|
||||||
|
// let output_buffer = match self.buffers.get(output_buffer) {
|
||||||
|
// Some(buffer) => buffer,
|
||||||
|
// None => return Err(Error::BufferNotFound(output_buffer.to_string())),
|
||||||
|
// };
|
||||||
|
|
||||||
|
// // Copy data from the compute output buffer to the staging buffer
|
||||||
|
// let mut encoder = self
|
||||||
|
// .device
|
||||||
|
// .create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||||
|
// label: Some("Readback Encoder"),
|
||||||
|
// });
|
||||||
|
|
||||||
|
// encoder.copy_buffer_to_buffer(output_buffer, 0, &staging_buffer, 0, buffer_size as u64);
|
||||||
|
// self.queue.submit(Some(encoder.finish()));
|
||||||
|
|
||||||
|
// // Map the buffer asynchronously
|
||||||
|
// let buffer_slice = staging_buffer.slice(..);
|
||||||
|
// let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||||
|
|
||||||
|
// buffer_slice.map_async(
|
||||||
|
// wgpu::MapMode::Read,
|
||||||
|
// move |v: Result<(), wgpu::BufferAsyncError>| {
|
||||||
|
// sender.send(v).unwrap();
|
||||||
|
// },
|
||||||
|
// );
|
||||||
|
|
||||||
|
// // Wait for the buffer to be ready
|
||||||
|
// self.device.poll(wgpu::Maintain::Wait);
|
||||||
|
// receiver.receive().await.unwrap()?;
|
||||||
|
|
||||||
|
// // Read data from the buffer
|
||||||
|
// let data = buffer_slice.get_mapped_range();
|
||||||
|
// let result = bytemuck::cast_slice::<_, T>(&data).to_vec();
|
||||||
|
|
||||||
|
// // Unmap the buffer
|
||||||
|
// drop(data);
|
||||||
|
// staging_buffer.unmap();
|
||||||
|
|
||||||
|
// Ok(result)
|
||||||
|
// }
|
||||||
|
317
src/ppu.rs
317
src/ppu.rs
@ -1,34 +1,64 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
use bytemuck;
|
||||||
|
// use core::slice::SlicePattern;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
use thiserror::Error;
|
||||||
use wgpu::util::DeviceExt;
|
use wgpu::util::DeviceExt;
|
||||||
|
|
||||||
use bytemuck;
|
#[derive(Error, Debug)]
|
||||||
|
pub enum Error {
|
||||||
|
#[error("Failed to find an appropriate adapter.")]
|
||||||
|
AdapterNotFound,
|
||||||
|
#[error("Failed to create device.")]
|
||||||
|
DeviceCreationFailed(#[from] wgpu::RequestDeviceError),
|
||||||
|
#[error("Failed to create query set. {0}")]
|
||||||
|
QuerySetCreationFailed(String),
|
||||||
|
#[error("Failed to async map a buffer.")]
|
||||||
|
BufferAsyncError(#[from] wgpu::BufferAsyncError),
|
||||||
|
#[error("Failed to create compute pipeline, shader module {0} was not found.")]
|
||||||
|
ShaderModuleNotFound(String),
|
||||||
|
#[error("No buffer found with name {0}.")]
|
||||||
|
BufferNotFound(String),
|
||||||
|
#[error("Buffer {0} is too small, expected {1} bytes found {2} bytes.")]
|
||||||
|
BufferTooSmall(String, usize, usize),
|
||||||
|
#[error("No bind group layout found with name {0}.")]
|
||||||
|
BindGroupLayoutNotFound(String),
|
||||||
|
#[error("No pipeline found with name {0}.")]
|
||||||
|
PipelineNotFound(String),
|
||||||
|
#[error("No bind group found with name {0}.")]
|
||||||
|
BindGroupNotFound(String),
|
||||||
|
}
|
||||||
|
|
||||||
pub struct PPU {
|
pub struct PPU {
|
||||||
instance: wgpu::Instance,
|
instance: wgpu::Instance,
|
||||||
adapter: wgpu::Adapter,
|
adapter: wgpu::Adapter,
|
||||||
device: wgpu::Device,
|
device: wgpu::Device,
|
||||||
queue: wgpu::Queue,
|
queue: wgpu::Queue,
|
||||||
query_set: Option<wgpu::QuerySet>,
|
query_set: wgpu::QuerySet,
|
||||||
// shader: wgpu::ShaderModule,
|
buffers: HashMap<String, wgpu::Buffer>,
|
||||||
|
shader_modules: HashMap<String, wgpu::ShaderModule>,
|
||||||
|
pipelines: HashMap<String, wgpu::ComputePipeline>,
|
||||||
|
bind_group_layouts: HashMap<String, wgpu::BindGroupLayout>,
|
||||||
|
bind_groups: HashMap<String, wgpu::BindGroup>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PPU {
|
impl PPU {
|
||||||
pub async fn init() -> Self {
|
/// Initialize a new PPU instance
|
||||||
let instance_desc = wgpu::InstanceDescriptor {
|
pub async fn new() -> Result<Self, Error> {
|
||||||
backends: wgpu::Backends::VULKAN,
|
let instance = wgpu::Instance::default();
|
||||||
dx12_shader_compiler: wgpu::Dx12Compiler::Fxc,
|
|
||||||
gles_minor_version: wgpu::Gles3MinorVersion::default(),
|
|
||||||
flags: wgpu::InstanceFlags::empty(),
|
|
||||||
};
|
|
||||||
let instance = wgpu::Instance::new(instance_desc);
|
|
||||||
let adapter = instance
|
let adapter = instance
|
||||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||||
power_preference: wgpu::PowerPreference::HighPerformance,
|
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||||
force_fallback_adapter: true,
|
force_fallback_adapter: true,
|
||||||
compatible_surface: None,
|
compatible_surface: None,
|
||||||
}).await.expect("Failed to find an appropriate adapter");
|
})
|
||||||
let features = adapter.features();
|
.await
|
||||||
|
.ok_or(Error::AdapterNotFound)?;
|
||||||
|
|
||||||
|
let features = adapter.features();
|
||||||
|
|
||||||
let (device, queue) = adapter
|
let (device, queue) = adapter
|
||||||
.request_device(
|
.request_device(
|
||||||
&wgpu::DeviceDescriptor {
|
&wgpu::DeviceDescriptor {
|
||||||
@ -38,25 +68,270 @@ impl PPU {
|
|||||||
},
|
},
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await?;
|
||||||
.unwrap();
|
|
||||||
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||||
Some(device.create_query_set(&wgpu::QuerySetDescriptor {
|
device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||||
count: 2,
|
count: 2,
|
||||||
ty: wgpu::QueryType::Timestamp,
|
ty: wgpu::QueryType::Timestamp,
|
||||||
label: None,
|
label: None,
|
||||||
}))
|
})
|
||||||
} else {
|
} else {
|
||||||
None
|
return Err(Error::QuerySetCreationFailed(
|
||||||
|
"Timestamp query is not supported".to_string(),
|
||||||
|
));
|
||||||
};
|
};
|
||||||
println!("PPU initialized");
|
|
||||||
Self {
|
Ok(Self {
|
||||||
instance,
|
instance,
|
||||||
adapter,
|
adapter,
|
||||||
device,
|
device,
|
||||||
queue,
|
queue,
|
||||||
query_set,
|
query_set,
|
||||||
|
buffers: HashMap::new(),
|
||||||
|
shader_modules: HashMap::new(),
|
||||||
|
pipelines: HashMap::new(),
|
||||||
|
bind_group_layouts: HashMap::new(),
|
||||||
|
bind_groups: HashMap::new(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a compute buffer
|
||||||
|
pub fn load_buffer<T: bytemuck::Pod>(
|
||||||
|
&mut self,
|
||||||
|
label: &str,
|
||||||
|
data: &[T],
|
||||||
|
usage: wgpu::BufferUsages,
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
let buffer = self
|
||||||
|
.device
|
||||||
|
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||||
|
label: Some(label),
|
||||||
|
contents: bytemuck::cast_slice(data),
|
||||||
|
usage,
|
||||||
|
});
|
||||||
|
|
||||||
|
self.buffers.insert(label.to_string(), buffer);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Update a compute buffer
|
||||||
|
pub fn update_buffer<T>(&self, label: &str, data: &[T]) -> Result<(), Error>
|
||||||
|
where
|
||||||
|
T: bytemuck::Pod,
|
||||||
|
{
|
||||||
|
let buffer = match self.buffers.get(label) {
|
||||||
|
Some(buffer) => buffer,
|
||||||
|
None => return Err(Error::BufferNotFound(label.to_string())),
|
||||||
|
};
|
||||||
|
|
||||||
|
let data_bytes = bytemuck::cast_slice(data);
|
||||||
|
let data_len = data_bytes.len() as wgpu::BufferAddress;
|
||||||
|
|
||||||
|
if buffer.size() < data_len {
|
||||||
|
return Err(Error::BufferTooSmall(
|
||||||
|
label.to_string(),
|
||||||
|
data_len as usize,
|
||||||
|
buffer.size() as usize,
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self.queue.write_buffer(buffer, 0, data_bytes);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve a buffer by name
|
||||||
|
pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> {
|
||||||
|
self.buffers.get(label)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_bind_group_layout(
|
||||||
|
&mut self,
|
||||||
|
label: &str,
|
||||||
|
layout_entries: &[wgpu::BindGroupLayoutEntry],
|
||||||
|
) {
|
||||||
|
let bind_group_layout = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||||
|
label: Some(label),
|
||||||
|
entries: layout_entries,
|
||||||
|
});
|
||||||
|
|
||||||
|
self.bind_group_layouts.insert(label.to_string(), bind_group_layout);
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> {
|
||||||
|
self.bind_group_layouts.get(label)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create and store a bind group
|
||||||
|
pub fn create_bind_group(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
layout_name: &str,
|
||||||
|
entries: &[wgpu::BindGroupEntry],
|
||||||
|
) -> Result<wgpu::BindGroup, Error> {
|
||||||
|
let layout = self.bind_group_layouts.get(layout_name)
|
||||||
|
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))?;
|
||||||
|
|
||||||
|
let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||||
|
label: Some(name),
|
||||||
|
layout,
|
||||||
|
entries,
|
||||||
|
});
|
||||||
|
Ok(bind_group)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve a bind group by name
|
||||||
|
pub fn get_bind_group(&self, name: &str) -> Option<&wgpu::BindGroup> {
|
||||||
|
self.bind_groups.get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Load a shader and store it in the hash map
|
||||||
|
pub fn load_shader(&mut self, name: &str, source: &str) -> Result<(), Error> {
|
||||||
|
let shader_module = self
|
||||||
|
.device
|
||||||
|
.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||||
|
label: Some(name),
|
||||||
|
source: wgpu::ShaderSource::Wgsl(source.into()),
|
||||||
|
});
|
||||||
|
self.shader_modules.insert(name.to_string(), shader_module);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a shader from the hash map
|
||||||
|
pub fn get_shader(&self, name: &str) -> Option<&wgpu::ShaderModule> {
|
||||||
|
self.shader_modules.get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn load_pipeline(
|
||||||
|
&mut self,
|
||||||
|
name: &str,
|
||||||
|
shader_module_name: &str,
|
||||||
|
entry_point: &str,
|
||||||
|
bind_group_layout_names: &[&str], // Use names of bind group layouts
|
||||||
|
) -> Result<(), Error> {
|
||||||
|
// Retrieve the shader module
|
||||||
|
let shader_module = self.get_shader(shader_module_name)
|
||||||
|
.ok_or(Error::ShaderModuleNotFound(shader_module_name.to_string()))?;
|
||||||
|
|
||||||
|
// Retrieve the bind group layouts
|
||||||
|
let bind_group_layouts = bind_group_layout_names
|
||||||
|
.iter()
|
||||||
|
.map(|layout_name| {
|
||||||
|
self.bind_group_layouts.get(*layout_name) // Assuming you have a HashMap for BindGroupLayouts
|
||||||
|
.ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))
|
||||||
|
})
|
||||||
|
.collect::<Result<Vec<_>, _>>()?;
|
||||||
|
|
||||||
|
// Create the pipeline layout
|
||||||
|
let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||||
|
label: Some(name),
|
||||||
|
bind_group_layouts: &bind_group_layouts.iter().map(|layout| *layout).collect::<Vec<_>>(),
|
||||||
|
push_constant_ranges: &[],
|
||||||
|
});
|
||||||
|
|
||||||
|
// Create the compute pipeline
|
||||||
|
let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||||
|
label: Some(name),
|
||||||
|
layout: Some(&pipeline_layout),
|
||||||
|
module: shader_module,
|
||||||
|
entry_point,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Store the pipeline
|
||||||
|
self.pipelines.insert(name.to_string(), pipeline);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn get_pipeline(&self, name: &str) -> Option<&wgpu::ComputePipeline> {
|
||||||
|
self.pipelines.get(name)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a compute task and retrieve the result
|
||||||
|
pub async fn execute_compute_task<T>(
|
||||||
|
&self,
|
||||||
|
pipeline_name: &str,
|
||||||
|
bind_group_name: &str,
|
||||||
|
workgroup_count: (u32, u32, u32),
|
||||||
|
output_buffer_name: &str,
|
||||||
|
buffer_size: usize,
|
||||||
|
) -> Result<Vec<T>, Error>
|
||||||
|
where
|
||||||
|
T: bytemuck::Pod + bytemuck::Zeroable,
|
||||||
|
{
|
||||||
|
// Retrieve the pipeline and bind group
|
||||||
|
let pipeline = self.get_pipeline(pipeline_name)
|
||||||
|
.ok_or_else(|| Error::PipelineNotFound(pipeline_name.to_string()))?;
|
||||||
|
let bind_group = self.get_bind_group(bind_group_name)
|
||||||
|
.ok_or_else(|| Error::BindGroupNotFound(bind_group_name.to_string()))?;
|
||||||
|
|
||||||
|
// Dispatch the compute task
|
||||||
|
let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||||
|
label: Some("Compute Task Encoder"),
|
||||||
|
});
|
||||||
|
{
|
||||||
|
let mut compute_pass = encoder.begin_compute_pass(&Default::default());
|
||||||
|
compute_pass.set_pipeline(pipeline);
|
||||||
|
compute_pass.set_bind_group(0, bind_group, &[]);
|
||||||
|
compute_pass.dispatch_workgroups(workgroup_count.0, workgroup_count.1, workgroup_count.2);
|
||||||
|
}
|
||||||
|
self.queue.submit(Some(encoder.finish()));
|
||||||
|
|
||||||
|
// Wait for the task to complete
|
||||||
|
self.device.poll(wgpu::Maintain::Wait);
|
||||||
|
|
||||||
|
// Retrieve and return the data from the output buffer
|
||||||
|
self.complete_task::<T>(output_buffer_name, buffer_size).await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Complete a compute task and retrieve the data from an output buffer
|
||||||
|
pub async fn complete_task<T>(
|
||||||
|
&self,
|
||||||
|
output_buffer_name: &str,
|
||||||
|
buffer_size: usize,
|
||||||
|
) -> Result<Vec<T>, Error>
|
||||||
|
where
|
||||||
|
T: bytemuck::Pod + bytemuck::Zeroable,
|
||||||
|
{
|
||||||
|
let output_buffer = self.buffers.get(output_buffer_name)
|
||||||
|
.ok_or_else(|| Error::BufferNotFound(output_buffer_name.to_string()))?;
|
||||||
|
|
||||||
|
// Create a staging buffer to read back data into
|
||||||
|
let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
label: Some("Staging Buffer"),
|
||||||
|
size: buffer_size as u64,
|
||||||
|
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
// Copy data from the compute output buffer to the staging buffer
|
||||||
|
let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
|
||||||
|
label: Some("Readback Encoder"),
|
||||||
|
});
|
||||||
|
encoder.copy_buffer_to_buffer(output_buffer, 0, &staging_buffer, 0, buffer_size as u64);
|
||||||
|
self.queue.submit(Some(encoder.finish()));
|
||||||
|
|
||||||
|
// Map the buffer asynchronously
|
||||||
|
let buffer_slice = staging_buffer.slice(..);
|
||||||
|
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||||
|
|
||||||
|
buffer_slice.map_async(wgpu::MapMode::Read, move |v: Result<(), wgpu::BufferAsyncError>| {
|
||||||
|
sender.send(v).unwrap();
|
||||||
|
});
|
||||||
|
|
||||||
|
// Wait for the buffer to be ready
|
||||||
|
self.device.poll(wgpu::Maintain::Wait);
|
||||||
|
receiver.receive().await.unwrap()?;
|
||||||
|
let data = buffer_slice.get_mapped_range();
|
||||||
|
|
||||||
|
// Read data from the buffer
|
||||||
|
let result = bytemuck::cast_slice::<_, T>(&data).to_vec();
|
||||||
|
|
||||||
|
// Unmap the buffer
|
||||||
|
drop(data);
|
||||||
|
staging_buffer.unmap();
|
||||||
|
|
||||||
|
Ok(result)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1,25 +0,0 @@
|
|||||||
#version 450
|
|
||||||
|
|
||||||
layout(set = 0, binding = 0) buffer MatrixA {
|
|
||||||
float data[4][4];
|
|
||||||
} matrixA;
|
|
||||||
|
|
||||||
layout(set = 0, binding = 1) buffer MatrixB {
|
|
||||||
float data[4][4];
|
|
||||||
} matrixB;
|
|
||||||
|
|
||||||
layout(set = 0, binding = 2) buffer MatrixC {
|
|
||||||
float data[4][4];
|
|
||||||
} matrixC;
|
|
||||||
|
|
||||||
void main() {
|
|
||||||
uint row = gl_GlobalInvocationID.y;
|
|
||||||
uint col = gl_GlobalInvocationID.x;
|
|
||||||
float sum = 0.0;
|
|
||||||
|
|
||||||
for (uint k = 0; k < 4; ++k) {
|
|
||||||
sum += matrixA.data[row][k] * matrixB.data[k][col];
|
|
||||||
}
|
|
||||||
|
|
||||||
matrixC.data[row][col] = sum;
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user