mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 08:14:12 +00:00
Collatz Computation (#623)
Build the compute shader for vulkan1.1 as required
This commit is contained in:
parent
a5e9fe751b
commit
cb952562dd
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -401,6 +401,7 @@ dependencies = [
|
|||||||
name = "compute-shader"
|
name = "compute-shader"
|
||||||
version = "0.4.0-alpha.8"
|
version = "0.4.0-alpha.8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"rayon",
|
||||||
"spirv-std",
|
"spirv-std",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
@ -367,6 +367,7 @@ fn invoke_rustc(builder: &SpirvBuilder) -> Result<PathBuf, SpirvBuilderError> {
|
|||||||
let mut cargo = Command::new("cargo");
|
let mut cargo = Command::new("cargo");
|
||||||
cargo.args(&[
|
cargo.args(&[
|
||||||
"build",
|
"build",
|
||||||
|
"--lib",
|
||||||
"--message-format=json-render-diagnostics",
|
"--message-format=json-render-diagnostics",
|
||||||
"-Zbuild-std=core",
|
"-Zbuild-std=core",
|
||||||
"-Zbuild-std-features=compiler-builtins-mem",
|
"-Zbuild-std-features=compiler-builtins-mem",
|
||||||
|
@ -11,7 +11,7 @@ fn build_shader(
|
|||||||
) -> Result<(), Box<dyn Error>> {
|
) -> Result<(), Box<dyn Error>> {
|
||||||
let builder_dir = &Path::new(env!("CARGO_MANIFEST_DIR"));
|
let builder_dir = &Path::new(env!("CARGO_MANIFEST_DIR"));
|
||||||
let path_to_crate = builder_dir.join(path_to_crate);
|
let path_to_crate = builder_dir.join(path_to_crate);
|
||||||
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.0");
|
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.1");
|
||||||
for &cap in caps {
|
for &cap in caps {
|
||||||
builder = builder.capability(cap);
|
builder = builder.capability(cap);
|
||||||
}
|
}
|
||||||
@ -28,7 +28,12 @@ fn build_shader(
|
|||||||
fn main() -> Result<(), Box<dyn Error>> {
|
fn main() -> Result<(), Box<dyn Error>> {
|
||||||
build_shader("../../../shaders/sky-shader", true, &[])?;
|
build_shader("../../../shaders/sky-shader", true, &[])?;
|
||||||
build_shader("../../../shaders/simplest-shader", false, &[])?;
|
build_shader("../../../shaders/simplest-shader", false, &[])?;
|
||||||
build_shader("../../../shaders/compute-shader", false, &[])?;
|
// We need the int8 capability for using `Option`
|
||||||
|
build_shader(
|
||||||
|
"../../../shaders/compute-shader",
|
||||||
|
false,
|
||||||
|
&[Capability::Int8],
|
||||||
|
)?;
|
||||||
build_shader("../../../shaders/mouse-shader", false, &[])?;
|
build_shader("../../../shaders/mouse-shader", false, &[])?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1,34 +1,15 @@
|
|||||||
|
use wgpu::util::DeviceExt;
|
||||||
|
|
||||||
use super::{shader_module, Options};
|
use super::{shader_module, Options};
|
||||||
use core::num::NonZeroU64;
|
use futures::future::join;
|
||||||
|
use std::{convert::TryInto, future::Future, num::NonZeroU64, time::Duration};
|
||||||
|
|
||||||
fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
|
fn block_on<T>(future: impl Future<Output = T>) -> T {
|
||||||
async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
|
|
||||||
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
|
|
||||||
let adapter = instance
|
|
||||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
|
||||||
power_preference: wgpu::PowerPreference::default(),
|
|
||||||
compatible_surface: None,
|
|
||||||
})
|
|
||||||
.await
|
|
||||||
.expect("Failed to find an appropriate adapter");
|
|
||||||
|
|
||||||
adapter
|
|
||||||
.request_device(
|
|
||||||
&wgpu::DeviceDescriptor {
|
|
||||||
label: None,
|
|
||||||
features: wgpu::Features::empty(),
|
|
||||||
limits: wgpu::Limits::default(),
|
|
||||||
},
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
.expect("Failed to create device")
|
|
||||||
}
|
|
||||||
cfg_if::cfg_if! {
|
cfg_if::cfg_if! {
|
||||||
if #[cfg(target_arch = "wasm32")] {
|
if #[cfg(target_arch = "wasm32")] {
|
||||||
wasm_bindgen_futures::spawn_local(create_device_queue_async())
|
wasm_bindgen_futures::spawn_local(future)
|
||||||
} else {
|
} else {
|
||||||
futures::executor::block_on(create_device_queue_async())
|
futures::executor::block_on(future)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -36,11 +17,49 @@ fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
|
|||||||
pub fn start(options: &Options) {
|
pub fn start(options: &Options) {
|
||||||
let shader_binary = shader_module(options.shader);
|
let shader_binary = shader_module(options.shader);
|
||||||
|
|
||||||
let (device, queue) = create_device_queue();
|
block_on(start_internal(options, shader_binary))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn start_internal(
|
||||||
|
_options: &Options,
|
||||||
|
shader_binary: wgpu::ShaderModuleDescriptor<'static>,
|
||||||
|
) {
|
||||||
|
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
|
||||||
|
let adapter = instance
|
||||||
|
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||||
|
power_preference: wgpu::PowerPreference::default(),
|
||||||
|
compatible_surface: None,
|
||||||
|
})
|
||||||
|
.await
|
||||||
|
.expect("Failed to find an appropriate adapter");
|
||||||
|
|
||||||
|
let timestamp_period = adapter.get_timestamp_period();
|
||||||
|
let (device, queue) = adapter
|
||||||
|
.request_device(
|
||||||
|
&wgpu::DeviceDescriptor {
|
||||||
|
label: None,
|
||||||
|
features: wgpu::Features::TIMESTAMP_QUERY,
|
||||||
|
limits: wgpu::Limits::default(),
|
||||||
|
},
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.expect("Failed to create device");
|
||||||
|
drop(instance);
|
||||||
|
drop(adapter);
|
||||||
// Load the shaders from disk
|
// Load the shaders from disk
|
||||||
let module = device.create_shader_module(&shader_binary);
|
let module = device.create_shader_module(&shader_binary);
|
||||||
|
|
||||||
|
let top = 2u32.pow(20);
|
||||||
|
let src_range = 1..top;
|
||||||
|
|
||||||
|
let src = src_range
|
||||||
|
.clone()
|
||||||
|
// Not sure which endianness is correct to use here
|
||||||
|
.map(u32::to_ne_bytes)
|
||||||
|
.flat_map(core::array::IntoIter::new)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||||
label: None,
|
label: None,
|
||||||
entries: &[
|
entries: &[
|
||||||
@ -72,10 +91,26 @@ pub fn start(options: &Options) {
|
|||||||
entry_point: "main_cs",
|
entry_point: "main_cs",
|
||||||
});
|
});
|
||||||
|
|
||||||
let buf = device.create_buffer(&wgpu::BufferDescriptor {
|
let readback_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
label: None,
|
label: None,
|
||||||
size: 1,
|
size: src.len() as wgpu::BufferAddress,
|
||||||
usage: wgpu::BufferUsage::STORAGE,
|
// Can be read to the CPU, and can be copied from the shader's storage buffer
|
||||||
|
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
|
||||||
|
mapped_at_creation: false,
|
||||||
|
});
|
||||||
|
|
||||||
|
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||||
|
label: Some("Collatz Conjecture Input"),
|
||||||
|
contents: &src,
|
||||||
|
usage: wgpu::BufferUsage::STORAGE
|
||||||
|
| wgpu::BufferUsage::COPY_DST
|
||||||
|
| wgpu::BufferUsage::COPY_SRC,
|
||||||
|
});
|
||||||
|
|
||||||
|
let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||||
|
label: Some("Timestamps buffer"),
|
||||||
|
size: 16,
|
||||||
|
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
|
||||||
mapped_at_creation: false,
|
mapped_at_creation: false,
|
||||||
});
|
});
|
||||||
|
|
||||||
@ -84,14 +119,15 @@ pub fn start(options: &Options) {
|
|||||||
layout: &bind_group_layout,
|
layout: &bind_group_layout,
|
||||||
entries: &[wgpu::BindGroupEntry {
|
entries: &[wgpu::BindGroupEntry {
|
||||||
binding: 0,
|
binding: 0,
|
||||||
resource: wgpu::BindingResource::Buffer {
|
resource: storage_buffer.as_entire_binding(),
|
||||||
buffer: &buf,
|
|
||||||
offset: 0,
|
|
||||||
size: None,
|
|
||||||
},
|
|
||||||
}],
|
}],
|
||||||
});
|
});
|
||||||
|
|
||||||
|
let queries = device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||||
|
count: 2,
|
||||||
|
ty: wgpu::QueryType::Timestamp,
|
||||||
|
});
|
||||||
|
|
||||||
let mut encoder =
|
let mut encoder =
|
||||||
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||||
|
|
||||||
@ -99,8 +135,58 @@ pub fn start(options: &Options) {
|
|||||||
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
|
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
|
||||||
cpass.set_bind_group(0, &bind_group, &[]);
|
cpass.set_bind_group(0, &bind_group, &[]);
|
||||||
cpass.set_pipeline(&compute_pipeline);
|
cpass.set_pipeline(&compute_pipeline);
|
||||||
cpass.dispatch(1, 1, 1);
|
cpass.write_timestamp(&queries, 0);
|
||||||
|
cpass.dispatch(src_range.len() as u32 / 64, 1, 1);
|
||||||
|
cpass.write_timestamp(&queries, 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
encoder.copy_buffer_to_buffer(
|
||||||
|
&storage_buffer,
|
||||||
|
0,
|
||||||
|
&readback_buffer,
|
||||||
|
0,
|
||||||
|
src.len() as wgpu::BufferAddress,
|
||||||
|
);
|
||||||
|
encoder.resolve_query_set(&queries, 0..2, ×tamp_buffer, 0);
|
||||||
|
|
||||||
queue.submit(Some(encoder.finish()));
|
queue.submit(Some(encoder.finish()));
|
||||||
|
let buffer_slice = readback_buffer.slice(..);
|
||||||
|
let timestamp_slice = timestamp_buffer.slice(..);
|
||||||
|
let timestamp_future = timestamp_slice.map_async(wgpu::MapMode::Read);
|
||||||
|
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
|
||||||
|
device.poll(wgpu::Maintain::Wait);
|
||||||
|
|
||||||
|
if let (Ok(()), Ok(())) = join(buffer_future, timestamp_future).await {
|
||||||
|
let data = buffer_slice.get_mapped_range();
|
||||||
|
let timing_data = timestamp_slice.get_mapped_range();
|
||||||
|
let result = data
|
||||||
|
.chunks_exact(4)
|
||||||
|
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let timings = timing_data
|
||||||
|
.chunks_exact(8)
|
||||||
|
.map(|b| u64::from_ne_bytes(b.try_into().unwrap()))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
drop(data);
|
||||||
|
readback_buffer.unmap();
|
||||||
|
drop(timing_data);
|
||||||
|
timestamp_buffer.unmap();
|
||||||
|
let mut max = 0;
|
||||||
|
for (src, out) in src_range.zip(result.iter().copied()) {
|
||||||
|
if out == u32::MAX {
|
||||||
|
println!("{}: overflowed", src);
|
||||||
|
break;
|
||||||
|
} else if out > max {
|
||||||
|
max = out;
|
||||||
|
// Should produce <https://oeis.org/A006877>
|
||||||
|
println!("{}: {}", src, out);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!(
|
||||||
|
"Took: {:?}",
|
||||||
|
Duration::from_nanos(
|
||||||
|
((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64
|
||||||
|
)
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
@ -61,7 +61,7 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
|
|||||||
{
|
{
|
||||||
use spirv_builder::{Capability, SpirvBuilder};
|
use spirv_builder::{Capability, SpirvBuilder};
|
||||||
use std::borrow::Cow;
|
use std::borrow::Cow;
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::PathBuf;
|
||||||
// Hack: spirv_builder builds into a custom directory if running under cargo, to not
|
// Hack: spirv_builder builds into a custom directory if running under cargo, to not
|
||||||
// deadlock, and the default target directory if not. However, packages like `proc-macro2`
|
// deadlock, and the default target directory if not. However, packages like `proc-macro2`
|
||||||
// have different configurations when being built here vs. when building
|
// have different configurations when being built here vs. when building
|
||||||
@ -73,22 +73,16 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
|
|||||||
let (crate_name, capabilities): (_, &[Capability]) = match shader {
|
let (crate_name, capabilities): (_, &[Capability]) = match shader {
|
||||||
RustGPUShader::Simplest => ("simplest-shader", &[]),
|
RustGPUShader::Simplest => ("simplest-shader", &[]),
|
||||||
RustGPUShader::Sky => ("sky-shader", &[]),
|
RustGPUShader::Sky => ("sky-shader", &[]),
|
||||||
RustGPUShader::Compute => ("compute-shader", &[]),
|
RustGPUShader::Compute => ("compute-shader", &[Capability::Int8]),
|
||||||
RustGPUShader::Mouse => ("mouse-shader", &[]),
|
RustGPUShader::Mouse => ("mouse-shader", &[]),
|
||||||
};
|
};
|
||||||
let manifest_dir = env!("CARGO_MANIFEST_DIR");
|
let manifest_dir = env!("CARGO_MANIFEST_DIR");
|
||||||
let crate_path = [
|
let crate_path = [manifest_dir, "..", "..", "shaders", crate_name]
|
||||||
Path::new(manifest_dir),
|
.iter()
|
||||||
Path::new(".."),
|
.copied()
|
||||||
Path::new(".."),
|
.collect::<PathBuf>();
|
||||||
Path::new("shaders"),
|
|
||||||
Path::new(crate_name),
|
|
||||||
]
|
|
||||||
.iter()
|
|
||||||
.copied()
|
|
||||||
.collect::<PathBuf>();
|
|
||||||
let mut builder =
|
let mut builder =
|
||||||
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.0").print_metadata(false);
|
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1").print_metadata(false);
|
||||||
for &cap in capabilities {
|
for &cap in capabilities {
|
||||||
builder = builder.capability(cap);
|
builder = builder.capability(cap);
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,10 @@ license = "MIT OR Apache-2.0"
|
|||||||
publish = false
|
publish = false
|
||||||
|
|
||||||
[lib]
|
[lib]
|
||||||
crate-type = ["dylib"]
|
crate-type = ["dylib", "lib"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] }
|
spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] }
|
||||||
|
|
||||||
|
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
|
||||||
|
rayon = "1.5"
|
||||||
|
@ -1,17 +1,48 @@
|
|||||||
#![cfg_attr(
|
#![cfg_attr(
|
||||||
target_arch = "spirv",
|
target_arch = "spirv",
|
||||||
no_std,
|
|
||||||
feature(register_attr),
|
feature(register_attr),
|
||||||
register_attr(spirv)
|
register_attr(spirv),
|
||||||
|
no_std
|
||||||
)]
|
)]
|
||||||
// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds.
|
// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds.
|
||||||
#![deny(warnings)]
|
#![deny(warnings)]
|
||||||
|
|
||||||
extern crate spirv_std;
|
extern crate spirv_std;
|
||||||
|
|
||||||
|
use glam::UVec3;
|
||||||
|
use spirv_std::glam;
|
||||||
#[cfg(not(target_arch = "spirv"))]
|
#[cfg(not(target_arch = "spirv"))]
|
||||||
use spirv_std::macros::spirv;
|
use spirv_std::macros::spirv;
|
||||||
|
|
||||||
// LocalSize/numthreads of (x = 32, y = 1, z = 1)
|
// Adapted from the wgpu hello-compute example
|
||||||
#[spirv(compute(threads(32)))]
|
|
||||||
pub fn main_cs() {}
|
pub fn collatz(mut n: u32) -> Option<u32> {
|
||||||
|
let mut i = 0;
|
||||||
|
if n == 0 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
while n != 1 {
|
||||||
|
n = if n % 2 == 0 {
|
||||||
|
n / 2
|
||||||
|
} else {
|
||||||
|
// Overflow? (i.e. 3*n + 1 > 0xffff_ffff)
|
||||||
|
if n >= 0x5555_5555 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
// TODO: Use this instead when/if checked add/mul can work: n.checked_mul(3)?.checked_add(1)?
|
||||||
|
3 * n + 1
|
||||||
|
};
|
||||||
|
i += 1;
|
||||||
|
}
|
||||||
|
Some(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// LocalSize/numthreads of (x = 64, y = 1, z = 1)
|
||||||
|
#[spirv(compute(threads(64)))]
|
||||||
|
pub fn main_cs(
|
||||||
|
#[spirv(global_invocation_id)] id: UVec3,
|
||||||
|
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] prime_indices: &mut [u32],
|
||||||
|
) {
|
||||||
|
let index = id.x as usize;
|
||||||
|
prime_indices[index] = collatz(prime_indices[index]).unwrap_or(u32::MAX);
|
||||||
|
}
|
||||||
|
32
examples/shaders/compute-shader/src/main.rs
Normal file
32
examples/shaders/compute-shader/src/main.rs
Normal file
@ -0,0 +1,32 @@
|
|||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
use compute_shader::collatz;
|
||||||
|
use rayon::prelude::*;
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
let top = 2u32.pow(20);
|
||||||
|
let src_range = 1..top;
|
||||||
|
let start = Instant::now();
|
||||||
|
let result = src_range
|
||||||
|
.clone()
|
||||||
|
.into_par_iter()
|
||||||
|
.map(collatz)
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let took = start.elapsed();
|
||||||
|
let mut max = 0;
|
||||||
|
for (src, out) in src_range.zip(result.iter().copied()) {
|
||||||
|
match out {
|
||||||
|
Some(out) if out > max => {
|
||||||
|
max = out;
|
||||||
|
// Should produce <https://oeis.org/A006877>
|
||||||
|
println!("{}: {}", src, out);
|
||||||
|
}
|
||||||
|
Some(_) => (),
|
||||||
|
None => {
|
||||||
|
println!("{}: overflowed", src);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
println!("Took: {:?}", took);
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user