Bind compute groups if layout is matching

This commit is contained in:
Dzmitry Malyshau 2019-01-19 23:40:35 -05:00
parent fb6a91589e
commit 3ed4620c1f
4 changed files with 59 additions and 19 deletions

View File

@ -1,6 +1,6 @@
use crate::track::{BufferTracker, TextureTracker};
use crate::{
LifeGuard,
LifeGuard, WeaklyStored,
BindGroupLayoutId, BufferId, SamplerId, TextureViewId,
};
@ -50,6 +50,7 @@ pub struct PipelineLayoutDescriptor {
pub(crate) struct PipelineLayout<B: hal::Backend> {
pub raw: B::PipelineLayout,
pub bind_group_layout_ids: Vec<WeaklyStored<BindGroupLayoutId>>,
}
#[repr(C)]
@ -81,6 +82,7 @@ pub struct BindGroupDescriptor {
pub(crate) struct BindGroup<B: hal::Backend> {
pub raw: B::DescriptorSet,
pub layout_id: WeaklyStored<BindGroupLayoutId>,
pub life_guard: LifeGuard,
pub used_buffers: BufferTracker,
pub used_textures: TextureTracker,

View File

@ -1,5 +1,9 @@
use crate::registry::{Items, HUB};
use crate::{BindGroupId, CommandBufferId, ComputePassId, ComputePipelineId, Stored};
use crate::{
Stored, WeaklyStored,
BindGroupId, CommandBufferId, ComputePassId, ComputePipelineId, PipelineLayoutId,
};
use super::{BindGroupEntry};
use hal::command::RawCommandBuffer;
@ -9,11 +13,18 @@ use std::iter;
pub struct ComputePass<B: hal::Backend> {
raw: B::CommandBuffer,
cmb_id: Stored<CommandBufferId>,
pipeline_layout_id: Option<WeaklyStored<PipelineLayoutId>>, //TODO: strongly `Stored`
bind_groups: Vec<BindGroupEntry>,
}
impl<B: hal::Backend> ComputePass<B> {
pub(crate) fn new(raw: B::CommandBuffer, cmb_id: Stored<CommandBufferId>) -> Self {
ComputePass { raw, cmb_id }
ComputePass {
raw,
cmb_id,
pipeline_layout_id: None,
bind_groups: Vec::new(),
}
}
}
@ -35,22 +46,35 @@ pub extern "C" fn wgpu_compute_pass_set_bind_group(
index: u32,
bind_group_id: BindGroupId,
) {
let mut pass_guard = HUB.compute_passes.write();
let pass = pass_guard.get_mut(pass_id);
let bind_group_guard = HUB.bind_groups.read();
let bind_group = bind_group_guard.get(bind_group_id);
let layout = unimplemented!();
// see https://github.com/gpuweb/gpuweb/pull/93
let pipeline_layout_guard = HUB.pipeline_layouts.read();
unsafe {
HUB.compute_passes
.write()
.get_mut(pass_id)
.raw
.bind_compute_descriptor_sets(
layout,
index as usize,
iter::once(&bind_group.raw),
&[],
);
while pass.bind_groups.len() <= index as usize {
pass.bind_groups.push(BindGroupEntry::default());
}
*pass.bind_groups.get_mut(index as usize).unwrap() = BindGroupEntry {
layout: Some(bind_group.layout_id.clone()),
data: Some(Stored {
value: bind_group_id,
ref_count: bind_group.life_guard.ref_count.clone(),
}),
};
if let Some(ref pipeline_layout_id) = pass.pipeline_layout_id {
let pipeline_layout = pipeline_layout_guard.get(pipeline_layout_id.0);
if pipeline_layout.bind_group_layout_ids[index as usize] == bind_group.layout_id {
unsafe {
pass.raw.bind_compute_descriptor_sets(
&pipeline_layout.raw,
index as usize,
iter::once(&bind_group.raw),
&[],
);
}
}
}
}

View File

@ -13,6 +13,7 @@ use crate::track::{BufferTracker, TextureTracker};
use crate::{conv, resource};
use crate::{
BufferId, CommandBufferId, ComputePassId, DeviceId,
BindGroupId, BindGroupLayoutId,
RenderPassId, TextureId, TextureViewId,
BufferUsageFlags, TextureUsageFlags, Color, Origin3d,
LifeGuard, Stored, WeaklyStored,
@ -29,6 +30,12 @@ use std::slice;
use std::thread::ThreadId;
#[derive(Clone, Default)]
struct BindGroupEntry {
layout: Option<WeaklyStored<BindGroupLayoutId>>,
data: Option<Stored<BindGroupId>>,
}
#[repr(C)]
#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)]
pub enum LoadOp {

View File

@ -462,10 +462,11 @@ pub extern "C" fn wgpu_device_create_pipeline_layout(
device_id: DeviceId,
desc: &binding_model::PipelineLayoutDescriptor,
) -> PipelineLayoutId {
let bind_group_layouts =
unsafe { slice::from_raw_parts(desc.bind_group_layouts, desc.bind_group_layouts_length) };
let bind_group_layout_ids = unsafe {
slice::from_raw_parts(desc.bind_group_layouts, desc.bind_group_layouts_length)
};
let bind_group_layout_guard = HUB.bind_group_layouts.read();
let descriptor_set_layouts = bind_group_layouts
let descriptor_set_layouts = bind_group_layout_ids
.iter()
.map(|&id| &bind_group_layout_guard.get(id).raw);
@ -483,6 +484,11 @@ pub extern "C" fn wgpu_device_create_pipeline_layout(
.write()
.register(binding_model::PipelineLayout {
raw: pipeline_layout,
bind_group_layout_ids: bind_group_layout_ids
.iter()
.cloned()
.map(WeaklyStored)
.collect(),
})
}
@ -562,6 +568,7 @@ pub extern "C" fn wgpu_device_create_bind_group(
.write()
.register(binding_model::BindGroup {
raw: desc_set,
layout_id: WeaklyStored(desc.layout),
life_guard: LifeGuard::new(),
used_buffers,
used_textures,