linker/inline: group all 3 "type properties" into a map of "relevant globals".

This commit is contained in:
Eduard-Mihai Burtescu 2023-04-03 12:01:16 +03:00 committed by Eduard-Mihai Burtescu
parent 11a2fe71b5
commit 0ace4c7c95

View File

@ -27,8 +27,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
.iter()
.map(|f| (f.def_id().unwrap(), f.clone()))
.collect();
let (disallowed_argument_types, disallowed_return_types) =
compute_disallowed_argument_and_return_types(module);
let relevant_globals = gather_relevant_globals(module);
let void = module
.types_global_values
.iter()
@ -39,7 +38,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
let mut dropped_ids = FxHashSet::default();
let mut inlined_dont_inlines = Vec::new();
module.functions.retain(|f| {
if should_inline(&disallowed_argument_types, &disallowed_return_types, f) {
if should_inline(&relevant_globals, f) {
if has_dont_inline(f) {
inlined_dont_inlines.push(f.def_id().unwrap());
}
@ -73,8 +72,7 @@ pub fn inline(sess: &Session, module: &mut Module) -> super::Result<()> {
types_global_values: &mut module.types_global_values,
void,
functions: &functions,
disallowed_argument_types: &disallowed_argument_types,
disallowed_return_types: &disallowed_return_types,
relevant_globals: &relevant_globals,
};
for function in &mut module.functions {
inliner.inline_fn(function);
@ -166,9 +164,41 @@ fn deny_recursion_in_module(sess: &Session, module: &Module) -> super::Result<()
}
}
fn compute_disallowed_argument_and_return_types(
module: &Module,
) -> (FxHashSet<Word>, FxHashSet<Word>) {
/// Any global types/variables, relevant to the inliner (mostly pointer-related).
enum Global {
Type {
// FIXME(eddyb) rewrite these to variants of `Global`.
illegal_pointee: bool,
illegal_fn_param: bool,
illegal_fn_ret: bool,
},
}
impl Global {
// FIXME(eddyb) these are negative checks, but should really be positive
// (which would require gathering *all* types as `Global`s).
fn illegal_as_pointee_ty(&self) -> bool {
match *self {
Self::Type {
illegal_pointee, ..
} => illegal_pointee,
}
}
fn illegal_as_fn_param_ty(&self) -> bool {
match *self {
Self::Type {
illegal_fn_param, ..
} => illegal_fn_param,
}
}
fn illegal_as_fn_ret_ty(&self) -> bool {
match *self {
Self::Type { illegal_fn_ret, .. } => illegal_fn_ret,
}
}
}
fn gather_relevant_globals(module: &Module) -> FxHashMap<Word, Global> {
let allowed_argument_storage_classes = &[
StorageClass::UniformConstant,
StorageClass::Function,
@ -176,48 +206,57 @@ fn compute_disallowed_argument_and_return_types(
StorageClass::Workgroup,
StorageClass::AtomicCounter,
];
let mut disallowed_argument_types = FxHashSet::default();
let mut disallowed_pointees = FxHashSet::default();
let mut disallowed_return_types = FxHashSet::default();
let mut relevant_globals = FxHashMap::<_, Global>::default();
for inst in &module.types_global_values {
match inst.class.opcode {
Op::TypePointer => {
let storage_class = inst.operands[0].unwrap_storage_class();
let pointee = inst.operands[1].unwrap_id_ref();
if !allowed_argument_storage_classes.contains(&storage_class)
|| disallowed_pointees.contains(&pointee)
|| disallowed_argument_types.contains(&pointee)
{
disallowed_argument_types.insert(inst.result_id.unwrap());
}
disallowed_pointees.insert(inst.result_id.unwrap());
disallowed_return_types.insert(inst.result_id.unwrap());
let illegal_fn_param = !allowed_argument_storage_classes.contains(&storage_class)
|| relevant_globals.get(&pointee).map_or(false, |pointee| {
pointee.illegal_as_pointee_ty() || pointee.illegal_as_fn_param_ty()
});
relevant_globals.insert(
inst.result_id.unwrap(),
Global::Type {
illegal_pointee: true,
illegal_fn_param,
illegal_fn_ret: true,
},
);
}
Op::TypeStruct => {
let fields = || inst.operands.iter().map(|op| op.id_ref_any().unwrap());
if fields().any(|id| disallowed_argument_types.contains(&id)) {
disallowed_argument_types.insert(inst.result_id.unwrap());
Op::TypeStruct | Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector => {
let mut illegal_pointee = false;
let mut illegal_fn_param = false;
let mut illegal_fn_ret = false;
let component_tys = if inst.class.opcode == Op::TypeStruct {
&inst.operands
} else {
&inst.operands[..1]
};
for ty in component_tys {
if let Some(ty) = relevant_globals.get(&ty.id_ref_any().unwrap()) {
illegal_pointee |= ty.illegal_as_pointee_ty();
illegal_fn_param |= ty.illegal_as_fn_param_ty();
illegal_fn_ret |= ty.illegal_as_fn_ret_ty();
}
}
if fields().any(|id| disallowed_pointees.contains(&id)) {
disallowed_pointees.insert(inst.result_id.unwrap());
}
if fields().any(|id| disallowed_return_types.contains(&id)) {
disallowed_return_types.insert(inst.result_id.unwrap());
}
}
Op::TypeArray | Op::TypeRuntimeArray | Op::TypeVector => {
let id = inst.operands[0].id_ref_any().unwrap();
if disallowed_argument_types.contains(&id) {
disallowed_argument_types.insert(inst.result_id.unwrap());
}
if disallowed_pointees.contains(&id) {
disallowed_pointees.insert(inst.result_id.unwrap());
if illegal_pointee || illegal_fn_param || illegal_fn_ret {
relevant_globals.insert(
inst.result_id.unwrap(),
Global::Type {
illegal_pointee,
illegal_fn_param,
illegal_fn_ret,
},
);
}
}
_ => {}
}
}
(disallowed_argument_types, disallowed_return_types)
relevant_globals
}
fn has_dont_inline(function: &Function) -> bool {
@ -226,19 +265,18 @@ fn has_dont_inline(function: &Function) -> bool {
control.contains(FunctionControl::DONT_INLINE)
}
fn should_inline(
disallowed_argument_types: &FxHashSet<Word>,
disallowed_return_types: &FxHashSet<Word>,
function: &Function,
) -> bool {
fn should_inline(relevant_globals: &FxHashMap<Word, Global>, function: &Function) -> bool {
let def = function.def.as_ref().unwrap();
let control = def.operands[0].unwrap_function_control();
control.contains(FunctionControl::INLINE)
|| function
.parameters
.iter()
.any(|inst| disallowed_argument_types.contains(inst.result_type.as_ref().unwrap()))
|| disallowed_return_types.contains(&function.def.as_ref().unwrap().result_type.unwrap())
|| function.parameters.iter().any(|inst| {
relevant_globals
.get(inst.result_type.as_ref().unwrap())
.map_or(false, |param_ty| param_ty.illegal_as_fn_param_ty())
})
|| relevant_globals
.get(&function.def.as_ref().unwrap().result_type.unwrap())
.map_or(false, |ret_ty| ret_ty.illegal_as_fn_ret_ty())
}
// This should be more general, but a very common problem is passing an OpAccessChain to an
@ -271,8 +309,7 @@ struct Inliner<'m, 'map> {
types_global_values: &'m mut Vec<Instruction>,
void: Word,
functions: &'map FunctionMap,
disallowed_argument_types: &'map FxHashSet<Word>,
disallowed_return_types: &'map FxHashSet<Word>,
relevant_globals: &'map FxHashMap<Word, Global>,
// rewrite_rules: FxHashMap<Word, Word>,
}
@ -335,11 +372,7 @@ impl Inliner<'_, '_> {
)
})
.find(|(_, inst, f)| {
should_inline(
self.disallowed_argument_types,
self.disallowed_return_types,
f,
) || args_invalid(caller, inst)
should_inline(self.relevant_globals, f) || args_invalid(caller, inst)
});
let (call_index, call_inst, callee) = match call {
None => return false,