mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2025-02-19 18:35:22 +00:00
Refactor linker and delete some unused code (#171)
Use more efficient solutions than DefAnalyzer, and shuffle linker step order to remove need for ScalarType/AggregateType.
This commit is contained in:
parent
52277d3336
commit
03e69fc1ea
@ -1,46 +0,0 @@
|
||||
use rspirv::dr::{Instruction, Module, Operand};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// `DefAnalyzer` is a simple lookup table for instructions: Sometimes, we have a spirv
|
||||
/// `result_id`, and we want to find the corresponding instruction. This struct loops over all
|
||||
/// instructions in the module (expensive!) and builds a table from `result_id` to its defining
|
||||
/// instruction. Note that it holds a reference to the instruction, instead of cloning it. While we
|
||||
/// really could clone it, it's nice to keep the reference here, since then rustc guarantees we do
|
||||
/// not mutate the module while a `DefAnalyzer` is alive (which would be really bad).
|
||||
pub struct DefAnalyzer<'a> {
|
||||
def_ids: HashMap<u32, &'a Instruction>,
|
||||
}
|
||||
|
||||
impl<'a> DefAnalyzer<'a> {
|
||||
pub fn new(module: &'a Module) -> Self {
|
||||
let mut def_ids = HashMap::new();
|
||||
|
||||
module.all_inst_iter().for_each(|inst| {
|
||||
if let Some(def_id) = inst.result_id {
|
||||
def_ids
|
||||
.entry(def_id)
|
||||
.and_modify(|stored_inst| {
|
||||
*stored_inst = inst;
|
||||
})
|
||||
.or_insert_with(|| inst);
|
||||
}
|
||||
});
|
||||
|
||||
Self { def_ids }
|
||||
}
|
||||
|
||||
pub fn def(&self, id: u32) -> Option<&'a Instruction> {
|
||||
self.def_ids.get(&id).copied()
|
||||
}
|
||||
|
||||
/// Helper that extracts the operand as an `IdRef` and then looks up that id's instruction.
|
||||
///
|
||||
/// # Panics
|
||||
///
|
||||
/// Panics when provided an operand that doesn't reference an id, or that id is missing.
|
||||
pub fn op_def(&self, operand: &Operand) -> Instruction {
|
||||
self.def(operand.id_ref_any().expect("Expected ID"))
|
||||
.unwrap()
|
||||
.clone()
|
||||
}
|
||||
}
|
@ -1,5 +1,4 @@
|
||||
use super::ty::trans_aggregate_type;
|
||||
use super::{print_type, DefAnalyzer, LinkerError, Result};
|
||||
use super::{LinkerError, Result};
|
||||
use rspirv::dr::{Instruction, Module};
|
||||
use rspirv::spirv::{Capability, Decoration, LinkageType, Op, Word};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
@ -15,7 +14,8 @@ pub fn run(module: &mut Module) -> Result<()> {
|
||||
fn find_import_export_pairs_and_killed_params(
|
||||
module: &Module,
|
||||
) -> Result<(HashMap<u32, u32>, HashSet<u32>)> {
|
||||
let defs = DefAnalyzer::new(module);
|
||||
let type_map = get_type_map(module);
|
||||
let fn_parameters = fn_parameters(module);
|
||||
|
||||
// Map from name -> (definition, type)
|
||||
let mut exports = HashMap::new();
|
||||
@ -29,7 +29,7 @@ fn find_import_export_pairs_and_killed_params(
|
||||
Some((id, name, LinkageType::Export)) => (id, name),
|
||||
_ => continue,
|
||||
};
|
||||
let type_id = get_type_for_link(&defs, id);
|
||||
let type_id = *type_map.get(&id).expect("Unexpected op");
|
||||
if exports.insert(name, (id, type_id)).is_some() {
|
||||
return Err(LinkerError::MultipleExports(name.to_string()));
|
||||
}
|
||||
@ -46,12 +46,14 @@ fn find_import_export_pairs_and_killed_params(
|
||||
}
|
||||
Some(&x) => x,
|
||||
};
|
||||
let import_type = get_type_for_link(&defs, import_id);
|
||||
let import_type = *type_map.get(&import_id).expect("Unexpected op");
|
||||
// Make sure the import/export pair has the same type.
|
||||
check_tys_equal(&defs, name, import_type, export_type)?;
|
||||
check_tys_equal(name, import_type, export_type)?;
|
||||
rewrite_rules.insert(import_id, export_id);
|
||||
for param in fn_parameters(module, &defs, import_id) {
|
||||
killed_parameters.insert(param);
|
||||
if let Some(params) = fn_parameters.get(&import_id) {
|
||||
for ¶m in params {
|
||||
killed_parameters.insert(param);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -71,64 +73,36 @@ fn get_linkage_inst(inst: &Instruction) -> Option<(Word, &str, LinkageType)> {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_type_for_link(defs: &DefAnalyzer<'_>, id: Word) -> Word {
|
||||
let def_inst = defs
|
||||
.def(id)
|
||||
.unwrap_or_else(|| panic!("Need a matching op for ID {}", id));
|
||||
|
||||
match def_inst.class.opcode {
|
||||
Op::Variable => def_inst.result_type.unwrap(),
|
||||
// Note: the result_type of OpFunction is the return type, not the function type. The
|
||||
// function type is in operands[1].
|
||||
Op::Function => def_inst.operands[1].unwrap_id_ref(),
|
||||
_ => panic!("Unexpected op"),
|
||||
}
|
||||
fn get_type_map(module: &Module) -> HashMap<Word, Word> {
|
||||
let vars = module
|
||||
.types_global_values
|
||||
.iter()
|
||||
.filter(|i| i.class.opcode == Op::Variable)
|
||||
.map(|i| (i.result_id.unwrap(), i.result_type.unwrap()));
|
||||
let fns = module.functions.iter().map(|i| {
|
||||
let d = i.def.as_ref().unwrap();
|
||||
(d.result_id.unwrap(), d.operands[1].unwrap_id_ref())
|
||||
});
|
||||
vars.chain(fns).collect()
|
||||
}
|
||||
|
||||
fn fn_parameters<'a>(
|
||||
module: &'a Module,
|
||||
defs: &DefAnalyzer<'_>,
|
||||
id: Word,
|
||||
) -> impl IntoIterator<Item = Word> + 'a {
|
||||
let def_inst = defs
|
||||
.def(id)
|
||||
.unwrap_or_else(|| panic!("Need a matching op for ID {}", id));
|
||||
|
||||
match def_inst.class.opcode {
|
||||
Op::Variable => &[],
|
||||
Op::Function => {
|
||||
&module
|
||||
.functions
|
||||
.iter()
|
||||
.find(|f| f.def.as_ref().unwrap().result_id == def_inst.result_id)
|
||||
.unwrap()
|
||||
.parameters as &[Instruction]
|
||||
}
|
||||
_ => panic!("Unexpected op"),
|
||||
}
|
||||
.iter()
|
||||
.map(|p| p.result_id.unwrap())
|
||||
fn fn_parameters(module: &Module) -> HashMap<Word, Vec<Word>> {
|
||||
module
|
||||
.functions
|
||||
.iter()
|
||||
.map(|f| {
|
||||
let params = f.parameters.iter().map(|i| i.result_id.unwrap()).collect();
|
||||
(f.def_id().unwrap(), params)
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn check_tys_equal(
|
||||
defs: &DefAnalyzer<'_>,
|
||||
name: &str,
|
||||
import_type_id: Word,
|
||||
export_type_id: Word,
|
||||
) -> Result<()> {
|
||||
let import_type = defs.def(import_type_id).unwrap();
|
||||
let export_type = defs.def(export_type_id).unwrap();
|
||||
|
||||
let imp = trans_aggregate_type(defs, import_type);
|
||||
let exp = trans_aggregate_type(defs, export_type);
|
||||
|
||||
if imp == exp {
|
||||
fn check_tys_equal(name: &str, import_type: Word, export_type: Word) -> Result<()> {
|
||||
if import_type == export_type {
|
||||
Ok(())
|
||||
} else {
|
||||
Err(LinkerError::TypeMismatch {
|
||||
name: name.to_string(),
|
||||
import_type: print_type(defs, import_type),
|
||||
export_type: print_type(defs, export_type),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -3,19 +3,16 @@ mod test;
|
||||
|
||||
mod capability_computation;
|
||||
mod dce;
|
||||
mod def_analyzer;
|
||||
mod duplicates;
|
||||
mod import_export_link;
|
||||
mod inline;
|
||||
mod mem2reg;
|
||||
mod simple_passes;
|
||||
mod structurizer;
|
||||
mod ty;
|
||||
mod zombies;
|
||||
|
||||
use def_analyzer::DefAnalyzer;
|
||||
use rspirv::binary::Consumer;
|
||||
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader, Operand};
|
||||
use rspirv::dr::{Block, Instruction, Loader, Module, ModuleHeader};
|
||||
use rspirv::spirv::{Op, Word};
|
||||
use rustc_session::Session;
|
||||
use std::collections::HashMap;
|
||||
@ -27,12 +24,8 @@ pub enum LinkerError {
|
||||
UnresolvedSymbol(String),
|
||||
#[error("Multiple exports found for {:?}", .0)]
|
||||
MultipleExports(String),
|
||||
#[error("Types mismatch for {:?}, imported with type {:?}, exported with type {:?}", .name, .import_type, .export_type)]
|
||||
TypeMismatch {
|
||||
name: String,
|
||||
import_type: String,
|
||||
export_type: String,
|
||||
},
|
||||
#[error("Types mismatch for {:?}", .name)]
|
||||
TypeMismatch { name: String },
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, LinkerError>;
|
||||
@ -51,18 +44,6 @@ fn id(header: &mut ModuleHeader) -> Word {
|
||||
result
|
||||
}
|
||||
|
||||
fn print_type(defs: &DefAnalyzer<'_>, ty: &Instruction) -> String {
|
||||
format!("{}", ty::trans_aggregate_type(defs, ty).unwrap())
|
||||
}
|
||||
|
||||
fn extract_literal_int_as_u64(op: &Operand) -> u64 {
|
||||
match op {
|
||||
Operand::LiteralInt32(v) => (*v).into(),
|
||||
Operand::LiteralInt64(v) => *v,
|
||||
_ => panic!("Unexpected literal int"),
|
||||
}
|
||||
}
|
||||
|
||||
fn apply_rewrite_rules(rewrite_rules: &HashMap<Word, Word>, blocks: &mut [Block]) {
|
||||
let apply = |inst: &mut Instruction| {
|
||||
if let Some(ref mut id) = &mut inst.result_id {
|
||||
@ -127,12 +108,6 @@ pub fn link(sess: Option<&Session>, inputs: &mut [&mut Module], opts: &Options)
|
||||
output
|
||||
};
|
||||
|
||||
// find import / export pairs
|
||||
{
|
||||
let _timer = timer("link_find_pairs");
|
||||
import_export_link::run(&mut output)?;
|
||||
}
|
||||
|
||||
// remove duplicates (https://github.com/KhronosGroup/SPIRV-Tools/blob/e7866de4b1dc2a7e8672867caeb0bdca49f458d3/source/opt/remove_duplicates_pass.cpp)
|
||||
{
|
||||
let _timer = timer("link_remove_duplicates");
|
||||
@ -143,6 +118,12 @@ pub fn link(sess: Option<&Session>, inputs: &mut [&mut Module], opts: &Options)
|
||||
// jb-todo: strip identical OpDecoration / OpDecorationGroups
|
||||
}
|
||||
|
||||
// find import / export pairs
|
||||
{
|
||||
let _timer = timer("link_find_pairs");
|
||||
import_export_link::run(&mut output)?;
|
||||
}
|
||||
|
||||
{
|
||||
let _timer = timer("link_remove_zombies");
|
||||
zombies::remove_zombies(sess, &mut output);
|
||||
|
@ -229,8 +229,6 @@ fn type_mismatch() -> Result<()> {
|
||||
result.err(),
|
||||
Some(LinkerError::TypeMismatch {
|
||||
name: "foo".to_string(),
|
||||
import_type: "f32".to_string(),
|
||||
export_type: "u32".to_string(),
|
||||
})
|
||||
);
|
||||
Ok(())
|
||||
|
@ -1,230 +0,0 @@
|
||||
use super::{extract_literal_int_as_u64, DefAnalyzer};
|
||||
use rspirv::dr::Instruction;
|
||||
use rspirv::spirv::{AccessQualifier, Dim, ImageFormat, Op, StorageClass};
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
pub enum ScalarType {
|
||||
Void,
|
||||
Bool,
|
||||
Int { width: u32, signed: bool },
|
||||
Float { width: u32 },
|
||||
Opaque { name: String },
|
||||
Event,
|
||||
DeviceEvent,
|
||||
ReserveId,
|
||||
Queue,
|
||||
Pipe,
|
||||
ForwardPointer { storage_class: StorageClass },
|
||||
PipeStorage,
|
||||
NamedBarrier,
|
||||
Sampler,
|
||||
}
|
||||
|
||||
fn trans_scalar_type(inst: &Instruction) -> Option<ScalarType> {
|
||||
Some(match inst.class.opcode {
|
||||
Op::TypeVoid => ScalarType::Void,
|
||||
Op::TypeBool => ScalarType::Bool,
|
||||
Op::TypeEvent => ScalarType::Event,
|
||||
Op::TypeDeviceEvent => ScalarType::DeviceEvent,
|
||||
Op::TypeReserveId => ScalarType::ReserveId,
|
||||
Op::TypeQueue => ScalarType::Queue,
|
||||
Op::TypePipe => ScalarType::Pipe,
|
||||
Op::TypePipeStorage => ScalarType::PipeStorage,
|
||||
Op::TypeNamedBarrier => ScalarType::NamedBarrier,
|
||||
Op::TypeSampler => ScalarType::Sampler,
|
||||
Op::TypeForwardPointer => ScalarType::ForwardPointer {
|
||||
storage_class: inst.operands[0].unwrap_storage_class(),
|
||||
},
|
||||
Op::TypeInt => ScalarType::Int {
|
||||
width: inst.operands[0].unwrap_literal_int32(),
|
||||
signed: inst.operands[1].unwrap_literal_int32() != 0,
|
||||
},
|
||||
Op::TypeFloat => ScalarType::Float {
|
||||
width: inst.operands[0].unwrap_literal_int32(),
|
||||
},
|
||||
Op::TypeOpaque => ScalarType::Opaque {
|
||||
name: inst.operands[0].unwrap_literal_string().to_string(),
|
||||
},
|
||||
_ => return None,
|
||||
})
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ScalarType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match *self {
|
||||
Self::Void => f.write_str("void"),
|
||||
Self::Bool => f.write_str("bool"),
|
||||
Self::Int { width, signed } => {
|
||||
if signed {
|
||||
write!(f, "i{}", width)
|
||||
} else {
|
||||
write!(f, "u{}", width)
|
||||
}
|
||||
}
|
||||
Self::Float { width } => write!(f, "f{}", width),
|
||||
Self::Opaque { ref name } => write!(f, "Opaque{{{}}}", name),
|
||||
Self::Event => f.write_str("Event"),
|
||||
Self::DeviceEvent => f.write_str("DeviceEvent"),
|
||||
Self::ReserveId => f.write_str("ReserveId"),
|
||||
Self::Queue => f.write_str("Queue"),
|
||||
Self::Pipe => f.write_str("Pipe"),
|
||||
Self::ForwardPointer { storage_class } => {
|
||||
write!(f, "ForwardPointer{{{:?}}}", storage_class)
|
||||
}
|
||||
Self::PipeStorage => f.write_str("PipeStorage"),
|
||||
Self::NamedBarrier => f.write_str("NamedBarrier"),
|
||||
Self::Sampler => f.write_str("Sampler"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(PartialEq, Debug)]
|
||||
#[allow(dead_code)]
|
||||
pub enum AggregateType {
|
||||
Scalar(ScalarType),
|
||||
Array {
|
||||
ty: Box<AggregateType>,
|
||||
len: u64,
|
||||
},
|
||||
Pointer {
|
||||
ty: Box<AggregateType>,
|
||||
storage_class: StorageClass,
|
||||
},
|
||||
Image {
|
||||
ty: Box<AggregateType>,
|
||||
dim: Dim,
|
||||
depth: u32,
|
||||
arrayed: u32,
|
||||
multi_sampled: u32,
|
||||
sampled: u32,
|
||||
format: ImageFormat,
|
||||
access: Option<AccessQualifier>,
|
||||
},
|
||||
SampledImage {
|
||||
ty: Box<AggregateType>,
|
||||
},
|
||||
Aggregate(Vec<AggregateType>),
|
||||
Function(Vec<AggregateType>, Box<AggregateType>),
|
||||
}
|
||||
|
||||
pub(crate) fn trans_aggregate_type(
|
||||
def: &DefAnalyzer<'_>,
|
||||
inst: &Instruction,
|
||||
) -> Option<AggregateType> {
|
||||
Some(match inst.class.opcode {
|
||||
Op::TypeArray => {
|
||||
let len_def = def.op_def(&inst.operands[1]);
|
||||
assert!(len_def.class.opcode == Op::Constant); // don't support spec constants yet
|
||||
|
||||
let len_value = extract_literal_int_as_u64(&len_def.operands[0]);
|
||||
|
||||
AggregateType::Array {
|
||||
ty: Box::new(
|
||||
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
|
||||
.expect("Expect base type for OpTypeArray"),
|
||||
),
|
||||
len: len_value,
|
||||
}
|
||||
}
|
||||
Op::TypePointer => AggregateType::Pointer {
|
||||
storage_class: inst.operands[0].unwrap_storage_class(),
|
||||
ty: Box::new(
|
||||
trans_aggregate_type(def, &def.op_def(&inst.operands[1]))
|
||||
.expect("Expect base type for OpTypePointer"),
|
||||
),
|
||||
},
|
||||
Op::TypeRuntimeArray | Op::TypeVector | Op::TypeMatrix | Op::TypeSampledImage => {
|
||||
AggregateType::Aggregate(
|
||||
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
|
||||
.map_or_else(Vec::new, |v| vec![v]),
|
||||
)
|
||||
}
|
||||
Op::TypeStruct => {
|
||||
let mut types = vec![];
|
||||
for operand in inst.operands.iter() {
|
||||
let op_def = def.op_def(operand);
|
||||
|
||||
match trans_aggregate_type(def, &op_def) {
|
||||
Some(ty) => types.push(ty),
|
||||
None => panic!("Expected type"),
|
||||
}
|
||||
}
|
||||
|
||||
AggregateType::Aggregate(types)
|
||||
}
|
||||
Op::TypeFunction => {
|
||||
let mut parameters = vec![];
|
||||
let ret = trans_aggregate_type(def, &def.op_def(&inst.operands[0])).unwrap();
|
||||
for operand in inst.operands.iter().skip(1) {
|
||||
let op_def = def.op_def(operand);
|
||||
|
||||
match trans_aggregate_type(def, &op_def) {
|
||||
Some(ty) => parameters.push(ty),
|
||||
None => panic!("Expected type"),
|
||||
}
|
||||
}
|
||||
|
||||
AggregateType::Function(parameters, Box::new(ret))
|
||||
}
|
||||
Op::TypeImage => AggregateType::Image {
|
||||
ty: Box::new(
|
||||
trans_aggregate_type(def, &def.op_def(&inst.operands[0]))
|
||||
.expect("Expect base type for OpTypeImage"),
|
||||
),
|
||||
dim: inst.operands[1].unwrap_dim(),
|
||||
depth: inst.operands[2].unwrap_literal_int32(),
|
||||
arrayed: inst.operands[3].unwrap_literal_int32(),
|
||||
multi_sampled: inst.operands[4].unwrap_literal_int32(),
|
||||
sampled: inst.operands[5].unwrap_literal_int32(),
|
||||
format: inst.operands[6].unwrap_image_format(),
|
||||
access: inst.operands.get(7).map(|op| op.unwrap_access_qualifier()),
|
||||
},
|
||||
_ => {
|
||||
if let Some(ty) = trans_scalar_type(inst) {
|
||||
AggregateType::Scalar(ty)
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AggregateType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Scalar(scalar) => write!(f, "{}", scalar),
|
||||
Self::Array { ty, len } => write!(f, "[{}; {}]", ty, len),
|
||||
Self::Pointer { ty, storage_class } => write!(f, "*{{{:?}}} {}", storage_class, ty),
|
||||
Self::Image {
|
||||
ty,
|
||||
dim,
|
||||
depth,
|
||||
arrayed,
|
||||
multi_sampled,
|
||||
sampled,
|
||||
format,
|
||||
access,
|
||||
} => write!(
|
||||
f,
|
||||
"Image {{ {}, dim:{:?}, depth:{}, arrayed:{}, \
|
||||
multi_sampled:{}, sampled:{}, format:{:?}, access:{:?} }}",
|
||||
ty, dim, depth, arrayed, multi_sampled, sampled, format, access
|
||||
),
|
||||
Self::SampledImage { ty } => write!(f, "SampledImage{{{}}}", ty),
|
||||
Self::Aggregate(agg) => {
|
||||
f.write_str("struct {")?;
|
||||
for elem in agg {
|
||||
write!(f, " {},", elem)?;
|
||||
}
|
||||
f.write_str(" }")
|
||||
}
|
||||
Self::Function(args, ret) => {
|
||||
f.write_str("fn(")?;
|
||||
for elem in args {
|
||||
write!(f, " {},", elem)?;
|
||||
}
|
||||
write!(f, " ) -> {}", ret)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user