mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 08:14:12 +00:00
Collatz Computation (#623)
Build the compute shader for vulkan1.1 as required
This commit is contained in:
parent
a5e9fe751b
commit
cb952562dd
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -401,6 +401,7 @@ dependencies = [
|
||||
name = "compute-shader"
|
||||
version = "0.4.0-alpha.8"
|
||||
dependencies = [
|
||||
"rayon",
|
||||
"spirv-std",
|
||||
]
|
||||
|
||||
|
@ -367,6 +367,7 @@ fn invoke_rustc(builder: &SpirvBuilder) -> Result<PathBuf, SpirvBuilderError> {
|
||||
let mut cargo = Command::new("cargo");
|
||||
cargo.args(&[
|
||||
"build",
|
||||
"--lib",
|
||||
"--message-format=json-render-diagnostics",
|
||||
"-Zbuild-std=core",
|
||||
"-Zbuild-std-features=compiler-builtins-mem",
|
||||
|
@ -11,7 +11,7 @@ fn build_shader(
|
||||
) -> Result<(), Box<dyn Error>> {
|
||||
let builder_dir = &Path::new(env!("CARGO_MANIFEST_DIR"));
|
||||
let path_to_crate = builder_dir.join(path_to_crate);
|
||||
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.0");
|
||||
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.1");
|
||||
for &cap in caps {
|
||||
builder = builder.capability(cap);
|
||||
}
|
||||
@ -28,7 +28,12 @@ fn build_shader(
|
||||
fn main() -> Result<(), Box<dyn Error>> {
|
||||
build_shader("../../../shaders/sky-shader", true, &[])?;
|
||||
build_shader("../../../shaders/simplest-shader", false, &[])?;
|
||||
build_shader("../../../shaders/compute-shader", false, &[])?;
|
||||
// We need the int8 capability for using `Option`
|
||||
build_shader(
|
||||
"../../../shaders/compute-shader",
|
||||
false,
|
||||
&[Capability::Int8],
|
||||
)?;
|
||||
build_shader("../../../shaders/mouse-shader", false, &[])?;
|
||||
Ok(())
|
||||
}
|
||||
|
@ -1,8 +1,29 @@
|
||||
use super::{shader_module, Options};
|
||||
use core::num::NonZeroU64;
|
||||
use wgpu::util::DeviceExt;
|
||||
|
||||
fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
|
||||
async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
|
||||
use super::{shader_module, Options};
|
||||
use futures::future::join;
|
||||
use std::{convert::TryInto, future::Future, num::NonZeroU64, time::Duration};
|
||||
|
||||
fn block_on<T>(future: impl Future<Output = T>) -> T {
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
wasm_bindgen_futures::spawn_local(future)
|
||||
} else {
|
||||
futures::executor::block_on(future)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(options: &Options) {
|
||||
let shader_binary = shader_module(options.shader);
|
||||
|
||||
block_on(start_internal(options, shader_binary))
|
||||
}
|
||||
|
||||
pub async fn start_internal(
|
||||
_options: &Options,
|
||||
shader_binary: wgpu::ShaderModuleDescriptor<'static>,
|
||||
) {
|
||||
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
|
||||
let adapter = instance
|
||||
.request_adapter(&wgpu::RequestAdapterOptions {
|
||||
@ -12,35 +33,33 @@ fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
|
||||
.await
|
||||
.expect("Failed to find an appropriate adapter");
|
||||
|
||||
adapter
|
||||
let timestamp_period = adapter.get_timestamp_period();
|
||||
let (device, queue) = adapter
|
||||
.request_device(
|
||||
&wgpu::DeviceDescriptor {
|
||||
label: None,
|
||||
features: wgpu::Features::empty(),
|
||||
features: wgpu::Features::TIMESTAMP_QUERY,
|
||||
limits: wgpu::Limits::default(),
|
||||
},
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("Failed to create device")
|
||||
}
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(target_arch = "wasm32")] {
|
||||
wasm_bindgen_futures::spawn_local(create_device_queue_async())
|
||||
} else {
|
||||
futures::executor::block_on(create_device_queue_async())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn start(options: &Options) {
|
||||
let shader_binary = shader_module(options.shader);
|
||||
|
||||
let (device, queue) = create_device_queue();
|
||||
|
||||
.expect("Failed to create device");
|
||||
drop(instance);
|
||||
drop(adapter);
|
||||
// Load the shaders from disk
|
||||
let module = device.create_shader_module(&shader_binary);
|
||||
|
||||
let top = 2u32.pow(20);
|
||||
let src_range = 1..top;
|
||||
|
||||
let src = src_range
|
||||
.clone()
|
||||
// Not sure which endianness is correct to use here
|
||||
.map(u32::to_ne_bytes)
|
||||
.flat_map(core::array::IntoIter::new)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
label: None,
|
||||
entries: &[
|
||||
@ -72,10 +91,26 @@ pub fn start(options: &Options) {
|
||||
entry_point: "main_cs",
|
||||
});
|
||||
|
||||
let buf = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
let readback_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: None,
|
||||
size: 1,
|
||||
usage: wgpu::BufferUsage::STORAGE,
|
||||
size: src.len() as wgpu::BufferAddress,
|
||||
// Can be read to the CPU, and can be copied from the shader's storage buffer
|
||||
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
|
||||
label: Some("Collatz Conjecture Input"),
|
||||
contents: &src,
|
||||
usage: wgpu::BufferUsage::STORAGE
|
||||
| wgpu::BufferUsage::COPY_DST
|
||||
| wgpu::BufferUsage::COPY_SRC,
|
||||
});
|
||||
|
||||
let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
label: Some("Timestamps buffer"),
|
||||
size: 16,
|
||||
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
|
||||
mapped_at_creation: false,
|
||||
});
|
||||
|
||||
@ -84,14 +119,15 @@ pub fn start(options: &Options) {
|
||||
layout: &bind_group_layout,
|
||||
entries: &[wgpu::BindGroupEntry {
|
||||
binding: 0,
|
||||
resource: wgpu::BindingResource::Buffer {
|
||||
buffer: &buf,
|
||||
offset: 0,
|
||||
size: None,
|
||||
},
|
||||
resource: storage_buffer.as_entire_binding(),
|
||||
}],
|
||||
});
|
||||
|
||||
let queries = device.create_query_set(&wgpu::QuerySetDescriptor {
|
||||
count: 2,
|
||||
ty: wgpu::QueryType::Timestamp,
|
||||
});
|
||||
|
||||
let mut encoder =
|
||||
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
|
||||
|
||||
@ -99,8 +135,58 @@ pub fn start(options: &Options) {
|
||||
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
|
||||
cpass.set_bind_group(0, &bind_group, &[]);
|
||||
cpass.set_pipeline(&compute_pipeline);
|
||||
cpass.dispatch(1, 1, 1);
|
||||
cpass.write_timestamp(&queries, 0);
|
||||
cpass.dispatch(src_range.len() as u32 / 64, 1, 1);
|
||||
cpass.write_timestamp(&queries, 1);
|
||||
}
|
||||
|
||||
encoder.copy_buffer_to_buffer(
|
||||
&storage_buffer,
|
||||
0,
|
||||
&readback_buffer,
|
||||
0,
|
||||
src.len() as wgpu::BufferAddress,
|
||||
);
|
||||
encoder.resolve_query_set(&queries, 0..2, ×tamp_buffer, 0);
|
||||
|
||||
queue.submit(Some(encoder.finish()));
|
||||
let buffer_slice = readback_buffer.slice(..);
|
||||
let timestamp_slice = timestamp_buffer.slice(..);
|
||||
let timestamp_future = timestamp_slice.map_async(wgpu::MapMode::Read);
|
||||
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
|
||||
device.poll(wgpu::Maintain::Wait);
|
||||
|
||||
if let (Ok(()), Ok(())) = join(buffer_future, timestamp_future).await {
|
||||
let data = buffer_slice.get_mapped_range();
|
||||
let timing_data = timestamp_slice.get_mapped_range();
|
||||
let result = data
|
||||
.chunks_exact(4)
|
||||
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
|
||||
.collect::<Vec<_>>();
|
||||
let timings = timing_data
|
||||
.chunks_exact(8)
|
||||
.map(|b| u64::from_ne_bytes(b.try_into().unwrap()))
|
||||
.collect::<Vec<_>>();
|
||||
drop(data);
|
||||
readback_buffer.unmap();
|
||||
drop(timing_data);
|
||||
timestamp_buffer.unmap();
|
||||
let mut max = 0;
|
||||
for (src, out) in src_range.zip(result.iter().copied()) {
|
||||
if out == u32::MAX {
|
||||
println!("{}: overflowed", src);
|
||||
break;
|
||||
} else if out > max {
|
||||
max = out;
|
||||
// Should produce <https://oeis.org/A006877>
|
||||
println!("{}: {}", src, out);
|
||||
}
|
||||
}
|
||||
println!(
|
||||
"Took: {:?}",
|
||||
Duration::from_nanos(
|
||||
((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
|
||||
{
|
||||
use spirv_builder::{Capability, SpirvBuilder};
|
||||
use std::borrow::Cow;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::path::PathBuf;
|
||||
// Hack: spirv_builder builds into a custom directory if running under cargo, to not
|
||||
// deadlock, and the default target directory if not. However, packages like `proc-macro2`
|
||||
// have different configurations when being built here vs. when building
|
||||
@ -73,22 +73,16 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
|
||||
let (crate_name, capabilities): (_, &[Capability]) = match shader {
|
||||
RustGPUShader::Simplest => ("simplest-shader", &[]),
|
||||
RustGPUShader::Sky => ("sky-shader", &[]),
|
||||
RustGPUShader::Compute => ("compute-shader", &[]),
|
||||
RustGPUShader::Compute => ("compute-shader", &[Capability::Int8]),
|
||||
RustGPUShader::Mouse => ("mouse-shader", &[]),
|
||||
};
|
||||
let manifest_dir = env!("CARGO_MANIFEST_DIR");
|
||||
let crate_path = [
|
||||
Path::new(manifest_dir),
|
||||
Path::new(".."),
|
||||
Path::new(".."),
|
||||
Path::new("shaders"),
|
||||
Path::new(crate_name),
|
||||
]
|
||||
let crate_path = [manifest_dir, "..", "..", "shaders", crate_name]
|
||||
.iter()
|
||||
.copied()
|
||||
.collect::<PathBuf>();
|
||||
let mut builder =
|
||||
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.0").print_metadata(false);
|
||||
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1").print_metadata(false);
|
||||
for &cap in capabilities {
|
||||
builder = builder.capability(cap);
|
||||
}
|
||||
|
@ -7,7 +7,10 @@ license = "MIT OR Apache-2.0"
|
||||
publish = false
|
||||
|
||||
[lib]
|
||||
crate-type = ["dylib"]
|
||||
crate-type = ["dylib", "lib"]
|
||||
|
||||
[dependencies]
|
||||
spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] }
|
||||
|
||||
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
|
||||
rayon = "1.5"
|
||||
|
@ -1,17 +1,48 @@
|
||||
#![cfg_attr(
|
||||
target_arch = "spirv",
|
||||
no_std,
|
||||
feature(register_attr),
|
||||
register_attr(spirv)
|
||||
register_attr(spirv),
|
||||
no_std
|
||||
)]
|
||||
// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds.
|
||||
#![deny(warnings)]
|
||||
|
||||
extern crate spirv_std;
|
||||
|
||||
use glam::UVec3;
|
||||
use spirv_std::glam;
|
||||
#[cfg(not(target_arch = "spirv"))]
|
||||
use spirv_std::macros::spirv;
|
||||
|
||||
// LocalSize/numthreads of (x = 32, y = 1, z = 1)
|
||||
#[spirv(compute(threads(32)))]
|
||||
pub fn main_cs() {}
|
||||
// Adapted from the wgpu hello-compute example
|
||||
|
||||
pub fn collatz(mut n: u32) -> Option<u32> {
|
||||
let mut i = 0;
|
||||
if n == 0 {
|
||||
return None;
|
||||
}
|
||||
while n != 1 {
|
||||
n = if n % 2 == 0 {
|
||||
n / 2
|
||||
} else {
|
||||
// Overflow? (i.e. 3*n + 1 > 0xffff_ffff)
|
||||
if n >= 0x5555_5555 {
|
||||
return None;
|
||||
}
|
||||
// TODO: Use this instead when/if checked add/mul can work: n.checked_mul(3)?.checked_add(1)?
|
||||
3 * n + 1
|
||||
};
|
||||
i += 1;
|
||||
}
|
||||
Some(i)
|
||||
}
|
||||
|
||||
// LocalSize/numthreads of (x = 64, y = 1, z = 1)
|
||||
#[spirv(compute(threads(64)))]
|
||||
pub fn main_cs(
|
||||
#[spirv(global_invocation_id)] id: UVec3,
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] prime_indices: &mut [u32],
|
||||
) {
|
||||
let index = id.x as usize;
|
||||
prime_indices[index] = collatz(prime_indices[index]).unwrap_or(u32::MAX);
|
||||
}
|
||||
|
32
examples/shaders/compute-shader/src/main.rs
Normal file
32
examples/shaders/compute-shader/src/main.rs
Normal file
@ -0,0 +1,32 @@
|
||||
use std::time::Instant;
|
||||
|
||||
use compute_shader::collatz;
|
||||
use rayon::prelude::*;
|
||||
|
||||
fn main() {
|
||||
let top = 2u32.pow(20);
|
||||
let src_range = 1..top;
|
||||
let start = Instant::now();
|
||||
let result = src_range
|
||||
.clone()
|
||||
.into_par_iter()
|
||||
.map(collatz)
|
||||
.collect::<Vec<_>>();
|
||||
let took = start.elapsed();
|
||||
let mut max = 0;
|
||||
for (src, out) in src_range.zip(result.iter().copied()) {
|
||||
match out {
|
||||
Some(out) if out > max => {
|
||||
max = out;
|
||||
// Should produce <https://oeis.org/A006877>
|
||||
println!("{}: {}", src, out);
|
||||
}
|
||||
Some(_) => (),
|
||||
None => {
|
||||
println!("{}: overflowed", src);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
println!("Took: {:?}", took);
|
||||
}
|
Loading…
Reference in New Issue
Block a user