mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 08:14:12 +00:00
linker/inline: group all 3 "type properties" into a map of "relevant globals".
This commit is contained in:
parent
11a2fe71b5
commit
0ace4c7c95
@ -27,8 +27,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
|
||||
.iter()
|
||||
.map(|f| (f.def_id().unwrap(), f.clone()))
|
||||
.collect();
|
||||
let (disallowed_argument_types, disallowed_return_types) =
|
||||
compute_disallowed_argument_and_return_types(module);
|
||||
let relevant_globals = gather_relevant_globals(module);
|
||||
let void = module
|
||||
.types_global_values
|
||||
.iter()
|
||||
@ -39,7 +38,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
|
||||
let mut dropped_ids = FxHashSet::default();
|
||||
let mut inlined_dont_inlines = Vec::new();
|
||||
module.functions.retain(|f| {
|
||||
if should_inline(&disallowed_argument_types, &disallowed_return_types, f) {
|
||||
if should_inline(&relevant_globals, f) {
|
||||
if has_dont_inline(f) {
|
||||
inlined_dont_inlines.push(f.def_id().unwrap());
|
||||
}
|
||||
@ -73,8 +72,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
|
||||
types_global_values: &mut module.types_global_values,
|
||||
void,
|
||||
functions: &functions,
|
||||
disallowed_argument_types: &disallowed_argument_types,
|
||||
disallowed_return_types: &disallowed_return_types,
|
||||
relevant_globals: &relevant_globals,
|
||||
};
|
||||
for function in &mut module.functions {
|
||||
inliner.inline_fn(function);
|
||||
@ -166,9 +164,41 @@ fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()
|
||||
}
|
||||
}
|
||||
|
||||
fn compute_disallowed_argument_and_return_types(
|
||||
module: &Module,
|
||||
) -> (FxHashSet<Word>, FxHashSet<Word>) {
|
||||
/// Any global types/variables, relevant to the inliner (mostly pointer-related).
|
||||
enum Global {
|
||||
Type {
|
||||
// FIXME(eddyb) rewrite these to variants of `Global`.
|
||||
illegal_pointee: bool,
|
||||
illegal_fn_param: bool,
|
||||
illegal_fn_ret: bool,
|
||||
},
|
||||
}
|
||||
|
||||
impl Global {
|
||||
// FIXME(eddyb) these are negative checks, but should really be positive
|
||||
// (which would require gathering *all* types as `Global`s).
|
||||
fn illegal_as_pointee_ty(&self) -> bool {
|
||||
match *self {
|
||||
Self::Type {
|
||||
illegal_pointee, ..
|
||||
} => illegal_pointee,
|
||||
}
|
||||
}
|
||||
fn illegal_as_fn_param_ty(&self) -> bool {
|
||||
match *self {
|
||||
Self::Type {
|
||||
illegal_fn_param, ..
|
||||
} => illegal_fn_param,
|
||||
}
|
||||
}
|
||||
fn illegal_as_fn_ret_ty(&self) -> bool {
|
||||
match *self {
|
||||
Self::Type { illegal_fn_ret, .. } => illegal_fn_ret,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn gather_relevant_globals(module: &Module) -> FxHashMap<Word, Global> {
|
||||
let allowed_argument_storage_classes = &[
|
||||
StorageClass::UniformConstant,
|
||||
StorageClass::Function,
|
||||
@ -176,48 +206,57 @@ fn compute_disallowed_argument_and_return_types(
|
||||
StorageClass::Workgroup,
|
||||
StorageClass::AtomicCounter,
|
||||
];
|
||||
let mut disallowed_argument_types = FxHashSet::default();
|
||||
let mut disallowed_pointees = FxHashSet::default();
|
||||
let mut disallowed_return_types = FxHashSet::default();
|
||||
let mut relevant_globals = FxHashMap::<_, Global>::default();
|
||||
for inst in &module.types_global_values {
|
||||
match inst.class.opcode {
|
||||
Op::TypePointer => {
|
||||
let storage_class = inst.operands[0].unwrap_storage_class();
|
||||
let pointee = inst.operands[1].unwrap_id_ref();
|
||||
if !allowed_argument_storage_classes.contains(&storage_class)
|
||||
|| disallowed_pointees.contains(&pointee)
|
||||
|| disallowed_argument_types.contains(&pointee)
|
||||
{
|
||||
disallowed_argument_types.insert(inst.result_id.unwrap());
|
||||
}
|
||||
disallowed_pointees.insert(inst.result_id.unwrap());
|
||||
disallowed_return_types.insert(inst.result_id.unwrap());
|
||||
let illegal_fn_param = !allowed_argument_storage_classes.contains(&storage_class)
|
||||
|| relevant_globals.get(&pointee).map_or(false, |pointee| {
|
||||
pointee.illegal_as_pointee_ty() || pointee.illegal_as_fn_param_ty()
|
||||
});
|
||||
|
||||
relevant_globals.insert(
|
||||
inst.result_id.unwrap(),
|
||||
Global::Type {
|
||||
illegal_pointee: true,
|
||||
illegal_fn_param,
|
||||
illegal_fn_ret: true,
|
||||
},
|
||||
);
|
||||
}
|
||||
Op::TypeStruct => {
|
||||
let fields = || inst.operands.iter().map(|op| op.id_ref_any().unwrap());
|
||||
if fields().any(|id| disallowed_argument_types.contains(&id)) {
|
||||
disallowed_argument_types.insert(inst.result_id.unwrap());
|
||||
Op::TypeStruct | Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector => {
|
||||
let mut illegal_pointee = false;
|
||||
let mut illegal_fn_param = false;
|
||||
let mut illegal_fn_ret = false;
|
||||
let component_tys = if inst.class.opcode == Op::TypeStruct {
|
||||
&inst.operands
|
||||
} else {
|
||||
&inst.operands[..1]
|
||||
};
|
||||
for ty in component_tys {
|
||||
if let Some(ty) = relevant_globals.get(&ty.id_ref_any().unwrap()) {
|
||||
illegal_pointee |= ty.illegal_as_pointee_ty();
|
||||
illegal_fn_param |= ty.illegal_as_fn_param_ty();
|
||||
illegal_fn_ret |= ty.illegal_as_fn_ret_ty();
|
||||
}
|
||||
}
|
||||
if fields().any(|id| disallowed_pointees.contains(&id)) {
|
||||
disallowed_pointees.insert(inst.result_id.unwrap());
|
||||
}
|
||||
if fields().any(|id| disallowed_return_types.contains(&id)) {
|
||||
disallowed_return_types.insert(inst.result_id.unwrap());
|
||||
}
|
||||
}
|
||||
Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector => {
|
||||
let id = inst.operands[0].id_ref_any().unwrap();
|
||||
if disallowed_argument_types.contains(&id) {
|
||||
disallowed_argument_types.insert(inst.result_id.unwrap());
|
||||
}
|
||||
if disallowed_pointees.contains(&id) {
|
||||
disallowed_pointees.insert(inst.result_id.unwrap());
|
||||
if illegal_pointee || illegal_fn_param || illegal_fn_ret {
|
||||
relevant_globals.insert(
|
||||
inst.result_id.unwrap(),
|
||||
Global::Type {
|
||||
illegal_pointee,
|
||||
illegal_fn_param,
|
||||
illegal_fn_ret,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
(disallowed_argument_types, disallowed_return_types)
|
||||
relevant_globals
|
||||
}
|
||||
|
||||
fn has_dont_inline(function: &Function) -> bool {
|
||||
@ -226,19 +265,18 @@ fn has_dont_inline(function: &Function) -> bool {
|
||||
control.contains(FunctionControl::DONT_INLINE)
|
||||
}
|
||||
|
||||
fn should_inline(
|
||||
disallowed_argument_types: &FxHashSet<Word>,
|
||||
disallowed_return_types: &FxHashSet<Word>,
|
||||
function: &Function,
|
||||
) -> bool {
|
||||
fn should_inline(relevant_globals: &FxHashMap<Word, Global>, function: &Function) -> bool {
|
||||
let def = function.def.as_ref().unwrap();
|
||||
let control = def.operands[0].unwrap_function_control();
|
||||
control.contains(FunctionControl::INLINE)
|
||||
|| function
|
||||
.parameters
|
||||
.iter()
|
||||
.any(|inst| disallowed_argument_types.contains(inst.result_type.as_ref().unwrap()))
|
||||
|| disallowed_return_types.contains(&function.def.as_ref().unwrap().result_type.unwrap())
|
||||
|| function.parameters.iter().any(|inst| {
|
||||
relevant_globals
|
||||
.get(inst.result_type.as_ref().unwrap())
|
||||
.map_or(false, |param_ty| param_ty.illegal_as_fn_param_ty())
|
||||
})
|
||||
|| relevant_globals
|
||||
.get(&function.def.as_ref().unwrap().result_type.unwrap())
|
||||
.map_or(false, |ret_ty| ret_ty.illegal_as_fn_ret_ty())
|
||||
}
|
||||
|
||||
// This should be more general, but a very common problem is passing an OpAccessChain to an
|
||||
@ -271,8 +309,7 @@ struct Inliner<'m, 'map> {
|
||||
types_global_values: &'m mut Vec<Instruction>,
|
||||
void: Word,
|
||||
functions: &'map FunctionMap,
|
||||
disallowed_argument_types: &'map FxHashSet<Word>,
|
||||
disallowed_return_types: &'map FxHashSet<Word>,
|
||||
relevant_globals: &'map FxHashMap<Word, Global>,
|
||||
// rewrite_rules: FxHashMap<Word, Word>,
|
||||
}
|
||||
|
||||
@ -335,11 +372,7 @@ impl Inliner<'_, '_> {
|
||||
)
|
||||
})
|
||||
.find(|(_, inst, f)| {
|
||||
should_inline(
|
||||
self.disallowed_argument_types,
|
||||
self.disallowed_return_types,
|
||||
f,
|
||||
) || args_invalid(caller, inst)
|
||||
should_inline(self.relevant_globals, f) || args_invalid(caller, inst)
|
||||
});
|
||||
let (call_index, call_inst, callee) = match call {
|
||||
None => return false,
|
||||
|
Loading…
Reference in New Issue
Block a user