[spv/msl/hlsl-out] support pipeline constant value replacements

This commit is contained in:
teoxoy 2024-01-05 14:42:07 +01:00 committed by Teodor Tanasoaia
parent 7ce422c57a
commit 2929ec333c
24 changed files with 463 additions and 15 deletions

View File

@ -297,6 +297,17 @@ impl<T> Arena<T> {
.map(|(i, v)| unsafe { (Handle::from_usize_unchecked(i), v) })
}
/// Drains the arena, returning an iterator over the items stored.
pub fn drain(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, T, Span)> {
let arena = std::mem::take(self);
arena
.data
.into_iter()
.zip(arena.span_info)
.enumerate()
.map(|(i, (v, span))| unsafe { (Handle::from_usize_unchecked(i), v, span) })
}
/// Returns a iterator over the items stored in this arena,
/// returning both the item's handle and a mutable reference to it.
pub fn iter_mut(&mut self) -> impl DoubleEndedIterator<Item = (Handle<T>, &mut T)> {

View File

@ -567,6 +567,12 @@ impl<'a, W: Write> Writer<'a, W> {
pipeline_options: &'a PipelineOptions,
policies: proc::BoundsCheckPolicies,
) -> Result<Self, Error> {
if !module.overrides.is_empty() {
return Err(Error::Custom(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}
// Check if the requested version is supported
if !options.version.is_supported() {
log::error!("Version {}", options.version);

View File

@ -255,6 +255,8 @@ pub enum Error {
Unimplemented(String), // TODO: Error used only during development
#[error("{0}")]
Custom(String),
#[error(transparent)]
PipelineConstant(#[from] back::pipeline_constants::PipelineConstantError),
}
#[derive(Default)]

View File

@ -167,8 +167,12 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
&mut self,
module: &Module,
module_info: &valid::ModuleInfo,
_pipeline_options: &PipelineOptions,
pipeline_options: &PipelineOptions,
) -> Result<super::ReflectionInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();
self.reset(module);
// Write special constants, if needed

View File

@ -16,6 +16,14 @@ pub mod spv;
#[cfg(feature = "wgsl-out")]
pub mod wgsl;
#[cfg(any(
feature = "hlsl-out",
feature = "msl-out",
feature = "spv-out",
feature = "glsl-out"
))]
mod pipeline_constants;
/// Names of vector components.
pub const COMPONENTS: &[char] = &['x', 'y', 'z', 'w'];
/// Indent for backends.

View File

@ -143,6 +143,8 @@ pub enum Error {
UnsupportedArrayOfType(Handle<crate::Type>),
#[error("ray tracing is not supported prior to MSL 2.3")]
UnsupportedRayTracing,
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}
#[derive(Clone, Debug, PartialEq, thiserror::Error)]

View File

@ -3223,6 +3223,10 @@ impl<W: Write> Writer<W> {
options: &Options,
pipeline_options: &PipelineOptions,
) -> Result<TranslationInfo, Error> {
let module =
back::pipeline_constants::process_overrides(module, &pipeline_options.constants)?;
let module = module.as_ref();
self.names.clear();
self.namer.reset(
module,

View File

@ -0,0 +1,213 @@
use super::PipelineConstants;
use crate::{Constant, Expression, Literal, Module, Scalar, Span, TypeInner};
use std::borrow::Cow;
use thiserror::Error;
#[derive(Error, Debug, Clone)]
#[cfg_attr(test, derive(PartialEq))]
pub enum PipelineConstantError {
#[error("Missing value for pipeline-overridable constant with identifier string: '{0}'")]
MissingValue(String),
#[error("Source f64 value needs to be finite (NaNs and Inifinites are not allowed) for number destinations")]
SrcNeedsToBeFinite,
#[error("Source f64 value doesn't fit in destination")]
DstRangeTooSmall,
}
pub(super) fn process_overrides<'a>(
module: &'a Module,
pipeline_constants: &PipelineConstants,
) -> Result<Cow<'a, Module>, PipelineConstantError> {
if module.overrides.is_empty() {
return Ok(Cow::Borrowed(module));
}
let mut module = module.clone();
for (_handle, override_, span) in module.overrides.drain() {
let key = if let Some(id) = override_.id {
Cow::Owned(id.to_string())
} else if let Some(ref name) = override_.name {
Cow::Borrowed(name)
} else {
unreachable!();
};
let init = if let Some(value) = pipeline_constants.get::<str>(&key) {
let literal = match module.types[override_.ty].inner {
TypeInner::Scalar(scalar) => map_value_to_literal(*value, scalar)?,
_ => unreachable!(),
};
module
.const_expressions
.append(Expression::Literal(literal), Span::UNDEFINED)
} else if let Some(init) = override_.init {
init
} else {
return Err(PipelineConstantError::MissingValue(key.to_string()));
};
let constant = Constant {
name: override_.name,
ty: override_.ty,
init,
};
module.constants.append(constant, span);
}
Ok(Cow::Owned(module))
}
fn map_value_to_literal(value: f64, scalar: Scalar) -> Result<Literal, PipelineConstantError> {
// note that in rust 0.0 == -0.0
match scalar {
Scalar::BOOL => {
// https://webidl.spec.whatwg.org/#js-boolean
let value = value != 0.0 && !value.is_nan();
Ok(Literal::Bool(value))
}
Scalar::I32 => {
// https://webidl.spec.whatwg.org/#js-long
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}
let value = value.trunc();
if value < f64::from(i32::MIN) || value > f64::from(i32::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}
let value = value as i32;
Ok(Literal::I32(value))
}
Scalar::U32 => {
// https://webidl.spec.whatwg.org/#js-unsigned-long
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}
let value = value.trunc();
if value < f64::from(u32::MIN) || value > f64::from(u32::MAX) {
return Err(PipelineConstantError::DstRangeTooSmall);
}
let value = value as u32;
Ok(Literal::U32(value))
}
Scalar::F32 => {
// https://webidl.spec.whatwg.org/#js-float
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}
let value = value as f32;
if !value.is_finite() {
return Err(PipelineConstantError::DstRangeTooSmall);
}
Ok(Literal::F32(value))
}
Scalar::F64 => {
// https://webidl.spec.whatwg.org/#js-double
if !value.is_finite() {
return Err(PipelineConstantError::SrcNeedsToBeFinite);
}
Ok(Literal::F64(value))
}
_ => unreachable!(),
}
}
#[test]
fn test_map_value_to_literal() {
let bool_test_cases = [
(0.0, false),
(-0.0, false),
(f64::NAN, false),
(1.0, true),
(f64::INFINITY, true),
(f64::NEG_INFINITY, true),
];
for (value, out) in bool_test_cases {
let res = Ok(Literal::Bool(out));
assert_eq!(map_value_to_literal(value, Scalar::BOOL), res);
}
for scalar in [Scalar::I32, Scalar::U32, Scalar::F32, Scalar::F64] {
for value in [f64::NAN, f64::INFINITY, f64::NEG_INFINITY] {
let res = Err(PipelineConstantError::SrcNeedsToBeFinite);
assert_eq!(map_value_to_literal(value, scalar), res);
}
}
// i32
assert_eq!(
map_value_to_literal(f64::from(i32::MIN), Scalar::I32),
Ok(Literal::I32(i32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(i32::MAX), Scalar::I32),
Ok(Literal::I32(i32::MAX))
);
assert_eq!(
map_value_to_literal(f64::from(i32::MIN) - 1.0, Scalar::I32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from(i32::MAX) + 1.0, Scalar::I32),
Err(PipelineConstantError::DstRangeTooSmall)
);
// u32
assert_eq!(
map_value_to_literal(f64::from(u32::MIN), Scalar::U32),
Ok(Literal::U32(u32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(u32::MAX), Scalar::U32),
Ok(Literal::U32(u32::MAX))
);
assert_eq!(
map_value_to_literal(f64::from(u32::MIN) - 1.0, Scalar::U32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from(u32::MAX) + 1.0, Scalar::U32),
Err(PipelineConstantError::DstRangeTooSmall)
);
// f32
assert_eq!(
map_value_to_literal(f64::from(f32::MIN), Scalar::F32),
Ok(Literal::F32(f32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from(f32::MAX), Scalar::F32),
Ok(Literal::F32(f32::MAX))
);
assert_eq!(
map_value_to_literal(-f64::from_bits(0x47efffffefffffff), Scalar::F32),
Ok(Literal::F32(f32::MIN))
);
assert_eq!(
map_value_to_literal(f64::from_bits(0x47efffffefffffff), Scalar::F32),
Ok(Literal::F32(f32::MAX))
);
assert_eq!(
map_value_to_literal(-f64::from_bits(0x47effffff0000000), Scalar::F32),
Err(PipelineConstantError::DstRangeTooSmall)
);
assert_eq!(
map_value_to_literal(f64::from_bits(0x47effffff0000000), Scalar::F32),
Err(PipelineConstantError::DstRangeTooSmall)
);
// f64
assert_eq!(
map_value_to_literal(f64::MIN, Scalar::F64),
Ok(Literal::F64(f64::MIN))
);
assert_eq!(
map_value_to_literal(f64::MAX, Scalar::F64),
Ok(Literal::F64(f64::MAX))
);
}

View File

@ -70,6 +70,8 @@ pub enum Error {
FeatureNotImplemented(&'static str),
#[error("module is not validated properly: {0}")]
Validation(&'static str),
#[error(transparent)]
PipelineConstant(#[from] crate::back::pipeline_constants::PipelineConstantError),
}
#[derive(Default)]

View File

@ -2029,6 +2029,16 @@ impl Writer {
debug_info: &Option<DebugInfo>,
words: &mut Vec<Word>,
) -> Result<(), Error> {
let ir_module = if let Some(pipeline_options) = pipeline_options {
crate::back::pipeline_constants::process_overrides(
ir_module,
&pipeline_options.constants,
)?
} else {
std::borrow::Cow::Borrowed(ir_module)
};
let ir_module = ir_module.as_ref();
self.reset();
// Try to find the entry point and corresponding index

View File

@ -106,6 +106,12 @@ impl<W: Write> Writer<W> {
}
pub fn write(&mut self, module: &Module, info: &valid::ModuleInfo) -> BackendResult {
if !module.overrides.is_empty() {
return Err(Error::Unimplemented(
"Pipeline constants are not yet supported for this back-end".to_string(),
));
}
self.reset(module);
// Save all ep result types

View File

@ -359,6 +359,7 @@ impl ExpressionConstnessTracker {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ConstantEvaluatorError {
#[error("Constants cannot access function arguments")]
FunctionArg,

View File

@ -186,6 +186,8 @@ pub enum ConstantError {
#[derive(Clone, Debug, thiserror::Error)]
pub enum OverrideError {
#[error("Override name and ID are missing")]
MissingNameAndID,
#[error("The type doesn't match the override")]
InvalidType,
#[error("The type is not constructible")]
@ -353,6 +355,10 @@ impl Validator {
) -> Result<(), OverrideError> {
let o = &gctx.overrides[handle];
if o.name.is_none() && o.id.is_none() {
return Err(OverrideError::MissingNameAndID);
}
let type_info = &self.types[o.ty.index()];
if !type_info.flags.contains(TypeFlags::CONSTRUCTIBLE) {
return Err(OverrideError::NonConstructibleType);
@ -360,7 +366,14 @@ impl Validator {
let decl_ty = &gctx.types[o.ty].inner;
match decl_ty {
&crate::TypeInner::Scalar(_) => {}
&crate::TypeInner::Scalar(scalar) => match scalar {
crate::Scalar::BOOL
| crate::Scalar::I32
| crate::Scalar::U32
| crate::Scalar::F32
| crate::Scalar::F64 => {}
_ => return Err(OverrideError::TypeNotScalar),
},
_ => return Err(OverrideError::TypeNotScalar),
}

View File

@ -0,0 +1,11 @@
(
spv: (
version: (1, 0),
separate_entry_points: true,
),
pipeline_constants: {
"0": NaN,
"1300": 1.1,
"depth": 2.3,
}
)

View File

@ -12,3 +12,6 @@
// overridable constant.
override inferred_f32 = 2.718;
@compute @workgroup_size(1)
fn main() {}

View File

@ -4,7 +4,22 @@
("DATA | SIZED | COPY | IO_SHAREABLE | HOST_SHAREABLE | ARGUMENT | CONSTRUCTIBLE"),
],
functions: [],
entry_points: [],
entry_points: [
(
flags: ("EXPRESSIONS | BLOCKS | CONTROL_FLOW_UNIFORMITY | STRUCT_LAYOUTS | CONSTANTS | BINDINGS"),
available_stages: ("VERTEX | FRAGMENT | COMPUTE"),
uniformity: (
non_uniform_result: None,
requirements: (""),
),
may_kill: false,
sampling_set: [],
global_uses: [],
expressions: [],
sampling: [],
dual_source_blending: false,
),
],
const_expression_types: [
Value(Scalar((
kind: Bool,

View File

@ -0,0 +1,12 @@
static const bool has_point_light = false;
static const float specular_param = 2.3;
static const float gain = 1.1;
static const float width = 0.0;
static const float depth = 2.3;
static const float inferred_f32_ = 2.718;
[numthreads(1, 1, 1)]
void main()
{
return;
}

View File

@ -0,0 +1,12 @@
(
vertex:[
],
fragment:[
],
compute:[
(
entry_point:"main",
target_profile:"cs_5_1",
),
],
)

View File

@ -67,5 +67,25 @@
Literal(F32(2.718)),
],
functions: [],
entry_points: [],
entry_points: [
(
name: "main",
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("main"),
arguments: [],
result: None,
local_variables: [],
expressions: [],
named_expressions: {},
body: [
Return(
value: None,
),
],
),
),
],
)

View File

@ -67,5 +67,25 @@
Literal(F32(2.718)),
],
functions: [],
entry_points: [],
entry_points: [
(
name: "main",
stage: Compute,
early_depth_test: None,
workgroup_size: (1, 1, 1),
function: (
name: Some("main"),
arguments: [],
result: None,
local_variables: [],
expressions: [],
named_expressions: {},
body: [
Return(
value: None,
),
],
),
),
],
)

View File

@ -0,0 +1,17 @@
// language: metal1.0
#include <metal_stdlib>
#include <simd/simd.h>
using metal::uint;
constant bool has_point_light = false;
constant float specular_param = 2.3;
constant float gain = 1.1;
constant float width = 0.0;
constant float depth = 2.3;
constant float inferred_f32_ = 2.718;
kernel void main_(
) {
return;
}

View File

@ -0,0 +1,25 @@
; SPIR-V
; Version: 1.0
; Generator: rspirv
; Bound: 15
OpCapability Shader
%1 = OpExtInstImport "GLSL.std.450"
OpMemoryModel Logical GLSL450
OpEntryPoint GLCompute %12 "main"
OpExecutionMode %12 LocalSize 1 1 1
%2 = OpTypeVoid
%3 = OpTypeBool
%4 = OpTypeFloat 32
%5 = OpConstantTrue %3
%6 = OpConstant %4 2.3
%7 = OpConstant %4 0.0
%8 = OpConstant %4 2.718
%9 = OpConstantFalse %3
%10 = OpConstant %4 1.1
%13 = OpTypeFunction %2
%12 = OpFunction %2 None %13
%11 = OpLabel
OpBranch %14
%14 = OpLabel
OpReturn
OpFunctionEnd

View File

@ -87,6 +87,17 @@ struct Parameters {
#[cfg(all(feature = "deserialize", feature = "glsl-out"))]
#[serde(default)]
glsl_multiview: Option<std::num::NonZeroU32>,
#[cfg(all(
feature = "deserialize",
any(
feature = "hlsl-out",
feature = "msl-out",
feature = "spv-out",
feature = "glsl-out"
)
))]
#[serde(default)]
pipeline_constants: naga::back::PipelineConstants,
}
/// Information about a shader input file.
@ -331,18 +342,25 @@ fn check_targets(
debug_info,
&params.spv,
params.bounds_check_policies,
&params.pipeline_constants,
);
}
}
#[cfg(all(feature = "deserialize", feature = "msl-out"))]
{
if targets.contains(Targets::METAL) {
if !params.msl_pipeline.constants.is_empty() {
panic!("Supply pipeline constants via pipeline_constants instead of msl_pipeline.constants!");
}
let mut pipeline_options = params.msl_pipeline.clone();
pipeline_options.constants = params.pipeline_constants.clone();
write_output_msl(
input,
module,
&info,
&params.msl,
&params.msl_pipeline,
&pipeline_options,
params.bounds_check_policies,
);
}
@ -363,6 +381,7 @@ fn check_targets(
&params.glsl,
params.bounds_check_policies,
params.glsl_multiview,
&params.pipeline_constants,
);
}
}
@ -377,7 +396,13 @@ fn check_targets(
#[cfg(all(feature = "deserialize", feature = "hlsl-out"))]
{
if targets.contains(Targets::HLSL) {
write_output_hlsl(input, module, &info, &params.hlsl);
write_output_hlsl(
input,
module,
&info,
&params.hlsl,
&params.pipeline_constants,
);
}
}
#[cfg(all(feature = "deserialize", feature = "wgsl-out"))]
@ -396,6 +421,7 @@ fn write_output_spv(
debug_info: Option<naga::back::spv::DebugInfo>,
params: &SpirvOutParameters,
bounds_check_policies: naga::proc::BoundsCheckPolicies,
pipeline_constants: &naga::back::PipelineConstants,
) {
use naga::back::spv;
use rspirv::binary::Disassemble;
@ -428,7 +454,7 @@ fn write_output_spv(
let pipeline_options = spv::PipelineOptions {
entry_point: ep.name.clone(),
shader_stage: ep.stage,
constants: naga::back::PipelineConstants::default(),
constants: pipeline_constants.clone(),
};
write_output_spv_inner(
input,
@ -508,6 +534,7 @@ fn write_output_glsl(
options: &naga::back::glsl::Options,
bounds_check_policies: naga::proc::BoundsCheckPolicies,
multiview: Option<std::num::NonZeroU32>,
pipeline_constants: &naga::back::PipelineConstants,
) {
use naga::back::glsl;
@ -517,7 +544,7 @@ fn write_output_glsl(
shader_stage: stage,
entry_point: ep_name.to_string(),
multiview,
constants: naga::back::PipelineConstants::default(),
constants: pipeline_constants.clone(),
};
let mut buffer = String::new();
@ -542,6 +569,7 @@ fn write_output_hlsl(
module: &naga::Module,
info: &naga::valid::ModuleInfo,
options: &naga::back::hlsl::Options,
pipeline_constants: &naga::back::PipelineConstants,
) {
use naga::back::hlsl;
use std::fmt::Write as _;
@ -551,7 +579,13 @@ fn write_output_hlsl(
let mut buffer = String::new();
let mut writer = hlsl::Writer::new(&mut buffer, options);
let reflection_info = writer
.write(module, info, &hlsl::PipelineOptions::default())
.write(
module,
info,
&hlsl::PipelineOptions {
constants: pipeline_constants.clone(),
},
)
.expect("HLSL write failed");
input.write_output_file("hlsl", "hlsl", buffer);
@ -817,11 +851,7 @@ fn convert_wgsl() {
),
(
"overrides",
Targets::IR | Targets::ANALYSIS, // | Targets::SPIRV
// | Targets::METAL
// | Targets::GLSL
// | Targets::HLSL
// | Targets::WGSL,
Targets::IR | Targets::ANALYSIS | Targets::SPIRV | Targets::METAL | Targets::HLSL,
),
];

View File

@ -1588,6 +1588,7 @@ impl crate::Device for super::Device {
.shared
.workarounds
.contains(super::Workarounds::SEPARATE_ENTRY_POINTS)
|| !naga_shader.module.overrides.is_empty()
{
return Ok(super::ShaderModule::Intermediate {
naga_shader,