mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-21 22:34:34 +00:00
Detect recursion in inline.rs and bail (#770)
* Detect recursion in inline.rs and bail * Bail on inline error
This commit is contained in:
parent
e5c2953ea6
commit
8d019e4e37
@ -6,14 +6,20 @@
|
||||
|
||||
use super::apply_rewrite_rules;
|
||||
use super::simple_passes::outgoing_edges;
|
||||
use super::{get_name, get_names};
|
||||
use rspirv::dr::{Block, Function, Instruction, Module, ModuleHeader, Operand};
|
||||
use rspirv::spirv::{FunctionControl, Op, StorageClass, Word};
|
||||
use rustc_data_structures::fx::{FxHashMap, FxHashSet};
|
||||
use rustc_session::Session;
|
||||
use std::mem::take;
|
||||
|
||||
type FunctionMap = FxHashMap<Word, Function>;
|
||||
|
||||
pub fn inline(module: &mut Module) {
|
||||
pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
|
||||
// This algorithm gets real sad if there's recursion - but, good news, SPIR-V bans recursion
|
||||
if module_has_recursion(sess, module) {
|
||||
return Err(rustc_errors::ErrorReported);
|
||||
}
|
||||
let functions = module
|
||||
.functions
|
||||
.iter()
|
||||
@ -58,6 +64,89 @@ pub fn inline(module: &mut Module) {
|
||||
inliner.inline_fn(function);
|
||||
fuse_trivial_branches(function);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// https://stackoverflow.com/a/53995651
|
||||
fn module_has_recursion(sess: &Session, module: &Module) -> bool {
|
||||
let func_to_index: FxHashMap<Word, usize> = module
|
||||
.functions
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(index, func)| (func.def_id().unwrap(), index))
|
||||
.collect();
|
||||
let mut discovered = vec![false; module.functions.len()];
|
||||
let mut finished = vec![false; module.functions.len()];
|
||||
let mut has_recursion = false;
|
||||
for index in 0..module.functions.len() {
|
||||
if !discovered[index] && !finished[index] {
|
||||
visit(
|
||||
sess,
|
||||
module,
|
||||
index,
|
||||
&mut discovered,
|
||||
&mut finished,
|
||||
&mut has_recursion,
|
||||
&func_to_index,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
fn visit(
|
||||
sess: &Session,
|
||||
module: &Module,
|
||||
current: usize,
|
||||
discovered: &mut Vec<bool>,
|
||||
finished: &mut Vec<bool>,
|
||||
has_recursion: &mut bool,
|
||||
func_to_index: &FxHashMap<Word, usize>,
|
||||
) {
|
||||
discovered[current] = true;
|
||||
|
||||
for next in calls(&module.functions[current], func_to_index) {
|
||||
if discovered[next] {
|
||||
let names = get_names(module);
|
||||
let current_name = get_name(&names, module.functions[current].def_id().unwrap());
|
||||
let next_name = get_name(&names, module.functions[next].def_id().unwrap());
|
||||
sess.err(&format!(
|
||||
"module has recursion, which is not allowed: `{}` calls `{}`",
|
||||
current_name, next_name
|
||||
));
|
||||
*has_recursion = true;
|
||||
break;
|
||||
}
|
||||
|
||||
if !finished[next] {
|
||||
visit(
|
||||
sess,
|
||||
module,
|
||||
next,
|
||||
discovered,
|
||||
finished,
|
||||
has_recursion,
|
||||
func_to_index,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
discovered[current] = false;
|
||||
finished[current] = true;
|
||||
}
|
||||
|
||||
fn calls<'a>(
|
||||
func: &'a Function,
|
||||
func_to_index: &'a FxHashMap<Word, usize>,
|
||||
) -> impl Iterator<Item = usize> + 'a {
|
||||
func.all_inst_iter()
|
||||
.filter(|inst| inst.class.opcode == Op::FunctionCall)
|
||||
.map(move |inst| {
|
||||
*func_to_index
|
||||
.get(&inst.operands[0].id_ref_any().unwrap())
|
||||
.unwrap()
|
||||
})
|
||||
}
|
||||
|
||||
has_recursion
|
||||
}
|
||||
|
||||
fn compute_disallowed_argument_and_return_types(
|
||||
|
@ -15,6 +15,8 @@ mod specializer;
|
||||
mod structurizer;
|
||||
mod zombies;
|
||||
|
||||
use std::borrow::Cow;
|
||||
|
||||
use crate::codegen_cx::SpirvMetadata;
|
||||
use crate::decorations::{CustomDecoration, UnrollLoopsDecoration};
|
||||
use rspirv::binary::{Assemble, Consumer};
|
||||
@ -77,6 +79,27 @@ fn apply_rewrite_rules(rewrite_rules: &FxHashMap<Word, Word>, blocks: &mut [Bloc
|
||||
}
|
||||
}
|
||||
|
||||
fn get_names(module: &Module) -> FxHashMap<Word, &str> {
|
||||
module
|
||||
.debug_names
|
||||
.iter()
|
||||
.filter(|i| i.class.opcode == Op::Name)
|
||||
.map(|i| {
|
||||
(
|
||||
i.operands[0].unwrap_id_ref(),
|
||||
i.operands[1].unwrap_literal_string(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_name<'a>(names: &FxHashMap<Word, &'a str>, id: Word) -> Cow<'a, str> {
|
||||
names
|
||||
.get(&id)
|
||||
.map(|&s| Cow::Borrowed(s))
|
||||
.unwrap_or_else(|| Cow::Owned(format!("Unnamed function ID %{}", id)))
|
||||
}
|
||||
|
||||
pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<LinkResult> {
|
||||
let mut output = {
|
||||
let _timer = sess.timer("link_merge");
|
||||
@ -178,7 +201,7 @@ pub fn link(sess: &Session, mut inputs: Vec<Module>, opts: &Options) -> Result<L
|
||||
|
||||
{
|
||||
let _timer = sess.timer("link_inline");
|
||||
inline::inline(&mut output);
|
||||
inline::inline(sess, &mut output)?;
|
||||
}
|
||||
|
||||
if opts.dce {
|
||||
|
@ -1,8 +1,9 @@
|
||||
//! See documentation on `CodegenCx::zombie` for a description of the zombie system.
|
||||
|
||||
use super::{get_name, get_names};
|
||||
use crate::decorations::{CustomDecoration, ZombieDecoration};
|
||||
use rspirv::dr::{Instruction, Module};
|
||||
use rspirv::spirv::{Op, Word};
|
||||
use rspirv::spirv::Word;
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_session::Session;
|
||||
use rustc_span::{Span, DUMMY_SP};
|
||||
@ -103,20 +104,6 @@ fn spread_zombie(module: &mut Module, zombie: &mut FxHashMap<Word, ZombieInfo<'_
|
||||
any
|
||||
}
|
||||
|
||||
fn get_names(module: &Module) -> FxHashMap<Word, &str> {
|
||||
module
|
||||
.debug_names
|
||||
.iter()
|
||||
.filter(|i| i.class.opcode == Op::Name)
|
||||
.map(|i| {
|
||||
(
|
||||
i.operands[0].unwrap_id_ref(),
|
||||
i.operands[1].unwrap_literal_string(),
|
||||
)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
// If an entry point references a zombie'd value, then the entry point would normally get removed.
|
||||
// That's an absolutely horrible experience to debug, though, so instead, create a nice error
|
||||
// message containing the stack trace of how the entry point got to the zombie value.
|
||||
@ -125,12 +112,10 @@ fn report_error_zombies(sess: &Session, module: &Module, zombie: &FxHashMap<Word
|
||||
for root in super::dce::collect_roots(module) {
|
||||
if let Some(reason) = zombie.get(&root) {
|
||||
let names = names.get_or_insert_with(|| get_names(module));
|
||||
let stack = reason.stack.iter().map(|s| {
|
||||
names
|
||||
.get(s)
|
||||
.map(|&n| n.to_string())
|
||||
.unwrap_or_else(|| format!("Unnamed function ID %{}", s))
|
||||
});
|
||||
let stack = reason
|
||||
.stack
|
||||
.iter()
|
||||
.map(|&s| get_name(names, s).into_owned());
|
||||
let stack_note = once("Stack:".to_string())
|
||||
.chain(stack)
|
||||
.collect::<Vec<_>>()
|
||||
@ -174,18 +159,10 @@ pub fn remove_zombies(sess: &Session, module: &mut Module) {
|
||||
}
|
||||
|
||||
if env::var("PRINT_ZOMBIE").is_ok() {
|
||||
let names = get_names(module);
|
||||
for f in &module.functions {
|
||||
if let Some(reason) = is_zombie(f.def.as_ref().unwrap(), &zombies) {
|
||||
let name_id = f.def_id().unwrap();
|
||||
let name = module.debug_names.iter().find(|inst| {
|
||||
inst.class.opcode == Op::Name && inst.operands[0].unwrap_id_ref() == name_id
|
||||
});
|
||||
let name = match name {
|
||||
Some(Instruction { ref operands, .. }) => {
|
||||
operands[1].unwrap_literal_string().to_string()
|
||||
}
|
||||
_ => format!("{}", name_id),
|
||||
};
|
||||
let name = get_name(&names, f.def_id().unwrap());
|
||||
println!("Function removed {:?} because {:?}", name, reason.reason);
|
||||
}
|
||||
}
|
||||
|
55
tests/ui/lang/control_flow/issue_764.rs
Normal file
55
tests/ui/lang/control_flow/issue_764.rs
Normal file
@ -0,0 +1,55 @@
|
||||
// build-pass
|
||||
|
||||
use glam::UVec3;
|
||||
use spirv_std::glam;
|
||||
use spirv_std::glam::{Mat3, Vec3, Vec4};
|
||||
|
||||
fn index_to_transform(index: usize, raw_data: &[u8]) -> Transform2D {
|
||||
Transform2D {
|
||||
own_transform: Mat3::IDENTITY,
|
||||
parent_offset: 0,
|
||||
}
|
||||
}
|
||||
|
||||
const SIZE_OF_TRANSFORM: usize = core::mem::size_of::<Transform2D>();
|
||||
|
||||
#[derive(Clone)]
|
||||
struct Transform2D {
|
||||
own_transform: Mat3,
|
||||
parent_offset: i32,
|
||||
}
|
||||
|
||||
trait GivesFinalTransform {
|
||||
fn get_final_transform(&self, raw_data: &[u8]) -> Mat3;
|
||||
}
|
||||
|
||||
impl GivesFinalTransform for (i32, Transform2D) {
|
||||
fn get_final_transform(&self, raw_data: &[u8]) -> Mat3 {
|
||||
if self.1.parent_offset == 0 {
|
||||
self.1.own_transform
|
||||
} else {
|
||||
let parent_index = self.0 + self.1.parent_offset;
|
||||
self.1.own_transform.mul_mat3(
|
||||
&((
|
||||
parent_index as i32,
|
||||
index_to_transform(parent_index as usize, raw_data),
|
||||
)
|
||||
.get_final_transform(raw_data)),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[spirv(compute(threads(64)))]
|
||||
pub fn main_cs(
|
||||
#[spirv(global_invocation_id)] id: UVec3,
|
||||
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] raw_data: &mut [u8],
|
||||
#[spirv(position)] output_position: &mut Vec4,
|
||||
) {
|
||||
let index = id.x as usize;
|
||||
let final_transform =
|
||||
(index as i32, index_to_transform(index, raw_data)).get_final_transform(raw_data);
|
||||
*output_position = final_transform
|
||||
.mul_vec3(Vec3::new(0.1, 0.2, 0.3))
|
||||
.extend(0.0);
|
||||
}
|
4
tests/ui/lang/control_flow/issue_764.stderr
Normal file
4
tests/ui/lang/control_flow/issue_764.stderr
Normal file
@ -0,0 +1,4 @@
|
||||
error: module has recursion, which is not allowed: `<(i32, issue_764::Transform2D) as issue_764::GivesFinalTransform>::get_final_transform` calls `<(i32, issue_764::Transform2D) as issue_764::GivesFinalTransform>::get_final_transform`
|
||||
|
||||
error: aborting due to previous error
|
||||
|
Loading…
Reference in New Issue
Block a user