🚀 Initial wgpu matrix multiplication test, fails to build on Linux
Signed-off-by: Julius Koskela <julius.koskela@unikie.com>
This commit is contained in:
commit
e1e9b8dd18
1
.gitignore
vendored
Normal file
1
.gitignore
vendored
Normal file
@ -0,0 +1 @@
|
|||||||
|
/target
|
1284
Cargo.lock
generated
Normal file
1284
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
13
Cargo.toml
Normal file
13
Cargo.toml
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
[package]
|
||||||
|
name = "wgpu_compute_shader"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
wgpu = { version = "0.18", features = ["vulkan-portability"] }
|
||||||
|
env_logger = "0.9.1"
|
||||||
|
pollster = "0.2.5"
|
||||||
|
futures-intrusive = "0.4"
|
||||||
|
bytemuck = { version = "1.12.1", features = ["derive"] }
|
9
build.rs
Normal file
9
build.rs
Normal file
@ -0,0 +1,9 @@
|
|||||||
|
/// Returns `true` when builds for `target_os` must link the system Vulkan loader.
fn links_vulkan(target_os: &str) -> bool {
    target_os == "linux"
}

/// Build script entry point: emits the linker directives the crate needs.
///
/// The decision is based on the *target* operating system. Cargo exports it to
/// build scripts as `CARGO_CFG_TARGET_OS`; `cfg!(target_os = ...)` inside a
/// build script would describe the machine the script itself was compiled
/// for, which is wrong when cross-compiling.
fn main() {
    // Nothing outside this file influences the output, so avoid re-running the
    // script on every source change.
    println!("cargo:rerun-if-changed=build.rs");

    // Fall back to the OS this script runs on when executed outside of Cargo.
    let target_os = std::env::var("CARGO_CFG_TARGET_OS")
        .unwrap_or_else(|_| std::env::consts::OS.to_owned());

    if links_vulkan(&target_os) {
        println!("cargo:rustc-link-lib=vulkan");
    }
}
|
45
default.nix
Normal file
45
default.nix
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
# Nix package for the wgpu matrix-multiplication test binary.
{pkgs}:
with pkgs;
  rustPlatform.buildRustPackage {
    pname = "matmul-vshader";
    version = "0.1.0";

    src = ./.;

    # Tools that run on the build machine.
    nativeBuildInputs = [
      pkg-config
      cmake
      python3
      git
    ];

    # Libraries and headers the crate is compiled and linked against
    # (build.rs links `vulkan` on Linux).
    buildInputs = [
      vulkan-headers
      vulkan-loader
      shaderc
    ];

    # cmake is only here for crates that invoke it from their own build
    # scripts. Without this flag its setup hook replaces the configure phase
    # and fails because this source tree has no CMakeLists.txt.
    dontUseCmakeConfigure = true;

    RUST_BACKTRACE = "1";

    cargoLock = {
      lockFile = ./Cargo.lock;
      allowBuiltinFetchGit = true;
    };
  }
|
27
flake.lock
Normal file
27
flake.lock
Normal file
@ -0,0 +1,27 @@
|
|||||||
|
{
|
||||||
|
"nodes": {
|
||||||
|
"nixpkgs": {
|
||||||
|
"locked": {
|
||||||
|
"lastModified": 1703013332,
|
||||||
|
"narHash": "sha256-+tFNwMvlXLbJZXiMHqYq77z/RfmpfpiI3yjL6o/Zo9M=",
|
||||||
|
"owner": "NixOS",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"rev": "54aac082a4d9bb5bbc5c4e899603abfb76a3f6d6",
|
||||||
|
"type": "github"
|
||||||
|
},
|
||||||
|
"original": {
|
||||||
|
"owner": "NixOS",
|
||||||
|
"ref": "nixos-unstable",
|
||||||
|
"repo": "nixpkgs",
|
||||||
|
"type": "github"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": {
|
||||||
|
"inputs": {
|
||||||
|
"nixpkgs": "nixpkgs"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"root": "root",
|
||||||
|
"version": 7
|
||||||
|
}
|
13
flake.nix
Normal file
13
flake.nix
Normal file
@ -0,0 +1,13 @@
|
|||||||
|
# Flake exposing the package, the development shell and the formatter for
# x86_64-linux, all taken from the pinned nixpkgs input.
{
  inputs.nixpkgs.url = "github:NixOS/nixpkgs/nixos-unstable";

  outputs = {
    self,
    nixpkgs,
  }: let
    system = "x86_64-linux";
    pkgs = nixpkgs.legacyPackages.${system};
  in {
    packages.${system}.default = pkgs.callPackage ./default.nix {};
    devShells.${system}.default = pkgs.callPackage ./shell.nix {};
    formatter.${system} = pkgs.alejandra;
  };
}
|
51
shell.nix
Normal file
51
shell.nix
Normal file
@ -0,0 +1,51 @@
|
|||||||
|
# Development shell for the wgpu matrix-multiplication test.
{pkgs}:
with pkgs; let
  build = pkgs.callPackage ./default.nix {};

  # Shared libraries that are looked up at run time rather than link time.
  runtimeLibs = [
    libX11
    libXcursor
    libXrandr
    libXi
    vulkan-loader
    vulkan-validation-layers
    stdenv.cc.cc.lib
  ];
in
  mkShell {
    # Reuse the package's build inputs without building the package itself, so
    # the shell stays usable while the package does not build yet.
    inputsFrom = [build];

    # The C toolchain is already provided by mkShell's stdenv.
    packages = [
      libX11
      libXcursor
      libXrandr
      libXi
      vulkan-headers
      vulkan-loader
      vulkan-validation-layers
      vulkan-tools
      pkg-config
      git
      gnumake
      cmake
      python3
      shaderc
    ];

    RUST_BACKTRACE = "1";

    shellHook = ''
      # Set here rather than as a derivation attribute so that the value the
      # user already has is expanded by the shell and kept at the end.
      export LD_LIBRARY_PATH="${lib.makeLibraryPath runtimeLibs}''${LD_LIBRARY_PATH:+:$LD_LIBRARY_PATH}"

      # VK_ICD_FILENAMES is deliberately not set: the previous value pointed
      # at a radeon ICD file inside vulkan-loader, which would restrict the
      # loader to a single hard-coded driver manifest. Leaving it unset lets
      # the loader use its default driver search.
      export VK_LAYER_PATH=${vulkan-validation-layers}/share/vulkan/explicit_layer.d
      export VK_INSTANCE_LAYERS=VK_LAYER_KHRONOS_validation
      export VK_DEVICE_LAYERS=VK_LAYER_KHRONOS_validation
      export VK_LOADER_DEBUG=all
      export VK_LOADER_DEBUG_FILE=/tmp/vulkan.log
      export VK_INSTANCE_EXTENSIONS=VK_EXT_debug_utils
      export VK_DEVICE_EXTENSIONS=VK_EXT_debug_utils
      export VK_LAYER_ENABLES=VK_LAYER_KHRONOS_validation
      export VK_LAYER_DISABLES=VK_LAYER_LUNARG_api_dump

      # Sanity check only. The previous unquoted ">= 2.11.1" was parsed by the
      # shell as a redirection and created a file named "=".
      pkg-config --exists vulkan || echo "warning: pkg-config cannot find vulkan" >&2
    '';
  }
|
196
src/main.rs
Normal file
196
src/main.rs
Normal file
@ -0,0 +1,196 @@
|
|||||||
|
// ... existing imports ...
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use wgpu::util::DeviceExt;
|
||||||
|
|
||||||
|
use bytemuck;
|
||||||
|
|
||||||
|
async fn run() {
|
||||||
|
let instance = wgpu::Instance::default();
|
||||||
|
println!("instance {:?}", instance);
|
||||||
|
|
||||||
|
let adapter = instance
|
||||||
|
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||||
|
power_preference: wgpu::PowerPreference::HighPerformance,
|
||||||
|
force_fallback_adapter: false,
|
||||||
|
compatible_surface: None,
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
let adapter = match adapter {
|
||||||
|
Some(adapter) => adapter,
|
||||||
|
None => {
|
||||||
|
println!("No suitable GPU adapters found on the system!");
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
let features = adapter.features();
|
||||||
|
let (device, queue) = adapter
|
||||||
|
.request_device(
|
||||||
|
&wgpu::DeviceDescriptor {
|
||||||
|
label: None,
|
||||||
|
features: features & wgpu::Features::TIMESTAMP_QUERY,
|
||||||
|
limits: Default::default(),
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let query_set = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||||
|
Some(device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||||
|
count: 2,
|
||||||
|
ty: wgpu::QueryType::Timestamp,
|
||||||
|
label: None,
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let start_instant = Instant::now();
|
||||||
|
let cs_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
|
||||||
|
label: None,
|
||||||
|
//source: wgpu::ShaderSource::SpirV(bytes_to_u32(include_bytes!("alu.spv")).into()),
|
||||||
|
source: wgpu::ShaderSource::Wgsl(include_str!("shader.wgsl").into()),
|
||||||
|
});
|
||||||
|
println!("shader compilation {:?}", start_instant.elapsed());
|
||||||
|
|
||||||
|
let matrix_size = std::mem::size_of::<[[f32; 4]; 4]>();
|
||||||
|
let matrix_a_data = [[1.0; 4]; 4]; // Example data
|
||||||
|
let matrix_b_data = [[2.0; 4]; 4]; // Example data
|
||||||
|
|
||||||
|
let matrix_a_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||||
|
label: Some("Matrix A Buffer"),
|
||||||
|
contents: bytemuck::bytes_of(&matrix_a_data),
|
||||||
|
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
||||||
|
});
|
||||||
|
let matrix_b_buf = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||||
|
label: Some("Matrix B Buffer"),
|
||||||
|
contents: bytemuck::bytes_of(&matrix_b_data),
|
||||||
|
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
|
||||||
|
});
|
||||||
|
let matrix_c_buf = device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
label: Some("Matrix C Buffer"),
|
||||||
|
size: matrix_size as u64,
|
||||||
|
usage: wgpu::BufferUsages::STORAGE
|
||||||
|
| wgpu::BufferUsages::MAP_READ
|
||||||
|
| wgpu::BufferUsages::COPY_DST,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
let query_buf = if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||||
|
Some(device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
label: Some("Query Buffer"),
|
||||||
|
size: 16, // Enough for two 64-bit timestamps
|
||||||
|
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
}))
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||||
|
label: None,
|
||||||
|
entries: &[wgpu::BindGroupLayoutEntry {
|
||||||
|
binding: 0,
|
||||||
|
visibility: wgpu::ShaderStages::COMPUTE,
|
||||||
|
ty: wgpu::BindingType::Buffer {
|
||||||
|
ty: wgpu::BufferBindingType::Storage { read_only: false },
|
||||||
|
has_dynamic_offset: false,
|
||||||
|
min_binding_size: None,
|
||||||
|
},
|
||||||
|
count: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
let compute_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||||
|
label: None,
|
||||||
|
bind_group_layouts: &[&bind_group_layout],
|
||||||
|
push_constant_ranges: &[],
|
||||||
|
});
|
||||||
|
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||||
|
label: None,
|
||||||
|
layout: Some(&compute_pipeline_layout),
|
||||||
|
module: &cs_module,
|
||||||
|
entry_point: "main",
|
||||||
|
});
|
||||||
|
|
||||||
|
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||||
|
label: Some("Compute Bind Group"),
|
||||||
|
layout: &bind_group_layout,
|
||||||
|
entries: &[
|
||||||
|
wgpu::BindGroupEntry {
|
||||||
|
binding: 0,
|
||||||
|
resource: matrix_a_buf.as_entire_binding(),
|
||||||
|
},
|
||||||
|
wgpu::BindGroupEntry {
|
||||||
|
binding: 1,
|
||||||
|
resource: matrix_b_buf.as_entire_binding(),
|
||||||
|
},
|
||||||
|
wgpu::BindGroupEntry {
|
||||||
|
binding: 2,
|
||||||
|
resource: matrix_c_buf.as_entire_binding(),
|
||||||
|
},
|
||||||
|
],
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut encoder = device.create_command_encoder(&Default::default());
|
||||||
|
if let Some(query_set) = &query_set {
|
||||||
|
encoder.write_timestamp(query_set, 0);
|
||||||
|
}
|
||||||
|
{
|
||||||
|
let mut cpass = encoder.begin_compute_pass(&Default::default());
|
||||||
|
cpass.set_pipeline(&pipeline);
|
||||||
|
cpass.set_bind_group(0, &bind_group, &[]);
|
||||||
|
cpass.dispatch_workgroups(4, 4, 1); // Dispatch for a 4x4 matrix
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(query_set) = &query_set {
|
||||||
|
encoder.write_timestamp(query_set, 1);
|
||||||
|
}
|
||||||
|
encoder.copy_buffer_to_buffer(&matrix_c_buf, 0, &matrix_c_buf, 0, matrix_size as u64);
|
||||||
|
if let Some(query_set) = &query_set {
|
||||||
|
if let Some(query_buf) = &query_buf {
|
||||||
|
encoder.write_timestamp(query_set, 1);
|
||||||
|
encoder.resolve_query_set(query_set, 0..2, query_buf, 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
queue.submit(Some(encoder.finish()));
|
||||||
|
|
||||||
|
// Assuming query_buf has been properly initialized earlier
|
||||||
|
|
||||||
|
let buf_slice = matrix_c_buf.slice(..);
|
||||||
|
let query_slice = query_buf.as_ref().map(|buf| buf.slice(..)); // Adjust this line
|
||||||
|
|
||||||
|
let (sender, receiver) = futures_intrusive::channel::shared::oneshot_channel();
|
||||||
|
|
||||||
|
// Map the buffer for reading (matrix_c_buf)
|
||||||
|
buf_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
|
||||||
|
|
||||||
|
if let Some(q_slice) = query_slice {
|
||||||
|
// Map the query buffer if it exists
|
||||||
|
let _query_future = q_slice.map_async(wgpu::MapMode::Read, |_| ());
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("pre-poll {:?}", std::time::Instant::now());
|
||||||
|
device.poll(wgpu::Maintain::Wait);
|
||||||
|
println!("post-poll {:?}", std::time::Instant::now());
|
||||||
|
|
||||||
|
if let Some(Ok(())) = receiver.receive().await {
|
||||||
|
let data_raw = &*buf_slice.get_mapped_range();
|
||||||
|
println!("compute shader result: {:?}", data_raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
if features.contains(wgpu::Features::TIMESTAMP_QUERY) {
|
||||||
|
if let Some(q_slice) = query_slice {
|
||||||
|
let ts_period = queue.get_timestamp_period();
|
||||||
|
let ts_data_raw = &*q_slice.get_mapped_range();
|
||||||
|
let ts_data: &[u64] = bytemuck::cast_slice(ts_data_raw);
|
||||||
|
println!(
|
||||||
|
"compute shader elapsed: {:?}ms",
|
||||||
|
(ts_data[1] - ts_data[0]) as f64 * ts_period as f64 * 1e-6
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
pollster::block_on(run());
|
||||||
|
}
|
24
src/shader.wgsl
Normal file
24
src/shader.wgsl
Normal file
@ -0,0 +1,24 @@
|
|||||||
|
// 4x4 matrix of f32, indexed as data[row][col].
struct Matrix {
    data: array<array<f32, 4>, 4>,
}

// Left operand.
@group(0) @binding(0)
var<storage, read> matrixA: Matrix;

// Right operand.
@group(0) @binding(1)
var<storage, read> matrixB: Matrix;

// Product, written by the shader.
@group(0) @binding(2)
var<storage, read_write> matrixC: Matrix;

// Each invocation computes one element of C = A * B; one 4x4 workgroup covers
// the whole matrix.
@compute @workgroup_size(4, 4)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let row: u32 = global_id.y;
    let col: u32 = global_id.x;

    // Ignore invocations outside the matrix when more than one workgroup is
    // dispatched.
    if (row >= 4u || col >= 4u) {
        return;
    }

    var sum: f32 = 0.0;
    for (var k: u32 = 0u; k < 4u; k = k + 1u) {
        sum = sum + matrixA.data[row][k] * matrixB.data[k][col];
    }
    matrixC.data[row][col] = sum;
}
|
Loading…
Reference in New Issue
Block a user