From 2929ec333cee981ef4cbf783c0e33d208c651c4d Mon Sep 17 00:00:00 2001 From: teoxoy <28601907+teoxoy@users.noreply.github.com> Date: Fri, 5 Jan 2024 14:42:07 +0100 Subject: [PATCH] [spv/msl/hlsl-out] support pipeline constant value replacements --- naga/src/arena.rs | 11 ++ naga/src/back/glsl/mod.rs | 6 + naga/src/back/hlsl/mod.rs | 2 + naga/src/back/hlsl/writer.rs | 6 +- naga/src/back/mod.rs | 8 + naga/src/back/msl/mod.rs | 2 + naga/src/back/msl/writer.rs | 4 + naga/src/back/pipeline_constants.rs | 213 +++++++++++++++++++++ naga/src/back/spv/mod.rs | 2 + naga/src/back/spv/writer.rs | 10 + naga/src/back/wgsl/writer.rs | 6 + naga/src/proc/constant_evaluator.rs | 1 + naga/src/valid/mod.rs | 15 +- naga/tests/in/overrides.param.ron | 11 ++ naga/tests/in/overrides.wgsl | 3 + naga/tests/out/analysis/overrides.info.ron | 17 +- naga/tests/out/hlsl/overrides.hlsl | 12 ++ naga/tests/out/hlsl/overrides.ron | 12 ++ naga/tests/out/ir/overrides.compact.ron | 22 ++- naga/tests/out/ir/overrides.ron | 22 ++- naga/tests/out/msl/overrides.msl | 17 ++ naga/tests/out/spv/overrides.main.spvasm | 25 +++ naga/tests/snapshots.rs | 50 ++++- wgpu-hal/src/vulkan/device.rs | 1 + 24 files changed, 463 insertions(+), 15 deletions(-) create mode 100644 naga/src/back/pipeline_constants.rs create mode 100644 naga/tests/in/overrides.param.ron create mode 100644 naga/tests/out/hlsl/overrides.hlsl create mode 100644 naga/tests/out/hlsl/overrides.ron create mode 100644 naga/tests/out/msl/overrides.msl create mode 100644 naga/tests/out/spv/overrides.main.spvasm diff --git a/naga/src/arena.rs b/naga/src/arena.rs index 4e5f5af6e..184102757 100644 --- a/naga/src/arena.rs +++ b/naga/src/arena.rs @@ -297,6 +297,17 @@ impl Arena { .map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) }) } + /// Drains the arena, returning an iterator over the items stored. + pub fn drain(&mut self) -> impl DoubleEndedIterator, T, Span)> { + let arena = std::mem::take(self); + arena + .data + .into_iter() + .zip(arena.span_info) + .enumerate() + .map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) }) + } + /// Returns a iterator over the items stored in this arena, /// returning both the item's handle and a mutable reference to it. pub fn iter_mut(&mut self) -> impl DoubleEndedIterator, &mut T)> { diff --git a/naga/src/back/glsl/mod.rs b/naga/src/back/glsl/mod.rs index 8241b0412..736a3b57b 100644 --- a/naga/src/back/glsl/mod.rs +++ b/naga/src/back/glsl/mod.rs @@ -567,6 +567,12 @@ impl<'a, W: Write> Writer<'a, W> { pipeline_options: &'a PipelineOptions, policies: proc::BoundsCheckPolicies, ) -> Result { + if !module.overrides.is_empty() { + return Err(Error::Custom( + "Pipeline constants are not yet supported for this back-end".to_string(), + )); + } + // Check if the requested version is supported if !options.version.is_supported() { log::error!("Version {}", options.version); diff --git a/naga/src/back/hlsl/mod.rs b/naga/src/back/hlsl/mod.rs index 37d26bf3b..588c91d69 100644 --- a/naga/src/back/hlsl/mod.rs +++ b/naga/src/back/hlsl/mod.rs @@ -255,6 +255,8 @@ pub enum Error { Unimplemented(String), // TODO: Error used only during development #[error("{0}")] Custom(String), + #[error(transparent)] + PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError), } #[derive(Default)] diff --git a/naga/src/back/hlsl/writer.rs b/naga/src/back/hlsl/writer.rs index 657774d07..0db648984 100644 --- a/naga/src/back/hlsl/writer.rs +++ b/naga/src/back/hlsl/writer.rs @@ -167,8 +167,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> { &mut self, module: &Module, module_info: &valid::ModuleInfo, - _pipeline_options: &PipelineOptions, + pipeline_options: &PipelineOptions, ) -> Result { + let module = + back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let module = module.as_ref(); + self.reset(module); // Write special constants, if needed diff --git a/naga/src/back/mod.rs b/naga/src/back/mod.rs index 61dc4a060..a95328d4f 100644 --- a/naga/src/back/mod.rs +++ b/naga/src/back/mod.rs @@ -16,6 +16,14 @@ pub mod spv; #[cfg(feature = "wgsl-out")] pub mod wgsl; +#[cfg(any( + feature = "hlsl-out", + feature = "msl-out", + feature = "spv-out", + feature = "glsl-out" +))] +mod pipeline_constants; + /// Names of vector components. pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w']; /// Indent for backends. diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index 7e05be29b..702b373cf 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -143,6 +143,8 @@ pub enum Error { UnsupportedArrayOfType(Handle), #[error("ray tracing is not supported prior to MSL 2.3")] UnsupportedRayTracing, + #[error(transparent)] + PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), } #[derive(Clone, Debug, PartialEq, thiserror::Error)] diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index f8fa2c4da..36d8bc820 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -3223,6 +3223,10 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { + let module = + back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?; + let module = module.as_ref(); + self.names.clear(); self.namer.reset( module, diff --git a/naga/src/back/pipeline_constants.rs b/naga/src/back/pipeline_constants.rs new file mode 100644 index 000000000..5a3cad2a6 --- /dev/null +++ b/naga/src/back/pipeline_constants.rs @@ -0,0 +1,213 @@ +use super::PipelineConstants; +use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner}; +use std::borrow::Cow; +use thiserror::Error; + +#[derive(Error, Debug, Clone)] +#[cfg_attr(test, derive(PartialEq))] +pub enum PipelineConstantError { + #[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")] + MissingValue(String), + #[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")] + SrcNeedsToBeFinite, + #[error("Source f64 value doesn't fit in destination")] + DstRangeTooSmall, +} + +pub(super) fn process_overrides<'a>( + module: &'a Module, + pipeline_constants: &PipelineConstants, +) -> Result, PipelineConstantError> { + if module.overrides.is_empty() { + return Ok(Cow::Borrowed(module)); + } + + let mut module = module.clone(); + + for (_handle, override_, span) in module.overrides.drain() { + let key = if let Some(id) = override_.id { + Cow::Owned(id.to_string()) + } else if let Some(ref name) = override_.name { + Cow::Borrowed(name) + } else { + unreachable!(); + }; + let init = if let Some(value) = pipeline_constants.get::(&key) { + let literal = match module.types[override_.ty].inner { + TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?, + _ => unreachable!(), + }; + module + .const_expressions + .append(Expression::Literal(literal), Span::UNDEFINED) + } else if let Some(init) = override_.init { + init + } else { + return Err(PipelineConstantError::MissingValue(key.to_string())); + }; + let constant = Constant { + name: override_.name, + ty: override_.ty, + init, + }; + module.constants.append(constant, span); + } + + Ok(Cow::Owned(module)) +} + +fn map_value_to_literal(value: f64, scalar: Scalar) -> Result { + // note that in rust 0.0 == -0.0 + match scalar { + Scalar::BOOL => { + // https://webidl.spec.whatwg.org/#js-boolean + let value = value != 0.0 && !value.is_nan(); + Ok(Literal::Bool(value)) + } + Scalar::I32 => { + // https://webidl.spec.whatwg.org/#js-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as i32; + Ok(Literal::I32(value)) + } + Scalar::U32 => { + // https://webidl.spec.whatwg.org/#js-unsigned-long + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value.trunc(); + if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + let value = value as u32; + Ok(Literal::U32(value)) + } + Scalar::F32 => { + // https://webidl.spec.whatwg.org/#js-float + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + let value = value as f32; + if !value.is_finite() { + return Err(PipelineConstantError::DstRangeTooSmall); + } + + Ok(Literal::F32(value)) + } + Scalar::F64 => { + // https://webidl.spec.whatwg.org/#js-double + if !value.is_finite() { + return Err(PipelineConstantError::SrcNeedsToBeFinite); + } + + Ok(Literal::F64(value)) + } + _ => unreachable!(), + } +} + +#[test] +fn test_map_value_to_literal() { + let bool_test_cases = [ + (0.0, false), + (-0.0, false), + (f64::NAN, false), + (1.0, true), + (f64::INFINITY, true), + (f64::NEG_INFINITY, true), + ]; + for (value, out) in bool_test_cases { + let res = Ok(Literal::Bool(out)); + assert_eq!(map_value_to_literal(value, Scalar::BOOL), res); + } + + for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] { + for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] { + let res = Err(PipelineConstantError::SrcNeedsToBeFinite); + assert_eq!(map_value_to_literal(value, scalar), res); + } + } + + // i32 + assert_eq!( + map_value_to_literal(f64::from(i32::MIN), Scalar::I32), + Ok(Literal::I32(i32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX), Scalar::I32), + Ok(Literal::I32(i32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // u32 + assert_eq!( + map_value_to_literal(f64::from(u32::MIN), Scalar::U32), + Ok(Literal::U32(u32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX), Scalar::U32), + Ok(Literal::U32(u32::MAX)) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f32 + assert_eq!( + map_value_to_literal(f64::from(f32::MIN), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from(f32::MAX), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32), + Ok(Literal::F32(f32::MAX)) + ); + assert_eq!( + map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + assert_eq!( + map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32), + Err(PipelineConstantError::DstRangeTooSmall) + ); + + // f64 + assert_eq!( + map_value_to_literal(f64::MIN, Scalar::F64), + Ok(Literal::F64(f64::MIN)) + ); + assert_eq!( + map_value_to_literal(f64::MAX, Scalar::F64), + Ok(Literal::F64(f64::MAX)) + ); +} diff --git a/naga/src/back/spv/mod.rs b/naga/src/back/spv/mod.rs index 087c49bcc..3c0332d59 100644 --- a/naga/src/back/spv/mod.rs +++ b/naga/src/back/spv/mod.rs @@ -70,6 +70,8 @@ pub enum Error { FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), + #[error(transparent)] + PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError), } #[derive(Default)] diff --git a/naga/src/back/spv/writer.rs b/naga/src/back/spv/writer.rs index a5065e062..975aa625d 100644 --- a/naga/src/back/spv/writer.rs +++ b/naga/src/back/spv/writer.rs @@ -2029,6 +2029,16 @@ impl Writer { debug_info: &Option, words: &mut Vec, ) -> Result<(), Error> { + let ir_module = if let Some(pipeline_options) = pipeline_options { + crate::back::pipeline_constants::process_overrides( + ir_module, + &pipeline_options.constants, + )? + } else { + std::borrow::Cow::Borrowed(ir_module) + }; + let ir_module = ir_module.as_ref(); + self.reset(); // Try to find the entry point and corresponding index diff --git a/naga/src/back/wgsl/writer.rs b/naga/src/back/wgsl/writer.rs index 607954706..7ca689f48 100644 --- a/naga/src/back/wgsl/writer.rs +++ b/naga/src/back/wgsl/writer.rs @@ -106,6 +106,12 @@ impl Writer { } pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult { + if !module.overrides.is_empty() { + return Err(Error::Unimplemented( + "Pipeline constants are not yet supported for this back-end".to_string(), + )); + } + self.reset(module); // Save all ep result types diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index 8a9da04d3..5617cc770 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -359,6 +359,7 @@ impl ExpressionConstnessTracker { } #[derive(Clone, Debug, thiserror::Error)] +#[cfg_attr(test, derive(PartialEq))] pub enum ConstantEvaluatorError { #[error("Constants cannot access function arguments")] FunctionArg, diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index d54079ac1..be11e8e39 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -186,6 +186,8 @@ pub enum ConstantError { #[derive(Clone, Debug, thiserror::Error)] pub enum OverrideError { + #[error("Override name and ID are missing")] + MissingNameAndID, #[error("The type doesn't match the override")] InvalidType, #[error("The type is not constructible")] @@ -353,6 +355,10 @@ impl Validator { ) -> Result<(), OverrideError> { let o = &gctx.overrides[handle]; + if o.name.is_none() && o.id.is_none() { + return Err(OverrideError::MissingNameAndID); + } + let type_info = &self.types[o.ty.index()]; if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) { return Err(OverrideError::NonConstructibleType); @@ -360,7 +366,14 @@ impl Validator { let decl_ty = &gctx.types[o.ty].inner; match decl_ty { - &crate::TypeInner::Scalar(_) => {} + &crate::TypeInner::Scalar(scalar) => match scalar { + crate::Scalar::BOOL + | crate::Scalar::I32 + | crate::Scalar::U32 + | crate::Scalar::F32 + | crate::Scalar::F64 => {} + _ => return Err(OverrideError::TypeNotScalar), + }, _ => return Err(OverrideError::TypeNotScalar), } diff --git a/naga/tests/in/overrides.param.ron b/naga/tests/in/overrides.param.ron new file mode 100644 index 000000000..5c9b72d31 --- /dev/null +++ b/naga/tests/in/overrides.param.ron @@ -0,0 +1,11 @@ +( + spv: ( + version: (1, 0), + separate_entry_points: true, + ), + pipeline_constants: { + "0": NaN, + "1300": 1.1, + "depth": 2.3, + } +) diff --git a/naga/tests/in/overrides.wgsl b/naga/tests/in/overrides.wgsl index 803269a65..b498a8b52 100644 --- a/naga/tests/in/overrides.wgsl +++ b/naga/tests/in/overrides.wgsl @@ -12,3 +12,6 @@ // overridable constant. override inferred_f32 = 2.718; + +@compute @workgroup_size(1) +fn main() {} \ No newline at end of file diff --git a/naga/tests/out/analysis/overrides.info.ron b/naga/tests/out/analysis/overrides.info.ron index 9ad1b3914..481c3eac9 100644 --- a/naga/tests/out/analysis/overrides.info.ron +++ b/naga/tests/out/analysis/overrides.info.ron @@ -4,7 +4,22 @@ ("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"), ], functions: [], - entry_points: [], + entry_points: [ + ( + flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"), + available_stages: ("VERTEX | FRAGMENT | COMPUTE"), + uniformity: ( + non_uniform_result: None, + requirements: (""), + ), + may_kill: false, + sampling_set: [], + global_uses: [], + expressions: [], + sampling: [], + dual_source_blending: false, + ), + ], const_expression_types: [ Value(Scalar(( kind: Bool, diff --git a/naga/tests/out/hlsl/overrides.hlsl b/naga/tests/out/hlsl/overrides.hlsl new file mode 100644 index 000000000..63b13a5d2 --- /dev/null +++ b/naga/tests/out/hlsl/overrides.hlsl @@ -0,0 +1,12 @@ +static const bool has_point_light = false; +static const float specular_param = 2.3; +static const float gain = 1.1; +static const float width = 0.0; +static const float depth = 2.3; +static const float inferred_f32_ = 2.718; + +[numthreads(1, 1, 1)] +void main() +{ + return; +} diff --git a/naga/tests/out/hlsl/overrides.ron b/naga/tests/out/hlsl/overrides.ron new file mode 100644 index 000000000..a07b03300 --- /dev/null +++ b/naga/tests/out/hlsl/overrides.ron @@ -0,0 +1,12 @@ +( + vertex:[ + ], + fragment:[ + ], + compute:[ + ( + entry_point:"main", + target_profile:"cs_5_1", + ), + ], +) diff --git a/naga/tests/out/ir/overrides.compact.ron b/naga/tests/out/ir/overrides.compact.ron index 5ac9ade6f..af4b31eba 100644 --- a/naga/tests/out/ir/overrides.compact.ron +++ b/naga/tests/out/ir/overrides.compact.ron @@ -67,5 +67,25 @@ Literal(F32(2.718)), ], functions: [], - entry_points: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [], + named_expressions: {}, + body: [ + Return( + value: None, + ), + ], + ), + ), + ], ) \ No newline at end of file diff --git a/naga/tests/out/ir/overrides.ron b/naga/tests/out/ir/overrides.ron index 5ac9ade6f..af4b31eba 100644 --- a/naga/tests/out/ir/overrides.ron +++ b/naga/tests/out/ir/overrides.ron @@ -67,5 +67,25 @@ Literal(F32(2.718)), ], functions: [], - entry_points: [], + entry_points: [ + ( + name: "main", + stage: Compute, + early_depth_test: None, + workgroup_size: (1, 1, 1), + function: ( + name: Some("main"), + arguments: [], + result: None, + local_variables: [], + expressions: [], + named_expressions: {}, + body: [ + Return( + value: None, + ), + ], + ), + ), + ], ) \ No newline at end of file diff --git a/naga/tests/out/msl/overrides.msl b/naga/tests/out/msl/overrides.msl new file mode 100644 index 000000000..419edd890 --- /dev/null +++ b/naga/tests/out/msl/overrides.msl @@ -0,0 +1,17 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +constant bool has_point_light = false; +constant float specular_param = 2.3; +constant float gain = 1.1; +constant float width = 0.0; +constant float depth = 2.3; +constant float inferred_f32_ = 2.718; + +kernel void main_( +) { + return; +} diff --git a/naga/tests/out/spv/overrides.main.spvasm b/naga/tests/out/spv/overrides.main.spvasm new file mode 100644 index 000000000..7dfa6df3e --- /dev/null +++ b/naga/tests/out/spv/overrides.main.spvasm @@ -0,0 +1,25 @@ +; SPIR-V +; Version: 1.0 +; Generator: rspirv +; Bound: 15 +OpCapability Shader +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpEntryPoint GLCompute %12 "main" +OpExecutionMode %12 LocalSize 1 1 1 +%2 = OpTypeVoid +%3 = OpTypeBool +%4 = OpTypeFloat 32 +%5 = OpConstantTrue %3 +%6 = OpConstant %4 2.3 +%7 = OpConstant %4 0.0 +%8 = OpConstant %4 2.718 +%9 = OpConstantFalse %3 +%10 = OpConstant %4 1.1 +%13 = OpTypeFunction %2 +%12 = OpFunction %2 None %13 +%11 = OpLabel +OpBranch %14 +%14 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 1d3734500..e2f6dff25 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -87,6 +87,17 @@ struct Parameters { #[cfg(all(feature = "deserialize", feature = "glsl-out"))] #[serde(default)] glsl_multiview: Option, + #[cfg(all( + feature = "deserialize", + any( + feature = "hlsl-out", + feature = "msl-out", + feature = "spv-out", + feature = "glsl-out" + ) + ))] + #[serde(default)] + pipeline_constants: naga::back::PipelineConstants, } /// Information about a shader input file. @@ -331,18 +342,25 @@ fn check_targets( debug_info, ¶ms.spv, params.bounds_check_policies, + ¶ms.pipeline_constants, ); } } #[cfg(all(feature = "deserialize", feature = "msl-out"))] { if targets.contains(Targets::METAL) { + if !params.msl_pipeline.constants.is_empty() { + panic!("Supply pipeline constants via pipeline_constants instead of msl_pipeline.constants!"); + } + let mut pipeline_options = params.msl_pipeline.clone(); + pipeline_options.constants = params.pipeline_constants.clone(); + write_output_msl( input, module, &info, ¶ms.msl, - ¶ms.msl_pipeline, + &pipeline_options, params.bounds_check_policies, ); } @@ -363,6 +381,7 @@ fn check_targets( ¶ms.glsl, params.bounds_check_policies, params.glsl_multiview, + ¶ms.pipeline_constants, ); } } @@ -377,7 +396,13 @@ fn check_targets( #[cfg(all(feature = "deserialize", feature = "hlsl-out"))] { if targets.contains(Targets::HLSL) { - write_output_hlsl(input, module, &info, ¶ms.hlsl); + write_output_hlsl( + input, + module, + &info, + ¶ms.hlsl, + ¶ms.pipeline_constants, + ); } } #[cfg(all(feature = "deserialize", feature = "wgsl-out"))] @@ -396,6 +421,7 @@ fn write_output_spv( debug_info: Option, params: &SpirvOutParameters, bounds_check_policies: naga::proc::BoundsCheckPolicies, + pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::spv; use rspirv::binary::Disassemble; @@ -428,7 +454,7 @@ fn write_output_spv( let pipeline_options = spv::PipelineOptions { entry_point: ep.name.clone(), shader_stage: ep.stage, - constants: naga::back::PipelineConstants::default(), + constants: pipeline_constants.clone(), }; write_output_spv_inner( input, @@ -508,6 +534,7 @@ fn write_output_glsl( options: &naga::back::glsl::Options, bounds_check_policies: naga::proc::BoundsCheckPolicies, multiview: Option, + pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::glsl; @@ -517,7 +544,7 @@ fn write_output_glsl( shader_stage: stage, entry_point: ep_name.to_string(), multiview, - constants: naga::back::PipelineConstants::default(), + constants: pipeline_constants.clone(), }; let mut buffer = String::new(); @@ -542,6 +569,7 @@ fn write_output_hlsl( module: &naga::Module, info: &naga::valid::ModuleInfo, options: &naga::back::hlsl::Options, + pipeline_constants: &naga::back::PipelineConstants, ) { use naga::back::hlsl; use std::fmt::Write as _; @@ -551,7 +579,13 @@ fn write_output_hlsl( let mut buffer = String::new(); let mut writer = hlsl::Writer::new(&mut buffer, options); let reflection_info = writer - .write(module, info, &hlsl::PipelineOptions::default()) + .write( + module, + info, + &hlsl::PipelineOptions { + constants: pipeline_constants.clone(), + }, + ) .expect("HLSL write failed"); input.write_output_file("hlsl", "hlsl", buffer); @@ -817,11 +851,7 @@ fn convert_wgsl() { ), ( "overrides", - Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV - // | Targets::METAL - // | Targets::GLSL - // | Targets::HLSL - // | Targets::WGSL, + Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL, ), ]; diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 2dcded220..989ad60c7 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1588,6 +1588,7 @@ impl crate::Device for super::Device { .shared .workarounds .contains(super::Workarounds::SEPARATE_ENTRY_POINTS) + || !naga_shader.module.overrides.is_empty() { return Ok(super::ShaderModule::Intermediate { naga_shader,