move out ID to resource mapping code from pipeline creation methods

This commit is contained in:
teoxoy 2024-07-02 11:26:00 +02:00 committed by Teodor Tanasoaia
parent d8b0b5975d
commit 1be51946e3
3 changed files with 280 additions and 80 deletions

View File

@ -12,7 +12,11 @@ use crate::{
init_tracker::TextureInitTracker,
instance::{self, Adapter, Surface},
lock::{rank, RwLock},
pipeline, present,
pipeline::{
self, ResolvedComputePipelineDescriptor, ResolvedFragmentState,
ResolvedProgrammableStageDescriptor, ResolvedRenderPipelineDescriptor, ResolvedVertexState,
},
present,
resource::{
self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError,
Trackable,
@ -1591,6 +1595,8 @@ impl Global {
let fid = hub.render_pipelines.prepare(id_in);
let implicit_context = implicit_pipeline_ids.map(|ipi| ipi.prepare(hub));
let is_auto_layout = desc.layout.is_none();
let error = 'error: {
let device = match hub.devices.get(device_id) {
Ok(device) => device,
@ -1606,12 +1612,107 @@ impl Global {
});
}
let pipeline = match device.create_render_pipeline(&device.adapter, desc, hub) {
let layout = desc
.layout
.map(|layout| {
hub.pipeline_layouts
.get(layout)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)
})
.transpose();
let layout = match layout {
Ok(layout) => layout,
Err(e) => break 'error e,
};
let cache = desc
.cache
.map(|cache| {
hub.pipeline_caches
.get(cache)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidCache)
})
.transpose();
let cache = match cache {
Ok(cache) => cache,
Err(e) => break 'error e,
};
let vertex = {
let module = hub
.shader_modules
.get(desc.vertex.stage.module)
.map_err(|_| pipeline::CreateRenderPipelineError::Stage {
stage: wgt::ShaderStages::VERTEX,
error: crate::validation::StageError::InvalidModule,
});
let module = match module {
Ok(module) => module,
Err(e) => break 'error e,
};
let stage = ResolvedProgrammableStageDescriptor {
module,
entry_point: desc.vertex.stage.entry_point.clone(),
constants: desc.vertex.stage.constants.clone(),
zero_initialize_workgroup_memory: desc
.vertex
.stage
.zero_initialize_workgroup_memory,
vertex_pulling_transform: desc.vertex.stage.vertex_pulling_transform,
};
ResolvedVertexState {
stage,
buffers: desc.vertex.buffers.clone(),
}
};
let fragment = if let Some(ref state) = desc.fragment {
let module = hub.shader_modules.get(state.stage.module).map_err(|_| {
pipeline::CreateRenderPipelineError::Stage {
stage: wgt::ShaderStages::FRAGMENT,
error: crate::validation::StageError::InvalidModule,
}
});
let module = match module {
Ok(module) => module,
Err(e) => break 'error e,
};
let stage = ResolvedProgrammableStageDescriptor {
module,
entry_point: state.stage.entry_point.clone(),
constants: state.stage.constants.clone(),
zero_initialize_workgroup_memory: desc
.vertex
.stage
.zero_initialize_workgroup_memory,
vertex_pulling_transform: state.stage.vertex_pulling_transform,
};
Some(ResolvedFragmentState {
stage,
targets: state.targets.clone(),
})
} else {
None
};
let desc = ResolvedRenderPipelineDescriptor {
label: desc.label.clone(),
layout,
vertex,
primitive: desc.primitive,
depth_stencil: desc.depth_stencil.clone(),
multisample: desc.multisample,
fragment,
multiview: desc.multiview,
cache,
};
let pipeline = match device.create_render_pipeline(&device.adapter, desc) {
Ok(pair) => pair,
Err(e) => break 'error e,
};
if desc.layout.is_none() {
if is_auto_layout {
// TODO: categorize the errors below as API misuse
let ids = if let Some(ids) = implicit_context.as_ref() {
let group_count = pipeline.layout.bind_group_layouts.len();
@ -1655,7 +1756,7 @@ impl Global {
let id = fid.assign_error();
if desc.layout.is_none() {
if is_auto_layout {
// We also need to assign errors to the implicit pipeline layout and the
// implicit bind group layouts.
if let Some(ids) = implicit_context {
@ -1748,6 +1849,8 @@ impl Global {
let fid = hub.compute_pipelines.prepare(id_in);
let implicit_context = implicit_pipeline_ids.map(|ipi| ipi.prepare(hub));
let is_auto_layout = desc.layout.is_none();
let error = 'error: {
let device = match hub.devices.get(device_id) {
Ok(device) => device,
@ -1763,12 +1866,61 @@ impl Global {
});
}
let pipeline = match device.create_compute_pipeline(desc, hub) {
let layout = desc
.layout
.map(|layout| {
hub.pipeline_layouts
.get(layout)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)
})
.transpose();
let layout = match layout {
Ok(layout) => layout,
Err(e) => break 'error e,
};
let cache = desc
.cache
.map(|cache| {
hub.pipeline_caches
.get(cache)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidCache)
})
.transpose();
let cache = match cache {
Ok(cache) => cache,
Err(e) => break 'error e,
};
let module = hub
.shader_modules
.get(desc.stage.module)
.map_err(|_| crate::validation::StageError::InvalidModule);
let module = match module {
Ok(module) => module,
Err(e) => break 'error e.into(),
};
let stage = ResolvedProgrammableStageDescriptor {
module,
entry_point: desc.stage.entry_point.clone(),
constants: desc.stage.constants.clone(),
zero_initialize_workgroup_memory: desc.stage.zero_initialize_workgroup_memory,
vertex_pulling_transform: desc.stage.vertex_pulling_transform,
};
let desc = ResolvedComputePipelineDescriptor {
label: desc.label.clone(),
layout,
stage,
cache,
};
let pipeline = match device.create_compute_pipeline(desc) {
Ok(pair) => pair,
Err(e) => break 'error e,
};
if desc.layout.is_none() {
if is_auto_layout {
// TODO: categorize the errors below as API misuse
let ids = if let Some(ids) = implicit_context.as_ref() {
let group_count = pipeline.layout.bind_group_layouts.len();
@ -1811,7 +1963,7 @@ impl Global {
let id = fid.assign_error();
if desc.layout.is_none() {
if is_auto_layout {
// We also need to assign errors to the implicit pipeline layout and the
// implicit bind group layouts.
if let Some(ids) = implicit_context {

View File

@ -2586,17 +2586,13 @@ impl<A: HalApi> Device<A> {
pub(crate) fn create_compute_pipeline(
self: &Arc<Self>,
desc: &pipeline::ComputePipelineDescriptor,
hub: &Hub<A>,
desc: pipeline::ResolvedComputePipelineDescriptor<A>,
) -> Result<Arc<pipeline::ComputePipeline<A>>, pipeline::CreateComputePipelineError> {
self.check_is_valid()?;
self.require_downlevel_flags(wgt::DownlevelFlags::COMPUTE_SHADERS)?;
let shader_module = hub
.shader_modules
.get(desc.stage.module)
.map_err(|_| validation::StageError::InvalidModule)?;
let shader_module = desc.stage.module;
shader_module.same_device(self)?;
@ -2604,14 +2600,8 @@ impl<A: HalApi> Device<A> {
// Get the pipeline layout from the desc if it is provided.
let pipeline_layout = match desc.layout {
Some(pipeline_layout_id) => {
let pipeline_layout = hub
.pipeline_layouts
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateComputePipelineError::InvalidLayout)?;
Some(pipeline_layout) => {
pipeline_layout.same_device(self)?;
Some(pipeline_layout)
}
None => None,
@ -2661,16 +2651,12 @@ impl<A: HalApi> Device<A> {
let late_sized_buffer_groups =
Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout);
let cache = 'cache: {
let Some(cache) = desc.cache else {
break 'cache None;
};
let Ok(cache) = hub.pipeline_caches.get(cache) else {
break 'cache None;
};
let cache = match desc.cache {
Some(cache) => {
cache.same_device(self)?;
Some(cache)
}
None => None,
};
let pipeline_desc = hal::ComputePipelineDescriptor {
@ -2732,8 +2718,7 @@ impl<A: HalApi> Device<A> {
pub(crate) fn create_render_pipeline(
self: &Arc<Self>,
adapter: &Adapter<A>,
desc: &pipeline::RenderPipelineDescriptor,
hub: &Hub<A>,
desc: pipeline::ResolvedRenderPipelineDescriptor<A>,
) -> Result<Arc<pipeline::RenderPipeline<A>>, pipeline::CreateRenderPipelineError> {
use wgt::TextureFormatFeatureFlags as Tfff;
@ -2758,6 +2743,7 @@ impl<A: HalApi> Device<A> {
.map_or(&[][..], |fragment| &fragment.targets);
let depth_stencil_state = desc.depth_stencil.as_ref();
{
let cts: ArrayVec<_, { hal::MAX_COLOR_ATTACHMENTS }> =
color_targets.iter().filter_map(|x| x.as_ref()).collect();
if !cts.is_empty() && {
@ -2769,6 +2755,7 @@ impl<A: HalApi> Device<A> {
log::debug!("Color targets: {:?}", color_targets);
self.require_downlevel_flags(wgt::DownlevelFlags::INDEPENDENT_BLEND)?;
}
}
let mut io = validation::StageIo::default();
let mut validated_stages = wgt::ShaderStages::empty();
@ -3043,14 +3030,8 @@ impl<A: HalApi> Device<A> {
// Get the pipeline layout from the desc if it is provided.
let pipeline_layout = match desc.layout {
Some(pipeline_layout_id) => {
let pipeline_layout = hub
.pipeline_layouts
.get(pipeline_layout_id)
.map_err(|_| pipeline::CreateRenderPipelineError::InvalidLayout)?;
Some(pipeline_layout) => {
pipeline_layout.same_device(self)?;
Some(pipeline_layout)
}
None => None,
@ -3071,19 +3052,12 @@ impl<A: HalApi> Device<A> {
sc
};
let vertex_shader_module;
let vertex_entry_point_name;
let vertex_stage = {
let stage_desc = &desc.vertex.stage;
let stage = wgt::ShaderStages::VERTEX;
vertex_shader_module = hub.shader_modules.get(stage_desc.module).map_err(|_| {
pipeline::CreateRenderPipelineError::Stage {
stage,
error: validation::StageError::InvalidModule,
}
})?;
let vertex_shader_module = &stage_desc.module;
vertex_shader_module.same_device(self)?;
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
@ -3118,20 +3092,12 @@ impl<A: HalApi> Device<A> {
}
};
let mut fragment_shader_module = None;
let fragment_entry_point_name;
let fragment_stage = match desc.fragment {
Some(ref fragment_state) => {
let stage = wgt::ShaderStages::FRAGMENT;
let shader_module = fragment_shader_module.insert(
hub.shader_modules
.get(fragment_state.stage.module)
.map_err(|_| pipeline::CreateRenderPipelineError::Stage {
stage,
error: validation::StageError::InvalidModule,
})?,
);
let shader_module = &fragment_state.stage.module;
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
@ -3227,7 +3193,7 @@ impl<A: HalApi> Device<A> {
Some(_) => wgt::ShaderStages::FRAGMENT,
None => wgt::ShaderStages::VERTEX,
};
if desc.layout.is_none() && !validated_stages.contains(last_stage) {
if is_auto_layout && !validated_stages.contains(last_stage) {
return Err(pipeline::ImplicitLayoutError::ReflectionError(last_stage).into());
}
@ -3265,16 +3231,12 @@ impl<A: HalApi> Device<A> {
let late_sized_buffer_groups =
Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout);
let pipeline_cache = 'cache: {
let Some(cache) = desc.cache else {
break 'cache None;
};
let Ok(cache) = hub.pipeline_caches.get(cache) else {
break 'cache None;
};
let cache = match desc.cache {
Some(cache) => {
cache.same_device(self)?;
Some(cache)
}
None => None,
};
let pipeline_desc = hal::RenderPipelineDescriptor {
@ -3288,7 +3250,7 @@ impl<A: HalApi> Device<A> {
fragment_stage,
color_targets,
multiview: desc.multiview,
cache: pipeline_cache.as_ref().and_then(|it| it.raw.as_ref()),
cache: cache.as_ref().and_then(|it| it.raw.as_ref()),
};
let raw = unsafe {
self.raw
@ -3346,8 +3308,8 @@ impl<A: HalApi> Device<A> {
let shader_modules = {
let mut shader_modules = ArrayVec::new();
shader_modules.push(vertex_shader_module);
shader_modules.extend(fragment_shader_module);
shader_modules.push(desc.vertex.stage.module);
shader_modules.extend(desc.fragment.map(|f| f.stage.module));
shader_modules
};

View File

@ -151,6 +151,35 @@ pub struct ProgrammableStageDescriptor<'a> {
pub vertex_pulling_transform: bool,
}
/// Describes a programmable pipeline stage.
#[derive(Clone, Debug)]
pub struct ResolvedProgrammableStageDescriptor<'a, A: HalApi> {
/// The compiled shader module for this stage.
pub module: Arc<ShaderModule<A>>,
/// The name of the entry point in the compiled shader. The name is selected using the
/// following logic:
///
/// * If `Some(name)` is specified, there must be a function with this name in the shader.
/// * If a single entry point associated with this stage must be in the shader, then proceed as
/// if `Some(…)` was specified with that entry point's name.
pub entry_point: Option<Cow<'a, str>>,
/// Specifies the values of pipeline-overridable constants in the shader module.
///
/// If an `@id` attribute was specified on the declaration,
/// the key must be the pipeline constant ID as a decimal ASCII number; if not,
/// the key must be the constant's identifier name.
///
/// The value may represent any of WGSL's concrete scalar types.
pub constants: Cow<'a, naga::back::PipelineConstants>,
/// Whether workgroup scoped memory will be initialized with zero values for this stage.
///
/// This is required by the WebGPU spec, but may have overhead which can be avoided
/// for cross-platform applications
pub zero_initialize_workgroup_memory: bool,
/// Should the pipeline attempt to transform vertex shaders to use vertex pulling.
pub vertex_pulling_transform: bool,
}
/// Number of implicit bind groups derived at pipeline creation.
pub type ImplicitBindGroupCount = u8;
@ -180,6 +209,18 @@ pub struct ComputePipelineDescriptor<'a> {
pub cache: Option<PipelineCacheId>,
}
/// Describes a compute pipeline.
#[derive(Clone, Debug)]
pub struct ResolvedComputePipelineDescriptor<'a, A: HalApi> {
pub label: Label<'a>,
/// The layout of bind groups for this pipeline.
pub layout: Option<Arc<PipelineLayout<A>>>,
/// The compiled compute stage and its entry point.
pub stage: ResolvedProgrammableStageDescriptor<'a, A>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<Arc<PipelineCache<A>>>,
}
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CreateComputePipelineError {
@ -187,6 +228,8 @@ pub enum CreateComputePipelineError {
Device(#[from] DeviceError),
#[error("Pipeline layout is invalid")]
InvalidLayout,
#[error("Cache is invalid")]
InvalidCache,
#[error("Unable to derive an implicit layout")]
Implicit(#[from] ImplicitLayoutError),
#[error("Error matching shader requirements against the pipeline")]
@ -306,6 +349,15 @@ pub struct VertexState<'a> {
pub buffers: Cow<'a, [VertexBufferLayout<'a>]>,
}
/// Describes the vertex process in a render pipeline.
#[derive(Clone, Debug)]
pub struct ResolvedVertexState<'a, A: HalApi> {
/// The compiled vertex stage and its entry point.
pub stage: ResolvedProgrammableStageDescriptor<'a, A>,
/// The format of any vertex buffers used with this pipeline.
pub buffers: Cow<'a, [VertexBufferLayout<'a>]>,
}
/// Describes fragment processing in a render pipeline.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -316,6 +368,15 @@ pub struct FragmentState<'a> {
pub targets: Cow<'a, [Option<wgt::ColorTargetState>]>,
}
/// Describes fragment processing in a render pipeline.
#[derive(Clone, Debug)]
pub struct ResolvedFragmentState<'a, A: HalApi> {
/// The compiled fragment stage and its entry point.
pub stage: ResolvedProgrammableStageDescriptor<'a, A>,
/// The effect of draw calls on the color aspect of the output target.
pub targets: Cow<'a, [Option<wgt::ColorTargetState>]>,
}
/// Describes a render (graphics) pipeline.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -343,6 +404,29 @@ pub struct RenderPipelineDescriptor<'a> {
pub cache: Option<PipelineCacheId>,
}
/// Describes a render (graphics) pipeline.
#[derive(Clone, Debug)]
pub struct ResolvedRenderPipelineDescriptor<'a, A: HalApi> {
pub label: Label<'a>,
/// The layout of bind groups for this pipeline.
pub layout: Option<Arc<PipelineLayout<A>>>,
/// The vertex processing state for this pipeline.
pub vertex: ResolvedVertexState<'a, A>,
/// The properties of the pipeline at the primitive assembly and rasterization level.
pub primitive: wgt::PrimitiveState,
/// The effect of draw calls on the depth and stencil aspects of the output target, if any.
pub depth_stencil: Option<wgt::DepthStencilState>,
/// The multi-sampling properties of the pipeline.
pub multisample: wgt::MultisampleState,
/// The fragment processing state for this pipeline.
pub fragment: Option<ResolvedFragmentState<'a, A>>,
/// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have.
pub multiview: Option<NonZeroU32>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<Arc<PipelineCache<A>>>,
}
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PipelineCacheDescriptor<'a> {
@ -395,6 +479,8 @@ pub enum CreateRenderPipelineError {
Device(#[from] DeviceError),
#[error("Pipeline layout is invalid")]
InvalidLayout,
#[error("Pipeline cache is invalid")]
InvalidCache,
#[error("Unable to derive an implicit layout")]
Implicit(#[from] ImplicitLayoutError),
#[error("Color state [{0}] is invalid")]