From c16f2097add7705f2439d45fa4334efbf495ff14 Mon Sep 17 00:00:00 2001 From: Jim Blandy Date: Thu, 27 May 2021 15:02:38 -0700 Subject: [PATCH] [spv-out]: Ensure array subscripts are in bounds. --- cli/src/main.rs | 21 ++ src/back/mod.rs | 43 +++ src/back/spv/index.rs | 509 +++++++++++++++++++++++++++ src/back/spv/instructions.rs | 15 + src/back/spv/mod.rs | 26 +- src/back/spv/writer.rs | 291 ++++++++++----- src/proc/index.rs | 79 +++++ src/proc/mod.rs | 24 +- src/valid/expression.rs | 46 ++- src/valid/type.rs | 16 +- tests/in/bounds-check-zero.param.ron | 4 + tests/in/bounds-check-zero.wgsl | 46 +++ tests/in/pointer-access.param.ron | 5 + tests/in/pointer-access.spv | Bin 0 -> 664 bytes tests/in/pointer-access.spvasm | 57 +++ tests/out/bounds-check-zero.spvasm | 212 +++++++++++ tests/out/pointer-access.spvasm | 55 +++ tests/snapshots.rs | 27 ++ tests/wgsl-errors.rs | 43 ++- 19 files changed, 1415 insertions(+), 104 deletions(-) create mode 100644 src/back/spv/index.rs create mode 100644 src/proc/index.rs create mode 100644 tests/in/bounds-check-zero.param.ron create mode 100644 tests/in/bounds-check-zero.wgsl create mode 100644 tests/in/pointer-access.param.ron create mode 100644 tests/in/pointer-access.spv create mode 100644 tests/in/pointer-access.spvasm create mode 100644 tests/out/bounds-check-zero.spvasm create mode 100644 tests/out/pointer-access.spvasm diff --git a/cli/src/main.rs b/cli/src/main.rs index 3e04b055b..edf8f3e98 100644 --- a/cli/src/main.rs +++ b/cli/src/main.rs @@ -6,6 +6,7 @@ use std::{env, error::Error, path::Path}; #[derive(Default)] struct Parameters { validation_flags: naga::valid::ValidationFlags, + index_bounds_check_policy: naga::back::IndexBoundsCheckPolicy, spv_adjust_coordinate_space: bool, spv_flow_dump_prefix: Option, spv: naga::back::spv::Options, @@ -69,6 +70,24 @@ fn main() { params.validation_flags = naga::valid::ValidationFlags::from_bits(value).unwrap(); } + "index-bounds-check-policy" => { + let value = args.next().unwrap(); + params.index_bounds_check_policy = match value.as_str() { + "Restrict" => naga::back::IndexBoundsCheckPolicy::Restrict, + "ReadZeroSkipWrite" => { + naga::back::IndexBoundsCheckPolicy::ReadZeroSkipWrite + } + "UndefinedBehavior" => { + naga::back::IndexBoundsCheckPolicy::UndefinedBehavior + } + other => { + panic!( + "Unrecognized '--index-bounds-check-policy' value: {:?}", + other + ); + } + }; + } "flow-dir" => params.spv_flow_dump_prefix = args.next(), "entry-point" => params.glsl.entry_point = args.next().unwrap(), "profile" => { @@ -241,6 +260,8 @@ fn main() { "spv" => { use naga::back::spv; + params.spv.index_bounds_check_policy = params.index_bounds_check_policy; + let spv = spv::write_vec(&module, info.as_ref().unwrap(), ¶ms.spv).unwrap_pretty(); let bytes = spv diff --git a/src/back/mod.rs b/src/back/mod.rs index 49c7c5689..4874cadde 100644 --- a/src/back/mod.rs +++ b/src/back/mod.rs @@ -13,6 +13,49 @@ pub mod spv; #[cfg(feature = "wgsl-out")] pub mod wgsl; +/// How should code generated by Naga do indexing bounds checks? +/// +/// When a vector, matrix, or array index is out of bounds—either negative, or +/// greater than or equal to the number of elements in the type—WGSL requires +/// that some other index of the implementation's choice that is in bounds is +/// used instead. (There are no types with zero elements.) +/// +/// Different users of Naga will prefer different defaults: +/// +/// - When used as part of a WebGPU implementation, the WGSL specification +/// requires the `Restrict` behavior. +/// +/// - When used by the `wgpu` crate for native development, `wgpu` selects +/// `ReadZeroSkipWrite` as its default. +/// +/// - Naga's own default is `UndefinedBehavior`, so that shader translations +/// are as faithful to the original as possible. +#[derive(Clone, Copy, Debug)] +pub enum IndexBoundsCheckPolicy { + /// Replace out-of-bounds indexes with some arbitrary in-bounds index. + /// + /// (This does not necessarily mean clamping. For example, interpreting the + /// index as unsigned and taking the minimum with the largest valid index + /// would also be a valid implementation. That would map negative indices to + /// the last element, not the first.) + Restrict, + + /// Out-of-bounds reads return zero, and writes have no effect. + ReadZeroSkipWrite, + + /// Out-of-bounds indexes are undefined behavior. Generate the fastest code + /// possible. This is the default for Naga, as a translator, but consumers + /// should consider defaulting to a safer behavior. + UndefinedBehavior, +} + +/// The default `IndexBoundsCheckPolicy` is `UndefinedBehavior`. +impl Default for IndexBoundsCheckPolicy { + fn default() -> Self { + IndexBoundsCheckPolicy::UndefinedBehavior + } +} + impl crate::Expression { /// Returns the ref count, upon reaching which this expression /// should be considered for baking. diff --git a/src/back/spv/index.rs b/src/back/spv/index.rs new file mode 100644 index 000000000..02495bbef --- /dev/null +++ b/src/back/spv/index.rs @@ -0,0 +1,509 @@ +//! Bounds-checking for SPIR-V output. + +use super::{Block, Error, Function, IdGenerator, Instruction, Word, Writer}; +use crate::{arena::Handle, back::IndexBoundsCheckPolicy, valid::FunctionInfo}; + +/// The results of emitting code for a left-hand-side expression. +/// +/// On success, `write_expression_pointer` returns one of these. +pub(super) enum ExpressionPointer { + /// The pointer to the expression's value is available, as the value of the + /// expression with the given id. + Ready { pointer_id: Word }, + + /// The access expression must be conditional on the value of `condition`, a boolean + /// expression that is true if all indices are in bounds. If `condition` is true, then + /// `access` is an `OpAccessChain` instruction that will compute a pointer to the + /// expression's value. If `condition` is false, then executing `access` would be + /// undefined behavior. + Conditional { + condition: Word, + access: Instruction, + }, +} + +/// The results of performing a bounds check. +/// +/// On success, `write_bounds_check` returns a value of this type. +pub enum BoundsCheckResult { + /// The index is statically known and in bounds, with the given value. + KnownInBounds(u32), + + /// The given instruction computes the index to be used. + Computed(Word), + + /// The given instruction computes a boolean condition which is true + /// if the index is in bounds. + Conditional(Word), +} + +/// A value that we either know at translation time, or need to compute at runtime. +pub enum MaybeKnown { + /// The value is known at shader translation time. + Known(T), + + /// The value is computed by the instruction with the given id. + Computed(Word), +} + +impl Writer { + /// Emit code to compute the length of a run-time array. + /// + /// Given `array`, an expression referring to the final member of a struct, + /// where the member in question is a runtime-sized array, return the + /// instruction id for the array's length. + pub(super) fn write_runtime_array_length( + &mut self, + array: Handle, + ir_function: &crate::Function, + function: &Function, + block: &mut Block, + ) -> Result { + // Look into the expression to find the value and type of the struct + // holding the dynamically-sized array. + let (structure_id, last_member_index) = match ir_function.expressions[array] { + crate::Expression::AccessIndex { base, index } => match ir_function.expressions[base] { + crate::Expression::GlobalVariable(handle) => { + (self.global_variables[handle.index()].id, index) + } + crate::Expression::FunctionArgument(index) => { + let parameter_id = function.parameter_id(index); + (parameter_id, index) + } + _ => return Err(Error::Validation("array length expression")), + }, + _ => return Err(Error::Validation("array length expression")), + }; + + let length_id = self.id_gen.next(); + block.body.push(Instruction::array_length( + self.get_uint_type_id()?, + length_id, + structure_id, + last_member_index, + )); + + Ok(length_id) + } + + /// Compute the length of a subscriptable value. + /// + /// Given `sequence`, an expression referring to some indexable type, return + /// its length. The result may either be computed by SPIR-V instructions, or + /// known at shader translation time. + /// + /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any + /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically + /// sized, or use a specializable constant as its length. + fn write_sequence_length( + &mut self, + sequence: Handle, + ir_module: &crate::Module, + ir_function: &crate::Function, + fun_info: &FunctionInfo, + function: &Function, + block: &mut Block, + ) -> Result, Error> { + let sequence_ty = fun_info[sequence].ty.inner_with(&ir_module.types); + match sequence_ty.indexable_length(ir_module)? { + crate::proc::IndexableLength::Known(known_length) => { + Ok(MaybeKnown::Known(known_length)) + } + crate::proc::IndexableLength::Dynamic => { + let length_id = + self.write_runtime_array_length(sequence, ir_function, function, block)?; + Ok(MaybeKnown::Computed(length_id)) + } + crate::proc::IndexableLength::Specializable(constant) => { + let length_id = self.constant_ids[constant.index()]; + Ok(MaybeKnown::Computed(length_id)) + } + } + } + + /// Compute the maximum valid index of a subscriptable value. + /// + /// Given `sequence`, an expression referring to some indexable type, return + /// its maximum valid index - one less than its length. The result may + /// either be computed, or known at shader translation time. + /// + /// `sequence` may be a `Vector`, `Matrix`, or `Array`, a `Pointer` to any + /// of those, or a `ValuePointer`. An array may be fixed-size, dynamically + /// sized, or use a specializable constant as its length. + fn write_sequence_max_index( + &mut self, + sequence: Handle, + ir_module: &crate::Module, + ir_function: &crate::Function, + fun_info: &FunctionInfo, + function: &Function, + block: &mut Block, + ) -> Result, Error> { + match self.write_sequence_length( + sequence, + ir_module, + ir_function, + fun_info, + function, + block, + )? { + MaybeKnown::Known(known_length) => { + // We should have thrown out all attempts to subscript zero-length + // sequences during validation, so the following subtraction should never + // underflow. + assert!(known_length > 0); + // Compute the max index from the length now. + Ok(MaybeKnown::Known(known_length - 1)) + } + MaybeKnown::Computed(length_id) => { + // Emit code to compute the max index from the length. + let const_one_id = self.get_index_constant(1)?; + let max_index_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::ISub, + self.get_uint_type_id()?, + max_index_id, + length_id, + const_one_id, + )); + Ok(MaybeKnown::Computed(max_index_id)) + } + } + } + + /// Restrict an index to be in range for a vector, matrix, or array. + /// + /// This is used to implement `IndexBoundsCheckPolicy::Restrict`. An + /// in-bounds index is left unchanged. An out-of-bounds index is replaced + /// with some arbitrary in-bounds index. Note,this is not necessarily + /// clamping; for example, negative indices might be changed to refer to the + /// last element of the sequence, not the first, as clamping would do. + /// + /// Either return the restricted index value, if known, or add instructions + /// to `block` to compute it, and return the id of the result. See the + /// documentation for `BoundsCheckResult` for details. + /// + /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a + /// `Pointer` to any of those, or a `ValuePointer`. An array may be + /// fixed-size, dynamically sized, or use a specializable constant as its + /// length. + #[allow(clippy::too_many_arguments)] + pub(super) fn write_restricted_index( + &mut self, + sequence: Handle, + index: Handle, + ir_module: &crate::Module, + ir_function: &crate::Function, + fun_info: &FunctionInfo, + function: &Function, + block: &mut Block, + ) -> Result { + let index_id = self.cached[index]; + + // Get the sequence's maximum valid index. Return early if we've already + // done the bounds check. + let max_index_id = match self.write_sequence_max_index( + sequence, + ir_module, + ir_function, + fun_info, + function, + block, + )? { + MaybeKnown::Known(known_max_index) => { + if let crate::Expression::Constant(index_k) = ir_function.expressions[index] { + if let Some(known_index) = ir_module.constants[index_k].to_array_length() { + // Both the index and length are known at compile time. + // + // In strict WGSL compliance mode, out-of-bounds indices cannot be + // reported at shader translation time, and must be replaced with + // in-bounds indices at run time. So we cannot assume that + // validation ensured the index was in bounds. Restrict now. + let restricted = std::cmp::min(known_index, known_max_index); + return Ok(BoundsCheckResult::KnownInBounds(restricted)); + } + } + + self.get_index_constant(known_max_index)? + } + MaybeKnown::Computed(max_index_id) => max_index_id, + }; + + // One or the other of the index or length is dynamic, so emit code for + // IndexBoundsCheckPolicy::Restrict. + let restricted_index_id = self.id_gen.next(); + block.body.push(Instruction::ext_inst( + self.gl450_ext_inst_id, + spirv::GLOp::UMin, + self.get_uint_type_id()?, + restricted_index_id, + &[index_id, max_index_id], + )); + Ok(BoundsCheckResult::Computed(restricted_index_id)) + } + + /// Write an index bounds comparison to `block`, if needed. + /// + /// If we're able to determine statically that `index` is in bounds for + /// `sequence`, return `KnownInBounds(value)`, where `value` is the actual + /// value of the index. (In principle, one could know that the index is in + /// bounds without knowing its specific value, but in our simple-minded + /// situation, we always know it.) + /// + /// If instead we must generate code to perform the comparison at run time, + /// return `Conditional(comparison_id)`, where `comparison_id` is an + /// instruction producing a boolean value that is true if `index` is in + /// bounds for `sequence`. + /// + /// The `sequence` expression may be a `Vector`, `Matrix`, or `Array`, a + /// `Pointer` to any of those, or a `ValuePointer`. An array may be + /// fixed-size, dynamically sized, or use a specializable constant as its + /// length. + #[allow(clippy::too_many_arguments)] + fn write_index_comparison( + &mut self, + sequence: Handle, + index: Handle, + ir_module: &crate::Module, + ir_function: &crate::Function, + fun_info: &FunctionInfo, + function: &mut Function, + block: &mut Block, + ) -> Result { + let index_id = self.cached[index]; + + // Get the sequence's length. Return early if we've already done the + // bounds check. + let length_id = match self.write_sequence_length( + sequence, + ir_module, + ir_function, + fun_info, + function, + block, + )? { + MaybeKnown::Known(known_length) => { + if let crate::Expression::Constant(index_k) = ir_function.expressions[index] { + if let Some(known_index) = ir_module.constants[index_k].to_array_length() { + // Both the index and length are known at compile time. + // + // It would be nice to assume that, since we are using the + // `ReadZeroSkipWrite` policy, we are not in strict WGSL + // compliance mode, and thus we can count on the validator to have + // rejected any programs with known out-of-bounds indices, and + // thus just return `KnownInBounds` here without actually + // checking. + // + // But it's also reasonable to expect that bounds check policies + // and error reporting policies should be able to vary + // independently without introducing security holes. So, we should + // support the case where bad indices do not cause validation + // errors, and are handled via `ReadZeroSkipWrite`. + // + // In theory, when `known_index` is bad, we could return a new + // `KnownOutOfBounds` variant here. But it's simpler just to fall + // through and let the bounds check take place. The shader is + // broken anyway, so it doesn't make sense to invest in emitting + // the ideal code for it. + if known_index < known_length { + return Ok(BoundsCheckResult::KnownInBounds(known_index)); + } + } + } + + self.get_index_constant(known_length)? + } + MaybeKnown::Computed(length_id) => length_id, + }; + + // Compare the index against the length. + let condition_id = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::ULessThan, + self.get_bool_type_id()?, + condition_id, + index_id, + length_id, + )); + + // Indicate that we did generate the check. + Ok(BoundsCheckResult::Conditional(condition_id)) + } + + /// Emit a conditional load for `IndexBoundsCheckPolicy::ReadZeroSkipWrite`. + /// + /// Generate code to load a value of `result_type` if `condition` is true, + /// and generate a null value of that type if it is false. Call `emit_load` + /// to emit the instructions to perform the load. Return the id of the + /// merged value of the two branches. + pub(super) fn write_conditional_indexed_load( + &mut self, + result_type: Word, + condition: Word, + function: &mut Function, + block: &mut Block, + emit_load: F, + ) -> Word + where + F: FnOnce(&mut IdGenerator, &mut Block) -> Word, + { + let header_block = block.label_id; + let merge_block = self.id_gen.next(); + let in_bounds_block = self.id_gen.next(); + + // Branch based on whether the index was in bounds. + // + // As it turns out, our out-of-bounds branch block would contain no + // instructions: it just produces a constant zero, whose instruction is + // in the module's declarations section at the front. In this case, + // SPIR-V lets us omit the empty 'else' block, and branch directly to + // the merge block. The phi instruction in the merge block can cite the + // header block as its CFG predecessor. + block.body.push(Instruction::selection_merge( + merge_block, + spirv::SelectionControl::NONE, + )); + function.consume( + std::mem::replace(block, Block::new(in_bounds_block)), + Instruction::branch_conditional(condition, in_bounds_block, merge_block), + ); + + // The in-bounds path. Perform the access and the load. + let value_id = emit_load(&mut self.id_gen, block); + + // Finish the in-bounds block and start the merge block. This + // is the block we'll leave current on return. + function.consume( + std::mem::replace(block, Block::new(merge_block)), + Instruction::branch(merge_block), + ); + + // For the out-of-bounds case, produce a zero value. + let null_id = self.write_constant_null(result_type); + + // Merge the results from the two paths. + let result_id = self.id_gen.next(); + block.body.push(Instruction::phi( + result_type, + result_id, + &[(value_id, in_bounds_block), (null_id, header_block)], + )); + + result_id + } + + /// Emit code for bounds checks, per self.index_bounds_check_policy. + /// + /// Return a `BoundsCheckResult` indicating how the index should be + /// consumed. See that type's documentation for details. + #[allow(clippy::too_many_arguments)] + pub(super) fn write_bounds_check( + &mut self, + ir_module: &crate::Module, + ir_function: &crate::Function, + fun_info: &FunctionInfo, + function: &mut Function, + base: Handle, + index: Handle, + block: &mut Block, + ) -> Result { + Ok(match self.index_bounds_check_policy { + IndexBoundsCheckPolicy::Restrict => self.write_restricted_index( + base, + index, + ir_module, + ir_function, + fun_info, + function, + block, + )?, + IndexBoundsCheckPolicy::ReadZeroSkipWrite => self.write_index_comparison( + base, + index, + ir_module, + ir_function, + fun_info, + function, + block, + )?, + IndexBoundsCheckPolicy::UndefinedBehavior => { + BoundsCheckResult::Computed(self.cached[index]) + } + }) + } + + /// Emit code to subscript a vector by value with a computed index. + /// + /// Return the id of the element value. + #[allow(clippy::too_many_arguments)] + pub(super) fn write_vector_access( + &mut self, + ir_module: &crate::Module, + ir_function: &crate::Function, + fun_info: &FunctionInfo, + function: &mut Function, + expr_handle: Handle, + base: Handle, + index: Handle, + block: &mut Block, + ) -> Result { + let result_type_id = self.get_expression_type_id(&fun_info[expr_handle].ty)?; + + let base_id = self.cached[base]; + let index_id = self.cached[index]; + + let result_id = match self.write_bounds_check( + ir_module, + ir_function, + fun_info, + function, + base, + index, + block, + )? { + BoundsCheckResult::KnownInBounds(known_index) => { + let result_id = self.id_gen.next(); + block.body.push(Instruction::composite_extract( + result_type_id, + result_id, + base_id, + &[known_index], + )); + result_id + } + BoundsCheckResult::Computed(computed_index_id) => { + let result_id = self.id_gen.next(); + block.body.push(Instruction::vector_extract_dynamic( + result_type_id, + result_id, + base_id, + computed_index_id, + )); + result_id + } + BoundsCheckResult::Conditional(comparison_id) => { + // Run-time bounds checks were required. Emit + // conditional load. + self.write_conditional_indexed_load( + result_type_id, + comparison_id, + function, + block, + |id_gen, block| { + // The in-bounds path. Generate the access. + let element_id = id_gen.next(); + block.body.push(Instruction::vector_extract_dynamic( + result_type_id, + element_id, + base_id, + index_id, + )); + element_id + }, + ) + } + }; + + Ok(result_id) + } +} diff --git a/src/back/spv/instructions.rs b/src/back/spv/instructions.rs index 21948d1bc..97fe602dd 100644 --- a/src/back/spv/instructions.rs +++ b/src/back/spv/instructions.rs @@ -758,6 +758,21 @@ impl super::Instruction { // Control-Flow Instructions // + pub(super) fn phi( + result_type_id: Word, + result_id: Word, + var_parent_pairs: &[(Word, Word)], + ) -> Self { + let mut instruction = Self::new(Op::Phi); + instruction.add_operand(result_type_id); + instruction.add_operand(result_id); + for &(variable, parent) in var_parent_pairs { + instruction.add_operand(variable); + instruction.add_operand(parent); + } + instruction + } + pub(super) fn selection_merge( merge_id: Word, selection_control: spirv::SelectionControl, diff --git a/src/back/spv/mod.rs b/src/back/spv/mod.rs index c579f0546..daabc4c4e 100644 --- a/src/back/spv/mod.rs +++ b/src/back/spv/mod.rs @@ -2,13 +2,14 @@ !*/ mod helpers; +mod index; mod instructions; mod layout; mod writer; pub use spirv::Capability; -use crate::arena::Handle; +use crate::{arena::Handle, back::IndexBoundsCheckPolicy}; use spirv::Word; use std::ops; @@ -57,6 +58,8 @@ pub enum Error { FeatureNotImplemented(&'static str), #[error("module is not validated properly: {0}")] Validation(&'static str), + #[error(transparent)] + Proc(#[from] crate::proc::ProcError), } #[derive(Default)] @@ -110,6 +113,20 @@ struct Function { entry_point_context: Option, } +impl Function { + fn consume(&mut self, mut block: Block, termination: Instruction) { + block.termination = Some(termination); + self.blocks.push(block); + } + + fn parameter_id(&self, index: u32) -> Word { + match self.entry_point_context { + Some(ref context) => context.argument_ids[index as usize], + None => self.parameters[index as usize].result_id.unwrap(), + } + } +} + #[derive(Debug, PartialEq, Hash, Eq, Copy, Clone)] enum LocalType { Value { @@ -211,7 +228,8 @@ pub struct Writer { debugs: Vec, annotations: Vec, flags: WriterFlags, - void_type: u32, + index_bounds_check_policy: IndexBoundsCheckPolicy, + void_type: Word, //TODO: convert most of these into vectors, addressable by handle indices lookup_type: crate::FastHashMap, lookup_function: crate::FastHashMap, Word>, @@ -245,6 +263,9 @@ pub struct Options { // Note: there is a major bug currently associated with deriving the capabilities. // We are calling `required_capabilities`, but the semantics of this is broken. pub capabilities: Option>, + /// How should the generated code handle array, vector, or matrix indices + /// that are out of range? + pub index_bounds_check_policy: IndexBoundsCheckPolicy, } impl Default for Options { @@ -257,6 +278,7 @@ impl Default for Options { lang_version: (1, 0), flags, capabilities: None, + index_bounds_check_policy: super::IndexBoundsCheckPolicy::default(), } } } diff --git a/src/back/spv/writer.rs b/src/back/spv/writer.rs index abcce32e2..1053fb94e 100644 --- a/src/back/spv/writer.rs +++ b/src/back/spv/writer.rs @@ -1,5 +1,6 @@ use super::{ helpers::{contains_builtin, map_storage_class}, + index::{BoundsCheckResult, ExpressionPointer}, Block, CachedExpressions, Dimension, EntryPointContext, Error, Function, GlobalVariable, IdGenerator, Instruction, LocalType, LocalVariable, LogicalLayout, LookupFunctionType, LookupType, Options, PhysicalLayout, ResultMember, Writer, WriterFlags, BITS_PER_BYTE, @@ -49,11 +50,6 @@ impl Function { block.termination.as_ref().unwrap().to_words(sink); } } - - fn consume(&mut self, mut block: Block, termination: Instruction) { - block.termination = Some(termination); - self.blocks.push(block); - } } impl PhysicalLayout { @@ -146,6 +142,7 @@ impl Writer { debugs: vec![], annotations: vec![], flags: options.flags, + index_bounds_check_policy: options.index_bounds_check_policy, void_type, lookup_type: crate::FastHashMap::default(), lookup_function: crate::FastHashMap::default(), @@ -189,7 +186,7 @@ impl Writer { } } - fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Result { + pub(super) fn get_expression_type_id(&mut self, tr: &TypeResolution) -> Result { let lookup_ty = match *tr { TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle), TypeResolution::Value(ref inner) => { @@ -224,6 +221,26 @@ impl Writer { }) } + pub(super) fn get_uint_type_id(&mut self) -> Result { + let local_type = LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Uint, + width: 4, + pointer_class: None, + }; + self.get_type_id(local_type.into()) + } + + pub(super) fn get_bool_type_id(&mut self) -> Result { + let local_type = LocalType::Value { + vector_size: None, + kind: crate::ScalarKind::Bool, + width: 1, + pointer_class: None, + }; + self.get_type_id(local_type.into()) + } + fn decorate(&mut self, id: Word, decoration: spirv::Decoration, operands: &[Word]) { self.annotations .push(Instruction::decorate(id, decoration, operands)); @@ -798,7 +815,7 @@ impl Writer { Ok(id) } - fn get_index_constant(&mut self, index: Word) -> Result { + pub(super) fn get_index_constant(&mut self, index: Word) -> Result { self.get_constant_scalar(crate::ScalarValue::Uint(index as _), 4) } @@ -905,7 +922,7 @@ impl Writer { Ok(()) } - fn write_constant_null(&mut self, type_id: Word) -> Word { + pub(super) fn write_constant_null(&mut self, type_id: Word) -> Word { let null_id = self.id_gen.next(); Instruction::constant_null(type_id, null_id) .to_words(&mut self.logical_layout.declarations); @@ -1199,33 +1216,31 @@ impl Writer { 0 } crate::Expression::Access { base, index } => { - let index_id = self.cached[index]; - let base_id = self.cached[base]; - match *fun_info[base].ty.inner_with(&ir_module.types) { - crate::TypeInner::Vector { .. } => { - let id = self.id_gen.next(); - block.body.push(Instruction::vector_extract_dynamic( - result_type_id, - id, - base_id, - index_id, - )); - id - } - crate::TypeInner::Array { .. } => { - return Err(Error::Validation( - "dynamic indexing of arrays not permitted", - )); - } + let base_ty = fun_info[base].ty.inner_with(&ir_module.types); + match *base_ty { + crate::TypeInner::Vector { .. } => (), ref other => { log::error!( "Unable to access base {:?} of type {:?}", ir_function.expressions[base], other ); - return Err(Error::FeatureNotImplemented("access for type")); + return Err(Error::Validation( + "only vectors may be dynamically indexed by value", + )); } - } + }; + + self.write_vector_access( + ir_module, + ir_function, + fun_info, + function, + expr_handle, + base, + index, + block, + )? } crate::Expression::AccessIndex { base, index: _ } if self.is_intermediate(base, ir_function, &ir_module.types) => @@ -1609,19 +1624,45 @@ impl Writer { } crate::Expression::LocalVariable(variable) => function.variables[&variable].id, crate::Expression::Load { pointer } => { - let pointer_id = - self.write_expression_pointer(ir_function, fun_info, pointer, block, function)?; - - let id = self.id_gen.next(); - block - .body - .push(Instruction::load(result_type_id, id, pointer_id, None)); - id + match self.write_expression_pointer( + ir_module, + ir_function, + fun_info, + pointer, + block, + function, + )? { + ExpressionPointer::Ready { pointer_id } => { + let id = self.id_gen.next(); + block + .body + .push(Instruction::load(result_type_id, id, pointer_id, None)); + id + } + ExpressionPointer::Conditional { condition, access } => { + self.write_conditional_indexed_load( + result_type_id, + condition, + function, + block, + move |id_gen, block| { + // The in-bounds path. Perform the access and the load. + let pointer_id = access.result_id.unwrap(); + let value_id = id_gen.next(); + block.body.push(access); + block.body.push(Instruction::load( + result_type_id, + value_id, + pointer_id, + None, + )); + value_id + }, + ) + } + } } - crate::Expression::FunctionArgument(index) => match function.entry_point_context { - Some(ref context) => context.argument_ids[index as usize], - None => function.parameters[index as usize].result_id.unwrap(), - }, + crate::Expression::FunctionArgument(index) => function.parameter_id(index), crate::Expression::Call(_function) => self.lookup_function_call[&expr_handle], crate::Expression::As { expr, @@ -2079,36 +2120,7 @@ impl Writer { id } crate::Expression::ArrayLength(expr) => { - let (structure_id, member_idx) = match ir_function.expressions[expr] { - crate::Expression::AccessIndex { base, .. } => { - match ir_function.expressions[base] { - crate::Expression::GlobalVariable(handle) => { - let global = &ir_module.global_variables[handle]; - let last_idx = match ir_module.types[global.ty].inner { - crate::TypeInner::Struct { ref members, .. } => { - members.len() as u32 - 1 - } - _ => return Err(Error::Validation("array length expression")), - }; - - (self.global_variables[handle.index()].id, last_idx) - } - _ => return Err(Error::Validation("array length expression")), - } - } - _ => return Err(Error::Validation("array length expression")), - }; - - // let structure_id = self.get_expression_global(ir_function, global); - let id = self.id_gen.next(); - - block.body.push(Instruction::array_length( - result_type_id, - id, - structure_id, - member_idx, - )); - id + self.write_runtime_array_length(expr, ir_function, function, block)? } }; @@ -2116,15 +2128,21 @@ impl Writer { Ok(()) } - /// Write a left-hand-side expression, returning an `id` of the pointer. + /// Build an `OpAccessChain` instruction for a left-hand-side expression. + /// + /// Emit any needed bounds-checking expressions to `block`. + /// + /// On success, the return value is an [`ExpressionPointer`] value; see the + /// documentation for that type. fn write_expression_pointer( &mut self, + ir_module: &crate::Module, ir_function: &crate::Function, fun_info: &FunctionInfo, mut expr_handle: Handle, block: &mut Block, function: &mut Function, - ) -> Result { + ) -> Result { let result_lookup_ty = match fun_info[expr_handle].ty { TypeResolution::Handle(ty_handle) => LookupType::Handle(ty_handle), TypeResolution::Value(ref inner) => { @@ -2133,12 +2151,60 @@ impl Writer { }; let result_type_id = self.get_type_id(result_lookup_ty)?; + // The id of the boolean `and` of all dynamic bounds checks up to this point. If + // `None`, then we haven't done any dynamic bounds checks yet. + // + // When we have a chain of bounds checks, we combine them with `OpLogicalAnd`, not + // a short-circuit branch. This means we might do comparisons we don't need to, + // but we expect these checks to almost always succeed, and keeping branches to a + // minimum is essential. + let mut accumulated_checks = None; + self.temp_list.clear(); let root_id = loop { expr_handle = match ir_function.expressions[expr_handle] { crate::Expression::Access { base, index } => { - let index_id = self.cached[index]; + let index_id = match self.write_bounds_check( + ir_module, + ir_function, + fun_info, + function, + base, + index, + block, + )? { + BoundsCheckResult::KnownInBounds(known_index) => { + // Even if the index is known, `OpAccessIndex` + // requires expression operands, not literals. + let scalar = crate::ScalarValue::Uint(known_index as u64); + self.get_constant_scalar(scalar, 4)? + } + BoundsCheckResult::Computed(computed_index_id) => computed_index_id, + BoundsCheckResult::Conditional(comparison_id) => { + match accumulated_checks { + Some(prior_checks) => { + let combined = self.id_gen.next(); + block.body.push(Instruction::binary( + spirv::Op::LogicalAnd, + self.get_bool_type_id()?, + combined, + prior_checks, + comparison_id, + )); + accumulated_checks = Some(combined); + } + None => { + // Start a fresh chain of checks. + accumulated_checks = Some(comparison_id); + } + } + + // Either way, the index to use is unchanged. + self.cached[index] + } + }; self.temp_list.push(index_id); + base } crate::Expression::AccessIndex { base, index } => { @@ -2162,20 +2228,30 @@ impl Writer { } }; - let id = if self.temp_list.is_empty() { - root_id + let pointer = if self.temp_list.is_empty() { + ExpressionPointer::Ready { + pointer_id: root_id, + } } else { self.temp_list.reverse(); - let id = self.id_gen.next(); - block.body.push(Instruction::access_chain( - result_type_id, - id, - root_id, - &self.temp_list, - )); - id + let pointer_id = self.id_gen.next(); + let access = + Instruction::access_chain(result_type_id, pointer_id, root_id, &self.temp_list); + + // If we generated some bounds checks, we need to leave it to our + // caller to generate the branch, the access, the load or store, and + // the zero value (for loads). Otherwise, we can emit the access + // ourselves, and just hand them the id of the pointer. + match accumulated_checks { + Some(condition) => ExpressionPointer::Conditional { condition, access }, + None => { + block.body.push(access); + ExpressionPointer::Ready { pointer_id } + } + } }; - Ok(id) + + Ok(pointer) } fn get_expression_global( @@ -2539,18 +2615,53 @@ impl Writer { )); } crate::Statement::Store { pointer, value } => { - let pointer_id = self.write_expression_pointer( + let value_id = self.cached[value]; + match self.write_expression_pointer( + ir_module, ir_function, fun_info, pointer, &mut block, function, - )?; - let value_id = self.cached[value]; + )? { + ExpressionPointer::Ready { pointer_id } => { + block + .body + .push(Instruction::store(pointer_id, value_id, None)); + } + ExpressionPointer::Conditional { condition, access } => { + let merge_block = self.id_gen.next(); + let in_bounds_block = self.id_gen.next(); - block - .body - .push(Instruction::store(pointer_id, value_id, None)); + // Emit the conditional branch. + block.body.push(Instruction::selection_merge( + merge_block, + spirv::SelectionControl::NONE, + )); + function.consume( + std::mem::replace(&mut block, Block::new(in_bounds_block)), + Instruction::branch_conditional( + condition, + in_bounds_block, + merge_block, + ), + ); + + // The in-bounds path. Perform the access and the store. + let pointer_id = access.result_id.unwrap(); + block.body.push(access); + block + .body + .push(Instruction::store(pointer_id, value_id, None)); + + // Finish the in-bounds block and start the merge block. This + // is the block we'll leave current on return. + function.consume( + std::mem::replace(&mut block, Block::new(merge_block)), + Instruction::branch(merge_block), + ); + } + }; } crate::Statement::ImageStore { image, diff --git a/src/proc/index.rs b/src/proc/index.rs new file mode 100644 index 000000000..69dd6c2bb --- /dev/null +++ b/src/proc/index.rs @@ -0,0 +1,79 @@ +//! Definitions for index bounds checking. + +use super::ProcError; + +impl crate::TypeInner { + /// Return the length of a subscriptable type. + /// + /// The `self` parameter should be a handle to a vector, matrix, or array + /// type, a pointer to one of those, or a value pointer. Arrays may be + /// fixed-size, dynamically sized, or sized by a specializable constant. + /// + /// The value returned is appropriate for bounds checks on subscripting. + /// + /// Return an error if `self` does not describe a subscriptable type at all. + pub fn indexable_length(&self, module: &crate::Module) -> Result { + use crate::TypeInner as Ti; + let known_length = match *self { + Ti::Vector { size, .. } => size as _, + Ti::Matrix { columns, .. } => columns as _, + Ti::Array { size, .. } => { + return size.to_indexable_length(module); + } + Ti::ValuePointer { + size: Some(size), .. + } => size as _, + Ti::Pointer { base, .. } => { + // When assigning types to expressions, ResolveContext::Resolve + // does a separate sub-match here instead of a full recursion, + // so we'll do the same. + let base_inner = &module.types[base].inner; + match *base_inner { + Ti::Vector { size, .. } => size as _, + Ti::Matrix { columns, .. } => columns as _, + Ti::Array { size, .. } => return size.to_indexable_length(module), + _ => return Err(ProcError::TypeNotIndexable), + } + } + _ => return Err(ProcError::TypeNotIndexable), + }; + Ok(IndexableLength::Known(known_length)) + } +} + +/// The number of elements in an indexable type. +/// +/// This summarizes the length of vectors, matrices, and arrays in a way that is +/// convenient for indexing and bounds-checking code. +pub enum IndexableLength { + /// Values of this type always have the given number of elements. + Known(u32), + + /// The value of the given specializable constant is the number of elements. + /// (Non-specializable constants are reported as `Known`.) + Specializable(crate::Handle), + + /// The number of elements is determined at runtime. + Dynamic, +} + +impl crate::ArraySize { + pub fn to_indexable_length(self, module: &crate::Module) -> Result { + use crate::Constant as K; + Ok(match self { + Self::Constant(k) => match module.constants[k] { + K { + specialization: Some(_), + .. + } => IndexableLength::Specializable(k), + ref unspecialized => { + let length = unspecialized + .to_array_length() + .ok_or(ProcError::InvalidArraySizeConstant(k))?; + IndexableLength::Known(length) + } + }, + Self::Dynamic => IndexableLength::Dynamic, + }) + } +} diff --git a/src/proc/mod.rs b/src/proc/mod.rs index a524981f5..b19fe6a47 100644 --- a/src/proc/mod.rs +++ b/src/proc/mod.rs @@ -1,16 +1,26 @@ //! Module processing functionality. +mod index; mod interpolator; mod layouter; mod namer; mod terminator; mod typifier; +pub use index::IndexableLength; pub use layouter::{Alignment, InvalidBaseType, Layouter, TypeLayout}; pub use namer::{EntryPointIndex, NameKey, Namer}; pub use terminator::ensure_block_returns; pub use typifier::{ResolveContext, ResolveError, TypeResolution}; +#[derive(Clone, Debug, thiserror::Error, PartialEq)] +pub enum ProcError { + #[error("type is not indexable, and has no length (validation error)")] + TypeNotIndexable, + #[error("array length is wrong kind of constant (validation error)")] + InvalidArraySizeConstant(crate::Handle), +} + impl From for super::ScalarKind { fn from(format: super::StorageFormat) -> Self { use super::{ScalarKind as Sk, StorageFormat as Sf}; @@ -217,7 +227,19 @@ impl crate::SampleLevel { } impl crate::Constant { - pub fn to_array_length(&self) -> Option { + /// Interpret this constant as an array length, and return it as a `u32`. + /// + /// Ignore any specialization available for this constant; return its + /// unspecialized value. + /// + /// If the constant has an inappropriate kind (non-scalar or non-integer) or + /// value (negative, out of range for u32), return `None`. This usually + /// indicates an error, but only the caller has enough information to report + /// the error helpfully: in back ends, it's a validation error, but in front + /// ends, it may indicate ill-formed input (for example, a SPIR-V + /// `OpArrayType` referring to an inappropriate `OpConstant`). So we return + /// `Option` and let the caller sort things out. + pub(crate) fn to_array_length(&self) -> Option { use std::convert::TryInto; match self.inner { crate::ConstantInner::Scalar { value, width: _ } => match value { diff --git a/src/valid/expression.rs b/src/valid/expression.rs index 561a088a7..a4035fede 100644 --- a/src/valid/expression.rs +++ b/src/valid/expression.rs @@ -1,7 +1,7 @@ use super::{compose::validate_compose, ComposeError, FunctionInfo, ShaderStages, TypeFlags}; use crate::{ arena::{Arena, Handle}, - proc::ResolveError, + proc::{ProcError, ResolveError}, }; #[derive(Clone, Debug, thiserror::Error)] @@ -17,8 +17,8 @@ pub enum ExpressionError { InvalidBaseType(Handle), #[error("Accessing with index {0:?} can't be done")] InvalidIndexType(Handle), - #[error("Accessing index {1} is out of {0:?} bounds")] - IndexOutOfBounds(Handle, u32), + #[error("Accessing index {1:?} is out of {0:?} bounds")] + IndexOutOfBounds(Handle, crate::ScalarValue), #[error("The expression {0:?} may only be indexed by a constant")] IndexMustBeConstant(Handle), #[error("Function argument {0:?} doesn't exist")] @@ -41,6 +41,8 @@ pub enum ExpressionError { InvalidSwizzleComponent(crate::SwizzleComponent, crate::VectorSize), #[error(transparent)] Compose(#[from] ComposeError), + #[error(transparent)] + Proc(#[from] ProcError), #[error("Operation {0:?} can't work with {1:?}")] InvalidUnaryOperandType(crate::UnaryOperator, Handle), #[error("Operation {0:?} can't work with {1:?} and {2:?}")] @@ -144,8 +146,9 @@ impl super::Validator { let stages = match *expression { E::Access { base, index } => { + let base_type = resolver.resolve(base)?; // See the documentation for `Expression::Access`. - let dynamic_indexing_restricted = match *resolver.resolve(base)? { + let dynamic_indexing_restricted = match *base_type { Ti::Vector { .. } => false, Ti::Matrix { .. } | Ti::Array { .. } => true, Ti::Pointer { .. } | Ti::ValuePointer { size: Some(_), .. } => false, @@ -174,6 +177,36 @@ impl super::Validator { { return Err(ExpressionError::IndexMustBeConstant(base)); } + + // If we know both the length and the index, we can do the + // bounds check now. + if let crate::proc::IndexableLength::Known(known_length) = + base_type.indexable_length(module)? + { + if let E::Constant(k) = function.expressions[index] { + if let crate::Constant { + // We must treat specializable constants as unknown. + specialization: None, + // Non-scalar indices should have been caught above. + inner: crate::ConstantInner::Scalar { value, .. }, + .. + } = module.constants[k] + { + match value { + crate::ScalarValue::Uint(u) if u >= known_length as u64 => { + return Err(ExpressionError::IndexOutOfBounds(base, value)); + } + crate::ScalarValue::Sint(s) + if s < 0 || s >= known_length as i64 => + { + return Err(ExpressionError::IndexOutOfBounds(base, value)); + } + _ => (), + } + } + } + } + ShaderStages::all() } E::AccessIndex { base, index } => { @@ -208,7 +241,10 @@ impl super::Validator { let limit = resolve_index_limit(module, base, resolver.resolve(base)?, true)?; if index >= limit { - return Err(ExpressionError::IndexOutOfBounds(base, index)); + return Err(ExpressionError::IndexOutOfBounds( + base, + crate::ScalarValue::Uint(limit as _), + )); } ShaderStages::all() } diff --git a/src/valid/type.rs b/src/valid/type.rs index abcd268c1..7ddc0f6aa 100644 --- a/src/valid/type.rs +++ b/src/valid/type.rs @@ -52,6 +52,8 @@ pub enum TypeError { InvalidArrayBaseType(Handle), #[error("The constant {0:?} can not be used for an array size")] InvalidArraySizeConstant(Handle), + #[error("Array type {0:?} must have a length of one or more")] + NonPositiveArrayLength(Handle), #[error("Array stride {stride} is smaller than the base element size {base_size}")] InsufficientArrayStride { stride: u32, base_size: u32 }, #[error("Field '{0}' can't be dynamically-sized, has type {1:?}")] @@ -288,15 +290,15 @@ impl super::Validator { let sized_flag = match size { crate::ArraySize::Constant(const_handle) => { - match constants.try_get(const_handle) { + let length_is_positive = match constants.try_get(const_handle) { Some(&crate::Constant { inner: crate::ConstantInner::Scalar { width: _, - value: crate::ScalarValue::Uint(_), + value: crate::ScalarValue::Uint(length), }, .. - }) => {} + }) => length > 0, // Accept a signed integer size to avoid // requiring an explicit uint // literal. Type inference should make @@ -305,14 +307,18 @@ impl super::Validator { inner: crate::ConstantInner::Scalar { width: _, - value: crate::ScalarValue::Sint(_), + value: crate::ScalarValue::Sint(length), }, .. - }) => {} + }) => length > 0, other => { log::warn!("Array size {:?}", other); return Err(TypeError::InvalidArraySizeConstant(const_handle)); } + }; + + if !length_is_positive { + return Err(TypeError::NonPositiveArrayLength(const_handle)); } TypeFlags::SIZED diff --git a/tests/in/bounds-check-zero.param.ron b/tests/in/bounds-check-zero.param.ron new file mode 100644 index 000000000..e6fc09bf4 --- /dev/null +++ b/tests/in/bounds-check-zero.param.ron @@ -0,0 +1,4 @@ +( + bounds_check_read_zero_skip_write: true, + spv_version: (1, 1), +) diff --git a/tests/in/bounds-check-zero.wgsl b/tests/in/bounds-check-zero.wgsl new file mode 100644 index 000000000..773a1788f --- /dev/null +++ b/tests/in/bounds-check-zero.wgsl @@ -0,0 +1,46 @@ +// Tests for `naga::back::IndexBoundsCheckPolicy::ReadZeroSkipWrite`. + +[[block]] +struct Globals { + a: array; + v: vec4; + m: mat3x4; +}; + +[[group(0), binding(0)]] var globals: Globals; + +fn index_array(i: i32) -> f32 { + return globals.a[i]; +} + +fn index_vector(i: i32) -> f32 { + return globals.v[i]; +} + +fn index_vector_by_value(v: vec4, i: i32) -> f32 { + return v[i]; +} + +fn index_matrix(i: i32) -> vec4 { + return globals.m[i]; +} + +fn index_twice(i: i32, j: i32) -> f32 { + return globals.m[i][j]; +} + +fn set_array(i: i32, v: f32) { + globals.a[i] = v; +} + +fn set_vector(i: i32, v: f32) { + globals.v[i] = v; +} + +fn set_matrix(i: i32, v: vec4) { + globals.m[i] = v; +} + +fn set_index_twice(i: i32, j: i32, v: f32) { + globals.m[i][j] = v; +} diff --git a/tests/in/pointer-access.param.ron b/tests/in/pointer-access.param.ron new file mode 100644 index 000000000..a049e71aa --- /dev/null +++ b/tests/in/pointer-access.param.ron @@ -0,0 +1,5 @@ +( + spv_version: (1, 0), + spv_debug: true, + spv_adjust_coordinate_space: true, +) diff --git a/tests/in/pointer-access.spv b/tests/in/pointer-access.spv new file mode 100644 index 0000000000000000000000000000000000000000..ba5664c3a7ad4b4d88069e858c69699126167a69 GIT binary patch literal 664 zcmYk3%}c{T5XGly(`c*Js`Ud@T6^@af_Uhu1uvrDHN++@f@s;Kp#Qli@%@sGxG;G$ zZ|3dW-86%7WMRY1B0JHS)wjT4z);w+u01_G$M>_xI4?e2vdrR_^|G?u=IMEI8I8rdf`_Zp` Lj+N8Hha8_@_)arN literal 0 HcmV?d00001 diff --git a/tests/in/pointer-access.spvasm b/tests/in/pointer-access.spvasm new file mode 100644 index 000000000..bdf572c78 --- /dev/null +++ b/tests/in/pointer-access.spvasm @@ -0,0 +1,57 @@ +;;; Indexing into composite values passed by pointer. + +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +OpMemoryModel Logical GLSL450 +OpName %dpa_arg_array "dpa_arg_array" +OpName %dpa_arg_index "dpa_arg_index" +OpName %dpra_arg_struct "dpra_arg_struct" +OpName %dpra_arg_index "dpra_arg_index" +OpDecorate %run_array ArrayStride 4 +OpDecorate %run_struct Block +OpDecorate %dummy_var DescriptorSet 0 +OpDecorate %dummy_var Binding 0 +OpMemberDecorate %run_struct 0 Offset 0 + +%uint = OpTypeInt 32 0 +%uint_ptr = OpTypePointer StorageBuffer %uint +%const_0 = OpConstant %uint 0 +%const_7 = OpConstant %uint 7 + +%array = OpTypeArray %uint %const_7 +%array_ptr = OpTypePointer StorageBuffer %array +%dpa_type = OpTypeFunction %uint %array_ptr %uint + +%run_array = OpTypeRuntimeArray %uint +%run_struct = OpTypeStruct %run_array +%run_struct_ptr = OpTypePointer StorageBuffer %run_struct +%dpra_type = OpTypeFunction %uint %run_struct_ptr %uint + +;;; This makes Naga request SPV_KHR_storage_buffer_storage_class in the output, +;;; too. +%dummy_var = OpVariable %run_struct_ptr StorageBuffer + +;;; Take a pointer to an array of unsigned integers, and an index, +;;; and return the given element of the array. +%dpa = OpFunction %uint None %dpa_type +%dpa_arg_array = OpFunctionParameter %array_ptr +%dpa_arg_index = OpFunctionParameter %uint + + %dpa_label = OpLabel + %elt_ptr = OpAccessChain %uint_ptr %dpa_arg_array %dpa_arg_index + %elt_value = OpLoad %uint %elt_ptr + OpReturnValue %elt_value + OpFunctionEnd + +;;; Take a pointer to a struct containing a run-time array, and an index, and +;;; return the given element of the array. +%dpra = OpFunction %uint None %dpra_type +%dpra_arg_struct = OpFunctionParameter %run_struct_ptr +%dpra_arg_index = OpFunctionParameter %uint + + %dpra_label = OpLabel + %elt_ptr2 = OpAccessChain %uint_ptr %dpra_arg_struct %const_0 %dpra_arg_index + %elt_value2 = OpLoad %uint %elt_ptr2 + OpReturnValue %elt_value2 + OpFunctionEnd diff --git a/tests/out/bounds-check-zero.spvasm b/tests/out/bounds-check-zero.spvasm new file mode 100644 index 000000000..da2a09ec6 --- /dev/null +++ b/tests/out/bounds-check-zero.spvasm @@ -0,0 +1,212 @@ +; SPIR-V +; Version: 1.1 +; Generator: rspirv +; Bound: 135 +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpDecorate %6 ArrayStride 4 +OpDecorate %9 Block +OpMemberDecorate %9 0 Offset 0 +OpMemberDecorate %9 1 Offset 48 +OpMemberDecorate %9 2 Offset 64 +OpMemberDecorate %9 2 ColMajor +OpMemberDecorate %9 2 MatrixStride 16 +OpDecorate %10 DescriptorSet 0 +OpDecorate %10 Binding 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpConstant %4 10 +%5 = OpTypeFloat 32 +%6 = OpTypeArray %5 %3 +%7 = OpTypeVector %5 4 +%8 = OpTypeMatrix %7 3 +%9 = OpTypeStruct %6 %7 %8 +%11 = OpTypePointer StorageBuffer %9 +%10 = OpVariable %11 StorageBuffer +%15 = OpTypeFunction %5 %4 +%17 = OpTypePointer StorageBuffer %6 +%18 = OpTypePointer StorageBuffer %5 +%20 = OpTypeInt 32 0 +%19 = OpConstant %20 10 +%22 = OpTypeBool +%23 = OpConstant %20 0 +%28 = OpConstantNull %5 +%34 = OpTypePointer StorageBuffer %7 +%35 = OpConstant %20 1 +%38 = OpConstant %20 4 +%43 = OpConstantNull %5 +%49 = OpTypeFunction %5 %7 %4 +%55 = OpConstantNull %5 +%60 = OpTypeFunction %7 %4 +%62 = OpTypePointer StorageBuffer %8 +%63 = OpTypePointer StorageBuffer %7 +%64 = OpConstant %20 3 +%66 = OpConstant %20 2 +%71 = OpConstantNull %7 +%77 = OpTypeFunction %5 %4 %4 +%84 = OpConstantNull %7 +%90 = OpConstantNull %5 +%96 = OpTypeFunction %2 %4 %5 +%107 = OpTypePointer StorageBuffer %5 +%116 = OpTypeFunction %2 %4 %7 +%127 = OpTypeFunction %2 %4 %4 %5 +%14 = OpFunction %5 None %15 +%13 = OpFunctionParameter %4 +%12 = OpLabel +OpBranch %16 +%16 = OpLabel +%21 = OpULessThan %22 %13 %19 +OpSelectionMerge %25 None +OpBranchConditional %21 %26 %25 +%26 = OpLabel +%24 = OpAccessChain %18 %10 %23 %13 +%27 = OpLoad %5 %24 +OpBranch %25 +%25 = OpLabel +%29 = OpPhi %5 %27 %26 %28 %16 +OpReturnValue %29 +OpFunctionEnd +%32 = OpFunction %5 None %15 +%31 = OpFunctionParameter %4 +%30 = OpLabel +OpBranch %33 +%33 = OpLabel +%36 = OpAccessChain %34 %10 %35 +%37 = OpLoad %7 %36 +%39 = OpULessThan %22 %31 %38 +OpSelectionMerge %40 None +OpBranchConditional %39 %41 %40 +%41 = OpLabel +%42 = OpVectorExtractDynamic %5 %37 %31 +OpBranch %40 +%40 = OpLabel +%44 = OpPhi %5 %42 %41 %43 %33 +OpReturnValue %44 +OpFunctionEnd +%48 = OpFunction %5 None %49 +%46 = OpFunctionParameter %7 +%47 = OpFunctionParameter %4 +%45 = OpLabel +OpBranch %50 +%50 = OpLabel +%51 = OpULessThan %22 %47 %38 +OpSelectionMerge %52 None +OpBranchConditional %51 %53 %52 +%53 = OpLabel +%54 = OpVectorExtractDynamic %5 %46 %47 +OpBranch %52 +%52 = OpLabel +%56 = OpPhi %5 %54 %53 %55 %50 +OpReturnValue %56 +OpFunctionEnd +%59 = OpFunction %7 None %60 +%58 = OpFunctionParameter %4 +%57 = OpLabel +OpBranch %61 +%61 = OpLabel +%65 = OpULessThan %22 %58 %64 +OpSelectionMerge %68 None +OpBranchConditional %65 %69 %68 +%69 = OpLabel +%67 = OpAccessChain %63 %10 %66 %58 +%70 = OpLoad %7 %67 +OpBranch %68 +%68 = OpLabel +%72 = OpPhi %7 %70 %69 %71 %61 +OpReturnValue %72 +OpFunctionEnd +%76 = OpFunction %5 None %77 +%74 = OpFunctionParameter %4 +%75 = OpFunctionParameter %4 +%73 = OpLabel +OpBranch %78 +%78 = OpLabel +%79 = OpULessThan %22 %74 %64 +OpSelectionMerge %81 None +OpBranchConditional %79 %82 %81 +%82 = OpLabel +%80 = OpAccessChain %63 %10 %66 %74 +%83 = OpLoad %7 %80 +OpBranch %81 +%81 = OpLabel +%85 = OpPhi %7 %83 %82 %84 %78 +%86 = OpULessThan %22 %75 %38 +OpSelectionMerge %87 None +OpBranchConditional %86 %88 %87 +%88 = OpLabel +%89 = OpVectorExtractDynamic %5 %85 %75 +OpBranch %87 +%87 = OpLabel +%91 = OpPhi %5 %89 %88 %90 %81 +OpReturnValue %91 +OpFunctionEnd +%95 = OpFunction %2 None %96 +%93 = OpFunctionParameter %4 +%94 = OpFunctionParameter %5 +%92 = OpLabel +OpBranch %97 +%97 = OpLabel +%98 = OpULessThan %22 %93 %19 +OpSelectionMerge %100 None +OpBranchConditional %98 %101 %100 +%101 = OpLabel +%99 = OpAccessChain %18 %10 %23 %93 +OpStore %99 %94 +OpBranch %100 +%100 = OpLabel +OpReturn +OpFunctionEnd +%105 = OpFunction %2 None %96 +%103 = OpFunctionParameter %4 +%104 = OpFunctionParameter %5 +%102 = OpLabel +OpBranch %106 +%106 = OpLabel +%108 = OpULessThan %22 %103 %38 +OpSelectionMerge %110 None +OpBranchConditional %108 %111 %110 +%111 = OpLabel +%109 = OpAccessChain %107 %10 %35 %103 +OpStore %109 %104 +OpBranch %110 +%110 = OpLabel +OpReturn +OpFunctionEnd +%115 = OpFunction %2 None %116 +%113 = OpFunctionParameter %4 +%114 = OpFunctionParameter %7 +%112 = OpLabel +OpBranch %117 +%117 = OpLabel +%118 = OpULessThan %22 %113 %64 +OpSelectionMerge %120 None +OpBranchConditional %118 %121 %120 +%121 = OpLabel +%119 = OpAccessChain %63 %10 %66 %113 +OpStore %119 %114 +OpBranch %120 +%120 = OpLabel +OpReturn +OpFunctionEnd +%126 = OpFunction %2 None %127 +%123 = OpFunctionParameter %4 +%124 = OpFunctionParameter %4 +%125 = OpFunctionParameter %5 +%122 = OpLabel +OpBranch %128 +%128 = OpLabel +%129 = OpULessThan %22 %124 %38 +%130 = OpULessThan %22 %123 %64 +%131 = OpLogicalAnd %22 %129 %130 +OpSelectionMerge %133 None +OpBranchConditional %131 %134 %133 +%134 = OpLabel +%132 = OpAccessChain %107 %10 %66 %123 %124 +OpStore %132 %125 +OpBranch %133 +%133 = OpLabel +OpReturn +OpFunctionEnd \ No newline at end of file diff --git a/tests/out/pointer-access.spvasm b/tests/out/pointer-access.spvasm new file mode 100644 index 000000000..ef4b7bbce --- /dev/null +++ b/tests/out/pointer-access.spvasm @@ -0,0 +1,55 @@ +; SPIR-V +; Version: 1.0 +; Generator: rspirv +; Bound: 35 +OpCapability Shader +OpCapability Linkage +OpExtension "SPV_KHR_storage_buffer_storage_class" +%1 = OpExtInstImport "GLSL.std.450" +OpMemoryModel Logical GLSL450 +OpSource GLSL 450 +OpDecorate %12 ArrayStride 4 +OpDecorate %14 ArrayStride 4 +OpDecorate %15 Block +OpMemberDecorate %15 0 Offset 0 +OpDecorate %17 DescriptorSet 0 +OpDecorate %17 Binding 0 +%2 = OpTypeVoid +%4 = OpTypeInt 32 1 +%3 = OpConstant %4 0 +%5 = OpConstant %4 1 +%6 = OpConstant %4 2 +%7 = OpConstant %4 3 +%9 = OpTypeInt 32 0 +%8 = OpConstant %9 0 +%10 = OpConstant %9 7 +%11 = OpTypePointer StorageBuffer %9 +%12 = OpTypeArray %9 %10 +%13 = OpTypePointer StorageBuffer %12 +%14 = OpTypeRuntimeArray %9 +%15 = OpTypeStruct %14 +%16 = OpTypePointer StorageBuffer %15 +%17 = OpVariable %16 StorageBuffer +%22 = OpTypeFunction %9 %13 %9 +%30 = OpTypeFunction %9 %16 %9 +%32 = OpTypePointer StorageBuffer %14 +%21 = OpFunction %9 None %22 +%19 = OpFunctionParameter %13 +%20 = OpFunctionParameter %9 +%18 = OpLabel +OpBranch %23 +%23 = OpLabel +%24 = OpAccessChain %11 %19 %20 +%25 = OpLoad %9 %24 +OpReturnValue %25 +OpFunctionEnd +%29 = OpFunction %9 None %30 +%27 = OpFunctionParameter %16 +%28 = OpFunctionParameter %9 +%26 = OpLabel +OpBranch %31 +%31 = OpLabel +%33 = OpAccessChain %11 %27 %8 %28 +%34 = OpLoad %9 %33 +OpReturnValue %34 +OpFunctionEnd \ No newline at end of file diff --git a/tests/snapshots.rs b/tests/snapshots.rs index c6ebaab81..e009f5696 100644 --- a/tests/snapshots.rs +++ b/tests/snapshots.rs @@ -23,6 +23,15 @@ bitflags::bitflags! { struct Parameters { #[serde(default)] god_mode: bool, + + // We can only deserialize `IndexBoundsCheckPolicy` values if `deserialize` + // feature was enabled, but features should not affect snapshot contents, so + // just take the policy as booleans instead. + #[serde(default)] + bounds_check_read_zero_skip_write: bool, + #[serde(default)] + bounds_check_restrict: bool, + #[cfg_attr(not(feature = "spv-out"), allow(dead_code))] spv_version: (u8, u8), #[cfg_attr(not(feature = "spv-out"), allow(dead_code))] @@ -52,6 +61,9 @@ fn check_targets(module: &naga::Module, name: &str, targets: Targets) { Ok(string) => ron::de::from_str(&string).expect("Couldn't find param file"), Err(_) => Parameters::default(), }; + if params.bounds_check_restrict && params.bounds_check_read_zero_skip_write { + panic!("select only one bounds check policy"); + } let capabilities = if params.god_mode { naga::valid::Capabilities::all() } else { @@ -143,6 +155,14 @@ fn check_output_spv( } else { Some(params.spv_capabilities.clone()) }, + index_bounds_check_policy: if params.bounds_check_restrict { + naga::back::IndexBoundsCheckPolicy::Restrict + } else if params.bounds_check_read_zero_skip_write { + naga::back::IndexBoundsCheckPolicy::ReadZeroSkipWrite + } else { + naga::back::IndexBoundsCheckPolicy::UndefinedBehavior + }, + ..spv::Options::default() }; let spv = spv::write_vec(module, info, &options).unwrap(); @@ -318,6 +338,7 @@ fn convert_wgsl() { "globals", Targets::SPIRV | Targets::METAL | Targets::GLSL | Targets::WGSL, ), + ("bounds-check-zero", Targets::SPIRV), ]; for &(name, targets) in inputs.iter() { @@ -368,6 +389,12 @@ fn convert_spv_shadow() { convert_spv("shadow", true, Targets::IR | Targets::ANALYSIS); } +#[cfg(all(feature = "spv-in", feature = "spv-out"))] +#[test] +fn convert_spv_pointer_access() { + convert_spv("pointer-access", true, Targets::SPIRV); +} + #[cfg(feature = "glsl-in")] fn convert_glsl( name: &str, diff --git a/tests/wgsl-errors.rs b/tests/wgsl-errors.rs index c1acad6ab..c33000403 100644 --- a/tests/wgsl-errors.rs +++ b/tests/wgsl-errors.rs @@ -109,6 +109,25 @@ fn unknown_identifier() { ); } +#[test] +fn negative_index() { + check( + r#" + fn main() -> f32 { + let a = array(0., 1., 2.); + return a[-1]; + } + "#, + r#"error: expected non-negative integer constant expression, found `-1` + ┌─ wgsl:4:26 + │ +4 │ return a[-1]; + │ ^^ expected non-negative integer + +"#, + ); +} + macro_rules! check_validation_error { // We want to support an optional guard expression after the pattern, so // that we can check values we can't match against, like strings. @@ -146,7 +165,13 @@ macro_rules! check_validation_error { } fn validation_error(source: &str) -> Result { - let module = naga::front::wgsl::parse_str(source).expect("expected WGSL parse to succeed"); + let module = match naga::front::wgsl::parse_str(source) { + Ok(module) => module, + Err(err) => { + eprintln!("WGSL parse failed:"); + panic!("{}", err.emit_to_string(source)); + } + }; naga::valid::Validator::new( naga::valid::ValidationFlags::all(), naga::valid::Capabilities::empty(), @@ -354,6 +379,22 @@ fn invalid_access() { .. }) } + + check_validation_error! { + r#" + fn main() -> f32 { + let a = array(0., 1., 2.); + return a[3]; + } + "#: + Err(naga::valid::ValidationError::Function { + error: naga::valid::FunctionError::Expression { + error: naga::valid::ExpressionError::IndexOutOfBounds(_, _), + .. + }, + .. + }) + } } #[test]