From 5b3ef81c84267de6995a8258604e5224d8a3589b Mon Sep 17 00:00:00 2001 From: Julius Koskela Date: Fri, 22 Dec 2023 16:01:40 +0200 Subject: [PATCH] =?UTF-8?q?=F0=9F=9A=80=20Builds=20on=20Linux,=20work=20on?= =?UTF-8?q?=20abstraction?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Julius Koskela --- Cargo.lock | 96 ++++++++++++++- Cargo.toml | 2 + src/main.rs | 212 +++++++++++++++++++++++++++++++- src/ppu.rs | 317 ++++++++++++++++++++++++++++++++++++++++++++---- src/shader.spv | 25 ---- src/shader.wgsl | 2 +- 6 files changed, 600 insertions(+), 54 deletions(-) delete mode 100644 src/shader.spv diff --git a/Cargo.lock b/Cargo.lock index 9429c52..c873b9d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -74,7 +74,7 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9b39be18770d11421cdb1b9947a45dd3f37e93092cbf377614828a319d5fee8" dependencies = [ - "hermit-abi", + "hermit-abi 0.1.19", "libc", "winapi", ] @@ -159,6 +159,12 @@ dependencies = [ "syn", ] +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + [[package]] name = "cc" version = "1.0.83" @@ -224,7 +230,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e16e44ab292b1dddfdaf7be62cfd8877df52f2f3fde5858d95bab606be259f20" dependencies = [ "bitflags 2.4.1", - "libloading 0.7.4", + "libloading 0.8.1", "winapi", ] @@ -453,6 +459,12 @@ dependencies = [ "libc", ] +[[package]] +name = "hermit-abi" +version = "0.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" + [[package]] name = "hexf-parse" version = "0.2.1" @@ -591,6 +603,17 @@ dependencies = [ "adler", ] +[[package]] +name = "mio" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +dependencies = [ + "libc", + "wasi", + "windows-sys", +] + [[package]] name = "naga" version = "0.14.2" @@ -630,6 +653,16 @@ dependencies = [ "autocfg", ] +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi 0.3.3", + "libc", +] + [[package]] name = "objc" version = "0.2.7" @@ -728,6 +761,12 @@ dependencies = [ "indexmap", ] +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + [[package]] name = "pkg-config" version = "0.3.28" @@ -853,6 +892,15 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "signal-hook-registry" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" +dependencies = [ + "libc", +] + [[package]] name = "slotmap" version = "1.0.7" @@ -868,6 +916,16 @@ version = "1.11.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4dccd0940a2dcdf68d092b8cbab7dc0ad8fa938bf95787e1b916b0e3d0e8e970" +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys", +] + [[package]] name = "spin" version = "0.9.8" @@ -933,6 +991,36 @@ dependencies = [ "syn", ] +[[package]] +name = "tokio" +version = "1.35.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c89b4efa943be685f629b149f53829423f8f5531ea21249408e8e2f8671ec104" +dependencies = [ + "backtrace", + "bytes", + "libc", + "mio", + "num_cpus", + "parking_lot 0.12.1", + "pin-project-lite", + "signal-hook-registry", + "socket2", + "tokio-macros", + "windows-sys", +] + +[[package]] +name = "tokio-macros" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "unicode-ident" version = "1.0.12" @@ -1110,7 +1198,7 @@ dependencies = [ "js-sys", "khronos-egl", "libc", - "libloading 0.7.4", + "libloading 0.8.1", "log", "metal", "naga", @@ -1149,6 +1237,8 @@ dependencies = [ "env_logger", "futures-intrusive", "pollster", + "thiserror", + "tokio", "wgpu", ] diff --git a/Cargo.toml b/Cargo.toml index b72d07e..1fd1135 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,3 +11,5 @@ env_logger = "0.9.1" pollster = "0.2.5" futures-intrusive = "0.4" bytemuck = { version = "1.12.1", features = ["derive"] } +thiserror = "1.0.51" +tokio = { version = "1.35.1", features = ["sync", "full"] } diff --git a/src/main.rs b/src/main.rs index 4e81e24..107b02a 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,8 @@ mod ppu; +use bytemuck; use std::time::Instant; use wgpu::util::DeviceExt; -use bytemuck; async fn run() { let instance_desc = wgpu::InstanceDescriptor { @@ -252,7 +252,8 @@ async fn run() { .map(|b| { let a = bytemuck::from_bytes::(b); a.clone() - }).collect::>(); + }) + .collect::>(); println!("matrix_c_staging_buf {:?}", parsed); } @@ -269,6 +270,209 @@ async fn run() { } } -fn main() { - pollster::block_on(run()); +use ppu::*; +use std::fs::File; + +#[tokio::main] +async fn main() -> Result<(), Error> { + // Initialize PPU + let mut ppu = PPU::new().await?; + + // Load Shader from a file + let shader_source = " + struct Matrix { + data: array, 4>, // Corrected: Use a comma instead of a semicolon + }; + + @group(0) @binding(0) + var matrixA: Matrix; + + @group(0) @binding(1) + var matrixB: Matrix; + + @group(0) @binding(2) + var matrixC: Matrix; + + @compute @workgroup_size(4, 4) + fn main(@builtin(global_invocation_id) global_id: vec3) { + let row: u32 = global_id.y; + let col: u32 = global_id.x; + + var sum: f32 = 0.0; + for (var k: u32 = 0u; k < 4u; k = k + 1u) { + sum = sum + matrixA.data[row][k] * matrixB.data[k][col]; + } + matrixC.data[row][col] = sum; + } + "; // Replace with your shader source + ppu.load_shader("compute_shader", shader_source)?; + + // Create Bind Group Layout + let bind_group_layout_entries = [ + wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 1, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: true }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + wgpu::BindGroupLayoutEntry { + binding: 2, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: None, + }, + count: None, + }, + ]; + ppu.load_bind_group_layout("my_bind_group_layout", &bind_group_layout_entries); + + // Create Bind Group +let bind_group_entries = [ + wgpu::BindGroupEntry { + binding: 0, + resource: ppu.get_buffer("matrix_a").unwrap().as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 1, + resource: ppu.get_buffer("matrix_b").unwrap().as_entire_binding(), + }, + wgpu::BindGroupEntry { + binding: 2, + resource: ppu.get_buffer("matrix_c").unwrap().as_entire_binding(), + }, +]; +ppu.create_bind_group("my_bind_group", "my_bind_group_layout", &bind_group_entries)?; + + + // Load Pipeline + ppu.load_pipeline( + "my_pipeline", + "compute_shader", + "main", // Entry point in the shader + &["my_bind_group_layout"], + )?; + + // Create a buffer to store the output data + let output_buffer_label = "output_buffer"; + let output_data: Vec = vec![0; 10]; // Example: buffer to store 10 unsigned integers + ppu.load_buffer( + output_buffer_label, + &output_data, + wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::MAP_READ, + )?; + + // Execute the compute task + let workgroup_count = (1, 1, 1); // Adjust according to your task's requirements + let results: Vec = ppu + .execute_compute_task( + "my_pipeline", + "my_bind_group", + workgroup_count, + output_buffer_label, + std::mem::size_of_val(&output_data[..]), + ) + .await?; + + // Process results + println!("Results: {:?}", results); + + Ok(()) } + +// /// Submit a compute task +// pub fn submit_task(&self, pipeline: &str, bind_group: &str, workgroup_count: (u32, u32, u32)) { +// let pipeline = self.get_pipeline(pipeline).unwrap(); +// let bind_group = self.get_bind_group(bind_group).unwrap(); +// let mut encoder = self +// .device +// .create_command_encoder(&wgpu::CommandEncoderDescriptor { +// label: Some("Compute Encoder"), +// }); + +// { +// let mut compute_pass = encoder.begin_compute_pass(&Default::default()); +// compute_pass.set_pipeline(pipeline); +// compute_pass.set_bind_group(0, bind_group, &[]); +// compute_pass.dispatch_workgroups( +// workgroup_count.0, +// workgroup_count.1, +// workgroup_count.2, +// ); +// } + +// self.queue.submit(Some(encoder.finish())); +// } + +// /// Complete a compute task +// pub async fn complete_task( +// &self, +// output_buffer: &str, +// buffer_size: usize, +// ) -> Result, Error> +// where +// T: bytemuck::Pod + bytemuck::Zeroable, +// { +// // Create a staging buffer to read back data into +// let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor { +// label: Some("Staging Buffer"), +// size: buffer_size as u64, +// usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, +// mapped_at_creation: false, +// }); + +// let output_buffer = match self.buffers.get(output_buffer) { +// Some(buffer) => buffer, +// None => return Err(Error::BufferNotFound(output_buffer.to_string())), +// }; + +// // Copy data from the compute output buffer to the staging buffer +// let mut encoder = self +// .device +// .create_command_encoder(&wgpu::CommandEncoderDescriptor { +// label: Some("Readback Encoder"), +// }); + +// encoder.copy_buffer_to_buffer(output_buffer, 0, &staging_buffer, 0, buffer_size as u64); +// self.queue.submit(Some(encoder.finish())); + +// // Map the buffer asynchronously +// let buffer_slice = staging_buffer.slice(..); +// let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + +// buffer_slice.map_async( +// wgpu::MapMode::Read, +// move |v: Result<(), wgpu::BufferAsyncError>| { +// sender.send(v).unwrap(); +// }, +// ); + +// // Wait for the buffer to be ready +// self.device.poll(wgpu::Maintain::Wait); +// receiver.receive().await.unwrap()?; + +// // Read data from the buffer +// let data = buffer_slice.get_mapped_range(); +// let result = bytemuck::cast_slice::<_, T>(&data).to_vec(); + +// // Unmap the buffer +// drop(data); +// staging_buffer.unmap(); + +// Ok(result) +// } diff --git a/src/ppu.rs b/src/ppu.rs index 6a61ef3..69a86e5 100644 --- a/src/ppu.rs +++ b/src/ppu.rs @@ -1,34 +1,64 @@ +#![allow(unused)] +use bytemuck; +// use core::slice::SlicePattern; +use std::collections::HashMap; use std::time::Instant; - +use thiserror::Error; use wgpu::util::DeviceExt; -use bytemuck; +#[derive(Error, Debug)] +pub enum Error { + #[error("Failed to find an appropriate adapter.")] + AdapterNotFound, + #[error("Failed to create device.")] + DeviceCreationFailed(#[from] wgpu::RequestDeviceError), + #[error("Failed to create query set. {0}")] + QuerySetCreationFailed(String), + #[error("Failed to async map a buffer.")] + BufferAsyncError(#[from] wgpu::BufferAsyncError), + #[error("Failed to create compute pipeline, shader module {0} was not found.")] + ShaderModuleNotFound(String), + #[error("No buffer found with name {0}.")] + BufferNotFound(String), + #[error("Buffer {0} is too small, expected {1} bytes found {2} bytes.")] + BufferTooSmall(String, usize, usize), + #[error("No bind group layout found with name {0}.")] + BindGroupLayoutNotFound(String), + #[error("No pipeline found with name {0}.")] + PipelineNotFound(String), + #[error("No bind group found with name {0}.")] + BindGroupNotFound(String), +} pub struct PPU { instance: wgpu::Instance, adapter: wgpu::Adapter, device: wgpu::Device, queue: wgpu::Queue, - query_set: Option, - // shader: wgpu::ShaderModule, + query_set: wgpu::QuerySet, + buffers: HashMap, + shader_modules: HashMap, + pipelines: HashMap, + bind_group_layouts: HashMap, + bind_groups: HashMap, } impl PPU { - pub async fn init() -> Self { - let instance_desc = wgpu::InstanceDescriptor { - backends: wgpu::Backends::VULKAN, - dx12_shader_compiler: wgpu::Dx12Compiler::Fxc, - gles_minor_version: wgpu::Gles3MinorVersion::default(), - flags: wgpu::InstanceFlags::empty(), - }; - let instance = wgpu::Instance::new(instance_desc); + /// Initialize a new PPU instance + pub async fn new() -> Result { + let instance = wgpu::Instance::default(); + let adapter = instance .request_adapter(&wgpu::RequestAdapterOptions { power_preference: wgpu::PowerPreference::HighPerformance, force_fallback_adapter: true, compatible_surface: None, - }).await.expect("Failed to find an appropriate adapter"); - let features = adapter.features(); + }) + .await + .ok_or(Error::AdapterNotFound)?; + + let features = adapter.features(); + let (device, queue) = adapter .request_device( &wgpu::DeviceDescriptor { @@ -38,25 +68,270 @@ impl PPU { }, None, ) - .await - .unwrap(); + .await?; + let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) { - Some(device.create_query_set(&wgpu::QuerySetDescriptor { + device.create_query_set(&wgpu::QuerySetDescriptor { count: 2, ty: wgpu::QueryType::Timestamp, label: None, - })) + }) } else { - None + return Err(Error::QuerySetCreationFailed( + "Timestamp query is not supported".to_string(), + )); }; - println!("PPU initialized"); - Self { + + Ok(Self { instance, adapter, device, queue, query_set, + buffers: HashMap::new(), + shader_modules: HashMap::new(), + pipelines: HashMap::new(), + bind_group_layouts: HashMap::new(), + bind_groups: HashMap::new(), + }) + } + + /// Create a compute buffer + pub fn load_buffer( + &mut self, + label: &str, + data: &[T], + usage: wgpu::BufferUsages, + ) -> Result<(), Error> { + let buffer = self + .device + .create_buffer_init(&wgpu::util::BufferInitDescriptor { + label: Some(label), + contents: bytemuck::cast_slice(data), + usage, + }); + + self.buffers.insert(label.to_string(), buffer); + Ok(()) + } + + /// Update a compute buffer + pub fn update_buffer(&self, label: &str, data: &[T]) -> Result<(), Error> + where + T: bytemuck::Pod, + { + let buffer = match self.buffers.get(label) { + Some(buffer) => buffer, + None => return Err(Error::BufferNotFound(label.to_string())), + }; + + let data_bytes = bytemuck::cast_slice(data); + let data_len = data_bytes.len() as wgpu::BufferAddress; + + if buffer.size() < data_len { + return Err(Error::BufferTooSmall( + label.to_string(), + data_len as usize, + buffer.size() as usize, + )); } + + self.queue.write_buffer(buffer, 0, data_bytes); + + Ok(()) + } + + /// Retrieve a buffer by name + pub fn get_buffer(&self, label: &str) -> Option<&wgpu::Buffer> { + self.buffers.get(label) + } + + pub fn load_bind_group_layout( + &mut self, + label: &str, + layout_entries: &[wgpu::BindGroupLayoutEntry], + ) { + let bind_group_layout = self.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some(label), + entries: layout_entries, + }); + + self.bind_group_layouts.insert(label.to_string(), bind_group_layout); + } + + pub fn get_bind_group_layout(&self, label: &str) -> Option<&wgpu::BindGroupLayout> { + self.bind_group_layouts.get(label) + } + + /// Create and store a bind group + pub fn create_bind_group( + &self, + name: &str, + layout_name: &str, + entries: &[wgpu::BindGroupEntry], + ) -> Result { + let layout = self.bind_group_layouts.get(layout_name) + .ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string()))?; + + let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some(name), + layout, + entries, + }); + Ok(bind_group) + } + + /// Retrieve a bind group by name + pub fn get_bind_group(&self, name: &str) -> Option<&wgpu::BindGroup> { + self.bind_groups.get(name) + } + + /// Load a shader and store it in the hash map + pub fn load_shader(&mut self, name: &str, source: &str) -> Result<(), Error> { + let shader_module = self + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some(name), + source: wgpu::ShaderSource::Wgsl(source.into()), + }); + self.shader_modules.insert(name.to_string(), shader_module); + Ok(()) + } + + /// Get a shader from the hash map + pub fn get_shader(&self, name: &str) -> Option<&wgpu::ShaderModule> { + self.shader_modules.get(name) + } + + pub fn load_pipeline( + &mut self, + name: &str, + shader_module_name: &str, + entry_point: &str, + bind_group_layout_names: &[&str], // Use names of bind group layouts + ) -> Result<(), Error> { + // Retrieve the shader module + let shader_module = self.get_shader(shader_module_name) + .ok_or(Error::ShaderModuleNotFound(shader_module_name.to_string()))?; + + // Retrieve the bind group layouts + let bind_group_layouts = bind_group_layout_names + .iter() + .map(|layout_name| { + self.bind_group_layouts.get(*layout_name) // Assuming you have a HashMap for BindGroupLayouts + .ok_or_else(|| Error::BindGroupLayoutNotFound(layout_name.to_string())) + }) + .collect::, _>>()?; + + // Create the pipeline layout + let pipeline_layout = self.device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some(name), + bind_group_layouts: &bind_group_layouts.iter().map(|layout| *layout).collect::>(), + push_constant_ranges: &[], + }); + + // Create the compute pipeline + let pipeline = self.device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some(name), + layout: Some(&pipeline_layout), + module: shader_module, + entry_point, + }); + + // Store the pipeline + self.pipelines.insert(name.to_string(), pipeline); + Ok(()) + } + + pub fn get_pipeline(&self, name: &str) -> Option<&wgpu::ComputePipeline> { + self.pipelines.get(name) + } + + /// Execute a compute task and retrieve the result + pub async fn execute_compute_task( + &self, + pipeline_name: &str, + bind_group_name: &str, + workgroup_count: (u32, u32, u32), + output_buffer_name: &str, + buffer_size: usize, + ) -> Result, Error> + where + T: bytemuck::Pod + bytemuck::Zeroable, + { + // Retrieve the pipeline and bind group + let pipeline = self.get_pipeline(pipeline_name) + .ok_or_else(|| Error::PipelineNotFound(pipeline_name.to_string()))?; + let bind_group = self.get_bind_group(bind_group_name) + .ok_or_else(|| Error::BindGroupNotFound(bind_group_name.to_string()))?; + + // Dispatch the compute task + let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Compute Task Encoder"), + }); + { + let mut compute_pass = encoder.begin_compute_pass(&Default::default()); + compute_pass.set_pipeline(pipeline); + compute_pass.set_bind_group(0, bind_group, &[]); + compute_pass.dispatch_workgroups(workgroup_count.0, workgroup_count.1, workgroup_count.2); + } + self.queue.submit(Some(encoder.finish())); + + // Wait for the task to complete + self.device.poll(wgpu::Maintain::Wait); + + // Retrieve and return the data from the output buffer + self.complete_task::(output_buffer_name, buffer_size).await + } + + /// Complete a compute task and retrieve the data from an output buffer + pub async fn complete_task( + &self, + output_buffer_name: &str, + buffer_size: usize, + ) -> Result, Error> + where + T: bytemuck::Pod + bytemuck::Zeroable, + { + let output_buffer = self.buffers.get(output_buffer_name) + .ok_or_else(|| Error::BufferNotFound(output_buffer_name.to_string()))?; + + // Create a staging buffer to read back data into + let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("Staging Buffer"), + size: buffer_size as u64, + usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST, + mapped_at_creation: false, + }); + + // Copy data from the compute output buffer to the staging buffer + let mut encoder = self.device.create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("Readback Encoder"), + }); + encoder.copy_buffer_to_buffer(output_buffer, 0, &staging_buffer, 0, buffer_size as u64); + self.queue.submit(Some(encoder.finish())); + + // Map the buffer asynchronously + let buffer_slice = staging_buffer.slice(..); + let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel(); + + buffer_slice.map_async(wgpu::MapMode::Read, move |v: Result<(), wgpu::BufferAsyncError>| { + sender.send(v).unwrap(); + }); + + // Wait for the buffer to be ready + self.device.poll(wgpu::Maintain::Wait); + receiver.receive().await.unwrap()?; + let data = buffer_slice.get_mapped_range(); + + // Read data from the buffer + let result = bytemuck::cast_slice::<_, T>(&data).to_vec(); + + // Unmap the buffer + drop(data); + staging_buffer.unmap(); + + Ok(result) } } diff --git a/src/shader.spv b/src/shader.spv deleted file mode 100644 index 5c9af3a..0000000 --- a/src/shader.spv +++ /dev/null @@ -1,25 +0,0 @@ -#version 450 - -layout(set = 0, binding = 0) buffer MatrixA { - float data[4][4]; -} matrixA; - -layout(set = 0, binding = 1) buffer MatrixB { - float data[4][4]; -} matrixB; - -layout(set = 0, binding = 2) buffer MatrixC { - float data[4][4]; -} matrixC; - -void main() { - uint row = gl_GlobalInvocationID.y; - uint col = gl_GlobalInvocationID.x; - float sum = 0.0; - - for (uint k = 0; k < 4; ++k) { - sum += matrixA.data[row][k] * matrixB.data[k][col]; - } - - matrixC.data[row][col] = sum; -} diff --git a/src/shader.wgsl b/src/shader.wgsl index a5a6091..b458d0f 100644 --- a/src/shader.wgsl +++ b/src/shader.wgsl @@ -21,4 +21,4 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { sum = sum + matrixA.data[row][k] * matrixB.data[k][col]; } matrixC.data[row][col] = sum; -} +} \ No newline at end of file