From 5735f85720abfde3846cac1a332a1d172d6431d8 Mon Sep 17 00:00:00 2001 From: Samson <16504129+sagudev@users.noreply.github.com> Date: Thu, 25 Apr 2024 12:17:00 +0200 Subject: [PATCH] CreateBindGroup validation error on device mismatch (#5596) * Fix cts_runner command invocation in readme * Remove assertDeviceMatch from deno_webgpu in createBindGroup This should be done as verification in wgpu-core. * Add device mismatched check to create_buffer_binding * Extract common logic to create_sampler_binding * Move common logic to create_texture_binding and add device mismatch check --- README.md | 2 +- deno_webgpu/01_webgpu.js | 15 --- wgpu-core/src/device/resource.rs | 205 +++++++++++++++++-------------- 3 files changed, 116 insertions(+), 106 deletions(-) diff --git a/README.md b/README.md index bc0f01b30..c1635042f 100644 --- a/README.md +++ b/README.md @@ -199,7 +199,7 @@ To run a given set of tests: ``` # Must be inside the `cts` folder we just checked out, else this will fail -cargo run --manifest-path ../Cargo.toml --bin cts_runner -- ./tools/run_deno --verbose "" +cargo run --manifest-path ../Cargo.toml -p cts_runner --bin cts_runner -- ./tools/run_deno --verbose "" ``` To find the full list of tests, go to the [online cts viewer](https://gpuweb.github.io/cts/standalone/?runnow=0&worker=0&debug=0&q=webgpu:*). diff --git a/deno_webgpu/01_webgpu.js b/deno_webgpu/01_webgpu.js index dbac07615..369d1cd9b 100644 --- a/deno_webgpu/01_webgpu.js +++ b/deno_webgpu/01_webgpu.js @@ -1289,11 +1289,6 @@ class GPUDevice extends EventTarget { const resource = entry.resource; if (ObjectPrototypeIsPrototypeOf(GPUSamplerPrototype, resource)) { const rid = assertResource(resource, prefix, context); - assertDeviceMatch(device, resource, { - prefix, - resourceContext: context, - selfContext: "this", - }); return { binding: entry.binding, kind: "GPUSampler", @@ -1304,11 +1299,6 @@ class GPUDevice extends EventTarget { ) { const rid = assertResource(resource, prefix, context); assertResource(resource[_texture], prefix, context); - assertDeviceMatch(device, resource[_texture], { - prefix, - resourceContext: context, - selfContext: "this", - }); return { binding: entry.binding, kind: "GPUTextureView", @@ -1318,11 +1308,6 @@ class GPUDevice extends EventTarget { // deno-lint-ignore prefer-primordials const rid = assertResource(resource.buffer, prefix, context); // deno-lint-ignore prefer-primordials - assertDeviceMatch(device, resource.buffer, { - prefix, - resourceContext: context, - selfContext: "this", - }); return { binding: entry.binding, kind: "GPUBufferBinding", diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 403edf87a..7dc57b83a 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -13,6 +13,7 @@ use crate::{ hal_api::HalApi, hal_label, hub::Hub, + id, init_tracker::{ BufferInitTracker, BufferInitTrackerAction, MemoryInitKind, TextureInitRange, TextureInitTracker, TextureInitTrackerAction, @@ -1949,6 +1950,7 @@ impl Device { used: &mut BindGroupStates, storage: &'a Storage>, limits: &wgt::Limits, + device_id: id::Id, snatch_guard: &'a SnatchGuard<'a>, ) -> Result, binding_model::CreateBindGroupError> { use crate::binding_model::CreateBindGroupError as Error; @@ -1967,6 +1969,7 @@ impl Device { }) } }; + let (pub_usage, internal_use, range_limit) = match binding_ty { wgt::BufferBindingType::Uniform => ( wgt::BufferUsages::UNIFORM, @@ -1999,6 +2002,10 @@ impl Device { .add_single(storage, bb.buffer_id, internal_use) .ok_or(Error::InvalidBuffer(bb.buffer_id))?; + if buffer.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice.into()); + } + check_buffer_usage(bb.buffer_id, buffer.usage, pub_usage)?; let raw_buffer = buffer .raw @@ -2077,13 +2084,53 @@ impl Device { }) } - pub(crate) fn create_texture_binding( - view: &TextureView, - internal_use: hal::TextureUses, - pub_usage: wgt::TextureUsages, + fn create_sampler_binding<'a>( + used: &BindGroupStates, + storage: &'a Storage>, + id: id::Id, + device_id: id::Id, + ) -> Result<&'a Sampler, binding_model::CreateBindGroupError> { + use crate::binding_model::CreateBindGroupError as Error; + + let sampler = used + .samplers + .add_single(storage, id) + .ok_or(Error::InvalidSampler(id))?; + + if sampler.device.as_info().id() != device_id { + return Err(DeviceError::WrongDevice.into()); + } + + Ok(sampler) + } + + pub(crate) fn create_texture_binding<'a>( + self: &Arc, + binding: u32, + decl: &wgt::BindGroupLayoutEntry, + storage: &'a Storage>, + id: id::Id, used: &mut BindGroupStates, used_texture_ranges: &mut Vec>, - ) -> Result<(), binding_model::CreateBindGroupError> { + snatch_guard: &'a SnatchGuard<'a>, + ) -> Result, binding_model::CreateBindGroupError> { + use crate::binding_model::CreateBindGroupError as Error; + + let view = used + .views + .add_single(storage, id) + .ok_or(Error::InvalidTextureView(id))?; + + if view.device.as_info().id() != self.as_info().id() { + return Err(DeviceError::WrongDevice.into()); + } + + let (pub_usage, internal_use) = self.texture_use_parameters( + binding, + decl, + view, + "SampledTexture, ReadonlyStorageTexture or WriteonlyStorageTexture", + )?; let texture = &view.parent; let texture_id = texture.as_info().id(); // Careful here: the texture may no longer have its own ref count, @@ -2113,7 +2160,12 @@ impl Device { kind: MemoryInitKind::NeedsInitializedMemory, }); - Ok(()) + Ok(hal::TextureBinding { + view: view + .raw(snatch_guard) + .ok_or(Error::InvalidTextureView(id))?, + usage: internal_use, + }) } // This function expects the provided bind group layout to be resolved @@ -2175,6 +2227,7 @@ impl Device { &mut used, &*buffer_guard, &self.limits, + self.as_info().id(), &snatch_guard, )?; @@ -2198,105 +2251,86 @@ impl Device { &mut used, &*buffer_guard, &self.limits, + self.as_info().id(), &snatch_guard, )?; hal_buffers.push(bb); } (res_index, num_bindings) } - Br::Sampler(id) => { - match decl.ty { - wgt::BindingType::Sampler(ty) => { - let sampler = used - .samplers - .add_single(&*sampler_guard, id) - .ok_or(Error::InvalidSampler(id))?; + Br::Sampler(id) => match decl.ty { + wgt::BindingType::Sampler(ty) => { + let sampler = Self::create_sampler_binding( + &used, + &sampler_guard, + id, + self.as_info().id(), + )?; - if sampler.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } - - // Allowed sampler values for filtering and comparison - let (allowed_filtering, allowed_comparison) = match ty { - wgt::SamplerBindingType::Filtering => (None, false), - wgt::SamplerBindingType::NonFiltering => (Some(false), false), - wgt::SamplerBindingType::Comparison => (None, true), - }; - - if let Some(allowed_filtering) = allowed_filtering { - if allowed_filtering != sampler.filtering { - return Err(Error::WrongSamplerFiltering { - binding, - layout_flt: allowed_filtering, - sampler_flt: sampler.filtering, - }); - } - } - - if allowed_comparison != sampler.comparison { - return Err(Error::WrongSamplerComparison { + let (allowed_filtering, allowed_comparison) = match ty { + wgt::SamplerBindingType::Filtering => (None, false), + wgt::SamplerBindingType::NonFiltering => (Some(false), false), + wgt::SamplerBindingType::Comparison => (None, true), + }; + if let Some(allowed_filtering) = allowed_filtering { + if allowed_filtering != sampler.filtering { + return Err(Error::WrongSamplerFiltering { binding, - layout_cmp: allowed_comparison, - sampler_cmp: sampler.comparison, + layout_flt: allowed_filtering, + sampler_flt: sampler.filtering, }); } - - let res_index = hal_samplers.len(); - hal_samplers.push(sampler.raw()); - (res_index, 1) } - _ => { - return Err(Error::WrongBindingType { + if allowed_comparison != sampler.comparison { + return Err(Error::WrongSamplerComparison { binding, - actual: decl.ty, - expected: "Sampler", - }) + layout_cmp: allowed_comparison, + sampler_cmp: sampler.comparison, + }); } + + let res_index = hal_samplers.len(); + hal_samplers.push(sampler.raw()); + (res_index, 1) } - } + _ => { + return Err(Error::WrongBindingType { + binding, + actual: decl.ty, + expected: "Sampler", + }) + } + }, Br::SamplerArray(ref bindings_array) => { let num_bindings = bindings_array.len(); Self::check_array_binding(self.features, decl.count, num_bindings)?; let res_index = hal_samplers.len(); for &id in bindings_array.iter() { - let sampler = used - .samplers - .add_single(&*sampler_guard, id) - .ok_or(Error::InvalidSampler(id))?; - if sampler.device.as_info().id() != self.as_info().id() { - return Err(DeviceError::WrongDevice.into()); - } + let sampler = Self::create_sampler_binding( + &used, + &sampler_guard, + id, + self.as_info().id(), + )?; + hal_samplers.push(sampler.raw()); } (res_index, num_bindings) } Br::TextureView(id) => { - let view = used - .views - .add_single(&*texture_view_guard, id) - .ok_or(Error::InvalidTextureView(id))?; - let (pub_usage, internal_use) = self.texture_use_parameters( + let tb = self.create_texture_binding( binding, decl, - view, - "SampledTexture, ReadonlyStorageTexture or WriteonlyStorageTexture", - )?; - Self::create_texture_binding( - view, - internal_use, - pub_usage, + &texture_view_guard, + id, &mut used, &mut used_texture_ranges, + &snatch_guard, )?; let res_index = hal_textures.len(); - hal_textures.push(hal::TextureBinding { - view: view - .raw(&snatch_guard) - .ok_or(Error::InvalidTextureView(id))?, - usage: internal_use, - }); + hal_textures.push(tb); (res_index, 1) } Br::TextureViewArray(ref bindings_array) => { @@ -2305,26 +2339,17 @@ impl Device { let res_index = hal_textures.len(); for &id in bindings_array.iter() { - let view = used - .views - .add_single(&*texture_view_guard, id) - .ok_or(Error::InvalidTextureView(id))?; - let (pub_usage, internal_use) = - self.texture_use_parameters(binding, decl, view, - "SampledTextureArray, ReadonlyStorageTextureArray or WriteonlyStorageTextureArray")?; - Self::create_texture_binding( - view, - internal_use, - pub_usage, + let tb = self.create_texture_binding( + binding, + decl, + &texture_view_guard, + id, &mut used, &mut used_texture_ranges, + &snatch_guard, )?; - hal_textures.push(hal::TextureBinding { - view: view - .raw(&snatch_guard) - .ok_or(Error::InvalidTextureView(id))?, - usage: internal_use, - }); + + hal_textures.push(tb); } (res_index, num_bindings)