diff --git a/wgpu-native/src/binding_model.rs b/wgpu-native/src/binding_model.rs index c0706a771..29975d5b5 100644 --- a/wgpu-native/src/binding_model.rs +++ b/wgpu-native/src/binding_model.rs @@ -1,6 +1,6 @@ use crate::track::{BufferTracker, TextureTracker}; use crate::{ - LifeGuard, + LifeGuard, WeaklyStored, BindGroupLayoutId, BufferId, SamplerId, TextureViewId, }; @@ -50,6 +50,7 @@ pub struct PipelineLayoutDescriptor { pub(crate) struct PipelineLayout { pub raw: B::PipelineLayout, + pub bind_group_layout_ids: Vec>, } #[repr(C)] @@ -81,6 +82,7 @@ pub struct BindGroupDescriptor { pub(crate) struct BindGroup { pub raw: B::DescriptorSet, + pub layout_id: WeaklyStored, pub life_guard: LifeGuard, pub used_buffers: BufferTracker, pub used_textures: TextureTracker, diff --git a/wgpu-native/src/command/compute.rs b/wgpu-native/src/command/compute.rs index ccbda31ac..551996733 100644 --- a/wgpu-native/src/command/compute.rs +++ b/wgpu-native/src/command/compute.rs @@ -1,5 +1,9 @@ use crate::registry::{Items, HUB}; -use crate::{BindGroupId, CommandBufferId, ComputePassId, ComputePipelineId, Stored}; +use crate::{ + Stored, WeaklyStored, + BindGroupId, CommandBufferId, ComputePassId, ComputePipelineId, PipelineLayoutId, +}; +use super::{BindGroupEntry}; use hal::command::RawCommandBuffer; @@ -9,11 +13,18 @@ use std::iter; pub struct ComputePass { raw: B::CommandBuffer, cmb_id: Stored, + pipeline_layout_id: Option>, //TODO: strongly `Stored` + bind_groups: Vec, } impl ComputePass { pub(crate) fn new(raw: B::CommandBuffer, cmb_id: Stored) -> Self { - ComputePass { raw, cmb_id } + ComputePass { + raw, + cmb_id, + pipeline_layout_id: None, + bind_groups: Vec::new(), + } } } @@ -35,22 +46,35 @@ pub extern "C" fn wgpu_compute_pass_set_bind_group( index: u32, bind_group_id: BindGroupId, ) { + let mut pass_guard = HUB.compute_passes.write(); + let pass = pass_guard.get_mut(pass_id); let bind_group_guard = HUB.bind_groups.read(); let bind_group = bind_group_guard.get(bind_group_id); - let layout = unimplemented!(); - // see https://github.com/gpuweb/gpuweb/pull/93 + let pipeline_layout_guard = HUB.pipeline_layouts.read(); - unsafe { - HUB.compute_passes - .write() - .get_mut(pass_id) - .raw - .bind_compute_descriptor_sets( - layout, - index as usize, - iter::once(&bind_group.raw), - &[], - ); + while pass.bind_groups.len() <= index as usize { + pass.bind_groups.push(BindGroupEntry::default()); + } + *pass.bind_groups.get_mut(index as usize).unwrap() = BindGroupEntry { + layout: Some(bind_group.layout_id.clone()), + data: Some(Stored { + value: bind_group_id, + ref_count: bind_group.life_guard.ref_count.clone(), + }), + }; + + if let Some(ref pipeline_layout_id) = pass.pipeline_layout_id { + let pipeline_layout = pipeline_layout_guard.get(pipeline_layout_id.0); + if pipeline_layout.bind_group_layout_ids[index as usize] == bind_group.layout_id { + unsafe { + pass.raw.bind_compute_descriptor_sets( + &pipeline_layout.raw, + index as usize, + iter::once(&bind_group.raw), + &[], + ); + } + } } } diff --git a/wgpu-native/src/command/mod.rs b/wgpu-native/src/command/mod.rs index 298ae77e9..2c10174db 100644 --- a/wgpu-native/src/command/mod.rs +++ b/wgpu-native/src/command/mod.rs @@ -13,6 +13,7 @@ use crate::track::{BufferTracker, TextureTracker}; use crate::{conv, resource}; use crate::{ BufferId, CommandBufferId, ComputePassId, DeviceId, + BindGroupId, BindGroupLayoutId, RenderPassId, TextureId, TextureViewId, BufferUsageFlags, TextureUsageFlags, Color, Origin3d, LifeGuard, Stored, WeaklyStored, @@ -29,6 +30,12 @@ use std::slice; use std::thread::ThreadId; +#[derive(Clone, Default)] +struct BindGroupEntry { + layout: Option>, + data: Option>, +} + #[repr(C)] #[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] pub enum LoadOp { diff --git a/wgpu-native/src/device.rs b/wgpu-native/src/device.rs index e55801fa3..bf594909e 100644 --- a/wgpu-native/src/device.rs +++ b/wgpu-native/src/device.rs @@ -462,10 +462,11 @@ pub extern "C" fn wgpu_device_create_pipeline_layout( device_id: DeviceId, desc: &binding_model::PipelineLayoutDescriptor, ) -> PipelineLayoutId { - let bind_group_layouts = - unsafe { slice::from_raw_parts(desc.bind_group_layouts, desc.bind_group_layouts_length) }; + let bind_group_layout_ids = unsafe { + slice::from_raw_parts(desc.bind_group_layouts, desc.bind_group_layouts_length) + }; let bind_group_layout_guard = HUB.bind_group_layouts.read(); - let descriptor_set_layouts = bind_group_layouts + let descriptor_set_layouts = bind_group_layout_ids .iter() .map(|&id| &bind_group_layout_guard.get(id).raw); @@ -483,6 +484,11 @@ pub extern "C" fn wgpu_device_create_pipeline_layout( .write() .register(binding_model::PipelineLayout { raw: pipeline_layout, + bind_group_layout_ids: bind_group_layout_ids + .iter() + .cloned() + .map(WeaklyStored) + .collect(), }) } @@ -562,6 +568,7 @@ pub extern "C" fn wgpu_device_create_bind_group( .write() .register(binding_model::BindGroup { raw: desc_set, + layout_id: WeaklyStored(desc.layout), life_guard: LifeGuard::new(), used_buffers, used_textures,