mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-02-14 16:02:47 +00:00
Merge #704
704: Pipeline layout validation r=cwfitzgerald a=kvark **Connections** Implements a solid part of #269 Starts converting the function to return results, related to #638 cc @GabrielMajeri **Description** This change matches shader bindings against the pipeline layout. It's *mostly* complete, minus some bugs and not handling the `storage_texture_format` properly. The risk here is that Naga reflection may have bugs, or our validation may have bugs, and we don't want to break the user content while this is in flux. So the PR introduces an internal `WGPU_SHADER_VALIDATION` environment variable. Switching it to "0" skips Naga shader parsing completely and allows the users to unsafely use the API. Another aspect of the PR is that some of the functions now return `Result`. The way I see us proceeding is that any errors that we don't expect users to handle should result in panics when `wgpu` is used natively (i.e. not from a browser). These panics would happen in the "direct" backend of wgpu-rs (as well as in wgpu-native), but the `Result` would not be exposed to wgpu-rs, so that it matches the Web behavior. At the same time, browser implementations (Gecko and Servo) will check the result on their GPU process and implement the WebGPU error model accordingly. This means `wgpu-core` can be super Rusty and safe. **Testing** Running on wgpu-rs examples. Most of them fail to get parsed by Naga, but `boids` succeeds and passes validation 🎉 Co-authored-by: Dzmitry Malyshau <kvarkus@gmail.com>
This commit is contained in:
commit
417ea69b45
@ -263,7 +263,8 @@ impl GlobalExt for wgc::hub::Global<IdentityPassThroughFactory> {
|
||||
entries_length: entries.len(),
|
||||
},
|
||||
id,
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
A::DestroyBindGroupLayout(id) => {
|
||||
self.bind_group_layout_destroy::<B>(id);
|
||||
@ -280,7 +281,8 @@ impl GlobalExt for wgc::hub::Global<IdentityPassThroughFactory> {
|
||||
bind_group_layouts_length: bind_group_layouts.len(),
|
||||
},
|
||||
id,
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
A::DestroyPipelineLayout(id) => {
|
||||
self.pipeline_layout_destroy::<B>(id);
|
||||
@ -353,7 +355,8 @@ impl GlobalExt for wgc::hub::Global<IdentityPassThroughFactory> {
|
||||
compute_stage: cs_stage.desc,
|
||||
},
|
||||
id,
|
||||
);
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
A::DestroyComputePipeline(id) => {
|
||||
self.compute_pipeline_destroy::<B>(id);
|
||||
|
@ -47,6 +47,38 @@ pub struct BindGroupLayoutEntry {
|
||||
pub storage_texture_format: wgt::TextureFormat,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum BindGroupLayoutEntryError {
|
||||
NoVisibility,
|
||||
UnexpectedHasDynamicOffset,
|
||||
UnexpectedMultisampled,
|
||||
}
|
||||
|
||||
impl BindGroupLayoutEntry {
|
||||
pub(crate) fn validate(&self) -> Result<(), BindGroupLayoutEntryError> {
|
||||
if self.visibility.is_empty() {
|
||||
return Err(BindGroupLayoutEntryError::NoVisibility);
|
||||
}
|
||||
match self.ty {
|
||||
BindingType::UniformBuffer | BindingType::StorageBuffer => {}
|
||||
_ => {
|
||||
if self.has_dynamic_offset {
|
||||
return Err(BindGroupLayoutEntryError::UnexpectedHasDynamicOffset);
|
||||
}
|
||||
}
|
||||
}
|
||||
match self.ty {
|
||||
BindingType::SampledTexture => {}
|
||||
_ => {
|
||||
if self.multisampled {
|
||||
return Err(BindGroupLayoutEntryError::UnexpectedMultisampled);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
pub struct BindGroupLayoutDescriptor {
|
||||
@ -55,12 +87,20 @@ pub struct BindGroupLayoutDescriptor {
|
||||
pub entries_length: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum BindGroupLayoutError {
|
||||
ConflictBinding(u32),
|
||||
Entry(u32, BindGroupLayoutEntryError),
|
||||
}
|
||||
|
||||
pub(crate) type BindEntryMap = FastHashMap<u32, BindGroupLayoutEntry>;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct BindGroupLayout<B: hal::Backend> {
|
||||
pub(crate) raw: B::DescriptorSetLayout,
|
||||
pub(crate) device_id: Stored<DeviceId>,
|
||||
pub(crate) life_guard: LifeGuard,
|
||||
pub(crate) entries: FastHashMap<u32, BindGroupLayoutEntry>,
|
||||
pub(crate) entries: BindEntryMap,
|
||||
pub(crate) desc_counts: DescriptorCounts,
|
||||
pub(crate) dynamic_count: usize,
|
||||
}
|
||||
@ -72,6 +112,11 @@ pub struct PipelineLayoutDescriptor {
|
||||
pub bind_group_layouts_length: usize,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum PipelineLayoutError {
|
||||
TooManyGroups(usize),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PipelineLayout<B: hal::Backend> {
|
||||
pub(crate) raw: B::PipelineLayout,
|
||||
|
@ -371,14 +371,14 @@ pub(crate) fn map_texture_format(
|
||||
// Depth and stencil formats
|
||||
Tf::Depth32Float => H::D32Sfloat,
|
||||
Tf::Depth24Plus => {
|
||||
if private_features.supports_texture_d24_s8 {
|
||||
if private_features.texture_d24_s8 {
|
||||
H::D24UnormS8Uint
|
||||
} else {
|
||||
H::D32Sfloat
|
||||
}
|
||||
}
|
||||
Tf::Depth24PlusStencil8 => {
|
||||
if private_features.supports_texture_d24_s8 {
|
||||
if private_features.texture_d24_s8 {
|
||||
H::D24UnormS8Uint
|
||||
} else {
|
||||
H::D32SfloatS8Uint
|
||||
|
@ -7,7 +7,7 @@ use crate::{
|
||||
hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Input, Token},
|
||||
id, pipeline, resource, swap_chain,
|
||||
track::{BufferState, TextureState, TrackerSet},
|
||||
FastHashMap, LifeGuard, PrivateFeatures, Stored,
|
||||
FastHashMap, LifeGuard, PrivateFeatures, Stored, MAX_BIND_GROUPS,
|
||||
};
|
||||
|
||||
use arrayvec::ArrayVec;
|
||||
@ -196,7 +196,7 @@ impl<B: GfxBackend> Device<B> {
|
||||
queue_group: hal::queue::QueueGroup<B>,
|
||||
mem_props: hal::adapter::MemoryProperties,
|
||||
hal_limits: hal::Limits,
|
||||
supports_texture_d24_s8: bool,
|
||||
private_features: PrivateFeatures,
|
||||
desc: &wgt::DeviceDescriptor,
|
||||
trace_path: Option<&std::path::Path>,
|
||||
) -> Self {
|
||||
@ -253,9 +253,7 @@ impl<B: GfxBackend> Device<B> {
|
||||
}
|
||||
}),
|
||||
hal_limits,
|
||||
private_features: PrivateFeatures {
|
||||
supports_texture_d24_s8,
|
||||
},
|
||||
private_features,
|
||||
limits: desc.limits.clone(),
|
||||
extensions: desc.extensions.clone(),
|
||||
pending_writes: queue::PendingWrites::new(),
|
||||
@ -1085,12 +1083,21 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
device_id: id::DeviceId,
|
||||
desc: &binding_model::BindGroupLayoutDescriptor,
|
||||
id_in: Input<G, id::BindGroupLayoutId>,
|
||||
) -> id::BindGroupLayoutId {
|
||||
) -> Result<id::BindGroupLayoutId, binding_model::BindGroupLayoutError> {
|
||||
let mut token = Token::root();
|
||||
let hub = B::hub(self);
|
||||
let entries = unsafe { slice::from_raw_parts(desc.entries, desc.entries_length) };
|
||||
let entry_map: FastHashMap<_, _> =
|
||||
entries.iter().cloned().map(|b| (b.binding, b)).collect();
|
||||
let mut entry_map = FastHashMap::default();
|
||||
for entry in entries {
|
||||
if let Err(e) = entry.validate() {
|
||||
return Err(binding_model::BindGroupLayoutError::Entry(entry.binding, e));
|
||||
}
|
||||
if entry_map.insert(entry.binding, entry.clone()).is_some() {
|
||||
return Err(binding_model::BindGroupLayoutError::ConflictBinding(
|
||||
entry.binding,
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: deduplicate the bind group layouts at some level.
|
||||
// We can't do it right here, because in the remote scenario
|
||||
@ -1102,7 +1109,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
.find(|(_, bgl)| bgl.entries == entry_map);
|
||||
|
||||
if let Some((id, _)) = bind_group_layout_id {
|
||||
return id;
|
||||
return Ok(id);
|
||||
}
|
||||
}
|
||||
|
||||
@ -1157,7 +1164,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
}),
|
||||
None => (),
|
||||
};
|
||||
id
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
pub fn bind_group_layout_destroy<B: GfxBackend>(
|
||||
@ -1191,7 +1198,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
device_id: id::DeviceId,
|
||||
desc: &binding_model::PipelineLayoutDescriptor,
|
||||
id_in: Input<G, id::PipelineLayoutId>,
|
||||
) -> id::PipelineLayoutId {
|
||||
) -> Result<id::PipelineLayoutId, binding_model::PipelineLayoutError> {
|
||||
let hub = B::hub(self);
|
||||
let mut token = Token::root();
|
||||
|
||||
@ -1201,12 +1208,11 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
slice::from_raw_parts(desc.bind_group_layouts, desc.bind_group_layouts_length)
|
||||
};
|
||||
|
||||
assert!(
|
||||
desc.bind_group_layouts_length <= (device.limits.max_bind_groups as usize),
|
||||
"Cannot set more bind groups ({}) than the `max_bind_groups` limit requested on device creation ({})",
|
||||
desc.bind_group_layouts_length,
|
||||
device.limits.max_bind_groups
|
||||
);
|
||||
if desc.bind_group_layouts_length > (device.limits.max_bind_groups as usize) {
|
||||
return Err(binding_model::PipelineLayoutError::TooManyGroups(
|
||||
desc.bind_group_layouts_length,
|
||||
));
|
||||
}
|
||||
|
||||
// TODO: push constants
|
||||
let pipeline_layout = {
|
||||
@ -1252,7 +1258,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
}),
|
||||
None => (),
|
||||
};
|
||||
id
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
pub fn pipeline_layout_destroy<B: GfxBackend>(&self, pipeline_layout_id: id::PipelineLayoutId) {
|
||||
@ -1570,7 +1576,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let spv = unsafe { slice::from_raw_parts(desc.code.bytes, desc.code.length) };
|
||||
let raw = unsafe { device.raw.create_shader_module(spv).unwrap() };
|
||||
|
||||
let module = {
|
||||
let module = if device.private_features.shader_validation {
|
||||
// Parse the given shader code and store its representation.
|
||||
let spv_iter = spv.into_iter().cloned();
|
||||
let mut parser = naga::front::spirv::Parser::new(spv_iter);
|
||||
@ -1581,6 +1587,8 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
log::warn!("Shader module will not be validated");
|
||||
})
|
||||
.ok()
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let shader = pipeline::ShaderModule {
|
||||
raw,
|
||||
@ -1851,7 +1859,14 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let device = &device_guard[device_id];
|
||||
let (raw_pipeline, layout_ref_count) = {
|
||||
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
|
||||
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
|
||||
let layout = &pipeline_layout_guard[desc.layout];
|
||||
let group_layouts = layout
|
||||
.bind_group_layout_ids
|
||||
.iter()
|
||||
.map(|id| &bgl_guard[id.value].entries)
|
||||
.collect::<ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>>();
|
||||
|
||||
let (shader_module_guard, _) = hub.shader_modules.read(&mut token);
|
||||
|
||||
let rp_key = RenderPassKey {
|
||||
@ -1901,9 +1916,12 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let shader_module = &shader_module_guard[desc.vertex_stage.module];
|
||||
|
||||
if let Some(ref module) = shader_module.module {
|
||||
if let Err(e) =
|
||||
validate_shader(module, entry_point_name, ExecutionModel::Vertex)
|
||||
{
|
||||
if let Err(e) = pipeline::validate_stage(
|
||||
module,
|
||||
&group_layouts,
|
||||
entry_point_name,
|
||||
ExecutionModel::Vertex,
|
||||
) {
|
||||
log::error!("Failed validating vertex shader module: {:?}", e);
|
||||
}
|
||||
}
|
||||
@ -1926,9 +1944,12 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let shader_module = &shader_module_guard[stage.module];
|
||||
|
||||
if let Some(ref module) = shader_module.module {
|
||||
if let Err(e) =
|
||||
validate_shader(module, entry_point_name, ExecutionModel::Fragment)
|
||||
{
|
||||
if let Err(e) = pipeline::validate_stage(
|
||||
module,
|
||||
&group_layouts,
|
||||
entry_point_name,
|
||||
ExecutionModel::Fragment,
|
||||
) {
|
||||
log::error!("Failed validating fragment shader module: {:?}", e);
|
||||
}
|
||||
}
|
||||
@ -2098,7 +2119,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
device_id: id::DeviceId,
|
||||
desc: &pipeline::ComputePipelineDescriptor,
|
||||
id_in: Input<G, id::ComputePipelineId>,
|
||||
) -> id::ComputePipelineId {
|
||||
) -> Result<id::ComputePipelineId, pipeline::ComputePipelineError> {
|
||||
let hub = B::hub(self);
|
||||
let mut token = Token::root();
|
||||
|
||||
@ -2106,7 +2127,14 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let device = &device_guard[device_id];
|
||||
let (raw_pipeline, layout_ref_count) = {
|
||||
let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
|
||||
let (bgl_guard, mut token) = hub.bind_group_layouts.read(&mut token);
|
||||
let layout = &pipeline_layout_guard[desc.layout];
|
||||
let group_layouts = layout
|
||||
.bind_group_layout_ids
|
||||
.iter()
|
||||
.map(|id| &bgl_guard[id.value].entries)
|
||||
.collect::<ArrayVec<[&binding_model::BindEntryMap; MAX_BIND_GROUPS]>>();
|
||||
|
||||
let pipeline_stage = &desc.compute_stage;
|
||||
let (shader_module_guard, _) = hub.shader_modules.read(&mut token);
|
||||
|
||||
@ -2118,9 +2146,13 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let shader_module = &shader_module_guard[pipeline_stage.module];
|
||||
|
||||
if let Some(ref module) = shader_module.module {
|
||||
if let Err(e) = validate_shader(module, entry_point_name, ExecutionModel::GLCompute)
|
||||
{
|
||||
log::error!("Failed validating compute shader module: {:?}", e);
|
||||
if let Err(e) = pipeline::validate_stage(
|
||||
module,
|
||||
&group_layouts,
|
||||
entry_point_name,
|
||||
ExecutionModel::GLCompute,
|
||||
) {
|
||||
return Err(pipeline::ComputePipelineError::Stage(e));
|
||||
}
|
||||
}
|
||||
|
||||
@ -2178,7 +2210,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
}),
|
||||
None => (),
|
||||
};
|
||||
id
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
pub fn compute_pipeline_destroy<B: GfxBackend>(
|
||||
@ -2580,27 +2612,3 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors produced when validating the shader modules of a pipeline.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
enum ShaderValidationError {
|
||||
/// Unable to find an entry point matching the specified execution model.
|
||||
MissingEntryPoint(ExecutionModel),
|
||||
}
|
||||
|
||||
fn validate_shader(
|
||||
module: &naga::Module,
|
||||
entry_point_name: &str,
|
||||
execution_model: ExecutionModel,
|
||||
) -> Result<(), ShaderValidationError> {
|
||||
// Since a shader module can have multiple entry points with the same name,
|
||||
// we need to look for one with the right execution model.
|
||||
let entry_point = module.entry_points.iter().find(|entry_point| {
|
||||
entry_point.name == entry_point_name && entry_point.exec_model == execution_model
|
||||
});
|
||||
|
||||
match entry_point {
|
||||
Some(_) => Ok(()),
|
||||
None => Err(ShaderValidationError::MissingEntryPoint(execution_model)),
|
||||
}
|
||||
}
|
||||
|
@ -177,6 +177,7 @@ impl<B: hal::Backend> Access<PipelineLayout<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<PipelineLayout<B>> for CommandBuffer<B> {}
|
||||
impl<B: hal::Backend> Access<BindGroupLayout<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<BindGroupLayout<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<BindGroupLayout<B>> for PipelineLayout<B> {}
|
||||
impl<B: hal::Backend> Access<BindGroup<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<BindGroup<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<BindGroup<B>> for BindGroupLayout<B> {}
|
||||
@ -191,7 +192,7 @@ impl<B: hal::Backend> Access<RenderPipeline<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<RenderPipeline<B>> for BindGroup<B> {}
|
||||
impl<B: hal::Backend> Access<RenderPipeline<B>> for ComputePipeline<B> {}
|
||||
impl<B: hal::Backend> Access<ShaderModule<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<ShaderModule<B>> for PipelineLayout<B> {}
|
||||
impl<B: hal::Backend> Access<ShaderModule<B>> for BindGroupLayout<B> {}
|
||||
impl<B: hal::Backend> Access<Buffer<B>> for Root {}
|
||||
impl<B: hal::Backend> Access<Buffer<B>> for Device<B> {}
|
||||
impl<B: hal::Backend> Access<Buffer<B>> for BindGroupLayout<B> {}
|
||||
|
@ -7,7 +7,7 @@ use crate::{
|
||||
device::Device,
|
||||
hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Input, Token},
|
||||
id::{AdapterId, DeviceId, SurfaceId},
|
||||
power, LifeGuard, Stored, MAX_BIND_GROUPS,
|
||||
power, LifeGuard, PrivateFeatures, Stored, MAX_BIND_GROUPS,
|
||||
};
|
||||
|
||||
use wgt::{Backend, BackendBit, DeviceDescriptor, PowerPreference, BIND_BUFFER_ALIGNMENT};
|
||||
@ -677,10 +677,29 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
}
|
||||
|
||||
let mem_props = phd.memory_properties();
|
||||
let supports_texture_d24_s8 = phd
|
||||
.format_properties(Some(hal::format::Format::D24UnormS8Uint))
|
||||
.optimal_tiling
|
||||
.contains(hal::format::ImageFeature::DEPTH_STENCIL_ATTACHMENT);
|
||||
let private_features = PrivateFeatures {
|
||||
shader_validation: match std::env::var("WGPU_SHADER_VALIDATION") {
|
||||
Ok(var) => match var.as_str() {
|
||||
"0" => {
|
||||
log::info!("Shader validation is disabled");
|
||||
false
|
||||
}
|
||||
"1" => {
|
||||
log::info!("Shader validation is enabled");
|
||||
true
|
||||
}
|
||||
_ => {
|
||||
log::warn!("Unknown shader validation setting: {:?}", var);
|
||||
true
|
||||
}
|
||||
},
|
||||
_ => true,
|
||||
},
|
||||
texture_d24_s8: phd
|
||||
.format_properties(Some(hal::format::Format::D24UnormS8Uint))
|
||||
.optimal_tiling
|
||||
.contains(hal::format::ImageFeature::DEPTH_STENCIL_ATTACHMENT),
|
||||
};
|
||||
|
||||
Device::new(
|
||||
gpu.device,
|
||||
@ -691,7 +710,7 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
gpu.queue_groups.swap_remove(0),
|
||||
mem_props,
|
||||
limits,
|
||||
supports_texture_d24_s8,
|
||||
private_features,
|
||||
desc,
|
||||
trace_path,
|
||||
)
|
||||
|
@ -171,7 +171,8 @@ pub struct U32Array {
|
||||
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
struct PrivateFeatures {
|
||||
pub supports_texture_d24_s8: bool,
|
||||
shader_validation: bool,
|
||||
texture_d24_s8: bool,
|
||||
}
|
||||
|
||||
#[macro_export]
|
||||
|
@ -3,10 +3,12 @@
|
||||
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
|
||||
|
||||
use crate::{
|
||||
binding_model::{BindEntryMap, BindGroupLayoutEntry, BindingType},
|
||||
device::RenderPassContext,
|
||||
id::{DeviceId, PipelineLayoutId, ShaderModuleId},
|
||||
LifeGuard, RawString, RefCount, Stored, U32Array,
|
||||
};
|
||||
use spirv_headers as spirv;
|
||||
use std::borrow::Borrow;
|
||||
use wgt::{
|
||||
BufferAddress, ColorStateDescriptor, DepthStencilStateDescriptor, IndexFormat, InputStepMode,
|
||||
@ -50,6 +52,184 @@ pub struct ProgrammableStageDescriptor {
|
||||
pub entry_point: RawString,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum BindingError {
|
||||
/// The binding is missing from the pipeline layout.
|
||||
Missing,
|
||||
/// The visibility flags don't include the shader stage.
|
||||
Invisible,
|
||||
/// The load/store access flags don't match the shader.
|
||||
WrongUsage(naga::GlobalUse),
|
||||
/// The type on the shader side does not match the pipeline binding.
|
||||
WrongType,
|
||||
/// The view dimension doesn't match the shader.
|
||||
WrongTextureViewDimension { dim: spirv::Dim, is_array: bool },
|
||||
/// The component type of a sampled texture doesn't match the shader.
|
||||
WrongTextureComponentType(Option<naga::ScalarKind>),
|
||||
/// Texture sampling capability doesn't match with the shader.
|
||||
WrongTextureSampled,
|
||||
/// The multisampled flag doesn't match.
|
||||
WrongTextureMultisampled,
|
||||
}
|
||||
|
||||
/// Errors produced when validating a programmable stage of a pipeline.
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ProgrammableStageError {
|
||||
/// Unable to find an entry point matching the specified execution model.
|
||||
MissingEntryPoint(spirv::ExecutionModel),
|
||||
/// Error matching a global binding to the pipeline layout.
|
||||
Binding {
|
||||
set: u32,
|
||||
binding: u32,
|
||||
error: BindingError,
|
||||
},
|
||||
}
|
||||
|
||||
fn validate_binding(
|
||||
module: &naga::Module,
|
||||
var: &naga::GlobalVariable,
|
||||
entry: &BindGroupLayoutEntry,
|
||||
usage: naga::GlobalUse,
|
||||
) -> Result<(), BindingError> {
|
||||
let allowed_usage = match module.types[var.ty].inner {
|
||||
naga::TypeInner::Struct { .. } => match entry.ty {
|
||||
BindingType::UniformBuffer => naga::GlobalUse::LOAD,
|
||||
BindingType::StorageBuffer => naga::GlobalUse::all(),
|
||||
BindingType::ReadonlyStorageBuffer => naga::GlobalUse::LOAD,
|
||||
_ => return Err(BindingError::WrongType),
|
||||
},
|
||||
naga::TypeInner::Sampler => match entry.ty {
|
||||
BindingType::Sampler | BindingType::ComparisonSampler => naga::GlobalUse::empty(),
|
||||
_ => return Err(BindingError::WrongType),
|
||||
},
|
||||
naga::TypeInner::Image { base, dim, flags } => {
|
||||
if entry.multisampled != flags.contains(naga::ImageFlags::MULTISAMPLED) {
|
||||
return Err(BindingError::WrongTextureMultisampled);
|
||||
}
|
||||
if flags.contains(naga::ImageFlags::ARRAYED) {
|
||||
match (dim, entry.view_dimension) {
|
||||
(spirv::Dim::Dim2D, wgt::TextureViewDimension::D2Array) => (),
|
||||
(spirv::Dim::DimCube, wgt::TextureViewDimension::CubeArray) => (),
|
||||
_ => {
|
||||
return Err(BindingError::WrongTextureViewDimension {
|
||||
dim,
|
||||
is_array: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
match (dim, entry.view_dimension) {
|
||||
(spirv::Dim::Dim1D, wgt::TextureViewDimension::D1) => (),
|
||||
(spirv::Dim::Dim2D, wgt::TextureViewDimension::D2) => (),
|
||||
(spirv::Dim::Dim3D, wgt::TextureViewDimension::D3) => (),
|
||||
(spirv::Dim::DimCube, wgt::TextureViewDimension::Cube) => (),
|
||||
_ => {
|
||||
return Err(BindingError::WrongTextureViewDimension {
|
||||
dim,
|
||||
is_array: false,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
let (allowed_usage, is_sampled) = match entry.ty {
|
||||
BindingType::SampledTexture => {
|
||||
let expected_scalar_kind = match entry.texture_component_type {
|
||||
wgt::TextureComponentType::Float => naga::ScalarKind::Float,
|
||||
wgt::TextureComponentType::Sint => naga::ScalarKind::Sint,
|
||||
wgt::TextureComponentType::Uint => naga::ScalarKind::Uint,
|
||||
};
|
||||
match module.types[base].inner {
|
||||
naga::TypeInner::Scalar { kind, .. }
|
||||
| naga::TypeInner::Vector { kind, .. }
|
||||
if kind == expected_scalar_kind =>
|
||||
{
|
||||
()
|
||||
}
|
||||
naga::TypeInner::Scalar { kind, .. }
|
||||
| naga::TypeInner::Vector { kind, .. } => {
|
||||
return Err(BindingError::WrongTextureComponentType(Some(kind)))
|
||||
}
|
||||
_ => return Err(BindingError::WrongTextureComponentType(None)),
|
||||
};
|
||||
(naga::GlobalUse::LOAD, true)
|
||||
}
|
||||
BindingType::ReadonlyStorageTexture => {
|
||||
//TODO: check entry.storage_texture_format
|
||||
(naga::GlobalUse::LOAD, false)
|
||||
}
|
||||
BindingType::WriteonlyStorageTexture => (naga::GlobalUse::STORE, false),
|
||||
_ => return Err(BindingError::WrongType),
|
||||
};
|
||||
if is_sampled != flags.contains(naga::ImageFlags::SAMPLED) {
|
||||
return Err(BindingError::WrongTextureSampled);
|
||||
}
|
||||
allowed_usage
|
||||
}
|
||||
_ => return Err(BindingError::WrongType),
|
||||
};
|
||||
if allowed_usage.contains(usage) {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(BindingError::WrongUsage(usage))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn validate_stage(
|
||||
module: &naga::Module,
|
||||
group_layouts: &[&BindEntryMap],
|
||||
entry_point_name: &str,
|
||||
execution_model: spirv::ExecutionModel,
|
||||
) -> Result<(), ProgrammableStageError> {
|
||||
// Since a shader module can have multiple entry points with the same name,
|
||||
// we need to look for one with the right execution model.
|
||||
let entry_point = module
|
||||
.entry_points
|
||||
.iter()
|
||||
.find(|entry_point| {
|
||||
entry_point.name == entry_point_name && entry_point.exec_model == execution_model
|
||||
})
|
||||
.ok_or(ProgrammableStageError::MissingEntryPoint(execution_model))?;
|
||||
let stage_bit = match execution_model {
|
||||
spirv::ExecutionModel::Vertex => wgt::ShaderStage::VERTEX,
|
||||
spirv::ExecutionModel::Fragment => wgt::ShaderStage::FRAGMENT,
|
||||
spirv::ExecutionModel::GLCompute => wgt::ShaderStage::COMPUTE,
|
||||
// the entry point wouldn't match otherwise
|
||||
_ => unreachable!(),
|
||||
};
|
||||
|
||||
let function = &module.functions[entry_point.function];
|
||||
for ((_, var), &usage) in module.global_variables.iter().zip(&function.global_usage) {
|
||||
if usage.is_empty() {
|
||||
continue;
|
||||
}
|
||||
match var.binding {
|
||||
Some(naga::Binding::Descriptor { set, binding }) => {
|
||||
let result = group_layouts
|
||||
.get(set as usize)
|
||||
.and_then(|map| map.get(&binding))
|
||||
.ok_or(BindingError::Missing)
|
||||
.and_then(|entry| {
|
||||
if entry.visibility.contains(stage_bit) {
|
||||
Ok(entry)
|
||||
} else {
|
||||
Err(BindingError::Invisible)
|
||||
}
|
||||
})
|
||||
.and_then(|entry| validate_binding(module, var, entry, usage));
|
||||
if let Err(error) = result {
|
||||
return Err(ProgrammableStageError::Binding {
|
||||
set,
|
||||
binding,
|
||||
error,
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {} //TODO
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
#[derive(Debug)]
|
||||
pub struct ComputePipelineDescriptor {
|
||||
@ -57,6 +237,11 @@ pub struct ComputePipelineDescriptor {
|
||||
pub compute_stage: ProgrammableStageDescriptor,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub enum ComputePipelineError {
|
||||
Stage(ProgrammableStageError),
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ComputePipeline<B: hal::Backend> {
|
||||
pub(crate) raw: B::ComputePipeline,
|
||||
|
Loading…
Reference in New Issue
Block a user