mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-25 00:03:29 +00:00
Merge #56
56: Initial compute pipeline support r=kvark a=swiftcoder Hey, I'd love to help out with this effort, but please let me know if I'm only going to be slowing you down :) Sample is a port of https://github.com/gfx-rs/gfx/blob/master/examples/compute/main.rs - as yet incomplete because we don't have a read-back API defined. I'm happy to do the implementation of the read-back API, but I wanted to touch base about how you envision `mapReadAsync`/`mapWriteAsync` working in Rust. Do we introduce futures to match the WebGPU spec, or expose it in a more direct fashion? Co-authored-by: Tristam MacDonald <swiftcoder@gmail.com>
This commit is contained in:
commit
c9f4936df4
@ -11,6 +11,10 @@ publish = false
|
||||
name = "hello_triangle"
|
||||
path = "hello_triangle_rust/main.rs"
|
||||
|
||||
[[bin]]
|
||||
name = "hello_compute"
|
||||
path = "hello_compute_rust/main.rs"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
remote = ["wgpu-native/remote"]
|
||||
|
31
examples/data/collatz.comp
Normal file
31
examples/data/collatz.comp
Normal file
@ -0,0 +1,31 @@
|
||||
#version 450
|
||||
layout(local_size_x = 1) in;
|
||||
|
||||
layout(set = 0, binding = 0) buffer PrimeIndices {
|
||||
uint[] indices;
|
||||
}; // this is used as both input and output for convenience
|
||||
|
||||
// The Collatz conjecture states that, for any positive integer n, repeating
//   n -> n / 2      when n is even
//   n -> 3 * n + 1  when n is odd
// always eventually reaches 1.
// Though the conjecture has not been proven, no counterexample has ever been found.
//
// Returns how many times this recurrence needs to be applied to `n` to reach 1.
// Inputs of 0 and 1 return 0: 0 maps to itself under the recurrence, so a loop
// that waits for `n == 1` would never terminate for it.
// Note: `3 * n + 1` is plain 32-bit unsigned arithmetic and wraps on overflow.
uint collatz_iterations(uint n) {
    uint i = 0u;
    while (n > 1u) {
        // Integer parity test. `mod()` is a floating-point built-in, so `mod(n, 2)`
        // relies on an implicit uint -> float conversion and loses the low bit
        // for values above 2^24.
        if ((n & 1u) == 0u) {
            n = n / 2u;
        }
        else {
            n = (3u * n) + 1u;
        }
        i++;
    }
    return i;
}
|
||||
|
||||
// Entry point: each invocation replaces one element of `indices` (selected by
// its global X id) with the Collatz iteration count of the value stored there.
void main() {
    uint slot = gl_GlobalInvocationID.x;
    uint start = indices[slot];
    indices[slot] = collatz_iterations(start);
}
|
BIN
examples/data/collatz.comp.spv
Normal file
BIN
examples/data/collatz.comp.spv
Normal file
Binary file not shown.
107
examples/hello_compute_rust/main.rs
Normal file
107
examples/hello_compute_rust/main.rs
Normal file
@ -0,0 +1,107 @@
|
||||
extern crate env_logger;
|
||||
extern crate wgpu;
|
||||
extern crate wgpu_native;
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
// TODO: deduplicate this with the copy in gfx-examples/framework
|
||||
/// Reinterprets a slice of `T` as the raw bytes that back it.
///
/// The returned slice borrows from `data`, starts at the same address and is
/// `data.len() * size_of::<T>()` bytes long. Bytes are in native memory order.
///
/// NOTE(review): `T` is unconstrained; this is only meaningful for plain-data
/// element types without padding (the example passes `u32`) — confirm before
/// using it with other types.
pub fn cast_slice<T>(data: &[T]) -> &[u8] {
    use std::mem::size_of_val;
    use std::slice::from_raw_parts;

    // Total byte length of the slice contents, computed by the standard library
    // rather than by hand-multiplying length and element size.
    let byte_len = size_of_val(data);

    // SAFETY: the pointer comes from a live slice and `byte_len` is exactly the
    // size of that slice, so the byte view stays inside the same allocation.
    // `u8` has alignment 1, and the result carries the lifetime of `data`.
    unsafe { from_raw_parts(data.as_ptr() as *const u8, byte_len) }
}
|
||||
|
||||
fn main() {
|
||||
env_logger::init();
|
||||
|
||||
// For now this just panics if you didn't pass numbers. Could add proper error handling.
|
||||
if std::env::args().len() == 1 {
|
||||
panic!("You must pass a list of positive integers!")
|
||||
}
|
||||
let numbers: Vec<u32> = std::env::args()
|
||||
.skip(1)
|
||||
.map(|s| u32::from_str(&s).expect("You must pass a list of positive integers!"))
|
||||
.collect();
|
||||
|
||||
let size = (numbers.len() * std::mem::size_of::<u32>()) as u32;
|
||||
|
||||
let instance = wgpu::Instance::new();
|
||||
let adapter = instance.get_adapter(&wgpu::AdapterDescriptor {
|
||||
power_preference: wgpu::PowerPreference::LowPower,
|
||||
});
|
||||
let mut device = adapter.create_device(&wgpu::DeviceDescriptor {
|
||||
extensions: wgpu::Extensions {
|
||||
anisotropic_filtering: false,
|
||||
},
|
||||
});
|
||||
|
||||
let cs_bytes = include_bytes!("./../data/collatz.comp.spv");
|
||||
let cs_module = device.create_shader_module(cs_bytes);
|
||||
|
||||
let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
size,
|
||||
usage: wgpu::BufferUsageFlags::MAP_READ
|
||||
| wgpu::BufferUsageFlags::TRANSFER_DST
|
||||
| wgpu::BufferUsageFlags::TRANSFER_SRC,
|
||||
});
|
||||
staging_buffer.set_sub_data(0, cast_slice(&numbers));
|
||||
|
||||
let storage_buffer = device.create_buffer(&wgpu::BufferDescriptor {
|
||||
size: (numbers.len() * std::mem::size_of::<u32>()) as u32,
|
||||
usage: wgpu::BufferUsageFlags::STORAGE
|
||||
| wgpu::BufferUsageFlags::TRANSFER_DST
|
||||
| wgpu::BufferUsageFlags::TRANSFER_SRC,
|
||||
});
|
||||
|
||||
let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
|
||||
bindings: &[wgpu::BindGroupLayoutBinding {
|
||||
binding: 0,
|
||||
visibility: wgpu::ShaderStageFlags::COMPUTE,
|
||||
ty: wgpu::BindingType::StorageBuffer,
|
||||
}],
|
||||
});
|
||||
|
||||
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
|
||||
layout: &bind_group_layout,
|
||||
bindings: &[wgpu::Binding {
|
||||
binding: 0,
|
||||
resource: wgpu::BindingResource::Buffer {
|
||||
buffer: &storage_buffer,
|
||||
range: 0..(numbers.len() as u32),
|
||||
},
|
||||
}],
|
||||
});
|
||||
|
||||
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
|
||||
bind_group_layouts: &[&bind_group_layout],
|
||||
});
|
||||
|
||||
let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
|
||||
layout: &pipeline_layout,
|
||||
compute_stage: wgpu::PipelineStageDescriptor {
|
||||
module: &cs_module,
|
||||
stage: wgpu::ShaderStage::Compute,
|
||||
entry_point: "main",
|
||||
},
|
||||
});
|
||||
|
||||
let mut cmd_buf = device.create_command_buffer(&wgpu::CommandBufferDescriptor { todo: 0 });
|
||||
{
|
||||
cmd_buf.copy_buffer_tobuffer(&staging_buffer, 0, &storage_buffer, 0, size);
|
||||
}
|
||||
{
|
||||
let mut cpass = cmd_buf.begin_compute_pass();
|
||||
cpass.set_pipeline(&compute_pipeline);
|
||||
cpass.set_bind_group(0, &bind_group);
|
||||
cpass.dispatch(numbers.len() as u32, 1, 1);
|
||||
cpass.end_pass();
|
||||
}
|
||||
{
|
||||
cmd_buf.copy_buffer_tobuffer(&storage_buffer, 0, &staging_buffer, 0, size);
|
||||
}
|
||||
|
||||
// TODO: read the results back out of the staging buffer
|
||||
|
||||
device.get_queue().submit(&[cmd_buf]);
|
||||
}
|
@ -2,12 +2,10 @@ use crate::{back, binding_model, command, conv, pipeline, resource, swap_chain};
|
||||
use crate::registry::{HUB, Items};
|
||||
use crate::track::{BufferTracker, TextureTracker, TrackPermit};
|
||||
use crate::{
|
||||
LifeGuard, RefCount, Stored, SubmissionIndex, WeaklyStored,
|
||||
BindGroupLayoutId, BindGroupId,
|
||||
BlendStateId, BufferId, CommandBufferId, DepthStencilStateId,
|
||||
AdapterId, DeviceId, PipelineLayoutId, QueueId, RenderPipelineId, ShaderModuleId,
|
||||
SamplerId, TextureId, TextureViewId,
|
||||
SurfaceId, SwapChainId,
|
||||
AdapterId, BindGroupId, BindGroupLayoutId, BlendStateId, BufferId, CommandBufferId,
|
||||
ComputePipelineId, DepthStencilStateId, DeviceId, LifeGuard, PipelineLayoutId, QueueId,
|
||||
RefCount, RenderPipelineId, SamplerId, ShaderModuleId, Stored, SubmissionIndex, SurfaceId,
|
||||
SwapChainId, TextureId, TextureViewId, WeaklyStored,
|
||||
};
|
||||
|
||||
use hal::command::RawCommandBuffer;
|
||||
@ -1143,6 +1141,55 @@ pub extern "C" fn wgpu_device_create_render_pipeline(
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn wgpu_device_create_compute_pipeline(
|
||||
device_id: DeviceId,
|
||||
desc: &pipeline::ComputePipelineDescriptor,
|
||||
) -> ComputePipelineId {
|
||||
let device_guard = HUB.devices.read();
|
||||
let device = device_guard.get(device_id);
|
||||
let pipeline_layout_guard = HUB.pipeline_layouts.read();
|
||||
let layout = &pipeline_layout_guard.get(desc.layout).raw;
|
||||
let pipeline_stage = &desc.compute_stage;
|
||||
let shader_module_guard = HUB.shader_modules.read();
|
||||
|
||||
assert!(pipeline_stage.stage == pipeline::ShaderStage::Compute); // TODO
|
||||
|
||||
let shader = hal::pso::EntryPoint::<back::Backend> {
|
||||
entry: unsafe { ffi::CStr::from_ptr(pipeline_stage.entry_point) }
|
||||
.to_str()
|
||||
.to_owned()
|
||||
.unwrap(), // TODO
|
||||
module: &shader_module_guard.get(pipeline_stage.module).raw,
|
||||
specialization: hal::pso::Specialization {
|
||||
// TODO
|
||||
constants: &[],
|
||||
data: &[],
|
||||
},
|
||||
};
|
||||
|
||||
// TODO
|
||||
let flags = hal::pso::PipelineCreationFlags::empty();
|
||||
// TODO
|
||||
let parent = hal::pso::BasePipeline::None;
|
||||
|
||||
let pipeline_desc = hal::pso::ComputePipelineDesc {
|
||||
shader,
|
||||
layout,
|
||||
flags,
|
||||
parent,
|
||||
};
|
||||
|
||||
let pipeline = unsafe { device.raw.create_compute_pipeline(&pipeline_desc, None) }.unwrap();
|
||||
|
||||
HUB.compute_pipelines
|
||||
.write()
|
||||
.register(pipeline::ComputePipeline {
|
||||
raw: pipeline,
|
||||
layout_id: WeaklyStored(desc.layout),
|
||||
})
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn wgpu_device_create_swap_chain(
|
||||
device_id: DeviceId,
|
||||
|
@ -1,12 +1,10 @@
|
||||
use crate::resource;
|
||||
use crate::{
|
||||
ByteArray, WeaklyStored,
|
||||
BlendStateId, DepthStencilStateId, PipelineLayoutId, ShaderModuleId,
|
||||
BlendStateId, ByteArray, DepthStencilStateId, PipelineLayoutId, ShaderModuleId, WeaklyStored,
|
||||
};
|
||||
|
||||
use bitflags::bitflags;
|
||||
|
||||
|
||||
pub type ShaderAttributeIndex = u32;
|
||||
|
||||
#[repr(C)]
|
||||
@ -208,7 +206,7 @@ pub struct PipelineStageDescriptor {
|
||||
#[repr(C)]
|
||||
pub struct ComputePipelineDescriptor {
|
||||
pub layout: PipelineLayoutId,
|
||||
pub stages: *const PipelineStageDescriptor,
|
||||
pub compute_stage: PipelineStageDescriptor,
|
||||
}
|
||||
|
||||
pub(crate) struct ComputePipeline<B: hal::Backend> {
|
||||
|
@ -9,15 +9,15 @@ use std::ops::Range;
|
||||
use std::ptr;
|
||||
|
||||
pub use wgn::{
|
||||
AdapterDescriptor, Attachment, BindGroupLayoutBinding, BindingType, BlendStateDescriptor,
|
||||
BufferDescriptor, BufferUsageFlags,
|
||||
IndexFormat, VertexFormat, InputStepMode, ShaderAttributeIndex, VertexAttributeDescriptor,
|
||||
Color, ColorWriteFlags, CommandBufferDescriptor, DepthStencilStateDescriptor,
|
||||
DeviceDescriptor, Extensions, Extent3d, LoadOp, Origin3d, PowerPreference, PrimitiveTopology,
|
||||
RenderPassColorAttachmentDescriptor, RenderPassDepthStencilAttachmentDescriptor,
|
||||
AdapterDescriptor, AddressMode, Attachment, BindGroupLayoutBinding, BindingType,
|
||||
BlendStateDescriptor, BorderColor, BufferDescriptor, BufferUsageFlags, Color, ColorWriteFlags,
|
||||
CommandBufferDescriptor, CompareFunction, DepthStencilStateDescriptor, DeviceDescriptor,
|
||||
Extensions, Extent3d, FilterMode, IndexFormat, InputStepMode, LoadOp, Origin3d,
|
||||
PowerPreference, PrimitiveTopology, RenderPassColorAttachmentDescriptor,
|
||||
RenderPassDepthStencilAttachmentDescriptor, SamplerDescriptor, ShaderAttributeIndex,
|
||||
ShaderModuleDescriptor, ShaderStage, ShaderStageFlags, StoreOp, SwapChainDescriptor,
|
||||
TextureDescriptor, TextureDimension, TextureFormat, TextureUsageFlags, TextureViewDescriptor,
|
||||
SamplerDescriptor, AddressMode, FilterMode, CompareFunction, BorderColor,
|
||||
VertexAttributeDescriptor, VertexFormat,
|
||||
};
|
||||
|
||||
pub struct Instance {
|
||||
@ -162,6 +162,11 @@ pub struct RenderPipelineDescriptor<'a> {
|
||||
pub vertex_buffers: &'a [VertexBufferDescriptor<'a>],
|
||||
}
|
||||
|
||||
/// Parameters for `Device::create_compute_pipeline`.
pub struct ComputePipelineDescriptor<'a> {
|
||||
/// Pipeline layout the compute pipeline is created with.
pub layout: &'a PipelineLayout,
|
||||
/// The single shader stage of the pipeline: shader module, stage kind and entry point name.
pub compute_stage: PipelineStageDescriptor<'a>,
|
||||
}
|
||||
|
||||
pub struct RenderPassDescriptor<'a> {
|
||||
pub color_attachments: &'a [RenderPassColorAttachmentDescriptor<&'a TextureView>],
|
||||
pub depth_stencil_attachment:
|
||||
@ -210,7 +215,6 @@ impl<'a> TextureCopyView<'a> {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
impl Instance {
|
||||
pub fn new() -> Self {
|
||||
Instance {
|
||||
@ -273,14 +277,17 @@ impl Device {
|
||||
.map(|binding| wgn::Binding {
|
||||
binding: binding.binding,
|
||||
resource: match binding.resource {
|
||||
BindingResource::Buffer { ref buffer, ref range } => {
|
||||
wgn::BindingResource::Buffer(wgn::BufferBinding {
|
||||
buffer: buffer.id,
|
||||
offset: range.start,
|
||||
size: range.end,
|
||||
})
|
||||
BindingResource::Buffer {
|
||||
ref buffer,
|
||||
ref range,
|
||||
} => wgn::BindingResource::Buffer(wgn::BufferBinding {
|
||||
buffer: buffer.id,
|
||||
offset: range.start,
|
||||
size: range.end,
|
||||
}),
|
||||
BindingResource::Sampler(ref sampler) => {
|
||||
wgn::BindingResource::Sampler(sampler.id)
|
||||
}
|
||||
BindingResource::Sampler(ref sampler) => wgn::BindingResource::Sampler(sampler.id),
|
||||
BindingResource::TextureView(ref texture_view) => {
|
||||
wgn::BindingResource::TextureView(texture_view.id)
|
||||
}
|
||||
@ -362,7 +369,8 @@ impl Device {
|
||||
.collect::<ArrayVec<[_; 2]>>();
|
||||
|
||||
let temp_blend_states = desc.blend_states.iter().map(|bs| bs.id).collect::<Vec<_>>();
|
||||
let temp_vertex_buffers = desc.vertex_buffers
|
||||
let temp_vertex_buffers = desc
|
||||
.vertex_buffers
|
||||
.iter()
|
||||
.map(|vbuf| wgn::VertexBufferDescriptor {
|
||||
stride: vbuf.stride,
|
||||
@ -403,6 +411,25 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_compute_pipeline(&self, desc: &ComputePipelineDescriptor) -> ComputePipeline {
|
||||
let entry_point = CString::new(desc.compute_stage.entry_point).unwrap();
|
||||
let compute_stage = wgn::PipelineStageDescriptor {
|
||||
module: desc.compute_stage.module.id,
|
||||
stage: desc.compute_stage.stage,
|
||||
entry_point: entry_point.as_ptr(),
|
||||
};
|
||||
|
||||
ComputePipeline {
|
||||
id: wgn::wgpu_device_create_compute_pipeline(
|
||||
self.id,
|
||||
&wgn::ComputePipelineDescriptor {
|
||||
layout: desc.layout.id,
|
||||
compute_stage,
|
||||
},
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn create_buffer(&self, desc: &BufferDescriptor) -> Buffer {
|
||||
Buffer {
|
||||
id: wgn::wgpu_device_create_buffer(self.id, desc),
|
||||
@ -651,7 +678,9 @@ impl SwapChain {
|
||||
pub fn get_next_texture(&mut self) -> SwapChainOutput {
|
||||
let output = wgn::wgpu_swap_chain_get_next_texture(self.id);
|
||||
SwapChainOutput {
|
||||
texture: Texture { id: output.texture_id },
|
||||
texture: Texture {
|
||||
id: output.texture_id,
|
||||
},
|
||||
view: TextureView { id: output.view_id },
|
||||
swap_chain_id: &self.id,
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user