Collatz Computation (#623)

Build the compute shader for vulkan1.1 as required
This commit is contained in:
Daniel McNab 2021-06-03 13:20:42 +01:00 committed by GitHub
parent a5e9fe751b
commit cb952562dd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 210 additions and 57 deletions

1
Cargo.lock generated
View File

@ -401,6 +401,7 @@ dependencies = [
name = "compute-shader"
version = "0.4.0-alpha.8"
dependencies = [
"rayon",
"spirv-std",
]

View File

@ -367,6 +367,7 @@ fn invoke_rustc(builder: &SpirvBuilder) -> Result<PathBuf, SpirvBuilderError> {
let mut cargo = Command::new("cargo");
cargo.args(&[
"build",
"--lib",
"--message-format=json-render-diagnostics",
"-Zbuild-std=core",
"-Zbuild-std-features=compiler-builtins-mem",

View File

@ -11,7 +11,7 @@ fn build_shader(
) -> Result<(), Box<dyn Error>> {
let builder_dir = &Path::new(env!("CARGO_MANIFEST_DIR"));
let path_to_crate = builder_dir.join(path_to_crate);
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.0");
let mut builder = SpirvBuilder::new(path_to_crate, "spirv-unknown-vulkan1.1");
for &cap in caps {
builder = builder.capability(cap);
}
@ -28,7 +28,12 @@ fn build_shader(
fn main() -> Result<(), Box<dyn Error>> {
build_shader("../../../shaders/sky-shader", true, &[])?;
build_shader("../../../shaders/simplest-shader", false, &[])?;
build_shader("../../../shaders/compute-shader", false, &[])?;
// We need the int8 capability for using `Option`
build_shader(
"../../../shaders/compute-shader",
false,
&[Capability::Int8],
)?;
build_shader("../../../shaders/mouse-shader", false, &[])?;
Ok(())
}

View File

@ -1,34 +1,15 @@
use wgpu::util::DeviceExt;
use super::{shader_module, Options};
use core::num::NonZeroU64;
use futures::future::join;
use std::{convert::TryInto, future::Future, num::NonZeroU64, time::Duration};
fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
async fn create_device_queue_async() -> (wgpu::Device, wgpu::Queue) {
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::default(),
compatible_surface: None,
})
.await
.expect("Failed to find an appropriate adapter");
adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: wgpu::Features::empty(),
limits: wgpu::Limits::default(),
},
None,
)
.await
.expect("Failed to create device")
}
fn block_on<T>(future: impl Future<Output = T>) -> T {
cfg_if::cfg_if! {
if #[cfg(target_arch = "wasm32")] {
wasm_bindgen_futures::spawn_local(create_device_queue_async())
wasm_bindgen_futures::spawn_local(future)
} else {
futures::executor::block_on(create_device_queue_async())
futures::executor::block_on(future)
}
}
}
@ -36,11 +17,49 @@ fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
pub fn start(options: &Options) {
let shader_binary = shader_module(options.shader);
let (device, queue) = create_device_queue();
block_on(start_internal(options, shader_binary))
}
pub async fn start_internal(
_options: &Options,
shader_binary: wgpu::ShaderModuleDescriptor<'static>,
) {
let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
let adapter = instance
.request_adapter(&wgpu::RequestAdapterOptions {
power_preference: wgpu::PowerPreference::default(),
compatible_surface: None,
})
.await
.expect("Failed to find an appropriate adapter");
let timestamp_period = adapter.get_timestamp_period();
let (device, queue) = adapter
.request_device(
&wgpu::DeviceDescriptor {
label: None,
features: wgpu::Features::TIMESTAMP_QUERY,
limits: wgpu::Limits::default(),
},
None,
)
.await
.expect("Failed to create device");
drop(instance);
drop(adapter);
// Load the shaders from disk
let module = device.create_shader_module(&shader_binary);
let top = 2u32.pow(20);
let src_range = 1..top;
let src = src_range
.clone()
// Not sure which endianness is correct to use here
.map(u32::to_ne_bytes)
.flat_map(core::array::IntoIter::new)
.collect::<Vec<_>>();
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: None,
entries: &[
@ -72,10 +91,26 @@ pub fn start(options: &Options) {
entry_point: "main_cs",
});
let buf = device.create_buffer(&wgpu::BufferDescriptor {
let readback_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: None,
size: 1,
usage: wgpu::BufferUsage::STORAGE,
size: src.len() as wgpu::BufferAddress,
// Can be read to the CPU, and can be copied from the shader's storage buffer
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
mapped_at_creation: false,
});
let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("Collatz Conjecture Input"),
contents: &src,
usage: wgpu::BufferUsage::STORAGE
| wgpu::BufferUsage::COPY_DST
| wgpu::BufferUsage::COPY_SRC,
});
let timestamp_buffer = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Timestamps buffer"),
size: 16,
usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
mapped_at_creation: false,
});
@ -84,14 +119,15 @@ pub fn start(options: &Options) {
layout: &bind_group_layout,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::Buffer {
buffer: &buf,
offset: 0,
size: None,
},
resource: storage_buffer.as_entire_binding(),
}],
});
let queries = device.create_query_set(&wgpu::QuerySetDescriptor {
count: 2,
ty: wgpu::QueryType::Timestamp,
});
let mut encoder =
device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
@ -99,8 +135,58 @@ pub fn start(options: &Options) {
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { label: None });
cpass.set_bind_group(0, &bind_group, &[]);
cpass.set_pipeline(&compute_pipeline);
cpass.dispatch(1, 1, 1);
cpass.write_timestamp(&queries, 0);
cpass.dispatch(src_range.len() as u32 / 64, 1, 1);
cpass.write_timestamp(&queries, 1);
}
encoder.copy_buffer_to_buffer(
&storage_buffer,
0,
&readback_buffer,
0,
src.len() as wgpu::BufferAddress,
);
encoder.resolve_query_set(&queries, 0..2, &timestamp_buffer, 0);
queue.submit(Some(encoder.finish()));
let buffer_slice = readback_buffer.slice(..);
let timestamp_slice = timestamp_buffer.slice(..);
let timestamp_future = timestamp_slice.map_async(wgpu::MapMode::Read);
let buffer_future = buffer_slice.map_async(wgpu::MapMode::Read);
device.poll(wgpu::Maintain::Wait);
if let (Ok(()), Ok(())) = join(buffer_future, timestamp_future).await {
let data = buffer_slice.get_mapped_range();
let timing_data = timestamp_slice.get_mapped_range();
let result = data
.chunks_exact(4)
.map(|b| u32::from_ne_bytes(b.try_into().unwrap()))
.collect::<Vec<_>>();
let timings = timing_data
.chunks_exact(8)
.map(|b| u64::from_ne_bytes(b.try_into().unwrap()))
.collect::<Vec<_>>();
drop(data);
readback_buffer.unmap();
drop(timing_data);
timestamp_buffer.unmap();
let mut max = 0;
for (src, out) in src_range.zip(result.iter().copied()) {
if out == u32::MAX {
println!("{}: overflowed", src);
break;
} else if out > max {
max = out;
// Should produce <https://oeis.org/A006877>
println!("{}: {}", src, out);
}
}
println!(
"Took: {:?}",
Duration::from_nanos(
((timings[1] - timings[0]) as f64 * f64::from(timestamp_period)) as u64
)
);
}
}

View File

@ -61,7 +61,7 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
{
use spirv_builder::{Capability, SpirvBuilder};
use std::borrow::Cow;
use std::path::{Path, PathBuf};
use std::path::PathBuf;
// Hack: spirv_builder builds into a custom directory if running under cargo, to not
// deadlock, and the default target directory if not. However, packages like `proc-macro2`
// have different configurations when being built here vs. when building
@ -73,22 +73,16 @@ fn shader_module(shader: RustGPUShader) -> wgpu::ShaderModuleDescriptor<'static>
let (crate_name, capabilities): (_, &[Capability]) = match shader {
RustGPUShader::Simplest => ("simplest-shader", &[]),
RustGPUShader::Sky => ("sky-shader", &[]),
RustGPUShader::Compute => ("compute-shader", &[]),
RustGPUShader::Compute => ("compute-shader", &[Capability::Int8]),
RustGPUShader::Mouse => ("mouse-shader", &[]),
};
let manifest_dir = env!("CARGO_MANIFEST_DIR");
let crate_path = [
Path::new(manifest_dir),
Path::new(".."),
Path::new(".."),
Path::new("shaders"),
Path::new(crate_name),
]
.iter()
.copied()
.collect::<PathBuf>();
let crate_path = [manifest_dir, "..", "..", "shaders", crate_name]
.iter()
.copied()
.collect::<PathBuf>();
let mut builder =
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.0").print_metadata(false);
SpirvBuilder::new(crate_path, "spirv-unknown-vulkan1.1").print_metadata(false);
for &cap in capabilities {
builder = builder.capability(cap);
}

View File

@ -7,7 +7,10 @@ license = "MIT OR Apache-2.0"
publish = false
[lib]
crate-type = ["dylib"]
crate-type = ["dylib", "lib"]
[dependencies]
spirv-std = { path = "../../../crates/spirv-std", features = ["glam"] }
[target.'cfg(not(target_arch = "spirv"))'.dependencies]
rayon = "1.5"

View File

@ -1,17 +1,48 @@
#![cfg_attr(
target_arch = "spirv",
no_std,
feature(register_attr),
register_attr(spirv)
register_attr(spirv),
no_std
)]
// HACK(eddyb) can't easily see warnings otherwise from `spirv-builder` builds.
#![deny(warnings)]
extern crate spirv_std;
use glam::UVec3;
use spirv_std::glam;
#[cfg(not(target_arch = "spirv"))]
use spirv_std::macros::spirv;
// LocalSize/numthreads of (x = 32, y = 1, z = 1)
#[spirv(compute(threads(32)))]
pub fn main_cs() {}
// Adapted from the wgpu hello-compute example
pub fn collatz(mut n: u32) -> Option<u32> {
let mut i = 0;
if n == 0 {
return None;
}
while n != 1 {
n = if n % 2 == 0 {
n / 2
} else {
// Overflow? (i.e. 3*n + 1 > 0xffff_ffff)
if n >= 0x5555_5555 {
return None;
}
// TODO: Use this instead when/if checked add/mul can work: n.checked_mul(3)?.checked_add(1)?
3 * n + 1
};
i += 1;
}
Some(i)
}
// LocalSize/numthreads of (x = 64, y = 1, z = 1)
#[spirv(compute(threads(64)))]
pub fn main_cs(
#[spirv(global_invocation_id)] id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] prime_indices: &mut [u32],
) {
let index = id.x as usize;
prime_indices[index] = collatz(prime_indices[index]).unwrap_or(u32::MAX);
}

View File

@ -0,0 +1,32 @@
use std::time::Instant;
use compute_shader::collatz;
use rayon::prelude::*;
fn main() {
let top = 2u32.pow(20);
let src_range = 1..top;
let start = Instant::now();
let result = src_range
.clone()
.into_par_iter()
.map(collatz)
.collect::<Vec<_>>();
let took = start.elapsed();
let mut max = 0;
for (src, out) in src_range.zip(result.iter().copied()) {
match out {
Some(out) if out > max => {
max = out;
// Should produce <https://oeis.org/A006877>
println!("{}: {}", src, out);
}
Some(_) => (),
None => {
println!("{}: overflowed", src);
break;
}
}
}
println!("Took: {:?}", took);
}