feat!: make ProgrammableStage::entry_point optional in wgpu-core

This commit is contained in:
Erich Gubler 2024-01-29 21:41:55 -05:00
parent 2c66504a59
commit 023b0e063f
11 changed files with 104 additions and 26 deletions

View File

@ -102,6 +102,7 @@ Bottom level categories:
```
- `wgpu::Id` now implements `PartialOrd`/`Ord` allowing it to be put in `BTreeMap`s. By @cwfitzgerald and @9291Sam in [#5176](https://github.com/gfx-rs/wgpu/pull/5176)
- `wgpu::CommandEncoder::write_timestamp` requires now the new `wgpu::Features::TIMESTAMP_QUERY_INSIDE_ENCODERS` feature which is available on all native backends but not on WebGPU (due to a spec change `write_timestamp` is no longer supported on WebGPU). By @wumpf in [#5188](https://github.com/gfx-rs/wgpu/pull/5188)
- Breaking change: [`wgpu_core::pipeline::ProgrammableStageDescriptor`](https://docs.rs/wgpu-core/latest/wgpu_core/pipeline/struct.ProgrammableStageDescriptor.html#structfield.entry_point) is now optional. By @ErichDonGubler in [#5305](https://github.com/gfx-rs/wgpu/pull/5305).
#### GLES

View File

@ -110,7 +110,7 @@ pub fn op_webgpu_create_compute_pipeline(
layout: pipeline_layout,
stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: compute_shader_module_resource.1,
entry_point: Cow::from(compute.entry_point),
entry_point: Some(Cow::from(compute.entry_point)),
// TODO(lucacasonato): support args.compute.constants
},
};
@ -355,7 +355,7 @@ pub fn op_webgpu_create_render_pipeline(
Some(wgpu_core::pipeline::FragmentState {
stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: fragment_shader_module_resource.1,
entry_point: Cow::from(fragment.entry_point),
entry_point: Some(Cow::from(fragment.entry_point)),
},
targets: Cow::from(fragment.targets),
})
@ -377,7 +377,7 @@ pub fn op_webgpu_create_render_pipeline(
vertex: wgpu_core::pipeline::VertexState {
stage: wgpu_core::pipeline::ProgrammableStageDescriptor {
module: vertex_shader_module_resource.1,
entry_point: Cow::Owned(args.vertex.entry_point),
entry_point: Some(Cow::Owned(args.vertex.entry_point)),
},
buffers: Cow::Owned(vertex_buffers),
},

View File

@ -56,7 +56,7 @@
layout: Some(Id(0, 1, Empty)),
stage: (
module: Id(0, 1, Empty),
entry_point: "main",
entry_point: Some("main"),
),
),
),

View File

@ -29,7 +29,7 @@
layout: Some(Id(0, 1, Empty)),
stage: (
module: Id(0, 1, Empty),
entry_point: "main",
entry_point: Some("main"),
),
),
),

View File

@ -57,14 +57,14 @@
vertex: (
stage: (
module: Id(0, 1, Empty),
entry_point: "vs_main",
entry_point: Some("vs_main"),
),
buffers: [],
),
fragment: Some((
stage: (
module: Id(0, 1, Empty),
entry_point: "fs_main",
entry_point: Some("fs_main"),
),
targets: [
Some((

View File

@ -133,7 +133,7 @@
layout: Some(Id(0, 1, Empty)),
stage: (
module: Id(0, 1, Empty),
entry_point: "main",
entry_point: Some("main"),
),
),
),

View File

@ -134,7 +134,7 @@
layout: Some(Id(0, 1, Empty)),
stage: (
module: Id(0, 1, Empty),
entry_point: "main",
entry_point: Some("main"),
),
),
),

View File

@ -2705,14 +2705,21 @@ impl<A: HalApi> Device<A> {
let mut shader_binding_sizes = FastHashMap::default();
let io = validation::StageIo::default();
let final_entry_point_name;
{
let stage = wgt::ShaderStages::COMPUTE;
final_entry_point_name = shader_module.finalize_entry_point_name(
stage,
desc.stage.entry_point.as_ref().map(|ep| ep.as_ref()),
)?;
if let Some(ref interface) = shader_module.interface {
let _ = interface.check_stage(
&mut binding_layout_source,
&mut shader_binding_sizes,
&desc.stage.entry_point,
&final_entry_point_name,
stage,
io,
None,
@ -2740,7 +2747,7 @@ impl<A: HalApi> Device<A> {
label: desc.label.to_hal(self.instance_flags),
layout: pipeline_layout.raw(),
stage: hal::ProgrammableStage {
entry_point: desc.stage.entry_point.as_ref(),
entry_point: final_entry_point_name.as_ref(),
module: shader_module.raw(),
},
};
@ -3115,6 +3122,7 @@ impl<A: HalApi> Device<A> {
};
let vertex_shader_module;
let vertex_entry_point_name;
let vertex_stage = {
let stage_desc = &desc.vertex.stage;
let stage = wgt::ShaderStages::VERTEX;
@ -3131,12 +3139,19 @@ impl<A: HalApi> Device<A> {
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
vertex_entry_point_name = vertex_shader_module
.finalize_entry_point_name(
stage,
stage_desc.entry_point.as_ref().map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
if let Some(ref interface) = vertex_shader_module.interface {
io = interface
.check_stage(
&mut binding_layout_source,
&mut shader_binding_sizes,
&stage_desc.entry_point,
&vertex_entry_point_name,
stage,
io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
@ -3147,11 +3162,12 @@ impl<A: HalApi> Device<A> {
hal::ProgrammableStage {
module: vertex_shader_module.raw(),
entry_point: stage_desc.entry_point.as_ref(),
entry_point: &vertex_entry_point_name,
}
};
let mut fragment_shader_module = None;
let fragment_entry_point_name;
let fragment_stage = match desc.fragment {
Some(ref fragment_state) => {
let stage = wgt::ShaderStages::FRAGMENT;
@ -3167,13 +3183,24 @@ impl<A: HalApi> Device<A> {
let stage_err = |error| pipeline::CreateRenderPipelineError::Stage { stage, error };
fragment_entry_point_name = shader_module
.finalize_entry_point_name(
stage,
fragment_state
.stage
.entry_point
.as_ref()
.map(|ep| ep.as_ref()),
)
.map_err(stage_err)?;
if validated_stages == wgt::ShaderStages::VERTEX {
if let Some(ref interface) = shader_module.interface {
io = interface
.check_stage(
&mut binding_layout_source,
&mut shader_binding_sizes,
&fragment_state.stage.entry_point,
&fragment_entry_point_name,
stage,
io,
desc.depth_stencil.as_ref().map(|d| d.depth_compare),
@ -3185,7 +3212,7 @@ impl<A: HalApi> Device<A> {
if let Some(ref interface) = shader_module.interface {
shader_expects_dual_source_blending = interface
.fragment_uses_dual_source_blending(&fragment_state.stage.entry_point)
.fragment_uses_dual_source_blending(&fragment_entry_point_name)
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
stage,
error,
@ -3194,7 +3221,7 @@ impl<A: HalApi> Device<A> {
Some(hal::ProgrammableStage {
module: shader_module.raw(),
entry_point: fragment_state.stage.entry_point.as_ref(),
entry_point: &fragment_entry_point_name,
})
}
None => None,

View File

@ -92,6 +92,19 @@ impl<A: HalApi> ShaderModule<A> {
pub(crate) fn raw(&self) -> &A::ShaderModule {
self.raw.as_ref().unwrap()
}
pub(crate) fn finalize_entry_point_name(
&self,
stage_bit: wgt::ShaderStages,
entry_point: Option<&str>,
) -> Result<String, validation::StageError> {
match &self.interface {
Some(interface) => interface.finalize_entry_point_name(stage_bit, entry_point),
None => entry_point
.map(|ep| ep.to_string())
.ok_or(validation::StageError::NoEntryPointFound),
}
}
}
#[derive(Clone, Debug)]
@ -213,9 +226,13 @@ impl CreateShaderModuleError {
pub struct ProgrammableStageDescriptor<'a> {
/// The compiled shader module for this stage.
pub module: ShaderModuleId,
/// The name of the entry point in the compiled shader. There must be a function with this name
/// in the shader.
pub entry_point: Cow<'a, str>,
/// The name of the entry point in the compiled shader. The name is selected using the
/// following logic:
///
/// * If `Some(name)` is specified, there must be a function with this name in the shader.
/// * If a single entry point associated with this stage must be in the shader, then proceed as
/// if `Some(…)` was specified with that entry point's name.
pub entry_point: Option<Cow<'a, str>>,
}
/// Number of implicit bind groups derived at pipeline creation.

View File

@ -283,6 +283,16 @@ pub enum StageError {
},
#[error("Location[{location}] is provided by the previous stage output but is not consumed as input by this stage.")]
InputNotConsumed { location: wgt::ShaderLocation },
#[error(
"Unable to select an entry point: no entry point was found in the provided shader module"
)]
NoEntryPointFound,
#[error(
"Unable to select an entry point: \
multiple entry points were found in the provided shader module, \
but no entry point was specified"
)]
MultipleEntryPointsFound,
}
fn map_storage_format_to_naga(format: wgt::TextureFormat) -> Option<naga::StorageFormat> {
@ -971,6 +981,28 @@ impl Interface {
}
}
pub fn finalize_entry_point_name(
&self,
stage_bit: wgt::ShaderStages,
entry_point_name: Option<&str>,
) -> Result<String, StageError> {
let stage = Self::shader_stage_from_stage_bit(stage_bit);
entry_point_name
.map(|ep| ep.to_string())
.map(Ok)
.unwrap_or_else(|| {
let mut entry_points = self
.entry_points
.keys()
.filter_map(|(ep_stage, name)| (ep_stage == &stage).then_some(name));
let first = entry_points.next().ok_or(StageError::NoEntryPointFound)?;
if entry_points.next().is_some() {
return Err(StageError::MultipleEntryPointsFound);
}
Ok(first.clone())
})
}
pub(crate) fn shader_stage_from_stage_bit(stage_bit: wgt::ShaderStages) -> naga::ShaderStage {
match stage_bit {
wgt::ShaderStages::VERTEX => naga::ShaderStage::Vertex,
@ -993,10 +1025,11 @@ impl Interface {
// we need to look for one with the right execution model.
let shader_stage = Self::shader_stage_from_stage_bit(stage_bit);
let pair = (shader_stage, entry_point_name.to_string());
let entry_point = self
.entry_points
.get(&pair)
.ok_or(StageError::MissingEntryPoint(pair.1))?;
let entry_point = match self.entry_points.get(&pair) {
Some(some) => some,
None => return Err(StageError::MissingEntryPoint(pair.1)),
};
let (_stage, entry_point_name) = pair;
// check resources visibility
for &handle in entry_point.resources.iter() {

View File

@ -1102,7 +1102,7 @@ impl crate::Context for ContextWgpuCore {
vertex: pipe::VertexState {
stage: pipe::ProgrammableStageDescriptor {
module: desc.vertex.module.id.into(),
entry_point: Borrowed(desc.vertex.entry_point),
entry_point: Some(Borrowed(desc.vertex.entry_point)),
},
buffers: Borrowed(&vertex_buffers),
},
@ -1112,7 +1112,7 @@ impl crate::Context for ContextWgpuCore {
fragment: desc.fragment.as_ref().map(|frag| pipe::FragmentState {
stage: pipe::ProgrammableStageDescriptor {
module: frag.module.id.into(),
entry_point: Borrowed(frag.entry_point),
entry_point: Some(Borrowed(frag.entry_point)),
},
targets: Borrowed(frag.targets),
}),
@ -1160,7 +1160,7 @@ impl crate::Context for ContextWgpuCore {
layout: desc.layout.map(|l| l.id.into()),
stage: pipe::ProgrammableStageDescriptor {
module: desc.module.id.into(),
entry_point: Borrowed(desc.entry_point),
entry_point: Some(Borrowed(desc.entry_point)),
},
};