native: bind groups for both compute and graphics

This commit is contained in:
Dzmitry Malyshau 2019-01-20 17:05:44 -05:00
parent 6699f4bed1
commit dee685aa0c
5 changed files with 143 additions and 73 deletions

View File

@ -1,12 +1,10 @@
use crate::registry::{HUB, Items, ConcreteItems};
use crate::registry::{HUB, Items};
use crate::{
B, Stored, WeaklyStored,
BindGroup, PipelineLayout,
BindGroupId, BindGroupLayoutId, PipelineLayoutId,
};
use hal;
use parking_lot::RwLockReadGuard;
#[derive(Clone, Default)]
@ -21,35 +19,20 @@ pub struct Binder {
entries: Vec<BindGroupEntry>,
}
pub struct NewBind<'a, B: hal::Backend> {
pipeline_layout_guard: RwLockReadGuard<'a, ConcreteItems<PipelineLayout<B>>>,
pipeline_layout_id: PipelineLayoutId,
bind_group_guard: RwLockReadGuard<'a, ConcreteItems<BindGroup<B>>>,
bind_group_id: BindGroupId,
}
impl<'a, B: hal::Backend> NewBind<'a, B> {
pub fn pipeline_layout(&self) -> &B::PipelineLayout {
&self.pipeline_layout_guard.get(self.pipeline_layout_id).raw
}
pub fn descriptor_set(&self) -> &B::DescriptorSet {
&self.bind_group_guard.get(self.bind_group_id).raw
}
}
//Note: we can probably make this much better than passing an `FnMut`
impl Binder {
//Note: `'a` is need to avoid inheriting the lifetime from `self`
pub fn bind_group<'a>(
&mut self, index: u32, bind_group_id: BindGroupId
) -> Option<NewBind<'a, B>> {
pub fn bind_group<F>(&mut self, index: usize, bind_group_id: BindGroupId, mut fun: F)
where
F: FnMut(&<B as hal::Backend>::PipelineLayout, &<B as hal::Backend>::DescriptorSet),
{
let bind_group_guard = HUB.bind_groups.read();
let bind_group = bind_group_guard.get(bind_group_id);
while self.entries.len() <= index as usize {
while self.entries.len() <= index {
self.entries.push(BindGroupEntry::default());
}
*self.entries.get_mut(index as usize).unwrap() = BindGroupEntry {
*self.entries.get_mut(index).unwrap() = BindGroupEntry {
layout: Some(bind_group.layout_id.clone()),
data: Some(Stored {
value: bind_group_id,
@ -61,16 +44,41 @@ impl Binder {
//TODO: we can cache the group layout ids of the current pipeline in `Binder` itself
let pipeline_layout_guard = HUB.pipeline_layouts.read();
let pipeline_layout = pipeline_layout_guard.get(pipeline_layout_id);
if pipeline_layout.bind_group_layout_ids[index as usize] == bind_group.layout_id {
return Some(NewBind {
pipeline_layout_guard,
pipeline_layout_id,
bind_group_guard,
bind_group_id,
})
if pipeline_layout.bind_group_layout_ids[index] == bind_group.layout_id {
fun(&pipeline_layout.raw, &bind_group.raw);
}
}
}
None
pub fn change_layout<F>(&mut self, pipeline_layout_id: PipelineLayoutId, mut fun: F)
where
F: FnMut(&<B as hal::Backend>::PipelineLayout, usize, &<B as hal::Backend>::DescriptorSet),
{
if self.pipeline_layout_id == Some(WeaklyStored(pipeline_layout_id)) {
return
}
self.pipeline_layout_id = Some(WeaklyStored(pipeline_layout_id));
let pipeline_layout_guard = HUB.pipeline_layouts.read();
let pipeline_layout = pipeline_layout_guard.get(pipeline_layout_id);
let bing_group_guard = HUB.bind_groups.read();
while self.entries.len() < pipeline_layout.bind_group_layout_ids.len() {
self.entries.push(BindGroupEntry::default());
}
for (index, (entry, bgl_id)) in self.entries
.iter_mut()
.zip(&pipeline_layout.bind_group_layout_ids)
.enumerate()
{
if entry.layout == Some(bgl_id.clone()) {
continue
}
entry.layout = Some(bgl_id.clone());
if let Some(ref bg_id) = entry.data {
let bind_group = bing_group_guard.get(bg_id.value);
fun(&pipeline_layout.raw, index, &bind_group.raw);
}
}
}
}

View File

@ -38,43 +38,6 @@ pub extern "C" fn wgpu_compute_pass_end_pass(pass_id: ComputePassId) -> CommandB
pass.cmb_id.value
}
#[no_mangle]
pub extern "C" fn wgpu_compute_pass_set_bind_group(
pass_id: ComputePassId,
index: u32,
bind_group_id: BindGroupId,
) {
let mut pass_guard = HUB.compute_passes.write();
let pass = pass_guard.get_mut(pass_id);
if let Some(bind) = pass.binder.bind_group(index, bind_group_id) {
unsafe {
pass.raw.bind_compute_descriptor_sets(
bind.pipeline_layout(),
index as usize,
iter::once(bind.descriptor_set()),
&[],
);
}
}
}
#[no_mangle]
pub extern "C" fn wgpu_compute_pass_set_pipeline(
pass_id: ComputePassId,
pipeline_id: ComputePipelineId,
) {
let pipeline_guard = HUB.compute_pipelines.read();
let pipeline = &pipeline_guard.get(pipeline_id).raw;
unsafe {
HUB.compute_passes
.write()
.get_mut(pass_id)
.raw
.bind_compute_pipeline(pipeline);
}
}
#[no_mangle]
pub extern "C" fn wgpu_compute_pass_dispatch(pass_id: ComputePassId, x: u32, y: u32, z: u32) {
unsafe {
@ -85,3 +48,46 @@ pub extern "C" fn wgpu_compute_pass_dispatch(pass_id: ComputePassId, x: u32, y:
.dispatch([x, y, z]);
}
}
#[no_mangle]
pub extern "C" fn wgpu_compute_pass_set_bind_group(
pass_id: ComputePassId,
index: u32,
bind_group_id: BindGroupId,
) {
let mut pass_guard = HUB.compute_passes.write();
let ComputePass { ref mut raw, ref mut binder, .. } = *pass_guard.get_mut(pass_id);
binder.bind_group(index as usize, bind_group_id, |pipeline_layout, desc_set| unsafe {
raw.bind_compute_descriptor_sets(
pipeline_layout,
index as usize,
iter::once(desc_set),
&[],
);
});
}
#[no_mangle]
pub extern "C" fn wgpu_compute_pass_set_pipeline(
pass_id: ComputePassId,
pipeline_id: ComputePipelineId,
) {
let mut pass_guard = HUB.compute_passes.write();
let ComputePass { ref mut raw, ref mut binder, .. } = *pass_guard.get_mut(pass_id);
let pipeline_guard = HUB.compute_pipelines.read();
let pipeline = pipeline_guard.get(pipeline_id);
unsafe {
raw.bind_compute_pipeline(&pipeline.raw);
}
binder.change_layout(pipeline.layout_id.0, |pipeline_layout, index, desc_set| unsafe {
raw.bind_compute_descriptor_sets(
pipeline_layout,
index,
iter::once(desc_set),
&[],
);
});
}

View File

@ -1,17 +1,21 @@
use crate::command::bind::Binder;
use crate::resource::BufferUsageFlags;
use crate::registry::{Items, HUB};
use crate::track::{BufferTracker, TextureTracker, TrackPermit};
use crate::{
CommandBuffer, Stored,
BufferId, CommandBufferId, RenderPassId,
BindGroupId, BufferId, CommandBufferId, RenderPassId, RenderPipelineId,
};
use hal::command::RawCommandBuffer;
use std::iter;
pub struct RenderPass<B: hal::Backend> {
raw: B::CommandBuffer,
cmb_id: Stored<CommandBufferId>,
binder: Binder,
buffer_tracker: BufferTracker,
texture_tracker: TextureTracker,
}
@ -21,6 +25,7 @@ impl<B: hal::Backend> RenderPass<B> {
RenderPass {
raw,
cmb_id,
binder: Binder::default(),
buffer_tracker: BufferTracker::new(),
texture_tracker: TextureTracker::new(),
}
@ -152,3 +157,46 @@ pub extern "C" fn wgpu_render_pass_draw_indexed(
);
}
}
#[no_mangle]
pub extern "C" fn wgpu_render_pass_set_bind_group(
pass_id: RenderPassId,
index: u32,
bind_group_id: BindGroupId,
) {
let mut pass_guard = HUB.render_passes.write();
let RenderPass { ref mut raw, ref mut binder, .. } = *pass_guard.get_mut(pass_id);
binder.bind_group(index as usize, bind_group_id, |pipeline_layout, desc_set| unsafe {
raw.bind_compute_descriptor_sets(
pipeline_layout,
index as usize,
iter::once(desc_set),
&[],
);
});
}
#[no_mangle]
pub extern "C" fn wgpu_render_pass_set_pipeline(
pass_id: RenderPassId,
pipeline_id: RenderPipelineId,
) {
let mut pass_guard = HUB.render_passes.write();
let RenderPass { ref mut raw, ref mut binder, .. } = *pass_guard.get_mut(pass_id);
let pipeline_guard = HUB.render_pipelines.read();
let pipeline = pipeline_guard.get(pipeline_id);
unsafe {
raw.bind_graphics_pipeline(&pipeline.raw);
}
binder.change_layout(pipeline.layout_id.0, |pipeline_layout, index, desc_set| unsafe {
raw.bind_graphics_descriptor_sets(
pipeline_layout,
index,
iter::once(desc_set),
&[],
);
});
}

View File

@ -986,7 +986,10 @@ pub extern "C" fn wgpu_device_create_render_pipeline(
HUB.render_pipelines
.write()
.register(pipeline::RenderPipeline { raw: pipeline })
.register(pipeline::RenderPipeline {
raw: pipeline,
layout_id: WeaklyStored(desc.layout),
})
}
#[no_mangle]

View File

@ -1,5 +1,8 @@
use crate::resource;
use crate::{BlendStateId, ByteArray, DepthStencilStateId, PipelineLayoutId, ShaderModuleId};
use crate::{
ByteArray, WeaklyStored,
BlendStateId, DepthStencilStateId, PipelineLayoutId, ShaderModuleId,
};
use bitflags::bitflags;
@ -212,6 +215,7 @@ pub struct ComputePipelineDescriptor {
pub(crate) struct ComputePipeline<B: hal::Backend> {
pub raw: B::ComputePipeline,
pub layout_id: WeaklyStored<PipelineLayoutId>,
}
#[repr(C)]
@ -251,4 +255,5 @@ pub struct RenderPipelineDescriptor {
pub(crate) struct RenderPipeline<B: hal::Backend> {
pub raw: B::GraphicsPipeline,
pub layout_id: WeaklyStored<PipelineLayoutId>,
}