Refactor linker and delete some unused code (#171)

Use more efficient solutions than DefAnalyzer, and shuffle linker step
order to remove need for ScalarType/AggregateType.
This commit is contained in:
Ashley Hauck 2020-10-29 13:03:16 +01:00 committed by GitHub
parent 52277d3336
commit 03e69fc1ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 41 additions and 364 deletions

View File

@ -1,46 +0,0 @@
use rspirv::dr::{Instruction, Module, Operand};
use std::collections::HashMap;
/// `DefAnalyzer` is a simple lookup table for instructions: Sometimes, we have a spirv
/// `result_id`, and we want to find the corresponding instruction. This struct loops over all
/// instructions in the module (expensive!) and builds a table from `result_id` to its defining
/// instruction. Note that it holds a reference to the instruction, instead of cloning it. While we
/// really could clone it, it's nice to keep the reference here, since then rustc guarantees we do
/// not mutate the module while a `DefAnalyzer` is alive (which would be really bad).
pub struct DefAnalyzer<'a> {
def_ids: HashMap<u32, &'a Instruction>,
}
impl<'a> DefAnalyzer<'a> {
pub fn new(module: &'a Module) -> Self {
let mut def_ids = HashMap::new();
module.all_inst_iter().for_each(|inst| {
if let Some(def_id) = inst.result_id {
def_ids
.entry(def_id)
.and_modify(|stored_inst| {
*stored_inst = inst;
})
.or_insert_with(|| inst);
}
});
Self { def_ids }
}
pub fn def(&self, id: u32) -> Option<&'a Instruction> {
self.def_ids.get(&id).copied()
}
/// Helper that extracts the operand as an `IdRef` and then looks up that id's instruction.
///
/// # Panics
///
/// Panics when provided an operand that doesn't reference an id, or that id is missing.
pub fn op_def(&self, operand: &Operand) -> Instruction {
self.def(operand.id_ref_any().expect("Expected ID"))
.unwrap()
.clone()
}
}

View File

@ -1,5 +1,4 @@
use super::ty::trans_aggregate_type;
use super::{print_type, DefAnalyzer, LinkerError, Result};
use super::{LinkerError, Result};
use rspirv::dr::{Instruction, Module};
use rspirv::spirv::{Capability, Decoration, LinkageType, Op, Word};
use std::collections::{HashMap, HashSet};
@ -15,7 +14,8 @@ pub fn run(module: &mut Module) -> Result<()> {
fn find_import_export_pairs_and_killed_params(
module: &Module,
) -> Result<(HashMap<u32, u32>, HashSet<u32>)> {
let defs = DefAnalyzer::new(module);
let type_map = get_type_map(module);
let fn_parameters = fn_parameters(module);
// Map from name -> (definition, type)
let mut exports = HashMap::new();
@ -29,7 +29,7 @@ fn find_import_export_pairs_and_killed_params(
Some((id, name, LinkageType::Export)) => (id, name),
_ => continue,
};
let type_id = get_type_for_link(&defs, id);
let type_id = *type_map.get(&id).expect("Unexpected op");
if exports.insert(name, (id, type_id)).is_some() {
return Err(LinkerError::MultipleExports(name.to_string()));
}
@ -46,12 +46,14 @@ fn find_import_export_pairs_and_killed_params(
}
Some(&x) => x,
};
let import_type = get_type_for_link(&defs, import_id);
let import_type = *type_map.get(&import_id).expect("Unexpected op");
// Make sure the import/export pair has the same type.
check_tys_equal(&defs, name, import_type, export_type)?;
check_tys_equal(name, import_type, export_type)?;
rewrite_rules.insert(import_id, export_id);
for param in fn_parameters(module, &defs, import_id) {
killed_parameters.insert(param);
if let Some(params) = fn_parameters.get(&import_id) {
for &param in params {
killed_parameters.insert(param);
}
}
}
@ -71,64 +73,36 @@ fn get_linkage_inst(inst: &Instruction) -> Option<(Word, &str, LinkageType)> {
}
}
fn get_type_for_link(defs: &DefAnalyzer<'_>, id: Word) -> Word {
let def_inst = defs
.def(id)
.unwrap_or_else(|| panic!("Need a matching op for ID {}", id));
match def_inst.class.opcode {
Op::Variable => def_inst.result_type.unwrap(),
// Note: the result_type of OpFunction is the return type, not the function type. The
// function type is in operands[1].
Op::Function => def_inst.operands[1].unwrap_id_ref(),
_ => panic!("Unexpected op"),
}
fn get_type_map(module: &Module) -> HashMap<Word, Word> {
let vars = module
.types_global_values
.iter()
.filter(|i| i.class.opcode == Op::Variable)
.map(|i| (i.result_id.unwrap(), i.result_type.unwrap()));
let fns = module.functions.iter().map(|i| {
let d = i.def.as_ref().unwrap();
(d.result_id.unwrap(), d.operands[1].unwrap_id_ref())
});
vars.chain(fns).collect()
}
fn fn_parameters<'a>(
module: &'a Module,
defs: &DefAnalyzer<'_>,
id: Word,
) -> impl IntoIterator<Item = Word> + 'a {
let def_inst = defs
.def(id)
.unwrap_or_else(|| panic!("Need a matching op for ID {}", id));
match def_inst.class.opcode {
Op::Variable => &[],
Op::Function => {
&module
.functions
.iter()
.find(|f| f.def.as_ref().unwrap().result_id == def_inst.result_id)
.unwrap()
.parameters as &[Instruction]
}
_ => panic!("Unexpected op"),
}
.iter()
.map(|p| p.result_id.unwrap())
fn fn_parameters(module: &Module) -> HashMap<Word, Vec<Word>> {
module
.functions
.iter()
.map(|f| {
let params = f.parameters.iter().map(|i| i.result_id.unwrap()).collect();
(f.def_id().unwrap(), params)
})
.collect()
}
fn check_tys_equal(
defs: &DefAnalyzer<'_>,
name: &str,
import_type_id: Word,
export_type_id: Word,
) -> Result<()> {
let import_type = defs.def(import_type_id).unwrap();
let export_type = defs.def(export_type_id).unwrap();
let imp = trans_aggregate_type(defs, import_type);
let exp = trans_aggregate_type(defs, export_type);
if imp == exp {
fn check_tys_equal(name: &str, import_type: Word, export_type: Word) -> Result<()> {
if import_type == export_type {
Ok(())
} else {
Err(LinkerError::TypeMismatch {
name: name.to_string(),
import_type: print_type(defs, import_type),
export_type: print_type(defs, export_type),
})
}
}

View File

@ -3,19 +3,16 @@ mod test;
mod capability_computation;
mod dce;
mod def_analyzer;
mod duplicates;
mod import_export_link;
mod inline;
mod mem2reg;
mod simple_passes;
mod structurizer;
mod ty;
mod zombies;
use def_analyzer::DefAnalyzer;
use rspirv::binary::Consumer;
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand};
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader};
use rspirv::spirv::{Op, Word};
use rustc_session::Session;
use std::collections::HashMap;
@ -27,12 +24,8 @@ pub enum LinkerError {
UnresolvedSymbol(String),
#[error("Multiple exports found for {:?}", .0)]
MultipleExports(String),
#[error("Types mismatch for {:?}, imported with type {:?}, exported with type {:?}", .name, .import_type, .export_type)]
TypeMismatch {
name: String,
import_type: String,
export_type: String,
},
#[error("Types mismatch for {:?}", .name)]
TypeMismatch { name: String },
}
pub type Result<T> = std::result::Result<T, LinkerError>;
@ -51,18 +44,6 @@ fn id(header: &mut ModuleHeader) -> Word {
result
}
fn print_type(defs: &DefAnalyzer<'_>, ty: &Instruction) -> String {
format!("{}", ty::trans_aggregate_type(defs, ty).unwrap())
}
fn extract_literal_int_as_u64(op: &Operand) -> u64 {
match op {
Operand::LiteralInt32(v) => (*v).into(),
Operand::LiteralInt64(v) => *v,
_ => panic!("Unexpected literal int"),
}
}
fn apply_rewrite_rules(rewrite_rules: &HashMap<Word, Word>, blocks: &mut [Block]) {
let apply = |inst: &mut Instruction| {
if let Some(ref mut id) = &mut inst.result_id {
@ -127,12 +108,6 @@ pub fn link(sess: Option<&Session>, inputs: &mut [&mut Module], opts: &Options)
output
};
// find import / export pairs
{
let _timer = timer("link_find_pairs");
import_export_link::run(&mut output)?;
}
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
{
let _timer = timer("link_remove_duplicates");
@ -143,6 +118,12 @@ pub fn link(sess: Option<&Session>, inputs: &mut [&mut Module], opts: &Options)
// jb-todo: strip identical OpDecoration / OpDecorationGroups
}
// find import / export pairs
{
let _timer = timer("link_find_pairs");
import_export_link::run(&mut output)?;
}
{
let _timer = timer("link_remove_zombies");
zombies::remove_zombies(sess, &mut output);

View File

@ -229,8 +229,6 @@ fn type_mismatch() -> Result<()> {
result.err(),
Some(LinkerError::TypeMismatch {
name: "foo".to_string(),
import_type: "f32".to_string(),
export_type: "u32".to_string(),
})
);
Ok(())

View File

@ -1,230 +0,0 @@
use super::{extract_literal_int_as_u64, DefAnalyzer};
use rspirv::dr::Instruction;
use rspirv::spirv::{AccessQualifier, Dim, ImageFormat, Op, StorageClass};
#[derive(PartialEq, Debug)]
pub enum ScalarType {
Void,
Bool,
Int { width: u32, signed: bool },
Float { width: u32 },
Opaque { name: String },
Event,
DeviceEvent,
ReserveId,
Queue,
Pipe,
ForwardPointer { storage_class: StorageClass },
PipeStorage,
NamedBarrier,
Sampler,
}
fn trans_scalar_type(inst: &Instruction) -> Option<ScalarType> {
Some(match inst.class.opcode {
Op::TypeVoid => ScalarType::Void,
Op::TypeBool => ScalarType::Bool,
Op::TypeEvent => ScalarType::Event,
Op::TypeDeviceEvent => ScalarType::DeviceEvent,
Op::TypeReserveId => ScalarType::ReserveId,
Op::TypeQueue => ScalarType::Queue,
Op::TypePipe => ScalarType::Pipe,
Op::TypePipeStorage => ScalarType::PipeStorage,
Op::TypeNamedBarrier => ScalarType::NamedBarrier,
Op::TypeSampler => ScalarType::Sampler,
Op::TypeForwardPointer => ScalarType::ForwardPointer {
storage_class: inst.operands[0].unwrap_storage_class(),
},
Op::TypeInt => ScalarType::Int {
width: inst.operands[0].unwrap_literal_int32(),
signed: inst.operands[1].unwrap_literal_int32() != 0,
},
Op::TypeFloat => ScalarType::Float {
width: inst.operands[0].unwrap_literal_int32(),
},
Op::TypeOpaque => ScalarType::Opaque {
name: inst.operands[0].unwrap_literal_string().to_string(),
},
_ => return None,
})
}
impl std::fmt::Display for ScalarType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match *self {
Self::Void => f.write_str("void"),
Self::Bool => f.write_str("bool"),
Self::Int { width, signed } => {
if signed {
write!(f, "i{}", width)
} else {
write!(f, "u{}", width)
}
}
Self::Float { width } => write!(f, "f{}", width),
Self::Opaque { ref name } => write!(f, "Opaque{{{}}}", name),
Self::Event => f.write_str("Event"),
Self::DeviceEvent => f.write_str("DeviceEvent"),
Self::ReserveId => f.write_str("ReserveId"),
Self::Queue => f.write_str("Queue"),
Self::Pipe => f.write_str("Pipe"),
Self::ForwardPointer { storage_class } => {
write!(f, "ForwardPointer{{{:?}}}", storage_class)
}
Self::PipeStorage => f.write_str("PipeStorage"),
Self::NamedBarrier => f.write_str("NamedBarrier"),
Self::Sampler => f.write_str("Sampler"),
}
}
}
#[derive(PartialEq, Debug)]
#[allow(dead_code)]
pub enum AggregateType {
Scalar(ScalarType),
Array {
ty: Box<AggregateType>,
len: u64,
},
Pointer {
ty: Box<AggregateType>,
storage_class: StorageClass,
},
Image {
ty: Box<AggregateType>,
dim: Dim,
depth: u32,
arrayed: u32,
multi_sampled: u32,
sampled: u32,
format: ImageFormat,
access: Option<AccessQualifier>,
},
SampledImage {
ty: Box<AggregateType>,
},
Aggregate(Vec<AggregateType>),
Function(Vec<AggregateType>, Box<AggregateType>),
}
pub(crate) fn trans_aggregate_type(
def: &DefAnalyzer<'_>,
inst: &Instruction,
) -> Option<AggregateType> {
Some(match inst.class.opcode {
Op::TypeArray => {
let len_def = def.op_def(&inst.operands[1]);
assert!(len_def.class.opcode == Op::Constant); // don't support spec constants yet
let len_value = extract_literal_int_as_u64(&len_def.operands[0]);
AggregateType::Array {
ty: Box::new(
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
.expect("Expect base type for OpTypeArray"),
),
len: len_value,
}
}
Op::TypePointer => AggregateType::Pointer {
storage_class: inst.operands[0].unwrap_storage_class(),
ty: Box::new(
trans_aggregate_type(def, &def.op_def(&inst.operands[1]))
.expect("Expect base type for OpTypePointer"),
),
},
Op::TypeRuntimeArray | Op::TypeVector | Op::TypeMatrix | Op::TypeSampledImage => {
AggregateType::Aggregate(
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
.map_or_else(Vec::new, |v| vec![v]),
)
}
Op::TypeStruct => {
let mut types = vec![];
for operand in inst.operands.iter() {
let op_def = def.op_def(operand);
match trans_aggregate_type(def, &op_def) {
Some(ty) => types.push(ty),
None => panic!("Expected type"),
}
}
AggregateType::Aggregate(types)
}
Op::TypeFunction => {
let mut parameters = vec![];
let ret = trans_aggregate_type(def, &def.op_def(&inst.operands[0])).unwrap();
for operand in inst.operands.iter().skip(1) {
let op_def = def.op_def(operand);
match trans_aggregate_type(def, &op_def) {
Some(ty) => parameters.push(ty),
None => panic!("Expected type"),
}
}
AggregateType::Function(parameters, Box::new(ret))
}
Op::TypeImage => AggregateType::Image {
ty: Box::new(
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
.expect("Expect base type for OpTypeImage"),
),
dim: inst.operands[1].unwrap_dim(),
depth: inst.operands[2].unwrap_literal_int32(),
arrayed: inst.operands[3].unwrap_literal_int32(),
multi_sampled: inst.operands[4].unwrap_literal_int32(),
sampled: inst.operands[5].unwrap_literal_int32(),
format: inst.operands[6].unwrap_image_format(),
access: inst.operands.get(7).map(|op| op.unwrap_access_qualifier()),
},
_ => {
if let Some(ty) = trans_scalar_type(inst) {
AggregateType::Scalar(ty)
} else {
return None;
}
}
})
}
impl std::fmt::Display for AggregateType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Scalar(scalar) => write!(f, "{}", scalar),
Self::Array { ty, len } => write!(f, "[{}; {}]", ty, len),
Self::Pointer { ty, storage_class } => write!(f, "*{{{:?}}} {}", storage_class, ty),
Self::Image {
ty,
dim,
depth,
arrayed,
multi_sampled,
sampled,
format,
access,
} => write!(
f,
"Image {{ {}, dim:{:?}, depth:{}, arrayed:{}, \
multi_sampled:{}, sampled:{}, format:{:?}, access:{:?} }}",
ty, dim, depth, arrayed, multi_sampled, sampled, format, access
),
Self::SampledImage { ty } => write!(f, "SampledImage{{{}}}", ty),
Self::Aggregate(agg) => {
f.write_str("struct {")?;
for elem in agg {
write!(f, " {},", elem)?;
}
f.write_str(" }")
}
Self::Function(args, ret) => {
f.write_str("fn(")?;
for elem in args {
write!(f, " {},", elem)?;
}
write!(f, " ) -> {}", ret)
}
}
}
}