Merge FunctionAnalysisError into FunctionError

This commit is contained in:
Dzmitry Malyshau 2021-03-22 00:45:11 -04:00
parent 1b3e7294fd
commit 1d3f2bbdb1
6 changed files with 88 additions and 111 deletions

View File

@ -1516,7 +1516,9 @@ fn test_stack_size() {
});
let _ = module.functions.append(fun);
// analyse the module
let info = ModuleInfo::new(&module, ValidationFlags::empty()).unwrap();
let info = crate::valid::Validator::new(ValidationFlags::empty())
.validate(&module)
.unwrap();
// process the module
let mut writer = Writer::new(String::new());
writer.write(&module, &info, &Default::default()).unwrap();

View File

@ -6,7 +6,7 @@ Figures out the following properties:
- expression reference counts
!*/
use super::ValidationFlags;
use super::{CallError, FunctionError, ModuleInfo, ValidationFlags};
use crate::arena::{Arena, Handle};
use std::ops;
@ -216,32 +216,6 @@ pub enum UniformityDisruptor {
Discard,
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionAnalysisError {
#[error("Expression {0:?} is not a global variable!")]
ExpectedGlobalVariable(crate::Expression),
#[error("Called function {0:?} that hasn't been declared in the IR yet")]
ForwardCall(Handle<crate::Function>),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
NonUniformControlFlow(
UniformityRequirements,
Handle<crate::Expression>,
UniformityDisruptor,
),
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum AnalysisError {
#[error("Function {0:?} analysis failed")]
Function(Handle<crate::Function>, #[source] FunctionAnalysisError),
#[error("Entry point {0:?}/'{1}' function analysis failed")]
EntryPoint(crate::ShaderStage, String, #[source] FunctionAnalysisError),
}
impl FunctionInfo {
/// Adds a value-type reference to an expression.
#[must_use]
@ -314,7 +288,7 @@ impl FunctionInfo {
arguments: &[crate::FunctionArgument],
global_var_arena: &Arena<crate::GlobalVariable>,
other_functions: &[FunctionInfo],
) -> Result<(), FunctionAnalysisError> {
) -> Result<(), FunctionError> {
use crate::{Expression as E, SampleLevel as Sl};
let mut assignable_global = None;
@ -404,17 +378,13 @@ impl FunctionInfo {
image: match expression_arena[image] {
crate::Expression::GlobalVariable(var) => var,
ref other => {
return Err(FunctionAnalysisError::ExpectedGlobalVariable(
other.clone(),
))
return Err(FunctionError::ExpectedGlobalVariable(other.clone()))
}
},
sampler: match expression_arena[sampler] {
crate::Expression::GlobalVariable(var) => var,
ref other => {
return Err(FunctionAnalysisError::ExpectedGlobalVariable(
other.clone(),
))
return Err(FunctionError::ExpectedGlobalVariable(other.clone()))
}
},
});
@ -512,9 +482,13 @@ impl FunctionInfo {
requirements: UniformityRequirements::empty(),
},
E::Call(function) => {
let fun = other_functions
let fun =
other_functions
.get(function.index())
.ok_or(FunctionAnalysisError::ForwardCall(function))?;
.ok_or(FunctionError::InvalidCall {
function,
error: CallError::ForwardDeclaredFunction,
})?;
self.process_call(fun).result
}
E::ArrayLength(expr) => Uniformity {
@ -546,7 +520,7 @@ impl FunctionInfo {
statements: &[crate::Statement],
other_functions: &[FunctionInfo],
mut disruptor: Option<UniformityDisruptor>,
) -> Result<FunctionUniformity, FunctionAnalysisError> {
) -> Result<FunctionUniformity, FunctionError> {
use crate::Statement as S;
let mut combined_uniformity = FunctionUniformity::new();
@ -562,9 +536,7 @@ impl FunctionInfo {
&& !req.is_empty()
{
if let Some(cause) = disruptor {
return Err(FunctionAnalysisError::NonUniformControlFlow(
req, expr, cause,
));
return Err(FunctionError::NonUniformControlFlow(req, expr, cause));
}
}
requirements |= req;
@ -670,9 +642,12 @@ impl FunctionInfo {
for &argument in arguments {
let _ = self.add_ref(argument);
}
let info = other_functions
.get(function.index())
.ok_or(FunctionAnalysisError::ForwardCall(function))?;
let info = other_functions.get(function.index()).ok_or(
FunctionError::InvalidCall {
function,
error: CallError::ForwardDeclaredFunction,
},
)?;
self.process_call(info)
}
};
@ -684,22 +659,15 @@ impl FunctionInfo {
}
}
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
functions: Vec<FunctionInfo>,
entry_points: Vec<FunctionInfo>,
}
impl ModuleInfo {
/// Builds the `FunctionInfo` based on the function, and validates the
/// uniform control flow if required by the expressions of this function.
fn process_function(
pub(super) fn process_function(
&self,
fun: &crate::Function,
global_var_arena: &Arena<crate::GlobalVariable>,
flags: ValidationFlags,
) -> Result<FunctionInfo, FunctionAnalysisError> {
) -> Result<FunctionInfo, FunctionError> {
let mut info = FunctionInfo {
flags,
uniformity: Uniformity::new(),
@ -726,29 +694,6 @@ impl ModuleInfo {
Ok(info)
}
/// Analyze a module and return the `ModuleInfo`, if successful.
pub fn new(module: &crate::Module, flags: ValidationFlags) -> Result<Self, AnalysisError> {
let mut this = ModuleInfo {
functions: Vec::with_capacity(module.functions.len()),
entry_points: Vec::with_capacity(module.entry_points.len()),
};
for (fun_handle, fun) in module.functions.iter() {
let info = this
.process_function(fun, &module.global_variables, flags)
.map_err(|source| AnalysisError::Function(fun_handle, source))?;
this.functions.push(info);
}
for ep in module.entry_points.iter() {
let info = this
.process_function(&ep.function, &module.global_variables, flags)
.map_err(|source| AnalysisError::EntryPoint(ep.stage, ep.name.clone(), source))?;
this.entry_points.push(info);
}
Ok(this)
}
pub fn get_entry_point(&self, index: usize) -> &FunctionInfo {
&self.entry_points[index]
}
@ -880,7 +825,7 @@ fn uniform_control_flow() {
};
assert_eq!(
info.process_block(&[stmt_emit2, stmt_if_non_uniform], &[], None),
Err(FunctionAnalysisError::NonUniformControlFlow(
Err(FunctionError::NonUniformControlFlow(
UniformityRequirements::DERIVATIVE,
derivative_expr,
UniformityDisruptor::Expression(non_uniform_global_expr)

View File

@ -4,6 +4,7 @@ use crate::{
};
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum ExpressionError {
#[error("Doesn't exist")]
DoesntExist,

View File

@ -1,10 +1,14 @@
use super::{analyzer::FunctionInfo, ExpressionError, TypeFlags, ValidationFlags};
use super::{
analyzer::{FunctionInfo, UniformityDisruptor, UniformityRequirements},
ExpressionError, ModuleInfo, TypeFlags, ValidationFlags,
};
use crate::{
arena::{Arena, Handle},
proc::{ResolveContext, TypifyError},
};
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum CallError {
#[error("Bad function")]
InvalidFunction,
@ -36,12 +40,14 @@ pub enum CallError {
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum LocalVariableError {
#[error("Initializer doesn't match the variable type")]
InitializerType,
}
#[derive(Clone, Debug, thiserror::Error)]
#[cfg_attr(test, derive(PartialEq))]
pub enum FunctionError {
#[error(transparent)]
Resolve(#[from] TypifyError),
@ -97,6 +103,16 @@ pub enum FunctionError {
#[source]
error: CallError,
},
#[error("Expression {0:?} is not a global variable!")]
ExpectedGlobalVariable(crate::Expression),
#[error(
"Required uniformity of control flow for {0:?} in {1:?} is not fulfilled because of {2:?}"
)]
NonUniformControlFlow(
UniformityRequirements,
Handle<crate::Expression>,
UniformityDisruptor,
),
}
bitflags::bitflags! {
@ -464,9 +480,9 @@ impl super::Validator {
pub(super) fn validate_function(
&mut self,
fun: &crate::Function,
_info: &FunctionInfo,
module: &crate::Module,
) -> Result<(), FunctionError> {
mod_info: &ModuleInfo,
) -> Result<FunctionInfo, FunctionError> {
let resolve_ctx = ResolveContext {
constants: &module.constants,
global_vars: &module.global_variables,
@ -476,6 +492,7 @@ impl super::Validator {
};
self.typifier
.resolve_all(&fun.expressions, &module.types, &resolve_ctx)?;
let info = mod_info.process_function(fun, &module.global_variables, self.flags)?;
for (var_handle, var) in fun.local_variables.iter() {
self.validate_local_var(var, &module.types, &module.constants)
@ -511,9 +528,8 @@ impl super::Validator {
}
if self.flags.contains(ValidationFlags::BLOCKS) {
self.validate_block(&fun.body, &BlockContext::new(fun, module))
} else {
Ok(())
}
self.validate_block(&fun.body, &BlockContext::new(fun, module))?;
}
Ok(info)
}
}

View File

@ -1,6 +1,6 @@
use super::{
analyzer::{FunctionInfo, GlobalUse},
Disalignment, FunctionError, TypeFlags,
Disalignment, FunctionError, ModuleInfo, TypeFlags,
};
use crate::arena::{Arena, Handle};
@ -352,9 +352,9 @@ impl super::Validator {
pub(super) fn validate_entry_point(
&mut self,
ep: &crate::EntryPoint,
info: &FunctionInfo,
module: &crate::Module,
) -> Result<(), EntryPointError> {
mod_info: &ModuleInfo,
) -> Result<FunctionInfo, EntryPointError> {
if ep.early_depth_test.is_some() && ep.stage != crate::ShaderStage::Fragment {
return Err(EntryPointError::UnexpectedEarlyDepthTest);
}
@ -370,6 +370,8 @@ impl super::Validator {
return Err(EntryPointError::UnexpectedWorkgroupSize);
}
let info = self.validate_function(&ep.function, module, &mod_info)?;
self.location_mask.clear();
for (index, fa) in ep.function.arguments.iter().enumerate() {
let ctx = VaryingContext {
@ -439,7 +441,6 @@ impl super::Validator {
}
}
self.validate_function(&ep.function, info, module)?;
Ok(())
Ok(info)
}
}

View File

@ -13,10 +13,7 @@ use bit_set::BitSet;
//TODO: analyze the model at the same time as we validate it,
// merge the corresponding matches over expressions and statements.
pub use analyzer::{
AnalysisError, ExpressionInfo, FunctionInfo, GlobalUse, ModuleInfo, Uniformity,
UniformityRequirements,
};
pub use analyzer::{ExpressionInfo, FunctionInfo, GlobalUse, Uniformity, UniformityRequirements};
pub use expression::ExpressionError;
pub use function::{CallError, FunctionError, LocalVariableError};
pub use interface::{EntryPointError, GlobalVariableError, VaryingError};
@ -33,6 +30,13 @@ bitflags::bitflags! {
}
}
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {
functions: Vec<FunctionInfo>,
entry_points: Vec<FunctionInfo>,
}
#[derive(Debug)]
pub struct Validator {
flags: ValidationFlags,
@ -94,8 +98,6 @@ pub enum ValidationError {
#[source]
error: EntryPointError,
},
#[error(transparent)]
Analysis(#[from] AnalysisError),
#[error("Module is corrupted")]
Corrupted,
}
@ -176,8 +178,10 @@ impl Validator {
pub fn validate(&mut self, module: &crate::Module) -> Result<ModuleInfo, ValidationError> {
self.reset_types(module.types.len());
let mod_info = ModuleInfo::new(module, self.flags)?;
let mut mod_info = ModuleInfo {
functions: Vec::with_capacity(module.functions.len()),
entry_points: Vec::with_capacity(module.entry_points.len()),
};
let layouter = Layouter::new(&module.types, &module.constants);
for (handle, constant) in module.constants.iter() {
@ -211,16 +215,20 @@ impl Validator {
}
for (handle, fun) in module.functions.iter() {
self.validate_function(fun, &mod_info[handle], module)
.map_err(|error| ValidationError::Function {
match self.validate_function(fun, module, &mod_info) {
Ok(info) => mod_info.functions.push(info),
Err(error) => {
return Err(ValidationError::Function {
handle,
name: fun.name.clone().unwrap_or_default(),
error,
})?;
})
}
}
}
let mut ep_map = FastHashSet::default();
for (index, ep) in module.entry_points.iter().enumerate() {
for ep in module.entry_points.iter() {
if !ep_map.insert((ep.stage, &ep.name)) {
return Err(ValidationError::EntryPoint {
stage: ep.stage,
@ -228,13 +236,17 @@ impl Validator {
error: EntryPointError::Conflict,
});
}
let info = mod_info.get_entry_point(index);
self.validate_entry_point(ep, info, module)
.map_err(|error| ValidationError::EntryPoint {
match self.validate_entry_point(ep, module, &mod_info) {
Ok(info) => mod_info.entry_points.push(info),
Err(error) => {
return Err(ValidationError::EntryPoint {
stage: ep.stage,
name: ep.name.clone(),
error,
})?;
})
}
}
}
Ok(mod_info)