First bits of expression validation

This commit is contained in:
Dzmitry Malyshau 2021-03-18 01:12:06 -04:00
parent 72ede02888
commit 67ca0e7e7f

View File

@ -149,10 +149,30 @@ pub enum VaryingError {
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]
pub enum ExpressionError { pub enum ExpressionError {
#[error("Is invalid")] #[error("Doesn't exist")]
Invalid, DoesntExist,
#[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")] #[error("Used by a statement before it was introduced into the scope by any of the dominating blocks")]
NotInScope, NotInScope,
#[error("Depends on {0:?}, which has not been processed yet")]
ForwardDependency(Handle<crate::Expression>),
#[error("Base type {0:?} is not compatible with this expression")]
InvalidBaseType(Handle<crate::Expression>),
#[error("Accessing with index {0:?} can't be done")]
InvalidIndexType(Handle<crate::Expression>),
#[error("Accessing index {1} is out of {0:?} bounds")]
IndexOutOfBounds(Handle<crate::Expression>, u32),
#[error("Function argument {0:?} doesn't exist")]
FunctionArgumentDoesntExist(u32),
#[error("Constant {0:?} doesn't exist")]
ConstantDoesntExist(Handle<crate::Constant>),
#[error("Global variable {0:?} doesn't exist")]
GlobalVarDoesntExist(Handle<crate::GlobalVariable>),
#[error("Local variable {0:?} doesn't exist")]
LocalVarDoesntExist(Handle<crate::LocalVariable>),
#[error("Loading of {0:?} can't be done")]
InvalidPointerType(Handle<crate::Expression>),
#[error("Array length of {0:?} can't be done")]
InvalidArrayType(Handle<crate::Expression>),
} }
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]
@ -499,6 +519,25 @@ impl VaryingContext<'_> {
} }
} }
struct ExpressionTypeResolver<'a> {
root: Handle<crate::Expression>,
types: &'a Arena<crate::Type>,
typifier: &'a Typifier,
}
impl<'a> ExpressionTypeResolver<'a> {
fn resolve(
&self,
handle: Handle<crate::Expression>,
) -> Result<&'a crate::TypeInner, ExpressionError> {
if handle < self.root {
Ok(self.typifier.get(handle, self.types))
} else {
Err(ExpressionError::ForwardDependency(handle))
}
}
}
impl Validator { impl Validator {
/// Construct a new validator instance. /// Construct a new validator instance.
pub fn new() -> Self { pub fn new() -> Self {
@ -793,7 +832,7 @@ impl Validator {
} }
for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() { for (index, (arg, &expr)) in fun.arguments.iter().zip(arguments).enumerate() {
let ty = self let ty = self
.resolve_type_impl(expr, context.types) .resolve_statement_type_impl(expr, context.types)
.map_err(|error| CallError::Argument { index, error })?; .map_err(|error| CallError::Argument { index, error })?;
if ty != &context.types[arg.ty].inner { if ty != &context.types[arg.ty].inner {
return Err(CallError::ArgumentType { return Err(CallError::ArgumentType {
@ -813,7 +852,7 @@ impl Validator {
} }
let result_ty = result let result_ty = result
.map(|expr| self.resolve_type_impl(expr, context.types)) .map(|expr| self.resolve_statement_type_impl(expr, context.types))
.transpose() .transpose()
.map_err(CallError::ResultValue)?; .map_err(CallError::ResultValue)?;
let expected_ty = fun.result.as_ref().map(|fr| &context.types[fr.ty].inner); let expected_ty = fun.result.as_ref().map(|fr| &context.types[fr.ty].inner);
@ -831,7 +870,174 @@ impl Validator {
Ok(()) Ok(())
} }
fn resolve_type_impl<'a>( #[allow(unused)]
fn validate_expression(
&self,
root: Handle<crate::Expression>,
expression: &crate::Expression,
function: &crate::Function,
stage: Option<crate::ShaderStage>,
module: &crate::Module,
) -> Result<(), ExpressionError> {
use crate::{Expression as E, TypeInner as Ti};
use std::convert::TryInto;
let resolver = ExpressionTypeResolver {
root,
types: &module.types,
typifier: &self.typifier,
};
match *expression {
E::Access { base, index } => {
match *resolver.resolve(base)? {
Ti::Vector { .. }
| Ti::Matrix { .. }
| Ti::Array { .. }
| Ti::Pointer { .. } => {}
ref other => {
log::error!("Indexing of {:?}", other);
return Err(ExpressionError::InvalidBaseType(base));
}
}
match *resolver.resolve(index)? {
//TODO: only allow one of these
Ti::Scalar {
kind: crate::ScalarKind::Sint,
width: _,
}
| Ti::Scalar {
kind: crate::ScalarKind::Uint,
width: _,
} => {}
ref other => {
log::error!("Indexing by {:?}", other);
return Err(ExpressionError::InvalidIndexType(index));
}
}
}
E::AccessIndex { base, index } => {
let limit = match *resolver.resolve(base)? {
Ti::Vector { size, .. } => size as u32,
Ti::Matrix { columns, .. } => columns as u32,
Ti::Array {
size: crate::ArraySize::Constant(handle),
..
} => {
match module
.constants
.try_get(handle)
.ok_or(ExpressionError::ConstantDoesntExist(handle))?
.inner
{
crate::ConstantInner::Scalar { value, width: _ } => match value {
crate::ScalarValue::Sint(value) => {
value.max(0).try_into().unwrap_or(!0)
}
crate::ScalarValue::Uint(value) => value.try_into().unwrap_or(!0),
_ => unreachable!(),
},
// caught by type validation
crate::ConstantInner::Composite { .. } => unreachable!(),
}
}
Ti::Array { .. } => !0, // can't statically know, but need run-time checks
Ti::Pointer { .. } => !0, //TODO
Ti::Struct {
ref members,
block: _,
} => members.len() as u32,
ref other => {
log::error!("Indexing of {:?}", other);
return Err(ExpressionError::InvalidBaseType(base));
}
};
if index >= limit {
return Err(ExpressionError::IndexOutOfBounds(base, index));
}
}
E::Constant(handle) => {
let _ = module
.constants
.try_get(handle)
.ok_or(ExpressionError::ConstantDoesntExist(handle))?;
}
E::Compose { ref components, ty } => {
//TODO
}
E::FunctionArgument(index) => {
if index >= function.arguments.len() as u32 {
return Err(ExpressionError::FunctionArgumentDoesntExist(index));
}
}
E::GlobalVariable(handle) => {
let _ = module
.global_variables
.try_get(handle)
.ok_or(ExpressionError::GlobalVarDoesntExist(handle))?;
}
E::LocalVariable(handle) => {
let _ = function
.local_variables
.try_get(handle)
.ok_or(ExpressionError::LocalVarDoesntExist(handle))?;
}
E::Load { pointer } => match *resolver.resolve(pointer)? {
Ti::Pointer { .. } | Ti::ValuePointer { .. } => {}
ref other => {
log::error!("Loading {:?}", other);
return Err(ExpressionError::InvalidPointerType(pointer));
}
},
E::ImageSample {
image,
sampler,
coordinate,
array_index,
offset,
level,
depth_ref,
} => {}
E::ImageLoad {
image,
coordinate,
array_index,
index,
} => {}
E::ImageQuery { image, query } => {}
E::Unary { op, expr } => {}
E::Binary { op, left, right } => {}
E::Select {
condition,
accept,
reject,
} => {}
E::Derivative { axis, expr } => {}
E::Relational { argument, .. } => {}
E::Math {
fun,
arg,
arg1,
arg2,
} => {}
E::As {
expr,
kind,
convert,
} => {}
E::Call(function) => {}
E::ArrayLength(expr) => match *resolver.resolve(expr)? {
Ti::Array { .. } => {}
ref other => {
log::error!("Array length of {:?}", other);
return Err(ExpressionError::InvalidArrayType(expr));
}
},
}
Ok(())
}
fn resolve_statement_type_impl<'a>(
&'a self, &'a self,
handle: Handle<crate::Expression>, handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>, types: &'a Arena<crate::Type>,
@ -841,15 +1047,15 @@ impl Validator {
} }
self.typifier self.typifier
.try_get(handle, types) .try_get(handle, types)
.ok_or(ExpressionError::Invalid) .ok_or(ExpressionError::DoesntExist)
} }
fn resolve_type<'a>( fn resolve_statement_type<'a>(
&'a self, &'a self,
handle: Handle<crate::Expression>, handle: Handle<crate::Expression>,
types: &'a Arena<crate::Type>, types: &'a Arena<crate::Type>,
) -> Result<&'a crate::TypeInner, FunctionError> { ) -> Result<&'a crate::TypeInner, FunctionError> {
self.resolve_type_impl(handle, types) self.resolve_statement_type_impl(handle, types)
.map_err(|error| FunctionError::Expression { handle, error }) .map_err(|error| FunctionError::Expression { handle, error })
} }
@ -880,7 +1086,7 @@ impl Validator {
ref accept, ref accept,
ref reject, ref reject,
} => { } => {
match *self.resolve_type(condition, context.types)? { match *self.resolve_statement_type(condition, context.types)? {
Ti::Scalar { Ti::Scalar {
kind: crate::ScalarKind::Bool, kind: crate::ScalarKind::Bool,
width: _, width: _,
@ -895,7 +1101,7 @@ impl Validator {
ref cases, ref cases,
ref default, ref default,
} => { } => {
match *self.resolve_type(selector, context.types)? { match *self.resolve_statement_type(selector, context.types)? {
Ti::Scalar { Ti::Scalar {
kind: crate::ScalarKind::Sint, kind: crate::ScalarKind::Sint,
width: _, width: _,
@ -940,7 +1146,7 @@ impl Validator {
return Err(FunctionError::InvalidReturnSpot); return Err(FunctionError::InvalidReturnSpot);
} }
let value_ty = value let value_ty = value
.map(|expr| self.resolve_type(expr, context.types)) .map(|expr| self.resolve_statement_type(expr, context.types))
.transpose()?; .transpose()?;
let expected_ty = context.return_type.map(|ty| &context.types[ty].inner); let expected_ty = context.return_type.map(|ty| &context.types[ty].inner);
if value_ty != expected_ty { if value_ty != expected_ty {
@ -962,7 +1168,7 @@ impl Validator {
self.typifier.try_get(current, context.types).ok_or( self.typifier.try_get(current, context.types).ok_or(
FunctionError::Expression { FunctionError::Expression {
handle: current, handle: current,
error: ExpressionError::Invalid, error: ExpressionError::DoesntExist,
}, },
)?; )?;
match context.expressions[current] { match context.expressions[current] {
@ -975,7 +1181,7 @@ impl Validator {
} }
} }
let value_ty = self.resolve_type(value, context.types)?; let value_ty = self.resolve_statement_type(value, context.types)?;
match *value_ty { match *value_ty {
Ti::Image { .. } | Ti::Sampler { .. } => { Ti::Image { .. } | Ti::Sampler { .. } => {
return Err(FunctionError::InvalidStoreValue(value)); return Err(FunctionError::InvalidStoreValue(value));
@ -1063,6 +1269,7 @@ impl Validator {
fun: &crate::Function, fun: &crate::Function,
_info: &FunctionInfo, _info: &FunctionInfo,
module: &crate::Module, module: &crate::Module,
stage: Option<crate::ShaderStage>,
) -> Result<(), FunctionError> { ) -> Result<(), FunctionError> {
let resolve_ctx = ResolveContext { let resolve_ctx = ResolveContext {
constants: &module.constants, constants: &module.constants,
@ -1097,6 +1304,9 @@ impl Validator {
if expr.needs_pre_emit() { if expr.needs_pre_emit() {
self.valid_expression_set.insert(handle.index()); self.valid_expression_set.insert(handle.index());
} }
if let Err(error) = self.validate_expression(handle, expr, fun, stage, module) {
return Err(FunctionError::Expression { handle, error });
}
} }
self.validate_block( self.validate_block(
@ -1201,7 +1411,7 @@ impl Validator {
} }
} }
self.validate_function(&ep.function, info, module)?; self.validate_function(&ep.function, info, module, Some(ep.stage))?;
Ok(()) Ok(())
} }
@ -1245,7 +1455,7 @@ impl Validator {
} }
for (handle, fun) in module.functions.iter() { for (handle, fun) in module.functions.iter() {
self.validate_function(fun, &analysis[handle], module) self.validate_function(fun, &analysis[handle], module, None)
.map_err(|error| ValidationError::Function { .map_err(|error| ValidationError::Function {
handle, handle,
name: fun.name.clone().unwrap_or_default(), name: fun.name.clone().unwrap_or_default(),