From cc916c784efe84e5c7247e7286e0fb360c1ea2cc Mon Sep 17 00:00:00 2001 From: khyperia Date: Tue, 22 Sep 2020 15:51:30 +0200 Subject: [PATCH] Documentation, code shuffling, and RPO block sorting --- Cargo.lock | 1 - rustc_codegen_spirv/Cargo.toml | 1 - rustc_codegen_spirv/src/abi.rs | 18 +- .../src/builder/builder_methods.rs | 151 +++++++++++++++- rustc_codegen_spirv/src/finalizing_passes.rs | 75 ++++---- rustc_codegen_spirv/src/link.rs | 23 +++ rustc_codegen_spirv/src/spirv_type.rs | 165 ++---------------- rustc_codegen_spirv/src/symbols.rs | 12 +- spirv-std/src/lib.rs | 10 +- 9 files changed, 233 insertions(+), 223 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 536c8537cb..e972c9c3ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -299,7 +299,6 @@ dependencies = [ "rspirv-linker", "tar", "tempfile", - "topological-sort", ] [[package]] diff --git a/rustc_codegen_spirv/Cargo.toml b/rustc_codegen_spirv/Cargo.toml index 239a970af3..6eee56a14d 100644 --- a/rustc_codegen_spirv/Cargo.toml +++ b/rustc_codegen_spirv/Cargo.toml @@ -16,7 +16,6 @@ crate-type = ["dylib"] rspirv = "0.7.0" rspirv-linker = { path = "../rspirv-linker" } tar = "0.4" -topological-sort = "0.1" [dev-dependencies] pretty_assertions = "0.6" diff --git a/rustc_codegen_spirv/src/abi.rs b/rustc_codegen_spirv/src/abi.rs index 33206d4c09..6c90ffbeaf 100644 --- a/rustc_codegen_spirv/src/abi.rs +++ b/rustc_codegen_spirv/src/abi.rs @@ -5,7 +5,7 @@ use crate::spirv_type::SpirvType; use crate::symbols::{parse_attr, SpirvAttribute}; use rspirv::spirv::{StorageClass, Word}; use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout}; -use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind}; +use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind, TypeAndMut}; use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg, RegKind}; use rustc_target::abi::{ Abi, Align, FieldsShape, LayoutOf, Primitive, Scalar, Size, TagEncoding, Variants, @@ -420,7 +420,7 @@ fn dig_scalar_pointee<'tcx>( index: Option, ) -> (Option, PointeeTy<'tcx>) { match *ty.ty.kind() { - TyKind::Ref(_region, elem_ty, _mutability) => { + TyKind::Ref(_, elem_ty, _) | TyKind::RawPtr(TypeAndMut { ty: elem_ty, .. }) => { let elem = cx.layout_of(elem_ty); match index { None => (None, PointeeTy::Ty(elem)), @@ -437,20 +437,6 @@ fn dig_scalar_pointee<'tcx>( } } } - TyKind::RawPtr(type_and_mut) => { - let elem = cx.layout_of(type_and_mut.ty); - match index { - None => (None, PointeeTy::Ty(elem)), - Some(index) => { - if elem.is_unsized() { - dig_scalar_pointee(cx, ty.field(cx, index), None) - } else { - // Same comment as TyKind::Ref - (None, PointeeTy::Ty(elem)) - } - } - } - } TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)), TyKind::Adt(def, _) if def.is_box() => { let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty())); diff --git a/rustc_codegen_spirv/src/builder/builder_methods.rs b/rustc_codegen_spirv/src/builder/builder_methods.rs index 0369ae4c7f..d5f68aef71 100644 --- a/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -2,7 +2,7 @@ use super::Builder; use crate::builder_spirv::{BuilderCursor, SpirvConst, SpirvValueExt}; use crate::spirv_type::SpirvType; use rspirv::dr::{InsertPoint, Instruction, Operand}; -use rspirv::spirv::{MemorySemantics, Op, Scope, StorageClass}; +use rspirv::spirv::{MemorySemantics, Op, Scope, StorageClass, Word}; use rustc_codegen_ssa::common::{ AtomicOrdering, AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope, }; @@ -65,6 +65,147 @@ fn ordering_to_semantics(ordering: AtomicOrdering) -> MemorySemantics { } } +fn memset_fill_u16(b: u8) -> u16 { + b as u16 | ((b as u16) << 8) +} + +fn memset_fill_u32(b: u8) -> u32 { + b as u32 | ((b as u32) << 8) | ((b as u32) << 16) | ((b as u32) << 24) +} + +fn memset_fill_u64(b: u8) -> u64 { + b as u64 + | ((b as u64) << 8) + | ((b as u64) << 16) + | ((b as u64) << 24) + | ((b as u64) << 32) + | ((b as u64) << 40) + | ((b as u64) << 48) + | ((b as u64) << 56) +} + +fn memset_dynamic_scalar<'a, 'tcx>( + builder: &Builder<'a, 'tcx>, + fill_var: Word, + byte_width: usize, + is_float: bool, +) -> Word { + let composite_type = SpirvType::Vector { + element: SpirvType::Integer(8, false).def(builder), + count: byte_width as u32, + } + .def(builder); + let composite = builder + .emit() + .composite_construct( + composite_type, + None, + std::iter::repeat(fill_var).take(byte_width), + ) + .unwrap(); + let result_type = if is_float { + SpirvType::Float(byte_width as u32 * 8) + } else { + SpirvType::Integer(byte_width as u32 * 8, false) + }; + builder + .emit() + .bitcast(result_type.def(builder), None, composite) + .unwrap() +} + +impl<'a, 'tcx> Builder<'a, 'tcx> { + fn memset_const_pattern(&self, ty: &SpirvType, fill_byte: u8) -> Word { + match *ty { + SpirvType::Void => panic!("memset invalid on void pattern"), + SpirvType::Bool => panic!("memset invalid on bool pattern"), + SpirvType::Integer(width, _signedness) => match width { + 8 => self.constant_u8(fill_byte).def, + 16 => self.constant_u16(memset_fill_u16(fill_byte)).def, + 32 => self.constant_u32(memset_fill_u32(fill_byte)).def, + 64 => self.constant_u64(memset_fill_u64(fill_byte)).def, + _ => panic!("memset on integer width {} not implemented yet", width), + }, + SpirvType::Float(width) => match width { + 32 => { + self.constant_f32(f32::from_bits(memset_fill_u32(fill_byte))) + .def + } + 64 => { + self.constant_f64(f64::from_bits(memset_fill_u64(fill_byte))) + .def + } + _ => panic!("memset on float width {} not implemented yet", width), + }, + SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"), + SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"), + SpirvType::Vector { element, count } => { + let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); + self.constant_composite(ty.def(self), vec![elem_pat; count as usize]) + .def + } + SpirvType::Array { element, count } => { + let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); + let count = self.builder.lookup_const_u64(count).unwrap() as usize; + self.constant_composite(ty.def(self), vec![elem_pat; count]) + .def + } + SpirvType::RuntimeArray { .. } => { + panic!("memset on runtime arrays not implemented yet") + } + SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"), + SpirvType::Function { .. } => panic!("memset on functions not implemented yet"), + } + } + + fn memset_dynamic_pattern(&self, ty: &SpirvType, fill_var: Word) -> Word { + match *ty { + SpirvType::Void => panic!("memset invalid on void pattern"), + SpirvType::Bool => panic!("memset invalid on bool pattern"), + SpirvType::Integer(width, _signedness) => match width { + 8 => fill_var, + 16 => memset_dynamic_scalar(self, fill_var, 2, false), + 32 => memset_dynamic_scalar(self, fill_var, 4, false), + 64 => memset_dynamic_scalar(self, fill_var, 8, false), + _ => panic!("memset on integer width {} not implemented yet", width), + }, + SpirvType::Float(width) => match width { + 32 => memset_dynamic_scalar(self, fill_var, 4, true), + 64 => memset_dynamic_scalar(self, fill_var, 8, true), + _ => panic!("memset on float width {} not implemented yet", width), + }, + SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"), + SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"), + SpirvType::Array { element, count } => { + let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var); + let count = self.builder.lookup_const_u64(count).unwrap() as usize; + self.emit() + .composite_construct( + ty.def(self), + None, + std::iter::repeat(elem_pat).take(count), + ) + .unwrap() + } + SpirvType::Vector { element, count } => { + let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var); + self.emit() + .composite_construct( + ty.def(self), + None, + std::iter::repeat(elem_pat).take(count as usize), + ) + .unwrap() + } + SpirvType::RuntimeArray { .. } => { + panic!("memset on runtime arrays not implemented yet") + } + SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"), + SpirvType::Function { .. } => panic!("memset on functions not implemented yet"), + } + } +} + impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn with_cx(cx: &'a Self::CodegenCx) -> Self { // Note: all defaults here *must* be filled out by position_at_end @@ -1008,8 +1149,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { SpirvConst::U32(_, v) => v as usize, other => panic!("memset size constant value not supported: {:?}", other), }; - let pat = elem_ty_spv - .memset_const_pattern(self, fill_byte) + let pat = self + .memset_const_pattern(&elem_ty_spv, fill_byte) .with_type(elem_ty); let elem_ty_sizeof = elem_ty_spv .sizeof(self) @@ -1033,8 +1174,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { SpirvConst::U32(_, v) => v as usize, other => panic!("memset size constant value not supported: {:?}", other), }; - let pat = elem_ty_spv - .memset_dynamic_pattern(self, fill_byte.def) + let pat = self + .memset_dynamic_pattern(&elem_ty_spv, fill_byte.def) .with_type(elem_ty); let elem_ty_sizeof = elem_ty_spv .sizeof(self) diff --git a/rustc_codegen_spirv/src/finalizing_passes.rs b/rustc_codegen_spirv/src/finalizing_passes.rs index 2617ed56bc..85dffd1f08 100644 --- a/rustc_codegen_spirv/src/finalizing_passes.rs +++ b/rustc_codegen_spirv/src/finalizing_passes.rs @@ -34,6 +34,7 @@ fn label_of(block: &Block) -> Word { block.label.as_ref().unwrap().result_id.unwrap() } +// TODO: Move this to the linker. pub fn delete_dead_blocks(func: &mut Function) { if func.blocks.len() < 2 { return; @@ -54,36 +55,45 @@ pub fn delete_dead_blocks(func: &mut Function) { func.blocks.retain(|b| visited.contains(&label_of(b))) } +// TODO: Do we move this to the linker? pub fn block_ordering_pass(func: &mut Function) { if func.blocks.len() < 2 { return; } - let mut graph = func - .blocks - .iter() - .map(|block| (label_of(block), outgoing_edges(block))) - .collect(); - let entry_label = label_of(&func.blocks[0]); - delete_backedges(&mut graph, entry_label); - - let mut sorter = topological_sort::TopologicalSort::::new(); - for (key, values) in graph { - for value in values { - sorter.add_dependency(key, value); + fn visit_postorder( + func: &Function, + visited: &mut HashSet, + postorder: &mut Vec, + current: Word, + ) { + if !visited.insert(current) { + return; } + let current_block = func.blocks.iter().find(|b| label_of(b) == current).unwrap(); + // Reverse the order, so reverse-postorder keeps things tidy + for &outgoing in outgoing_edges(current_block).iter().rev() { + visit_postorder(func, visited, postorder, outgoing); + } + postorder.push(current); } + + let mut visited = HashSet::new(); + let mut postorder = Vec::new(); + + let entry_label = label_of(&func.blocks[0]); + visit_postorder(func, &mut visited, &mut postorder, entry_label); + let mut old_blocks = replace(&mut func.blocks, Vec::new()); - while let Some(item) = sorter.pop() { - let index = old_blocks.iter().position(|b| label_of(b) == item).unwrap(); + // Order blocks according to reverse postorder + for &block in postorder.iter().rev() { + let index = old_blocks + .iter() + .position(|b| label_of(b) == block) + .unwrap(); func.blocks.push(old_blocks.remove(index)); } - assert!(sorter.is_empty()); assert!(old_blocks.is_empty()); - assert_eq!( - label_of(&func.blocks[0]), - entry_label, - "Topo sorter did something weird (unreachable blocks?)" - ); + assert_eq!(label_of(&func.blocks[0]), entry_label,); } fn outgoing_edges(block: &Block) -> Vec { @@ -113,28 +123,3 @@ fn outgoing_edges(block: &Block) -> Vec { _ => panic!("Invalid block terminator: {:?}", terminator), } } - -fn delete_backedges(graph: &mut HashMap>, entry: Word) { - // TODO: This has extremely bad runtime - let mut backedges = HashSet::new(); - fn re( - graph: &HashMap>, - entry: Word, - stack: &mut Vec, - backedges: &mut HashSet<(Word, Word)>, - ) { - stack.push(entry); - for &item in &graph[&entry] { - if stack.contains(&item) { - backedges.insert((entry, item)); - } else if !backedges.contains(&(entry, item)) { - re(graph, item, stack, backedges); - } - } - assert_eq!(stack.pop(), Some(entry)); - } - re(graph, entry, &mut Vec::new(), &mut backedges); - for (from, to) in backedges { - graph.get_mut(&from).unwrap().retain(|&o| o != to); - } -} diff --git a/rustc_codegen_spirv/src/link.rs b/rustc_codegen_spirv/src/link.rs index bf81c42578..ac6292fa41 100644 --- a/rustc_codegen_spirv/src/link.rs +++ b/rustc_codegen_spirv/src/link.rs @@ -256,13 +256,18 @@ pub fn read_metadata(rlib: &Path) -> MetadataRef { panic!("No .metadata file in rlib: {:?}", rlib); } +/// This is the actual guts of linking: the rest of the link-related functions are just digging through rustc's +/// shenanigans to collect all the object files we need to link. fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) { let mut modules = Vec::new(); + // `objects` are the plain obj files we need to link - usually produced by the final crate. for obj in objects { let mut bytes = Vec::new(); File::open(obj).unwrap().read_to_end(&mut bytes).unwrap(); modules.push(load(&bytes)); } + // `rlibs` are archive files we've created in `create_archive`, usually produced by crates that are being + // referenced. We need to unpack them and add the modules inside. for rlib in rlibs { for entry in Archive::new(File::open(rlib).unwrap()).entries().unwrap() { let mut entry = entry.unwrap(); @@ -276,6 +281,21 @@ fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) { let mut module_refs = modules.iter_mut().collect::>(); + { + let path = Path::new("/home/khyperia/tmp"); + if path.is_file() { + std::fs::remove_file(path).unwrap(); + } + std::fs::create_dir_all(path).unwrap(); + for (num, module) in module_refs.iter().enumerate() { + File::create(path.join(format!("mod_{}.spv", num))) + .unwrap() + .write_all(crate::slice_u32_to_u8(&module.assemble())) + .unwrap(); + } + } + + // Do the link... let result = match rspirv_linker::link( &mut module_refs, &rspirv_linker::Options { @@ -302,6 +322,7 @@ fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) { } }; + // And finally write out the linked binary. use rspirv::binary::Assemble; File::create(out_filename) .unwrap() @@ -315,6 +336,8 @@ fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) { } } +/// As of right now, this is essentially a no-op, just plumbing through all the files. +// TODO: WorkProduct impl pub(crate) fn run_thin( cgcx: &CodegenContext, modules: Vec<(String, SpirvThinBuffer)>, diff --git a/rustc_codegen_spirv/src/spirv_type.rs b/rustc_codegen_spirv/src/spirv_type.rs index cd37774002..e6f89dc4ec 100644 --- a/rustc_codegen_spirv/src/spirv_type.rs +++ b/rustc_codegen_spirv/src/spirv_type.rs @@ -1,5 +1,4 @@ use crate::abi::RecursivePointeeCache; -use crate::builder::Builder; use crate::codegen_cx::CodegenCx; use rspirv::dr::Operand; use rspirv::spirv::{Decoration, StorageClass, Word}; @@ -11,6 +10,11 @@ use std::iter::once; use std::lazy::SyncLazy; use std::sync::Mutex; +/// Spir-v types are represented as simple Words, which are the result_id of instructions like OpTypeInteger. Sometimes, +/// however, we want to inspect one of these Words and ask questions like "Is this an OpTypeInteger? How many bits does +/// it have?". This struct holds all of that information. All types that are emitted are registered in CodegenCx, so you +/// can always look up the definition of a Word via cx.lookup_type. Note that this type doesn't actually store the +/// result_id of the type declaration instruction, merely the contents. #[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)] pub enum SpirvType { Void, @@ -52,55 +56,6 @@ pub enum SpirvType { }, } -fn memset_fill_u16(b: u8) -> u16 { - b as u16 | ((b as u16) << 8) -} - -fn memset_fill_u32(b: u8) -> u32 { - b as u32 | ((b as u32) << 8) | ((b as u32) << 16) | ((b as u32) << 24) -} - -fn memset_fill_u64(b: u8) -> u64 { - b as u64 - | ((b as u64) << 8) - | ((b as u64) << 16) - | ((b as u64) << 24) - | ((b as u64) << 32) - | ((b as u64) << 40) - | ((b as u64) << 48) - | ((b as u64) << 56) -} - -fn memset_dynamic_scalar<'a, 'tcx>( - builder: &Builder<'a, 'tcx>, - fill_var: Word, - byte_width: usize, - is_float: bool, -) -> Word { - let composite_type = SpirvType::Vector { - element: SpirvType::Integer(8, false).def(builder), - count: byte_width as u32, - } - .def(builder); - let composite = builder - .emit() - .composite_construct( - composite_type, - None, - std::iter::repeat(fill_var).take(byte_width), - ) - .unwrap(); - let result_type = if is_float { - SpirvType::Float(byte_width as u32 * 8) - } else { - SpirvType::Integer(byte_width as u32 * 8, false) - }; - builder - .emit() - .bitcast(result_type.def(builder), None, composite) - .unwrap() -} - impl SpirvType { /// Note: Builder::type_* should be called *nowhere else* but here, to ensure CodegenCx::type_defs stays up-to-date pub fn def<'tcx>(&self, cx: &CodegenCx<'tcx>) -> Word { @@ -211,6 +166,8 @@ impl SpirvType { result } + /// def_with_id is used by the RecursivePointeeCache to handle OpTypeForwardPointer: when emitting the subsequent + /// OpTypePointer, the ID is already known and must be re-used. pub fn def_with_id<'tcx>(&self, cx: &CodegenCx<'tcx>, id: Word) -> Word { if let Some(cached) = cx.type_cache.get(self) { assert_eq!(cached, id); @@ -236,6 +193,7 @@ impl SpirvType { result } + /// Use this if you want a pretty type printing that recursively prints the types within (e.g. struct fields) pub fn debug<'cx, 'tcx>( self, id: Word, @@ -281,106 +239,6 @@ impl SpirvType { SpirvType::Function { .. } => cx.tcx.data_layout.pointer_align.abi, } } - - pub fn memset_const_pattern<'tcx>(&self, cx: &CodegenCx<'tcx>, fill_byte: u8) -> Word { - match *self { - SpirvType::Void => panic!("memset invalid on void pattern"), - SpirvType::Bool => panic!("memset invalid on bool pattern"), - SpirvType::Integer(width, _signedness) => match width { - 8 => cx.constant_u8(fill_byte).def, - 16 => cx.constant_u16(memset_fill_u16(fill_byte)).def, - 32 => cx.constant_u32(memset_fill_u32(fill_byte)).def, - 64 => cx.constant_u64(memset_fill_u64(fill_byte)).def, - _ => panic!("memset on integer width {} not implemented yet", width), - }, - SpirvType::Float(width) => match width { - 32 => { - cx.constant_f32(f32::from_bits(memset_fill_u32(fill_byte))) - .def - } - 64 => { - cx.constant_f64(f64::from_bits(memset_fill_u64(fill_byte))) - .def - } - _ => panic!("memset on float width {} not implemented yet", width), - }, - SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"), - SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"), - SpirvType::Vector { element, count } => { - let elem_pat = cx.lookup_type(element).memset_const_pattern(cx, fill_byte); - cx.constant_composite(self.def(cx), vec![elem_pat; count as usize]) - .def - } - SpirvType::Array { element, count } => { - let elem_pat = cx.lookup_type(element).memset_const_pattern(cx, fill_byte); - let count = cx.builder.lookup_const_u64(count).unwrap() as usize; - cx.constant_composite(self.def(cx), vec![elem_pat; count]) - .def - } - SpirvType::RuntimeArray { .. } => { - panic!("memset on runtime arrays not implemented yet") - } - SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"), - SpirvType::Function { .. } => panic!("memset on functions not implemented yet"), - } - } - - pub fn memset_dynamic_pattern<'a, 'tcx>( - &self, - builder: &Builder<'a, 'tcx>, - fill_var: Word, - ) -> Word { - match *self { - SpirvType::Void => panic!("memset invalid on void pattern"), - SpirvType::Bool => panic!("memset invalid on bool pattern"), - SpirvType::Integer(width, _signedness) => match width { - 8 => fill_var, - 16 => memset_dynamic_scalar(builder, fill_var, 2, false), - 32 => memset_dynamic_scalar(builder, fill_var, 4, false), - 64 => memset_dynamic_scalar(builder, fill_var, 8, false), - _ => panic!("memset on integer width {} not implemented yet", width), - }, - SpirvType::Float(width) => match width { - 32 => memset_dynamic_scalar(builder, fill_var, 4, true), - 64 => memset_dynamic_scalar(builder, fill_var, 8, true), - _ => panic!("memset on float width {} not implemented yet", width), - }, - SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"), - SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"), - SpirvType::Array { element, count } => { - let elem_pat = builder - .lookup_type(element) - .memset_dynamic_pattern(builder, fill_var); - let count = builder.builder.lookup_const_u64(count).unwrap() as usize; - builder - .emit() - .composite_construct( - self.def(builder), - None, - std::iter::repeat(elem_pat).take(count), - ) - .unwrap() - } - SpirvType::Vector { element, count } => { - let elem_pat = builder - .lookup_type(element) - .memset_dynamic_pattern(builder, fill_var); - builder - .emit() - .composite_construct( - self.def(builder), - None, - std::iter::repeat(elem_pat).take(count as usize), - ) - .unwrap() - } - SpirvType::RuntimeArray { .. } => { - panic!("memset on runtime arrays not implemented yet") - } - SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"), - SpirvType::Function { .. } => panic!("memset on functions not implemented yet"), - } - } } pub struct SpirvTypePrinter<'cx, 'tcx> { @@ -389,6 +247,9 @@ pub struct SpirvTypePrinter<'cx, 'tcx> { cx: &'cx CodegenCx<'tcx>, } +/// Types can be recursive, e.g. a struct can contain a pointer to itself. So, we need to keep track of a stack of what +/// types are currently being printed, to not infinitely loop. Unfortunately, unlike fmt::Display, we can't easily pass +/// down the "stack" of currently-being-printed types, so we use a global static. static DEBUG_STACK: SyncLazy>> = SyncLazy::new(|| Mutex::new(Vec::new())); impl fmt::Debug for SpirvTypePrinter<'_, '_> { @@ -497,6 +358,10 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> { } } +/// Types can be recursive, e.g. a struct can contain a pointer to itself. So, we need to keep track of a stack of what +/// types are currently being printed, to not infinitely loop. So, we only use fmt::Display::fmt as an "entry point", and +/// then call through to our own (recursive) custom function that has a parameter for the current stack. Make sure to not +/// call Display on a type inside the custom function! impl fmt::Display for SpirvTypePrinter<'_, '_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.display(&mut Vec::new(), f) diff --git a/rustc_codegen_spirv/src/symbols.rs b/rustc_codegen_spirv/src/symbols.rs index b075f52059..56db2b9a54 100644 --- a/rustc_codegen_spirv/src/symbols.rs +++ b/rustc_codegen_spirv/src/symbols.rs @@ -4,6 +4,11 @@ use rustc_ast::ast::{AttrKind, Attribute}; use rustc_span::symbol::Symbol; use std::collections::HashMap; +/// Various places in the codebase (mostly attribute parsing) need to compare rustc Symbols to particular keywords. +/// Symbols are interned, as in, they don't actually store the string itself inside them, but rather an index into a +/// global table of strings. Then, whenever a new Symbol is created, the global table is checked to see if the string +/// already exists, deduplicating it if so. This makes things like comparison and cloning really cheap. So, this struct +/// is to allocate all our keywords up front and intern them all, so we can do comparisons really easily and fast. pub struct Symbols { pub spirv: Symbol, pub storage_class: Symbol, @@ -92,10 +97,13 @@ pub enum SpirvAttribute { Entry(ExecutionModel), } -// Note that we could mark thie attr as used via cx.tcx.sess.mark_attr_used(attr), but unused reporting already happens +// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused reporting already happens // even before we get here :( -/// Returns None if this attribute is not a spirv attribute, or if it's a malformed (and an error is reported). +/// Returns None if this attribute is not a spirv attribute, or if it's malformed (and an error is reported). pub fn parse_attr<'tcx>(cx: &CodegenCx<'tcx>, attr: &Attribute) -> Option { + // Example attributes that we parse here: + // #[spirv(storage_class = "uniform")] + // #[spirv(entry = "kernel")] let is_spirv = match attr.kind { AttrKind::Normal(ref item) => { // TODO: We ignore the rest of the path. Is this right? diff --git a/spirv-std/src/lib.rs b/spirv-std/src/lib.rs index 2e7e18751a..9e8ad50a5c 100644 --- a/spirv-std/src/lib.rs +++ b/spirv-std/src/lib.rs @@ -47,9 +47,13 @@ pointer_addrspace!("incoming_ray_payload_nv", IncomingRayPayloadNV); pointer_addrspace!("shader_record_buffer_nv", ShaderRecordBufferNV); pointer_addrspace!("physical_storage_buffer", PhysicalStorageBuffer); -/// # Safety -/// -/// TODO: Copied from compiler_builtins/mem.rs +/// libcore requires a few external symbols to be defined: +/// https://github.com/rust-lang/rust/blob/c2bc344eb23d8c1d18e803b3f1e631cf99926fbb/library/core/src/lib.rs#L23-L27 +/// TODO: This is copied from compiler_builtins/mem.rs. Can we use that one instead? The note in the above link says +/// "[the symbols] can also be provided by the compiler-builtins crate". The memcpy in compiler_builtins is behind a +/// "mem" feature flag - can we enable that somehow? +/// https://github.com/rust-lang/compiler-builtins/blob/eff506cd49b637f1ab5931625a33cef7e91fbbf6/src/mem.rs#L12-L13 +#[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32 { let mut i = 0;