wgpu_compute_shader/examples/mat_mul_4x4.rs
Julius Koskela 668293d956
🔧 Implement as a library
- Make the crate a lib
- Move main to examples as mat_mul_4x4.rs
- Correctly track elapsed time of compute task

Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
2023-12-23 04:00:45 +02:00

169 lines
4.5 KiB
Rust

use wgpu_compute_shader::*;
/// WGSL source of the 4x4 matrix-multiplication compute shader.
///
/// The shader declares three storage bindings in group 0, each a `Matrix`
/// struct wrapping `array<array<f32, 4>, 4>`:
/// - binding 0: `matrixA` (read-only input)
/// - binding 1: `matrixB` (read-only input)
/// - binding 2: `matrixC` (read-write output)
///
/// The entry point is `main` with a workgroup size of `(4, 4, 1)`. Each
/// invocation whose `global_invocation_id` has `x < 4` and `y < 4` takes
/// `row = y`, `col = x` and writes the dot product of row `row` of `matrixA`
/// with column `col` of `matrixB` into `matrixC.data[row][col]`.
const MM4X4_SHADER: &str = "
struct Matrix {
data: array<array<f32, 4>, 4>,
};
@group(0) @binding(0)
var<storage, read> matrixA: Matrix;
@group(0) @binding(1)
var<storage, read> matrixB: Matrix;
@group(0) @binding(2)
var<storage, read_write> matrixC: Matrix;
// Consider setting workgroup size to power of 2 for better efficiency
@compute @workgroup_size(4, 4, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
// Ensure the computation is within the bounds of the 4x4 matrix
if (global_id.x < 4u && global_id.y < 4u) {
let row: u32 = global_id.y;
let col: u32 = global_id.x;
var sum: f32 = 0.0;
for (var k: u32 = 0u; k < 4u; k = k + 1u) {
sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
}
matrixC.data[row][col] = sum;
}
}
";
/// Left-hand operand of the multiplication, holding the values 1.0 through
/// 16.0 with each inner array being one row.
///
/// `main` uploads it as the buffer labelled `"MATRIX_A"`, which is bound at
/// binding 0 (`matrixA` in the shader).
const MATRIX_A: [[f32; 4]; 4] = [
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
/// Right-hand operand of the multiplication, holding the values 1.0 through
/// 16.0 with each inner array being one row.
///
/// `main` uploads it as the buffer labelled `"MATRIX_B"`, which is bound at
/// binding 1 (`matrixB` in the shader).
const MATRIX_B: [[f32; 4]; 4] = [
[1.0, 2.0, 3.0, 4.0],
[5.0, 6.0, 7.0, 8.0],
[9.0, 10.0, 11.0, 12.0],
[13.0, 14.0, 15.0, 16.0],
];
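// Expected product MATRIX_A * MATRIX_B, kept for reference when checking the
// "Matrix C" output printed by `main`:
// const EXPECT: [[f32; 4]; 4] = [
// [90.0, 100.0, 110.0, 120.0],
// [202.0, 228.0, 254.0, 280.0],
// [314.0, 356.0, 398.0, 440.0],
// [426.0, 484.0, 542.0, 600.0],
// ];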
/// Runs the 4x4 matrix-multiplication shader on `MATRIX_A` and `MATRIX_B`
/// and prints the resulting matrix together with the reported elapsed time.
///
/// The function performs these steps in order:
/// 1. creates a `PPU` and hands it the WGSL source under the label
///    `"MM4X4_SHADER"`;
/// 2. creates the three buffers `"MATRIX_A"`, `"MATRIX_B"` and `"MATRIX_C"`
///    (the last one initialised with sixteen zeros);
/// 3. registers a bind group layout and a bind group matching the shader's
///    bindings 0, 1 and 2;
/// 4. registers the pipeline with entry point `"main"`;
/// 5. executes the compute task, reading the result from `"MATRIX_C"`;
/// 6. prints the result four values per line, followed by the elapsed time.
///
/// # Errors
///
/// Returns the `Error` produced by any of the `?`-propagated calls:
/// `PPU::new`, `load_shader`, `create_bind_group`, `load_pipeline` or
/// `execute_compute_task`.
///
/// # Panics
///
/// Panics if one of the buffers added in step 2 cannot be looked up again by
/// its label; that would indicate a bug rather than a runtime condition.
#[tokio::main]
async fn main() -> Result<(), Error> {
    /// Builds the layout entry for a compute-visible storage buffer at
    /// `binding`. The three bindings of the shader differ only in their
    /// index and in whether the shader may write to them.
    fn storage_layout_entry(binding: u32, read_only: bool) -> wgpu::BindGroupLayoutEntry {
        wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty: wgpu::BindingType::Buffer {
                ty: wgpu::BufferBindingType::Storage { read_only },
                has_dynamic_offset: false,
                min_binding_size: None,
            },
            count: None,
        }
    }

    // Create PPU
    let mut ppu = PPU::new().await?;
    ppu.load_shader("MM4X4_SHADER", MM4X4_SHADER)?;

    // NOTE(review): 16 matches the element count of one 4x4 matrix; confirm
    // the exact meaning of this argument against `ComputeBuffers::new`.
    let mut buffers = ComputeBuffers::new::<f32>(&ppu, 16);

    // Both inputs are only read by the shader.
    let input_usage = wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST;
    buffers.add_buffer_init(&ppu, "MATRIX_A", &MATRIX_A, input_usage);
    buffers.add_buffer_init(&ppu, "MATRIX_B", &MATRIX_B, input_usage);
    // The output starts zeroed and is flagged COPY_SRC instead of COPY_DST.
    buffers.add_buffer_init(
        &ppu,
        "MATRIX_C",
        &[0.0; 16],
        wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
    );

    // Create Bind Group Layout: bindings 0 and 1 are read-only, binding 2 is
    // read-write, mirroring the `var<storage, ...>` declarations in the shader.
    let bind_group_layout_entries = [
        storage_layout_entry(0, true),
        storage_layout_entry(1, true),
        storage_layout_entry(2, false),
    ];
    ppu.load_bind_group_layout("MM4X4_BIND_GROUP_LAYOUT", &bind_group_layout_entries);

    // Finally, create the bind group
    ppu.create_bind_group(
        "MM4X4_BIND_GROUP",
        "MM4X4_BIND_GROUP_LAYOUT",
        &[
            wgpu::BindGroupEntry {
                binding: 0,
                resource: buffers
                    .get_buffer("MATRIX_A")
                    .expect("MATRIX_A buffer was added above")
                    .as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 1,
                resource: buffers
                    .get_buffer("MATRIX_B")
                    .expect("MATRIX_B buffer was added above")
                    .as_entire_binding(),
            },
            wgpu::BindGroupEntry {
                binding: 2,
                resource: buffers
                    .get_buffer("MATRIX_C")
                    .expect("MATRIX_C buffer was added above")
                    .as_entire_binding(),
            },
        ],
    )?;

    // Load Pipeline
    ppu.load_pipeline(
        "MM4X4_PIPELINE",
        "MM4X4_SHADER",
        "main",
        &["MM4X4_BIND_GROUP_LAYOUT"],
    )?;

    // Execute the compute task. The shader's workgroup size is (4, 4, 1), so
    // a single workgroup covers all 16 output elements.
    let workgroup_count = (1, 1, 1);
    let results: ComputeResult<f32> = ppu
        .execute_compute_task(
            "MM4X4_PIPELINE",
            "MM4X4_BIND_GROUP",
            &buffers,
            "MATRIX_C",
            workgroup_count,
        )
        .await?;

    // Process results: print four values per line, i.e. one matrix row each.
    println!("Matrix C:");
    for row in results.data().chunks(4) {
        println!("{:?}", row);
    }
    println!("Time elapsed: {} us", results.time_elapsed_us());

    Ok(())
}