[wgpu-core] introduce Registry .strict_get() & .strict_unregister() and use them for adapters

This works because we never assign errors to adapters (they are never invalid).
This commit is contained in:
teoxoy 2024-09-05 14:30:41 +02:00 committed by Teodor Tanasoaia
parent 70a9c01b48
commit 98426329a4
9 changed files with 81 additions and 105 deletions

View File

@ -414,9 +414,9 @@ pub fn op_webgpu_request_adapter(
})
}
};
let adapter_features = instance.adapter_features(adapter)?;
let adapter_features = instance.adapter_features(adapter);
let features = deserialize_features(&adapter_features);
let adapter_limits = instance.adapter_limits(adapter)?;
let adapter_limits = instance.adapter_limits(adapter);
let instance = instance.clone();
@ -705,7 +705,7 @@ pub fn op_webgpu_request_adapter_info(
let adapter = adapter_resource.1;
let instance = state.borrow::<Instance>();
let info = instance.adapter_get_info(adapter)?;
let info = instance.adapter_get_info(adapter);
adapter_resource.close();
Ok(GPUAdapterInfo {

View File

@ -78,7 +78,7 @@ fn main() {
)
.expect("Unable to find an adapter for selected backend");
let info = global.adapter_get_info(adapter).unwrap();
let info = global.adapter_get_info(adapter);
log::info!("Picked '{}'", info.name);
let device_id = wgc::id::Id::zip(1, 0, backend);
let queue_id = wgc::id::Id::zip(1, 0, backend);

View File

@ -244,8 +244,8 @@ impl Corpus {
};
println!("\tBackend {:?}", backend);
let supported_features = global.adapter_features(adapter).unwrap();
let downlevel_caps = global.adapter_downlevel_capabilities(adapter).unwrap();
let supported_features = global.adapter_features(adapter);
let downlevel_caps = global.adapter_downlevel_capabilities(adapter);
let test = Test::load(dir.join(test_path), adapter.backend());
if !supported_features.contains(test.features) {

View File

@ -43,10 +43,7 @@ impl Global {
let hub = &self.hub;
let surface_guard = self.surfaces.read();
let adapter_guard = hub.adapters.read();
let adapter = adapter_guard
.get(adapter_id)
.map_err(|_| instance::IsSurfaceSupportedError::InvalidAdapter)?;
let adapter = hub.adapters.strict_get(adapter_id);
let surface = surface_guard
.get(surface_id)
.map_err(|_| instance::IsSurfaceSupportedError::InvalidSurface)?;
@ -87,15 +84,13 @@ impl Global {
let hub = &self.hub;
let surface_guard = self.surfaces.read();
let adapter_guard = hub.adapters.read();
let adapter = adapter_guard
.get(adapter_id)
.map_err(|_| instance::GetSurfaceSupportError::InvalidAdapter)?;
let adapter = hub.adapters.strict_get(adapter_id);
let surface = surface_guard
.get(surface_id)
.map_err(|_| instance::GetSurfaceSupportError::InvalidSurface)?;
get_supported_callback(adapter, surface)
get_supported_callback(&adapter, surface)
}
pub fn device_features(&self, device_id: DeviceId) -> Result<wgt::Features, DeviceError> {

View File

@ -349,8 +349,6 @@ crate::impl_storage_item!(Adapter);
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum IsSurfaceSupportedError {
#[error("Invalid adapter")]
InvalidAdapter,
#[error("Invalid surface")]
InvalidSurface,
}
@ -358,8 +356,6 @@ pub enum IsSurfaceSupportedError {
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum GetSurfaceSupportError {
#[error("Invalid adapter")]
InvalidAdapter,
#[error("Invalid surface")]
InvalidSurface,
#[error("Surface is not supported by the adapter")]
@ -373,8 +369,6 @@ pub enum GetSurfaceSupportError {
pub enum RequestDeviceError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error("Parent adapter is invalid")]
InvalidAdapter,
#[error(transparent)]
LimitsExceeded(#[from] FailedLimit),
#[error("Device has no queue supporting graphics")]
@ -403,10 +397,6 @@ impl<M: Marker> AdapterInputs<'_, M> {
}
}
#[derive(Clone, Debug, Error)]
#[error("Adapter is invalid")]
pub struct InvalidAdapter;
#[derive(Clone, Debug, Error)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
@ -869,73 +859,51 @@ impl Global {
id
}
pub fn adapter_get_info(
&self,
adapter_id: AdapterId,
) -> Result<wgt::AdapterInfo, InvalidAdapter> {
self.hub
.adapters
.get(adapter_id)
.map(|adapter| adapter.raw.info.clone())
.map_err(|_| InvalidAdapter)
pub fn adapter_get_info(&self, adapter_id: AdapterId) -> wgt::AdapterInfo {
let adapter = self.hub.adapters.strict_get(adapter_id);
adapter.raw.info.clone()
}
pub fn adapter_get_texture_format_features(
&self,
adapter_id: AdapterId,
format: wgt::TextureFormat,
) -> Result<wgt::TextureFormatFeatures, InvalidAdapter> {
self.hub
.adapters
.get(adapter_id)
.map(|adapter| adapter.get_texture_format_features(format))
.map_err(|_| InvalidAdapter)
) -> wgt::TextureFormatFeatures {
let adapter = self.hub.adapters.strict_get(adapter_id);
adapter.get_texture_format_features(format)
}
pub fn adapter_features(&self, adapter_id: AdapterId) -> Result<wgt::Features, InvalidAdapter> {
self.hub
.adapters
.get(adapter_id)
.map(|adapter| adapter.raw.features)
.map_err(|_| InvalidAdapter)
pub fn adapter_features(&self, adapter_id: AdapterId) -> wgt::Features {
let adapter = self.hub.adapters.strict_get(adapter_id);
adapter.raw.features
}
pub fn adapter_limits(&self, adapter_id: AdapterId) -> Result<wgt::Limits, InvalidAdapter> {
self.hub
.adapters
.get(adapter_id)
.map(|adapter| adapter.raw.capabilities.limits.clone())
.map_err(|_| InvalidAdapter)
pub fn adapter_limits(&self, adapter_id: AdapterId) -> wgt::Limits {
let adapter = self.hub.adapters.strict_get(adapter_id);
adapter.raw.capabilities.limits.clone()
}
pub fn adapter_downlevel_capabilities(
&self,
adapter_id: AdapterId,
) -> Result<wgt::DownlevelCapabilities, InvalidAdapter> {
self.hub
.adapters
.get(adapter_id)
.map(|adapter| adapter.raw.capabilities.downlevel.clone())
.map_err(|_| InvalidAdapter)
) -> wgt::DownlevelCapabilities {
let adapter = self.hub.adapters.strict_get(adapter_id);
adapter.raw.capabilities.downlevel.clone()
}
pub fn adapter_get_presentation_timestamp(
&self,
adapter_id: AdapterId,
) -> Result<wgt::PresentationTimestamp, InvalidAdapter> {
let hub = &self.hub;
let adapter = hub.adapters.get(adapter_id).map_err(|_| InvalidAdapter)?;
Ok(unsafe { adapter.raw.adapter.get_presentation_timestamp() })
) -> wgt::PresentationTimestamp {
let adapter = self.hub.adapters.strict_get(adapter_id);
unsafe { adapter.raw.adapter.get_presentation_timestamp() }
}
pub fn adapter_drop(&self, adapter_id: AdapterId) {
profiling::scope!("Adapter::drop");
api_log!("Adapter::drop {adapter_id:?}");
let hub = &self.hub;
hub.adapters.unregister(adapter_id);
self.hub.adapters.strict_unregister(adapter_id);
}
}
@ -956,10 +924,7 @@ impl Global {
let queue_fid = self.hub.queues.prepare(backend, queue_id_in);
let error = 'error: {
let adapter = match self.hub.adapters.get(adapter_id) {
Ok(adapter) => adapter,
Err(_) => break 'error RequestDeviceError::InvalidAdapter,
};
let adapter = self.hub.adapters.strict_get(adapter_id);
let (device, queue) =
match adapter.create_device_and_queue(desc, self.instance.flags, trace_path) {
Ok((device, queue)) => (device, queue),
@ -1000,10 +965,7 @@ impl Global {
let queues_fid = self.hub.queues.prepare(backend, queue_id_in);
let error = 'error: {
let adapter = match self.hub.adapters.get(adapter_id) {
Ok(adapter) => adapter,
Err(_) => break 'error RequestDeviceError::InvalidAdapter,
};
let adapter = self.hub.adapters.strict_get(adapter_id);
let (device, queue) = match adapter.create_device_and_queue_from_hal(
hal_device,
desc,

View File

@ -120,6 +120,15 @@ impl<T: StorageItem> Registry<T> {
//Returning None is legal if it's an error ID
value
}
pub(crate) fn strict_unregister(&self, id: Id<T::Marker>) -> T {
let value = self.storage.write().strict_remove(id);
// This needs to happen *after* removing it from the storage, to maintain the
// invariant that `self.identity` only contains ids which are actually available
// See https://github.com/gfx-rs/wgpu/issues/5372
self.identity.free(id);
//Returning None is legal if it's an error ID
value
}
pub(crate) fn generate_report(&self) -> RegistryReport {
let storage = self.storage.read();
@ -143,6 +152,10 @@ impl<T: StorageItem + Clone> Registry<T> {
pub(crate) fn get(&self, id: Id<T::Marker>) -> Result<T, InvalidId> {
self.read().get_owned(id)
}
pub(crate) fn strict_get(&self, id: Id<T::Marker>) -> T {
self.read().strict_get(id)
}
}
#[cfg(test)]

View File

@ -1272,11 +1272,8 @@ impl Global {
profiling::scope!("Adapter::as_hal");
let hub = &self.hub;
let adapter = hub.adapters.get(id).ok();
let hal_adapter = adapter
.as_ref()
.map(|adapter| &adapter.raw.adapter)
.and_then(|adapter| adapter.as_any().downcast_ref());
let adapter = hub.adapters.strict_get(id);
let hal_adapter = adapter.raw.adapter.as_any().downcast_ref();
hal_adapter_callback(hal_adapter)
}

View File

@ -157,6 +157,18 @@ where
}
}
pub(crate) fn strict_remove(&mut self, id: Id<T::Marker>) -> T {
let (index, epoch, _) = id.unzip();
match std::mem::replace(&mut self.map[index as usize], Element::Vacant) {
Element::Occupied(value, storage_epoch) => {
assert_eq!(epoch, storage_epoch);
value
}
Element::Error(_) => unreachable!(),
Element::Vacant => panic!("Cannot remove a vacant resource"),
}
}
pub(crate) fn iter(&self, backend: Backend) -> impl Iterator<Item = (Id<T::Marker>, &T)> {
self.map
.iter()
@ -183,4 +195,21 @@ where
pub(crate) fn get_owned(&self, id: Id<T::Marker>) -> Result<T, InvalidId> {
Ok(self.get(id)?.clone())
}
/// Get an owned reference to an item.
/// Panics if there is an epoch mismatch, the entry is empty or in error.
pub(crate) fn strict_get(&self, id: Id<T::Marker>) -> T {
let (index, epoch, _) = id.unzip();
let (result, storage_epoch) = match self.map.get(index as usize) {
Some(&Element::Occupied(ref v, epoch)) => (v.clone(), epoch),
None | Some(&Element::Vacant) => panic!("{}[{:?}] does not exist", self.kind, id),
Some(&Element::Error(_)) => unreachable!(),
};
assert_eq!(
epoch, storage_epoch,
"{}[{:?}] is no longer alive",
self.kind, id
);
result
}
}

View File

@ -651,34 +651,22 @@ impl crate::Context for ContextWgpuCore {
}
fn adapter_features(&self, adapter_data: &Self::AdapterData) -> Features {
match self.0.adapter_features(*adapter_data) {
Ok(features) => features,
Err(err) => self.handle_error_fatal(err, "Adapter::features"),
}
self.0.adapter_features(*adapter_data)
}
fn adapter_limits(&self, adapter_data: &Self::AdapterData) -> Limits {
match self.0.adapter_limits(*adapter_data) {
Ok(limits) => limits,
Err(err) => self.handle_error_fatal(err, "Adapter::limits"),
}
self.0.adapter_limits(*adapter_data)
}
fn adapter_downlevel_capabilities(
&self,
adapter_data: &Self::AdapterData,
) -> DownlevelCapabilities {
match self.0.adapter_downlevel_capabilities(*adapter_data) {
Ok(downlevel) => downlevel,
Err(err) => self.handle_error_fatal(err, "Adapter::downlevel_properties"),
}
self.0.adapter_downlevel_capabilities(*adapter_data)
}
fn adapter_get_info(&self, adapter_data: &Self::AdapterData) -> AdapterInfo {
match self.0.adapter_get_info(*adapter_data) {
Ok(info) => info,
Err(err) => self.handle_error_fatal(err, "Adapter::get_info"),
}
self.0.adapter_get_info(*adapter_data)
}
fn adapter_get_texture_format_features(
@ -686,23 +674,15 @@ impl crate::Context for ContextWgpuCore {
adapter_data: &Self::AdapterData,
format: wgt::TextureFormat,
) -> wgt::TextureFormatFeatures {
match self
.0
self.0
.adapter_get_texture_format_features(*adapter_data, format)
{
Ok(info) => info,
Err(err) => self.handle_error_fatal(err, "Adapter::get_texture_format_features"),
}
}
fn adapter_get_presentation_timestamp(
&self,
adapter_data: &Self::AdapterData,
) -> wgt::PresentationTimestamp {
match self.0.adapter_get_presentation_timestamp(*adapter_data) {
Ok(timestamp) => timestamp,
Err(err) => self.handle_error_fatal(err, "Adapter::correlate_presentation_timestamp"),
}
self.0.adapter_get_presentation_timestamp(*adapter_data)
}
fn surface_get_capabilities(