56: Initial compute pipeline support r=kvark a=swiftcoder

Hey, I'd love to help out with this effort, but please let me know if I'm only going to be slowing you down :)

Sample is a port of https://github.com/gfx-rs/gfx/blob/master/examples/compute/main.rs - as yet incomplete because we don't have a read-back API defined.

I'm happy to do the implementation of the read-back API, but I wanted to touch base about how you envision `mapReadAsync`/`mapWriteAsync` working in Rust. Do we introduce futures to match the WebGPU spec, or expose it in a more direct fashion?

Co-authored-by: Tristam MacDonald <swiftcoder@gmail.com>
This commit is contained in:
bors[bot] 2019-02-12 04:39:37 +00:00
commit c9f4936df4
7 changed files with 243 additions and 27 deletions

View File

@ -11,6 +11,10 @@ publish = false
name = "hello_triangle"
path = "hello_triangle_rust/main.rs"
[[bin]]
name = "hello_compute"
path = "hello_compute_rust/main.rs"
[features]
default = []
remote = ["wgpu-native/remote"]

View File

@ -0,0 +1,31 @@
#version 450
layout(local_size_x = 1) in;
layout(set = 0, binding = 0) buffer PrimeIndices {
uint[] indices;
}; // this is used as both input and output for convenience
// The Collatz Conjecture states that for any positive integer n:
// If n is even, n = n/2
// If n is odd, n = 3n+1
// If you repeat this process for each new n, you will always eventually reach 1.
// Though the conjecture has not been proven, no counterexample has ever been found.
// This function returns how many times this recurrence needs to be applied to reach 1.
// Returns the number of Collatz steps (n -> n/2 for even n, n -> 3n+1 for odd n)
// needed to bring `n` down to 1. Returns 0 for n == 1 and for n == 0.
//
// NOTE: `3u * n + 1u` is computed in 32-bit unsigned arithmetic and wraps on
// overflow, so the count is only meaningful while every intermediate value
// fits in a uint.
uint collatz_iterations(uint n) {
    uint i = 0u;
    // `n > 1u` instead of `n != 1u`: 0 is even and 0 / 2 == 0, so an input of 0
    // would otherwise spin forever.
    while (n > 1u) {
        // Use the integer remainder operator. The built-in mod() only has
        // floating-point overloads, so a uint argument is implicitly converted
        // to float and the parity test becomes unreliable above 2^24.
        if (n % 2u == 0u) {
            n = n / 2u;
        }
        else {
            n = (3u * n) + 1u;
        }
        i++;
    }
    return i;
}
// Entry point: each invocation reads the value in its own slot of `indices`
// and overwrites that slot with the value's Collatz iteration count.
void main() {
    uint slot = gl_GlobalInvocationID.x;
    uint start = indices[slot];
    indices[slot] = collatz_iterations(start);
}

Binary file not shown.

View File

@ -0,0 +1,107 @@
extern crate env_logger;
extern crate wgpu;
extern crate wgpu_native;
use std::str::FromStr;
// TODO: deduplicate this with the copy in gfx-examples/framework
/// Views a slice of `T` as the raw bytes that back it.
///
/// The returned slice starts at the same address as `data`, borrows it for the
/// same lifetime, and is `data.len() * size_of::<T>()` bytes long.
pub fn cast_slice<T>(data: &[T]) -> &[u8] {
    let byte_len = std::mem::size_of_val(data);
    let first_byte = data.as_ptr() as *const u8;
    // SAFETY: `first_byte` and `byte_len` describe exactly the memory already
    // borrowed through `data`, and `u8` has no alignment requirement.
    // NOTE(review): assumes `T` has no padding bytes — confirm for new callers.
    unsafe { std::slice::from_raw_parts(first_byte, byte_len) }
}
/// Runs the Collatz compute shader over the integers given on the command line.
///
/// The numbers are uploaded to a staging buffer, copied into a storage buffer,
/// processed by one compute invocation per number, and copied back to the
/// staging buffer. Reading the results back on the CPU is not implemented yet.
///
/// # Panics
///
/// Panics if no arguments are given or if any argument fails to parse as `u32`.
fn main() {
    env_logger::init();

    // For now this just panics if you didn't pass numbers. Could add proper error handling.
    let numbers: Vec<u32> = std::env::args()
        .skip(1)
        .map(|s| u32::from_str(&s).expect("You must pass a list of positive integers!"))
        .collect();
    if numbers.is_empty() {
        panic!("You must pass a list of positive integers!")
    }

    // Size of the data in bytes; used for buffer creation, the binding range and the copies.
    let size = (numbers.len() * std::mem::size_of::<u32>()) as u32;

    let instance = wgpu::Instance::new();
    let adapter = instance.get_adapter(&wgpu::AdapterDescriptor {
        power_preference: wgpu::PowerPreference::LowPower,
    });
    let mut device = adapter.create_device(&wgpu::DeviceDescriptor {
        extensions: wgpu::Extensions {
            anisotropic_filtering: false,
        },
    });

    let cs_bytes = include_bytes!("./../data/collatz.comp.spv");
    let cs_module = device.create_shader_module(cs_bytes);

    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        size,
        usage: wgpu::BufferUsageFlags::MAP_READ
            | wgpu::BufferUsageFlags::TRANSFER_DST
            | wgpu::BufferUsageFlags::TRANSFER_SRC,
    });
    staging_buffer.set_sub_data(0, cast_slice(&numbers));

    let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
        size,
        usage: wgpu::BufferUsageFlags::STORAGE
            | wgpu::BufferUsageFlags::TRANSFER_DST
            | wgpu::BufferUsageFlags::TRANSFER_SRC,
    });

    let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
        bindings: &[wgpu::BindGroupLayoutBinding {
            binding: 0,
            visibility: wgpu::ShaderStageFlags::COMPUTE,
            ty: wgpu::BindingType::StorageBuffer,
        }],
    });

    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
        layout: &bind_group_layout,
        bindings: &[wgpu::Binding {
            binding: 0,
            resource: wgpu::BindingResource::Buffer {
                buffer: &storage_buffer,
                // The range is a byte range, not an element count: bind the whole buffer.
                range: 0..size,
            },
        }],
    });

    let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
        bind_group_layouts: &[&bind_group_layout],
    });

    let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
        layout: &pipeline_layout,
        compute_stage: wgpu::PipelineStageDescriptor {
            module: &cs_module,
            stage: wgpu::ShaderStage::Compute,
            entry_point: "main",
        },
    });

    let mut cmd_buf = device.create_command_buffer(&wgpu::CommandBufferDescriptor { todo: 0 });
    // Upload: staging -> storage.
    {
        cmd_buf.copy_buffer_tobuffer(&staging_buffer, 0, &storage_buffer, 0, size);
    }
    // One workgroup per input number (the shader uses local_size_x = 1).
    {
        let mut cpass = cmd_buf.begin_compute_pass();
        cpass.set_pipeline(&compute_pipeline);
        cpass.set_bind_group(0, &bind_group);
        cpass.dispatch(numbers.len() as u32, 1, 1);
        cpass.end_pass();
    }
    // Download: storage -> staging.
    {
        cmd_buf.copy_buffer_tobuffer(&storage_buffer, 0, &staging_buffer, 0, size);
    }

    // TODO: read the results back out of the staging buffer

    device.get_queue().submit(&[cmd_buf]);
}

View File

@ -2,12 +2,10 @@ use crate::{back, binding_model, command, conv, pipeline, resource, swap_chain};
use crate::registry::{HUB, Items};
use crate::track::{BufferTracker, TextureTracker, TrackPermit};
use crate::{
LifeGuard, RefCount, Stored, SubmissionIndex, WeaklyStored,
BindGroupLayoutId, BindGroupId,
BlendStateId, BufferId, CommandBufferId, DepthStencilStateId,
AdapterId, DeviceId, PipelineLayoutId, QueueId, RenderPipelineId, ShaderModuleId,
SamplerId, TextureId, TextureViewId,
SurfaceId, SwapChainId,
AdapterId, BindGroupId, BindGroupLayoutId, BlendStateId, BufferId, CommandBufferId,
ComputePipelineId, DepthStencilStateId, DeviceId, LifeGuard, PipelineLayoutId, QueueId,
RefCount, RenderPipelineId, SamplerId, ShaderModuleId, Stored, SubmissionIndex, SurfaceId,
SwapChainId, TextureId, TextureViewId, WeaklyStored,
};
use hal::command::RawCommandBuffer;
@ -1143,6 +1141,55 @@ pub extern "C" fn wgpu_device_create_render_pipeline(
})
}
/// Creates a compute pipeline on the device `device_id` from `desc` and
/// registers it in the hub, returning the new pipeline's id.
///
/// `desc.compute_stage.entry_point` is read as a NUL-terminated C string.
///
/// # Panics
///
/// Panics if the stage is not `ShaderStage::Compute`, if the entry point name
/// is not valid UTF-8, or if the backend fails to create the pipeline.
#[no_mangle]
pub extern "C" fn wgpu_device_create_compute_pipeline(
    device_id: DeviceId,
    desc: &pipeline::ComputePipelineDescriptor,
) -> ComputePipelineId {
    let device_guard = HUB.devices.read();
    let device = device_guard.get(device_id);
    let pipeline_layout_guard = HUB.pipeline_layouts.read();
    let layout = &pipeline_layout_guard.get(desc.layout).raw;
    let pipeline_stage = &desc.compute_stage;
    let shader_module_guard = HUB.shader_modules.read();

    assert!(pipeline_stage.stage == pipeline::ShaderStage::Compute); // TODO

    // SAFETY: relies on the caller passing a valid, NUL-terminated string that
    // stays alive for the duration of this call.
    let entry = unsafe { ffi::CStr::from_ptr(pipeline_stage.entry_point) }
        .to_str()
        .expect("compute stage entry point must be valid UTF-8"); // TODO

    let shader = hal::pso::EntryPoint::<back::Backend> {
        entry,
        module: &shader_module_guard.get(pipeline_stage.module).raw,
        specialization: hal::pso::Specialization {
            // TODO
            constants: &[],
            data: &[],
        },
    };

    // TODO
    let flags = hal::pso::PipelineCreationFlags::empty();
    // TODO
    let parent = hal::pso::BasePipeline::None;

    let pipeline_desc = hal::pso::ComputePipelineDesc {
        shader,
        layout,
        flags,
        parent,
    };

    let pipeline = unsafe { device.raw.create_compute_pipeline(&pipeline_desc, None) }
        .expect("failed to create compute pipeline");

    HUB.compute_pipelines
        .write()
        .register(pipeline::ComputePipeline {
            raw: pipeline,
            layout_id: WeaklyStored(desc.layout),
        })
}
#[no_mangle]
pub extern "C" fn wgpu_device_create_swap_chain(
device_id: DeviceId,

View File

@ -1,12 +1,10 @@
use crate::resource;
use crate::{
ByteArray, WeaklyStored,
BlendStateId, DepthStencilStateId, PipelineLayoutId, ShaderModuleId,
BlendStateId, ByteArray, DepthStencilStateId, PipelineLayoutId, ShaderModuleId, WeaklyStored,
};
use bitflags::bitflags;
pub type ShaderAttributeIndex = u32;
#[repr(C)]
@ -208,7 +206,7 @@ pub struct PipelineStageDescriptor {
#[repr(C)]
pub struct ComputePipelineDescriptor {
pub layout: PipelineLayoutId,
pub stages: *const PipelineStageDescriptor,
pub compute_stage: PipelineStageDescriptor,
}
pub(crate) struct ComputePipeline<B: hal::Backend> {

View File

@ -9,15 +9,15 @@ use std::ops::Range;
use std::ptr;
pub use wgn::{
AdapterDescriptor, Attachment, BindGroupLayoutBinding, BindingType, BlendStateDescriptor,
BufferDescriptor, BufferUsageFlags,
IndexFormat, VertexFormat, InputStepMode, ShaderAttributeIndex, VertexAttributeDescriptor,
Color, ColorWriteFlags, CommandBufferDescriptor, DepthStencilStateDescriptor,
DeviceDescriptor, Extensions, Extent3d, LoadOp, Origin3d, PowerPreference, PrimitiveTopology,
RenderPassColorAttachmentDescriptor, RenderPassDepthStencilAttachmentDescriptor,
AdapterDescriptor, AddressMode, Attachment, BindGroupLayoutBinding, BindingType,
BlendStateDescriptor, BorderColor, BufferDescriptor, BufferUsageFlags, Color, ColorWriteFlags,
CommandBufferDescriptor, CompareFunction, DepthStencilStateDescriptor, DeviceDescriptor,
Extensions, Extent3d, FilterMode, IndexFormat, InputStepMode, LoadOp, Origin3d,
PowerPreference, PrimitiveTopology, RenderPassColorAttachmentDescriptor,
RenderPassDepthStencilAttachmentDescriptor, SamplerDescriptor, ShaderAttributeIndex,
ShaderModuleDescriptor, ShaderStage, ShaderStageFlags, StoreOp, SwapChainDescriptor,
TextureDescriptor, TextureDimension, TextureFormat, TextureUsageFlags, TextureViewDescriptor,
SamplerDescriptor, AddressMode, FilterMode, CompareFunction, BorderColor,
VertexAttributeDescriptor, VertexFormat,
};
pub struct Instance {
@ -162,6 +162,11 @@ pub struct RenderPipelineDescriptor<'a> {
pub vertex_buffers: &'a [VertexBufferDescriptor<'a>],
}
/// Describes a compute pipeline: the pipeline layout it uses and its single
/// compute shader stage.
pub struct ComputePipelineDescriptor<'a> {
    /// Pipeline layout the compute pipeline is created with.
    pub layout: &'a PipelineLayout,
    /// The compute shader stage (module, stage kind and entry point name).
    pub compute_stage: PipelineStageDescriptor<'a>,
}
pub struct RenderPassDescriptor<'a> {
pub color_attachments: &'a [RenderPassColorAttachmentDescriptor<&'a TextureView>],
pub depth_stencil_attachment:
@ -210,7 +215,6 @@ impl<'a> TextureCopyView<'a> {
}
}
impl Instance {
pub fn new() -> Self {
Instance {
@ -273,14 +277,17 @@ impl Device {
.map(|binding| wgn::Binding {
binding: binding.binding,
resource: match binding.resource {
BindingResource::Buffer { ref buffer, ref range } => {
wgn::BindingResource::Buffer(wgn::BufferBinding {
buffer: buffer.id,
offset: range.start,
size: range.end,
})
BindingResource::Buffer {
ref buffer,
ref range,
} => wgn::BindingResource::Buffer(wgn::BufferBinding {
buffer: buffer.id,
offset: range.start,
size: range.end,
}),
BindingResource::Sampler(ref sampler) => {
wgn::BindingResource::Sampler(sampler.id)
}
BindingResource::Sampler(ref sampler) => wgn::BindingResource::Sampler(sampler.id),
BindingResource::TextureView(ref texture_view) => {
wgn::BindingResource::TextureView(texture_view.id)
}
@ -362,7 +369,8 @@ impl Device {
.collect::<ArrayVec<[_; 2]>>();
let temp_blend_states = desc.blend_states.iter().map(|bs| bs.id).collect::<Vec<_>>();
let temp_vertex_buffers = desc.vertex_buffers
let temp_vertex_buffers = desc
.vertex_buffers
.iter()
.map(|vbuf| wgn::VertexBufferDescriptor {
stride: vbuf.stride,
@ -403,6 +411,25 @@ impl Device {
}
}
/// Creates a compute pipeline on this device from `desc`.
///
/// The descriptor is translated into its native (`wgn`) form and passed to
/// `wgpu_device_create_compute_pipeline`; the returned id is wrapped in a
/// `ComputePipeline`.
///
/// # Panics
///
/// Panics if the entry point name contains an interior NUL byte.
pub fn create_compute_pipeline(&self, desc: &ComputePipelineDescriptor) -> ComputePipeline {
    let stage = &desc.compute_stage;
    // Keep the CString in a local so the pointer handed to the native call
    // stays valid until that call has returned.
    let entry_name = CString::new(stage.entry_point).unwrap();
    let native_desc = wgn::ComputePipelineDescriptor {
        layout: desc.layout.id,
        compute_stage: wgn::PipelineStageDescriptor {
            module: stage.module.id,
            stage: stage.stage,
            entry_point: entry_name.as_ptr(),
        },
    };
    let id = wgn::wgpu_device_create_compute_pipeline(self.id, &native_desc);
    ComputePipeline { id }
}
pub fn create_buffer(&self, desc: &BufferDescriptor) -> Buffer {
Buffer {
id: wgn::wgpu_device_create_buffer(self.id, desc),
@ -651,7 +678,9 @@ impl SwapChain {
pub fn get_next_texture(&mut self) -> SwapChainOutput {
let output = wgn::wgpu_swap_chain_get_next_texture(self.id);
SwapChainOutput {
texture: Texture { id: output.texture_id },
texture: Texture {
id: output.texture_id,
},
view: TextureView { id: output.view_id },
swap_chain_id: &self.id,
}