168 lines
4.4 KiB
Rust
168 lines
4.4 KiB
Rust
mod ppu;
|
|
use ppu::*;
|
|
|
|
/// WGSL compute shader computing `matrixC = matrixA * matrixB` for 4x4 `f32` matrices.
///
/// Each invocation produces one output element: `global_id.x` selects the column and
/// `global_id.y` the row. Bindings 0 and 1 of group 0 are read-only storage buffers
/// (A and B); binding 2 is the read-write output buffer (C). The entry point is `main`.
const MM4X4_SHADER: &str = "
struct Matrix {
    data: array<array<f32, 4>, 4>,
};

@group(0) @binding(0)
var<storage, read> matrixA: Matrix;

@group(0) @binding(1)
var<storage, read> matrixB: Matrix;

@group(0) @binding(2)
var<storage, read_write> matrixC: Matrix;

// One invocation per output element: a single 4x4 workgroup covers the whole result.
@compute @workgroup_size(4, 4, 1)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    // Ignore invocations outside the 4x4 result (e.g. if extra workgroups are dispatched).
    if (global_id.x < 4u && global_id.y < 4u) {
        let row: u32 = global_id.y;
        let col: u32 = global_id.x;

        var sum: f32 = 0.0;
        for (var k: u32 = 0u; k < 4u; k = k + 1u) {
            sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
        }
        matrixC.data[row][col] = sum;
    }
}
";
|
|
|
|
/// Left-hand operand `A`: the values 1 through 16, stored row by row.
#[rustfmt::skip]
const MATRIX_A: [[f32; 4]; 4] = [
    [ 1.0,  2.0,  3.0,  4.0],
    [ 5.0,  6.0,  7.0,  8.0],
    [ 9.0, 10.0, 11.0, 12.0],
    [13.0, 14.0, 15.0, 16.0],
];
|
|
|
|
/// Right-hand operand `B`: the values 1 through 16, stored row by row.
///
/// These are the same values as the left-hand operand, so the demo multiplies a
/// matrix by itself.
#[rustfmt::skip]
const MATRIX_B: [[f32; 4]; 4] = [
    [ 1.0,  2.0,  3.0,  4.0],
    [ 5.0,  6.0,  7.0,  8.0],
    [ 9.0, 10.0, 11.0, 12.0],
    [13.0, 14.0, 15.0, 16.0],
];
|
|
|
|
// const EXPECT: [[f32; 4]; 4] = [
|
|
// [90.0, 100.0, 110.0, 120.0],
|
|
// [202.0, 228.0, 254.0, 280.0],
|
|
// [314.0, 356.0, 398.0, 440.0],
|
|
// [426.0, 484.0, 542.0, 600.0],
|
|
// ];
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Error> {
|
|
// Create PPU
|
|
let mut ppu = PPU::new().await?;
|
|
ppu.load_shader("MM4X4_SHADER", MM4X4_SHADER)?;
|
|
|
|
let mut buffers = ComputeBuffers::new::<f32>(&ppu, 16);
|
|
|
|
buffers.add_buffer_init(
|
|
&ppu,
|
|
"MATRIX_A",
|
|
&MATRIX_A,
|
|
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
|
);
|
|
buffers.add_buffer_init(
|
|
&ppu,
|
|
"MATRIX_B",
|
|
&MATRIX_B,
|
|
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
|
);
|
|
buffers.add_buffer_init(
|
|
&ppu,
|
|
"MATRIX_C",
|
|
&[0.0; 16],
|
|
wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
|
|
);
|
|
|
|
// Create Bind Group Layout
|
|
let bind_group_layout_entries = [
|
|
wgpu::BindGroupLayoutEntry {
|
|
binding: 0,
|
|
visibility: wgpu::ShaderStages::COMPUTE,
|
|
ty: wgpu::BindingType::Buffer {
|
|
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
|
has_dynamic_offset: false,
|
|
min_binding_size: None,
|
|
},
|
|
count: None,
|
|
},
|
|
wgpu::BindGroupLayoutEntry {
|
|
binding: 1,
|
|
visibility: wgpu::ShaderStages::COMPUTE,
|
|
ty: wgpu::BindingType::Buffer {
|
|
ty: wgpu::BufferBindingType::Storage { read_only: true },
|
|
has_dynamic_offset: false,
|
|
min_binding_size: None,
|
|
},
|
|
count: None,
|
|
},
|
|
wgpu::BindGroupLayoutEntry {
|
|
binding: 2,
|
|
visibility: wgpu::ShaderStages::COMPUTE,
|
|
ty: wgpu::BindingType::Buffer {
|
|
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
|
has_dynamic_offset: false,
|
|
min_binding_size: None,
|
|
},
|
|
count: None,
|
|
},
|
|
];
|
|
|
|
ppu.load_bind_group_layout("MM4X4_BIND_GROUP_LAYOUT", &bind_group_layout_entries);
|
|
|
|
// Finally, create the bind group
|
|
ppu.create_bind_group(
|
|
"MM4X4_BIND_GROUP",
|
|
"MM4X4_BIND_GROUP_LAYOUT",
|
|
&[
|
|
wgpu::BindGroupEntry {
|
|
binding: 0,
|
|
resource: buffers.get_buffer("MATRIX_A").unwrap().as_entire_binding(),
|
|
},
|
|
wgpu::BindGroupEntry {
|
|
binding: 1,
|
|
resource: buffers.get_buffer("MATRIX_B").unwrap().as_entire_binding(),
|
|
},
|
|
wgpu::BindGroupEntry {
|
|
binding: 2,
|
|
resource: buffers.get_buffer("MATRIX_C").unwrap().as_entire_binding(),
|
|
},
|
|
],
|
|
)?;
|
|
|
|
// Load Pipeline
|
|
ppu.load_pipeline(
|
|
"MM4X4_PIPELINE",
|
|
"MM4X4_SHADER",
|
|
"main",
|
|
&["MM4X4_BIND_GROUP_LAYOUT"],
|
|
)?;
|
|
|
|
// Execute the compute task
|
|
let workgroup_count = (1, 1, 1);
|
|
let results: ComputeResult<f32> = ppu
|
|
.execute_compute_task(
|
|
"MM4X4_PIPELINE",
|
|
"MM4X4_BIND_GROUP",
|
|
&buffers,
|
|
"MATRIX_C",
|
|
workgroup_count,
|
|
)
|
|
.await?;
|
|
|
|
// Process results
|
|
println!("Matrix C:");
|
|
|
|
results.data().chunks(4).for_each(|row| {
|
|
println!("{:?}", row);
|
|
});
|
|
|
|
Ok(())
|
|
}
|