Detect recursion in inline.rs and bail (#770)

* Detect recursion in inline.rs and bail

* Bail on inline error
This commit is contained in:
Ashley Hauck 2021-10-25 09:42:45 +02:00 committed by GitHub
parent e5c2953ea6
commit 8d019e4e37
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 181 additions and 33 deletions

View File

@ -6,14 +6,20 @@
use super::apply_rewrite_rules;
use super::simple_passes::outgoing_edges;
use super::{get_name, get_names};
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
use rustc_session::Session;
use std::mem::take;
type FunctionMap = FxHashMap<Word, Function>;
pub fn inline(module: &mut Module) {
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
if module_has_recursion(sess, module) {
return Err(rustc_errors::ErrorReported);
}
let functions = module
.functions
.iter()
@ -58,6 +64,89 @@ pub fn inline(module: &mut Module) {
inliner.inline_fn(function);
fuse_trivial_branches(function);
}
Ok(())
}
// https://stackoverflow.com/a/53995651
fn module_has_recursion(sess: &Session, module: &Module) -> bool {
let func_to_index: FxHashMap<Word, usize> = module
.functions
.iter()
.enumerate()
.map(|(index, func)| (func.def_id().unwrap(), index))
.collect();
let mut discovered = vec![false; module.functions.len()];
let mut finished = vec![false; module.functions.len()];
let mut has_recursion = false;
for index in 0..module.functions.len() {
if !discovered[index] && !finished[index] {
visit(
sess,
module,
index,
&mut discovered,
&mut finished,
&mut has_recursion,
&func_to_index,
);
}
}
fn visit(
sess: &Session,
module: &Module,
current: usize,
discovered: &mut Vec<bool>,
finished: &mut Vec<bool>,
has_recursion: &mut bool,
func_to_index: &FxHashMap<Word, usize>,
) {
discovered[current] = true;
for next in calls(&module.functions[current], func_to_index) {
if discovered[next] {
let names = get_names(module);
let current_name = get_name(&names, module.functions[current].def_id().unwrap());
let next_name = get_name(&names, module.functions[next].def_id().unwrap());
sess.err(&format!(
"module has recursion, which is not allowed: `{}` calls `{}`",
current_name, next_name
));
*has_recursion = true;
break;
}
if !finished[next] {
visit(
sess,
module,
next,
discovered,
finished,
has_recursion,
func_to_index,
);
}
}
discovered[current] = false;
finished[current] = true;
}
fn calls<'a>(
func: &'a Function,
func_to_index: &'a FxHashMap<Word, usize>,
) -> impl Iterator<Item = usize> + 'a {
func.all_inst_iter()
.filter(|inst| inst.class.opcode == Op::FunctionCall)
.map(move |inst| {
*func_to_index
.get(&inst.operands[0].id_ref_any().unwrap())
.unwrap()
})
}
has_recursion
}
fn compute_disallowed_argument_and_return_types(

View File

@ -15,6 +15,8 @@ mod specializer;
mod structurizer;
mod zombies;
use std::borrow::Cow;
use crate::codegen_cx::SpirvMetadata;
use crate::decorations::{CustomDecoration, UnrollLoopsDecoration};
use rspirv::binary::{Assemble, Consumer};
@ -77,6 +79,27 @@ fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Bloc
}
}
fn get_names(module: &Module) -> FxHashMap<Word, &str> {
module
.debug_names
.iter()
.filter(|i| i.class.opcode == Op::Name)
.map(|i| {
(
i.operands[0].unwrap_id_ref(),
i.operands[1].unwrap_literal_string(),
)
})
.collect()
}
fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
names
.get(&id)
.map(|&s| Cow::Borrowed(s))
.unwrap_or_else(|| Cow::Owned(format!("Unnamed function ID %{}", id)))
}
pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<LinkResult> {
let mut output = {
let _timer = sess.timer("link_merge");
@ -178,7 +201,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
{
let _timer = sess.timer("link_inline");
inline::inline(&mut output);
inline::inline(sess, &mut output)?;
}
if opts.dce {

View File

@ -1,8 +1,9 @@
//! See documentation on `CodegenCx::zombie` for a description of the zombie system.
use super::{get_name, get_names};
use crate::decorations::{CustomDecoration, ZombieDecoration};
use rspirv::dr::{Instruction, Module};
use rspirv::spirv::{Op, Word};
use rspirv::spirv::Word;
use rustc_data_structures::fx::FxHashMap;
use rustc_session::Session;
use rustc_span::{Span, DUMMY_SP};
@ -103,20 +104,6 @@ fn spread_zombie(module: &mut Module, zombie: &mut FxHashMap<Word, ZombieInfo<'_
any
}
fn get_names(module: &Module) -> FxHashMap<Word, &str> {
module
.debug_names
.iter()
.filter(|i| i.class.opcode == Op::Name)
.map(|i| {
(
i.operands[0].unwrap_id_ref(),
i.operands[1].unwrap_literal_string(),
)
})
.collect()
}
// If an entry point references a zombie'd value, then the entry point would normally get removed.
// That's an absolutely horrible experience to debug, though, so instead, create a nice error
// message containing the stack trace of how the entry point got to the zombie value.
@ -125,12 +112,10 @@ fn report_error_zombies(sess: &Session, module: &Module, zombie: &FxHashMap<Word
for root in super::dce::collect_roots(module) {
if let Some(reason) = zombie.get(&root) {
let names = names.get_or_insert_with(|| get_names(module));
let stack = reason.stack.iter().map(|s| {
names
.get(s)
.map(|&n| n.to_string())
.unwrap_or_else(|| format!("Unnamed function ID %{}", s))
});
let stack = reason
.stack
.iter()
.map(|&s| get_name(names, s).into_owned());
let stack_note = once("Stack:".to_string())
.chain(stack)
.collect::<Vec<_>>()
@ -174,18 +159,10 @@ pub fn remove_zombies(sess: &Session, module: &mut Module) {
}
if env::var("PRINT_ZOMBIE").is_ok() {
let names = get_names(module);
for f in &module.functions {
if let Some(reason) = is_zombie(f.def.as_ref().unwrap(), &zombies) {
let name_id = f.def_id().unwrap();
let name = module.debug_names.iter().find(|inst| {
inst.class.opcode == Op::Name && inst.operands[0].unwrap_id_ref() == name_id
});
let name = match name {
Some(Instruction { ref operands, .. }) => {
operands[1].unwrap_literal_string().to_string()
}
_ => format!("{}", name_id),
};
let name = get_name(&names, f.def_id().unwrap());
println!("Function removed {:?} because {:?}", name, reason.reason);
}
}

View File

@ -0,0 +1,55 @@
// build-pass
use glam::UVec3;
use spirv_std::glam;
use spirv_std::glam::{Mat3, Vec3, Vec4};
fn index_to_transform(index: usize, raw_data: &[u8]) -> Transform2D {
Transform2D {
own_transform: Mat3::IDENTITY,
parent_offset: 0,
}
}
const SIZE_OF_TRANSFORM: usize = core::mem::size_of::<Transform2D>();
#[derive(Clone)]
struct Transform2D {
own_transform: Mat3,
parent_offset: i32,
}
trait GivesFinalTransform {
fn get_final_transform(&self, raw_data: &[u8]) -> Mat3;
}
impl GivesFinalTransform for (i32, Transform2D) {
fn get_final_transform(&self, raw_data: &[u8]) -> Mat3 {
if self.1.parent_offset == 0 {
self.1.own_transform
} else {
let parent_index = self.0 + self.1.parent_offset;
self.1.own_transform.mul_mat3(
&((
parent_index as i32,
index_to_transform(parent_index as usize, raw_data),
)
.get_final_transform(raw_data)),
)
}
}
}
#[spirv(compute(threads(64)))]
pub fn main_cs(
#[spirv(global_invocation_id)] id: UVec3,
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] raw_data: &mut [u8],
#[spirv(position)] output_position: &mut Vec4,
) {
let index = id.x as usize;
let final_transform =
(index as i32, index_to_transform(index, raw_data)).get_final_transform(raw_data);
*output_position = final_transform
.mul_vec3(Vec3::new(0.1, 0.2, 0.3))
.extend(0.0);
}

View File

@ -0,0 +1,4 @@
error: module has recursion, which is not allowed: `<(i32, issue_764::Transform2D) as issue_764::GivesFinalTransform>::get_final_transform` calls `<(i32, issue_764::Transform2D) as issue_764::GivesFinalTransform>::get_final_transform`
error: aborting due to previous error