move the burden of evaluating override-expressions to users of naga's API

This commit is contained in:
teoxoy 2024-04-04 11:58:28 +02:00 committed by Teodor Tanasoaia
parent 7bed9e8bce
commit 7df0aa6364
20 changed files with 175 additions and 168 deletions

View File

@ -597,17 +597,18 @@ fn write_output(
let mut options = params.msl.clone();
options.bounds_check_policies = params.bounds_check_policies;
let info = info.as_ref().ok_or(CliError(
"Generating metal output requires validation to \
succeed, and it failed in a previous step",
))?;
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let pipeline_options = msl::PipelineOptions::default();
let (msl, _) = msl::write_string(
module,
info.as_ref().ok_or(CliError(
"Generating metal output requires validation to \
succeed, and it failed in a previous step",
))?,
&options,
&pipeline_options,
)
.unwrap_pretty();
let (msl, _) =
msl::write_string(&module, &info, &options, &pipeline_options).unwrap_pretty();
fs::write(output_path, msl)?;
}
"spv" => {
@ -624,23 +625,23 @@ fn write_output(
pipeline_options_owned = spv::PipelineOptions {
entry_point: name.clone(),
shader_stage: module.entry_points[ep_index].stage,
constants: naga::back::PipelineConstants::default(),
};
Some(&pipeline_options_owned)
}
None => None,
};
let spv = spv::write_vec(
module,
info.as_ref().ok_or(CliError(
"Generating SPIR-V output requires validation to \
succeed, and it failed in a previous step",
))?,
&params.spv_out,
pipeline_options,
)
.unwrap_pretty();
let info = info.as_ref().ok_or(CliError(
"Generating SPIR-V output requires validation to \
succeed, and it failed in a previous step",
))?;
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let spv =
spv::write_vec(&module, &info, &params.spv_out, pipeline_options).unwrap_pretty();
let bytes = spv
.iter()
.fold(Vec::with_capacity(spv.len() * 4), |mut v, w| {
@ -665,17 +666,22 @@ fn write_output(
_ => unreachable!(),
},
multiview: None,
constants: naga::back::PipelineConstants::default(),
};
let info = info.as_ref().ok_or(CliError(
"Generating glsl output requires validation to \
succeed, and it failed in a previous step",
))?;
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let mut buffer = String::new();
let mut writer = glsl::Writer::new(
&mut buffer,
module,
info.as_ref().ok_or(CliError(
"Generating glsl output requires validation to \
succeed, and it failed in a previous step",
))?,
&module,
&info,
&params.glsl,
&pipeline_options,
params.bounds_check_policies,
@ -692,20 +698,19 @@ fn write_output(
}
"hlsl" => {
use naga::back::hlsl;
let info = info.as_ref().ok_or(CliError(
"Generating hlsl output requires validation to \
succeed, and it failed in a previous step",
))?;
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, &params.overrides)
.unwrap_pretty();
let mut buffer = String::new();
let mut writer = hlsl::Writer::new(&mut buffer, &params.hlsl);
writer
.write(
module,
info.as_ref().ok_or(CliError(
"Generating hlsl output requires validation to \
succeed, and it failed in a previous step",
))?,
&hlsl::PipelineOptions {
constants: params.overrides.clone(),
},
)
.unwrap_pretty();
writer.write(&module, &info).unwrap_pretty();
fs::write(output_path, buffer)?;
}
"wgsl" => {

View File

@ -193,7 +193,6 @@ fn backends(c: &mut Criterion) {
let pipeline_options = naga::back::spv::PipelineOptions {
shader_stage: ep.stage,
entry_point: ep.name.clone(),
constants: naga::back::PipelineConstants::default(),
};
writer
.write(module, info, Some(&pipeline_options), &None, &mut data)
@ -224,11 +223,10 @@ fn backends(c: &mut Criterion) {
group.bench_function("hlsl", |b| {
b.iter(|| {
let options = naga::back::hlsl::Options::default();
let pipeline_options = naga::back::hlsl::PipelineOptions::default();
let mut string = String::new();
for &(ref module, ref info) in inputs.iter() {
let mut writer = naga::back::hlsl::Writer::new(&mut string, &options);
let _ = writer.write(module, info, &pipeline_options); // may fail on unimplemented things
let _ = writer.write(module, info); // may fail on unimplemented things
string.clear();
}
});
@ -250,7 +248,6 @@ fn backends(c: &mut Criterion) {
shader_stage: ep.stage,
entry_point: ep.name.clone(),
multiview: None,
constants: naga::back::PipelineConstants::default(),
};
// might be `Err` if missing features

View File

@ -294,8 +294,6 @@ pub struct PipelineOptions {
pub entry_point: String,
/// How many views to render to, if doing multiview rendering.
pub multiview: Option<std::num::NonZeroU32>,
/// Pipeline constants.
pub constants: back::PipelineConstants,
}
#[derive(Debug)]
@ -499,6 +497,8 @@ pub enum Error {
ImageMultipleSamplers,
#[error("{0}")]
Custom(String),
#[error("overrides should not be present at this stage")]
Override,
}
/// Binary operation with a different logic on the GLSL side.
@ -568,9 +568,7 @@ impl<'a, W: Write> Writer<'a, W> {
policies: proc::BoundsCheckPolicies,
) -> Result<Self, Error> {
if !module.overrides.is_empty() {
return Err(Error::Custom(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
return Err(Error::Override);
}
// Check if the requested version is supported
@ -2544,7 +2542,7 @@ impl<'a, W: Write> Writer<'a, W> {
|writer, expr| writer.write_expr(expr, ctx),
)?;
}
Expression::Override(_) => return Err(Error::Custom("overrides are WIP".into())),
Expression::Override(_) => return Err(Error::Override),
// `Access` is applied to arrays, vectors and matrices and is written as indexing
Expression::Access { base, index } => {
self.write_expr(base, ctx)?;

View File

@ -195,14 +195,6 @@ pub struct Options {
pub zero_initialize_workgroup_memory: bool,
}
#[derive(Clone, Debug, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct PipelineOptions {
/// Pipeline constants.
pub constants: back::PipelineConstants,
}
impl Default for Options {
fn default() -> Self {
Options {
@ -255,8 +247,8 @@ pub enum Error {
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Custom(String),
#[error(transparent)]
PipelineConstant(#[from] Box<back::pipeline_constants::PipelineConstantError>),
#[error("overrides should not be present at this stage")]
Override,
}
#[derive(Default)]

View File

@ -1,7 +1,7 @@
use super::{
help::{WrappedArrayLength, WrappedConstructor, WrappedImageQuery, WrappedStructMatrixAccess},
storage::StoreValue,
BackendResult, Error, Options, PipelineOptions,
BackendResult, Error, Options,
};
use crate::{
back,
@ -167,16 +167,10 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&mut self,
module: &Module,
module_info: &valid::ModuleInfo,
pipeline_options: &PipelineOptions,
) -> Result<super::ReflectionInfo, Error> {
let (module, module_info) = back::pipeline_constants::process_overrides(
module,
module_info,
&pipeline_options.constants,
)
.map_err(Box::new)?;
let module = module.as_ref();
let module_info = module_info.as_ref();
if !module.overrides.is_empty() {
return Err(Error::Override);
}
self.reset(module);
@ -2150,9 +2144,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
Expression::Override(_) => return Err(Error::Override),
// All of the multiplication can be expressed as `mul`,
// except vector * vector, which needs to use the "*" operator.
Expression::Binary {

View File

@ -22,7 +22,7 @@ pub mod wgsl;
feature = "spv-out",
feature = "glsl-out"
))]
mod pipeline_constants;
pub mod pipeline_constants;
/// Names of vector components.
pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];

View File

@ -143,8 +143,8 @@ pub enum Error {
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
#[error(transparent)]
PipelineConstant(#[from] Box<crate::back::pipeline_constants::PipelineConstantError>),
#[error("overrides should not be present at this stage")]
Override,
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]
@ -234,8 +234,6 @@ pub struct PipelineOptions {
///
/// Enable this for vertex shaders with point primitive topologies.
pub allow_and_force_point_size: bool,
/// Pipeline constants.
pub constants: crate::back::PipelineConstants,
}
impl Options {

View File

@ -1431,9 +1431,7 @@ impl<W: Write> Writer<W> {
|writer, context, expr| writer.put_expression(expr, context, true),
)?;
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP".into()))
}
crate::Expression::Override(_) => return Err(Error::Override),
crate::Expression::Access { base, .. }
| crate::Expression::AccessIndex { base, .. } => {
// This is an acceptable place to generate a `ReadZeroSkipWrite` check.
@ -3223,11 +3221,9 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
let (module, info) =
back::pipeline_constants::process_overrides(module, info, &pipeline_options.constants)
.map_err(Box::new)?;
let module = module.as_ref();
let info = info.as_ref();
if !module.overrides.is_empty() {
return Err(Error::Override);
}
self.names.clear();
self.namer.reset(

View File

@ -36,7 +36,7 @@ pub enum PipelineConstantError {
/// fully-evaluated expressions.
///
/// [`global_expressions`]: Module::global_expressions
pub(super) fn process_overrides<'a>(
pub fn process_overrides<'a>(
module: &'a Module,
module_info: &'a ModuleInfo,
pipeline_constants: &PipelineConstants,

View File

@ -239,9 +239,7 @@ impl<'w> BlockContext<'w> {
let init = self.ir_module.constants[handle].init;
self.writer.constant_ids[init.index()]
}
crate::Expression::Override(_) => {
return Err(Error::FeatureNotImplemented("overrides are WIP"))
}
crate::Expression::Override(_) => return Err(Error::Override),
crate::Expression::ZeroValue(_) => self.writer.get_constant_null(result_type_id),
crate::Expression::Compose { ty, ref components } => {
self.temp_list.clear();

View File

@ -70,8 +70,8 @@ pub enum Error {
FeatureNotImplemented(&'static str),
#[error("module is not validated properly: {0}")]
Validation(&'static str),
#[error(transparent)]
PipelineConstant(#[from] Box<crate::back::pipeline_constants::PipelineConstantError>),
#[error("overrides should not be present at this stage")]
Override,
}
#[derive(Default)]
@ -773,8 +773,6 @@ pub struct PipelineOptions {
///
/// If no entry point that matches is found while creating a [`Writer`], a error will be thrown.
pub entry_point: String,
/// Pipeline constants.
pub constants: crate::back::PipelineConstants,
}
pub fn write_vec(

View File

@ -2029,21 +2029,9 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
let (ir_module, info) = if let Some(pipeline_options) = pipeline_options {
crate::back::pipeline_constants::process_overrides(
ir_module,
info,
&pipeline_options.constants,
)
.map_err(Box::new)?
} else {
(
std::borrow::Cow::Borrowed(ir_module),
std::borrow::Cow::Borrowed(info),
)
};
let ir_module = ir_module.as_ref();
let info = info.as_ref();
if !ir_module.overrides.is_empty() {
return Err(Error::Override);
}
self.reset();

View File

@ -1205,9 +1205,7 @@ impl<W: Write> Writer<W> {
|writer, expr| writer.write_expr(module, expr, func_ctx),
)?;
}
Expression::Override(_) => {
return Err(Error::Unimplemented("overrides are WIP".into()))
}
Expression::Override(_) => unreachable!(),
Expression::FunctionArgument(pos) => {
let name_key = func_ctx.argument_key(pos);
let name = &self.names[&name_key];

View File

@ -27,6 +27,5 @@
),
msl_pipeline: (
allow_and_force_point_size: true,
constants: {},
),
)

View File

@ -0,0 +1,29 @@
#version 310 es
precision highp float;
precision highp int;
layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
const bool has_point_light = false;
const float specular_param = 2.3;
const float gain = 1.1;
const float width = 0.0;
const float depth = 2.3;
const float height = 4.6;
const float inferred_f32_ = 2.718;
float gain_x_10_ = 11.0;
void main() {
float t = 0.0;
bool x = false;
float gain_x_100_ = 0.0;
t = 23.0;
x = true;
float _e10 = gain_x_10_;
gain_x_100_ = (_e10 * 10.0);
return;
}

View File

@ -349,19 +349,14 @@ fn check_targets(
#[cfg(all(feature = "deserialize", feature = "msl-out"))]
{
if targets.contains(Targets::METAL) {
if !params.msl_pipeline.constants.is_empty() {
panic!("Supply pipeline constants via pipeline_constants instead of msl_pipeline.constants!");
}
let mut pipeline_options = params.msl_pipeline.clone();
pipeline_options.constants = params.pipeline_constants.clone();
write_output_msl(
input,
module,
&info,
&params.msl,
&pipeline_options,
&params.msl_pipeline,
params.bounds_check_policies,
&params.pipeline_constants,
);
}
}
@ -449,25 +444,27 @@ fn write_output_spv(
debug_info,
};
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants)
.expect("override evaluation failed");
if params.separate_entry_points {
for ep in module.entry_points.iter() {
let pipeline_options = spv::PipelineOptions {
entry_point: ep.name.clone(),
shader_stage: ep.stage,
constants: pipeline_constants.clone(),
};
write_output_spv_inner(
input,
module,
info,
&module,
&info,
&options,
Some(&pipeline_options),
&format!("{}.spvasm", ep.name),
);
}
} else {
assert!(pipeline_constants.is_empty());
write_output_spv_inner(input, module, info, &options, None, "spvasm");
write_output_spv_inner(input, &module, &info, &options, None, "spvasm");
}
}
@ -505,14 +502,19 @@ fn write_output_msl(
options: &naga::back::msl::Options,
pipeline_options: &naga::back::msl::PipelineOptions,
bounds_check_policies: naga::proc::BoundsCheckPolicies,
pipeline_constants: &naga::back::PipelineConstants,
) {
use naga::back::msl;
println!("generating MSL");
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants)
.expect("override evaluation failed");
let mut options = options.clone();
options.bounds_check_policies = bounds_check_policies;
let (string, tr_info) = msl::write_string(module, info, &options, pipeline_options)
let (string, tr_info) = msl::write_string(&module, &info, &options, pipeline_options)
.unwrap_or_else(|err| panic!("Metal write failed: {err}"));
for (ep, result) in module.entry_points.iter().zip(tr_info.entry_point_names) {
@ -545,14 +547,16 @@ fn write_output_glsl(
shader_stage: stage,
entry_point: ep_name.to_string(),
multiview,
constants: pipeline_constants.clone(),
};
let mut buffer = String::new();
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants)
.expect("override evaluation failed");
let mut writer = glsl::Writer::new(
&mut buffer,
module,
info,
&module,
&info,
options,
&pipeline_options,
bounds_check_policies,
@ -577,17 +581,13 @@ fn write_output_hlsl(
println!("generating HLSL");
let (module, info) =
naga::back::pipeline_constants::process_overrides(module, info, pipeline_constants)
.expect("override evaluation failed");
let mut buffer = String::new();
let mut writer = hlsl::Writer::new(&mut buffer, options);
let reflection_info = writer
.write(
module,
info,
&hlsl::PipelineOptions {
constants: pipeline_constants.clone(),
},
)
.expect("HLSL write failed");
let reflection_info = writer.write(&module, &info).expect("HLSL write failed");
input.write_output_file("hlsl", "hlsl", buffer);
@ -852,7 +852,12 @@ fn convert_wgsl() {
),
(
"overrides",
Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL,
Targets::IR
| Targets::ANALYSIS
| Targets::SPIRV
| Targets::METAL
| Targets::HLSL
| Targets::GLSL,
),
(
"overrides-atomicCompareExchangeWeak",

View File

@ -218,17 +218,21 @@ impl super::Device {
use naga::back::hlsl;
let stage_bit = crate::auxil::map_naga_stage(naga_stage);
let module = &stage.module.naga.module;
let (module, info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
stage.constants,
)
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?;
//TODO: reuse the writer
let mut source = String::new();
let mut writer = hlsl::Writer::new(&mut source, &layout.naga_options);
let pipeline_options = hlsl::PipelineOptions {
constants: stage.constants.to_owned(),
};
let reflection_info = {
profiling::scope!("naga::back::hlsl::write");
writer
.write(module, &stage.module.naga.info, &pipeline_options)
.write(&module, &info)
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("HLSL: {e:?}")))?
};

View File

@ -218,12 +218,19 @@ impl super::Device {
shader_stage: naga_stage,
entry_point: stage.entry_point.to_string(),
multiview: context.multiview,
constants: stage.constants.to_owned(),
};
let shader = &stage.module.naga;
let entry_point_index = shader
.module
let (module, info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
stage.constants,
)
.map_err(|e| {
let msg = format!("{e}");
crate::PipelineError::Linkage(map_naga_stage(naga_stage), msg)
})?;
let entry_point_index = module
.entry_points
.iter()
.position(|ep| ep.name.as_str() == stage.entry_point)
@ -250,8 +257,8 @@ impl super::Device {
let mut output = String::new();
let mut writer = glsl::Writer::new(
&mut output,
&shader.module,
&shader.info,
&module,
&info,
&context.layout.naga_options,
&pipeline_options,
policies,
@ -270,8 +277,8 @@ impl super::Device {
context.consume_reflection(
gl,
&shader.module,
shader.info.get_entry_point(entry_point_index),
&module,
info.get_entry_point(entry_point_index),
reflection_info,
naga_stage,
program,

View File

@ -69,7 +69,13 @@ impl super::Device {
) -> Result<CompiledShader, crate::PipelineError> {
let stage_bit = map_naga_stage(naga_stage);
let module = &stage.module.naga.module;
let (module, module_info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
stage.constants,
)
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {:?}", e)))?;
let ep_resources = &layout.per_stage_map[naga_stage];
let bounds_check_policy = if stage.module.runtime_checks {
@ -112,16 +118,11 @@ impl super::Device {
metal::MTLPrimitiveTopologyClass::Point => true,
_ => false,
},
constants: stage.constants.to_owned(),
};
let (source, info) = naga::back::msl::write_string(
module,
&stage.module.naga.info,
&options,
&pipeline_options,
)
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {:?}", e)))?;
let (source, info) =
naga::back::msl::write_string(&module, &module_info, &options, &pipeline_options)
.map_err(|e| crate::PipelineError::Linkage(stage_bit, format!("MSL: {:?}", e)))?;
log::debug!(
"Naga generated shader for entry point '{}' and stage {:?}\n{}",
@ -169,7 +170,7 @@ impl super::Device {
})?;
// collect sizes indices, immutable buffers, and work group memory sizes
let ep_info = &stage.module.naga.info.get_entry_point(ep_index);
let ep_info = &module_info.get_entry_point(ep_index);
let mut wg_memory_sizes = Vec::new();
let mut sized_bindings = Vec::new();
let mut immutable_buffer_mask = 0;

View File

@ -734,7 +734,6 @@ impl super::Device {
let pipeline_options = naga::back::spv::PipelineOptions {
entry_point: stage.entry_point.to_string(),
shader_stage: naga_stage,
constants: stage.constants.to_owned(),
};
let needs_temp_options = !runtime_checks
|| !binding_map.is_empty()
@ -766,14 +765,17 @@ impl super::Device {
} else {
&self.naga_options
};
let (module, info) = naga::back::pipeline_constants::process_overrides(
&naga_shader.module,
&naga_shader.info,
stage.constants,
)
.map_err(|e| crate::PipelineError::Linkage(stage_flags, format!("{e}")))?;
let spv = {
profiling::scope!("naga::spv::write_vec");
naga::back::spv::write_vec(
&naga_shader.module,
&naga_shader.info,
options,
Some(&pipeline_options),
)
naga::back::spv::write_vec(&module, &info, options, Some(&pipeline_options))
}
.map_err(|e| crate::PipelineError::Linkage(stage_flags, format!("{e}")))?;
self.create_shader_module_impl(&spv)?