mirror of https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-25 08:14:12 +00:00

Documentation, code shuffling, and RPO block sorting

This commit is contained in:
parent 1b3e5230da
commit cc916c784e
Cargo.lock (generated): 1 changed line
@@ -299,7 +299,6 @@ dependencies = [
"rspirv-linker",
"tar",
"tempfile",
"topological-sort",
]

[[package]]

@@ -16,7 +16,6 @@ crate-type = ["dylib"]
rspirv = "0.7.0"
rspirv-linker = { path = "../rspirv-linker" }
tar = "0.4"
topological-sort = "0.1"

[dev-dependencies]
pretty_assertions = "0.6"

@@ -5,7 +5,7 @@ use crate::spirv_type::SpirvType;
use crate::symbols::{parse_attr, SpirvAttribute};
use rspirv::spirv::{StorageClass, Word};
use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout};
use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind};
use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind, TypeAndMut};
use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg, RegKind};
use rustc_target::abi::{
Abi, Align, FieldsShape, LayoutOf, Primitive, Scalar, Size, TagEncoding, Variants,
@@ -420,7 +420,7 @@ fn dig_scalar_pointee<'tcx>(
index: Option<usize>,
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
match *ty.ty.kind() {
TyKind::Ref(_region, elem_ty, _mutability) => {
TyKind::Ref(_, elem_ty, _) | TyKind::RawPtr(TypeAndMut { ty: elem_ty, .. }) => {
let elem = cx.layout_of(elem_ty);
match index {
None => (None, PointeeTy::Ty(elem)),
@@ -437,20 +437,6 @@ fn dig_scalar_pointee<'tcx>(
}
}
}
TyKind::RawPtr(type_and_mut) => {
let elem = cx.layout_of(type_and_mut.ty);
match index {
None => (None, PointeeTy::Ty(elem)),
Some(index) => {
if elem.is_unsized() {
dig_scalar_pointee(cx, ty.field(cx, index), None)
} else {
// Same comment as TyKind::Ref
(None, PointeeTy::Ty(elem))
}
}
}
}
TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)),
TyKind::Adt(def, _) if def.is_box() => {
let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty()));

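The hunk above folds the separate `TyKind::Ref` and `TyKind::RawPtr` arms into a single arm by using an or-pattern that binds `elem_ty` from either shape. A minimal, self-contained sketch of that technique — the enum and field names below are made-up stand-ins, not rustc's actual `TyKind`:

```rust
// Hypothetical stand-ins for the reference / raw-pointer shapes.
enum Kind {
    Ref(&'static str, u32, bool),
    RawPtr { pointee: u32, mutable: bool },
    Other,
}

// Both alternatives of an or-pattern must bind the same names with the same
// types; here `pointee` is bound from either variant, so one arm covers both.
fn pointee_of(k: &Kind) -> Option<u32> {
    match *k {
        Kind::Ref(_, pointee, _) | Kind::RawPtr { pointee, .. } => Some(pointee),
        Kind::Other => None,
    }
}

fn main() {
    assert_eq!(pointee_of(&Kind::Ref("'a", 7, false)), Some(7));
    assert_eq!(pointee_of(&Kind::RawPtr { pointee: 9, mutable: true }), Some(9));
    assert_eq!(pointee_of(&Kind::Other), None);
}
```
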
@@ -2,7 +2,7 @@ use super::Builder;
use crate::builder_spirv::{BuilderCursor, SpirvConst, SpirvValueExt};
use crate::spirv_type::SpirvType;
use rspirv::dr::{InsertPoint, Instruction, Operand};
use rspirv::spirv::{MemorySemantics, Op, Scope, StorageClass};
use rspirv::spirv::{MemorySemantics, Op, Scope, StorageClass, Word};
use rustc_codegen_ssa::common::{
AtomicOrdering, AtomicRmwBinOp, IntPredicate, RealPredicate, SynchronizationScope,
};
@@ -65,6 +65,147 @@ fn ordering_to_semantics(ordering: AtomicOrdering) -> MemorySemantics {
}
}

fn memset_fill_u16(b: u8) -> u16 {
b as u16 | ((b as u16) << 8)
}

fn memset_fill_u32(b: u8) -> u32 {
b as u32 | ((b as u32) << 8) | ((b as u32) << 16) | ((b as u32) << 24)
}

fn memset_fill_u64(b: u8) -> u64 {
b as u64
| ((b as u64) << 8)
| ((b as u64) << 16)
| ((b as u64) << 24)
| ((b as u64) << 32)
| ((b as u64) << 40)
| ((b as u64) << 48)
| ((b as u64) << 56)
}

fn memset_dynamic_scalar<'a, 'tcx>(
builder: &Builder<'a, 'tcx>,
fill_var: Word,
byte_width: usize,
is_float: bool,
) -> Word {
let composite_type = SpirvType::Vector {
element: SpirvType::Integer(8, false).def(builder),
count: byte_width as u32,
}
.def(builder);
let composite = builder
.emit()
.composite_construct(
composite_type,
None,
std::iter::repeat(fill_var).take(byte_width),
)
.unwrap();
let result_type = if is_float {
SpirvType::Float(byte_width as u32 * 8)
} else {
SpirvType::Integer(byte_width as u32 * 8, false)
};
builder
.emit()
.bitcast(result_type.def(builder), None, composite)
.unwrap()
}

impl<'a, 'tcx> Builder<'a, 'tcx> {
fn memset_const_pattern(&self, ty: &SpirvType, fill_byte: u8) -> Word {
match *ty {
SpirvType::Void => panic!("memset invalid on void pattern"),
SpirvType::Bool => panic!("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width {
8 => self.constant_u8(fill_byte).def,
16 => self.constant_u16(memset_fill_u16(fill_byte)).def,
32 => self.constant_u32(memset_fill_u32(fill_byte)).def,
64 => self.constant_u64(memset_fill_u64(fill_byte)).def,
_ => panic!("memset on integer width {} not implemented yet", width),
},
SpirvType::Float(width) => match width {
32 => {
self.constant_f32(f32::from_bits(memset_fill_u32(fill_byte)))
.def
}
64 => {
self.constant_f64(f64::from_bits(memset_fill_u64(fill_byte)))
.def
}
_ => panic!("memset on float width {} not implemented yet", width),
},
SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"),
SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"),
SpirvType::Vector { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
self.constant_composite(ty.def(self), vec![elem_pat; count as usize])
.def
}
SpirvType::Array { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
self.constant_composite(ty.def(self), vec![elem_pat; count])
.def
}
SpirvType::RuntimeArray { .. } => {
panic!("memset on runtime arrays not implemented yet")
}
SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"),
SpirvType::Function { .. } => panic!("memset on functions not implemented yet"),
}
}

fn memset_dynamic_pattern(&self, ty: &SpirvType, fill_var: Word) -> Word {
match *ty {
SpirvType::Void => panic!("memset invalid on void pattern"),
SpirvType::Bool => panic!("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width {
8 => fill_var,
16 => memset_dynamic_scalar(self, fill_var, 2, false),
32 => memset_dynamic_scalar(self, fill_var, 4, false),
64 => memset_dynamic_scalar(self, fill_var, 8, false),
_ => panic!("memset on integer width {} not implemented yet", width),
},
SpirvType::Float(width) => match width {
32 => memset_dynamic_scalar(self, fill_var, 4, true),
64 => memset_dynamic_scalar(self, fill_var, 8, true),
_ => panic!("memset on float width {} not implemented yet", width),
},
SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"),
SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"),
SpirvType::Array { element, count } => {
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
let count = self.builder.lookup_const_u64(count).unwrap() as usize;
self.emit()
.composite_construct(
ty.def(self),
None,
std::iter::repeat(elem_pat).take(count),
)
.unwrap()
}
SpirvType::Vector { element, count } => {
let elem_pat = self.memset_dynamic_pattern(&self.lookup_type(element), fill_var);
self.emit()
.composite_construct(
ty.def(self),
None,
std::iter::repeat(elem_pat).take(count as usize),
)
.unwrap()
}
SpirvType::RuntimeArray { .. } => {
panic!("memset on runtime arrays not implemented yet")
}
SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"),
SpirvType::Function { .. } => panic!("memset on functions not implemented yet"),
}
}
}

impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn with_cx(cx: &'a Self::CodegenCx) -> Self {
// Note: all defaults here *must* be filled out by position_at_end
@@ -1008,8 +1149,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
SpirvConst::U32(_, v) => v as usize,
other => panic!("memset size constant value not supported: {:?}", other),
};
let pat = elem_ty_spv
.memset_const_pattern(self, fill_byte)
let pat = self
.memset_const_pattern(&elem_ty_spv, fill_byte)
.with_type(elem_ty);
let elem_ty_sizeof = elem_ty_spv
.sizeof(self)
@@ -1033,8 +1174,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
SpirvConst::U32(_, v) => v as usize,
other => panic!("memset size constant value not supported: {:?}", other),
};
let pat = elem_ty_spv
.memset_dynamic_pattern(self, fill_byte.def)
let pat = self
.memset_dynamic_pattern(&elem_ty_spv, fill_byte.def)
.with_type(elem_ty);
let elem_ty_sizeof = elem_ty_spv
.sizeof(self)

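The `memset_fill_*` helpers moved here replicate one fill byte across a wider scalar so a single constant can stand in for a byte-wise memset. A standalone sketch of the same idea in plain Rust, independent of the builder types; the `from_ne_bytes` form is just an equivalent alternative to the shift-and-or form the commit uses:

```rust
/// Replicate `b` into every byte of a u32, e.g. 0xAB -> 0xABABABAB.
fn memset_fill_u32(b: u8) -> u32 {
    u32::from_ne_bytes([b; 4])
}

/// Same pattern for u64.
fn memset_fill_u64(b: u8) -> u64 {
    u64::from_ne_bytes([b; 8])
}

fn main() {
    assert_eq!(memset_fill_u32(0xAB), 0xABAB_ABAB);
    assert_eq!(memset_fill_u64(0x01), 0x0101_0101_0101_0101);
    // The shift-and-or form used in the commit produces the same value:
    let b = 0xCDu8;
    let shifted = b as u32 | ((b as u32) << 8) | ((b as u32) << 16) | ((b as u32) << 24);
    assert_eq!(shifted, memset_fill_u32(b));
}
```
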
@@ -34,6 +34,7 @@ fn label_of(block: &Block) -> Word {
block.label.as_ref().unwrap().result_id.unwrap()
}

// TODO: Move this to the linker.
pub fn delete_dead_blocks(func: &mut Function) {
if func.blocks.len() < 2 {
return;
@@ -54,36 +55,45 @@ pub fn delete_dead_blocks(func: &mut Function) {
func.blocks.retain(|b| visited.contains(&label_of(b)))
}

// TODO: Do we move this to the linker?
pub fn block_ordering_pass(func: &mut Function) {
if func.blocks.len() < 2 {
return;
}
let mut graph = func
.blocks
.iter()
.map(|block| (label_of(block), outgoing_edges(block)))
.collect();
let entry_label = label_of(&func.blocks[0]);
delete_backedges(&mut graph, entry_label);

let mut sorter = topological_sort::TopologicalSort::<Word>::new();
for (key, values) in graph {
for value in values {
sorter.add_dependency(key, value);
fn visit_postorder(
func: &Function,
visited: &mut HashSet<Word>,
postorder: &mut Vec<Word>,
current: Word,
) {
if !visited.insert(current) {
return;
}
let current_block = func.blocks.iter().find(|b| label_of(b) == current).unwrap();
// Reverse the order, so reverse-postorder keeps things tidy
for &outgoing in outgoing_edges(current_block).iter().rev() {
visit_postorder(func, visited, postorder, outgoing);
}
postorder.push(current);
}

let mut visited = HashSet::new();
let mut postorder = Vec::new();

let entry_label = label_of(&func.blocks[0]);
visit_postorder(func, &mut visited, &mut postorder, entry_label);

let mut old_blocks = replace(&mut func.blocks, Vec::new());
while let Some(item) = sorter.pop() {
let index = old_blocks.iter().position(|b| label_of(b) == item).unwrap();
// Order blocks according to reverse postorder
for &block in postorder.iter().rev() {
let index = old_blocks
.iter()
.position(|b| label_of(b) == block)
.unwrap();
func.blocks.push(old_blocks.remove(index));
}
assert!(sorter.is_empty());
assert!(old_blocks.is_empty());
assert_eq!(
label_of(&func.blocks[0]),
entry_label,
"Topo sorter did something weird (unreachable blocks?)"
);
assert_eq!(label_of(&func.blocks[0]), entry_label,);
}

fn outgoing_edges(block: &Block) -> Vec<Word> {
@@ -113,28 +123,3 @@ fn outgoing_edges(block: &Block) -> Vec<Word> {
_ => panic!("Invalid block terminator: {:?}", terminator),
}
}

fn delete_backedges(graph: &mut HashMap<Word, Vec<Word>>, entry: Word) {
// TODO: This has extremely bad runtime
let mut backedges = HashSet::new();
fn re(
graph: &HashMap<Word, Vec<Word>>,
entry: Word,
stack: &mut Vec<Word>,
backedges: &mut HashSet<(Word, Word)>,
) {
stack.push(entry);
for &item in &graph[&entry] {
if stack.contains(&item) {
backedges.insert((entry, item));
} else if !backedges.contains(&(entry, item)) {
re(graph, item, stack, backedges);
}
}
assert_eq!(stack.pop(), Some(entry));
}
re(graph, entry, &mut Vec::new(), &mut backedges);
for (from, to) in backedges {
graph.get_mut(&from).unwrap().retain(|&o| o != to);
}
}

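These hunks replace the topological sort (which needed backedges deleted first) with a depth-first postorder walk whose reverse becomes the new block order. A self-contained sketch of reverse postorder over a plain adjacency map, mirroring the shape of `visit_postorder` above; the graph representation and names here are simplified stand-ins, not the rspirv types:

```rust
use std::collections::{HashMap, HashSet};

// Depth-first postorder: a node is pushed only after all of its successors.
fn visit_postorder(
    graph: &HashMap<u32, Vec<u32>>,
    visited: &mut HashSet<u32>,
    postorder: &mut Vec<u32>,
    current: u32,
) {
    if !visited.insert(current) {
        return;
    }
    if let Some(successors) = graph.get(&current) {
        for &succ in successors {
            visit_postorder(graph, visited, postorder, succ);
        }
    }
    postorder.push(current);
}

fn main() {
    // 0 -> {1, 3}, 1 -> 2, 2 -> 1 (a loop); the `visited` set simply ignores the backedge.
    let graph: HashMap<u32, Vec<u32>> =
        [(0, vec![1, 3]), (1, vec![2]), (2, vec![1]), (3, vec![])]
            .into_iter()
            .collect();
    let (mut visited, mut postorder) = (HashSet::new(), Vec::new());
    visit_postorder(&graph, &mut visited, &mut postorder, 0);
    // Reverse postorder: the entry comes first, and every node appears
    // before its successors along non-backedge edges.
    let rpo: Vec<u32> = postorder.iter().rev().copied().collect();
    assert_eq!(rpo[0], 0);
    println!("reverse postorder: {:?}", rpo);
}
```
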
@@ -256,13 +256,18 @@ pub fn read_metadata(rlib: &Path) -> MetadataRef {
panic!("No .metadata file in rlib: {:?}", rlib);
}

/// This is the actual guts of linking: the rest of the link-related functions are just digging through rustc's
/// shenanigans to collect all the object files we need to link.
fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) {
let mut modules = Vec::new();
// `objects` are the plain obj files we need to link - usually produced by the final crate.
for obj in objects {
let mut bytes = Vec::new();
File::open(obj).unwrap().read_to_end(&mut bytes).unwrap();
modules.push(load(&bytes));
}
// `rlibs` are archive files we've created in `create_archive`, usually produced by crates that are being
// referenced. We need to unpack them and add the modules inside.
for rlib in rlibs {
for entry in Archive::new(File::open(rlib).unwrap()).entries().unwrap() {
let mut entry = entry.unwrap();
@@ -276,6 +281,21 @@ fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) {

let mut module_refs = modules.iter_mut().collect::<Vec<_>>();

{
let path = Path::new("/home/khyperia/tmp");
if path.is_file() {
std::fs::remove_file(path).unwrap();
}
std::fs::create_dir_all(path).unwrap();
for (num, module) in module_refs.iter().enumerate() {
File::create(path.join(format!("mod_{}.spv", num)))
.unwrap()
.write_all(crate::slice_u32_to_u8(&module.assemble()))
.unwrap();
}
}

// Do the link...
let result = match rspirv_linker::link(
&mut module_refs,
&rspirv_linker::Options {
@@ -302,6 +322,7 @@ fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) {
}
};

// And finally write out the linked binary.
use rspirv::binary::Assemble;
File::create(out_filename)
.unwrap()
@@ -315,6 +336,8 @@ fn do_link(objects: &[PathBuf], rlibs: &[PathBuf], out_filename: &Path) {
}
}

/// As of right now, this is essentially a no-op, just plumbing through all the files.
// TODO: WorkProduct impl
pub(crate) fn run_thin(
cgcx: &CodegenContext<SpirvCodegenBackend>,
modules: Vec<(String, SpirvThinBuffer)>,

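As the new comments describe, `do_link` gathers SPIR-V modules from two places: plain object files and the rlib (tar) archives produced earlier. A minimal sketch of that collection step using the `tar` crate; the `.spv` extension filter, function name, and error handling are illustrative assumptions, not the crate's actual code:

```rust
use std::fs::File;
use std::io::Read;
use std::path::PathBuf;
use tar::Archive;

/// Collect raw SPIR-V blobs from plain object files and from rlib (tar) archives.
fn collect_spirv(objects: &[PathBuf], rlibs: &[PathBuf]) -> std::io::Result<Vec<Vec<u8>>> {
    let mut blobs = Vec::new();
    // Plain object files: read the whole file.
    for obj in objects {
        let mut bytes = Vec::new();
        File::open(obj)?.read_to_end(&mut bytes)?;
        blobs.push(bytes);
    }
    // Rlibs: walk each tar entry and keep the SPIR-V members (assumed `.spv` here).
    for rlib in rlibs {
        for entry in Archive::new(File::open(rlib)?).entries()? {
            let mut entry = entry?;
            let keep = entry
                .path()?
                .extension()
                .map_or(false, |ext| ext == "spv");
            if keep {
                let mut bytes = Vec::new();
                entry.read_to_end(&mut bytes)?;
                blobs.push(bytes);
            }
        }
    }
    Ok(blobs)
}

fn main() -> std::io::Result<()> {
    // With no inputs this is a no-op; in the real linker the slices come from rustc.
    let blobs = collect_spirv(&[], &[])?;
    assert!(blobs.is_empty());
    Ok(())
}
```
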
@@ -1,5 +1,4 @@
use crate::abi::RecursivePointeeCache;
use crate::builder::Builder;
use crate::codegen_cx::CodegenCx;
use rspirv::dr::Operand;
use rspirv::spirv::{Decoration, StorageClass, Word};
@@ -11,6 +10,11 @@ use std::iter::once;
use std::lazy::SyncLazy;
use std::sync::Mutex;

/// Spir-v types are represented as simple Words, which are the result_id of instructions like OpTypeInteger. Sometimes,
/// however, we want to inspect one of these Words and ask questions like "Is this an OpTypeInteger? How many bits does
/// it have?". This struct holds all of that information. All types that are emitted are registered in CodegenCx, so you
/// can always look up the definition of a Word via cx.lookup_type. Note that this type doesn't actually store the
/// result_id of the type declaration instruction, merely the contents.
#[derive(Clone, Debug, PartialEq, Eq, Ord, PartialOrd, Hash)]
pub enum SpirvType {
Void,
@@ -52,55 +56,6 @@ pub enum SpirvType {
},
}

fn memset_fill_u16(b: u8) -> u16 {
b as u16 | ((b as u16) << 8)
}

fn memset_fill_u32(b: u8) -> u32 {
b as u32 | ((b as u32) << 8) | ((b as u32) << 16) | ((b as u32) << 24)
}

fn memset_fill_u64(b: u8) -> u64 {
b as u64
| ((b as u64) << 8)
| ((b as u64) << 16)
| ((b as u64) << 24)
| ((b as u64) << 32)
| ((b as u64) << 40)
| ((b as u64) << 48)
| ((b as u64) << 56)
}

fn memset_dynamic_scalar<'a, 'tcx>(
builder: &Builder<'a, 'tcx>,
fill_var: Word,
byte_width: usize,
is_float: bool,
) -> Word {
let composite_type = SpirvType::Vector {
element: SpirvType::Integer(8, false).def(builder),
count: byte_width as u32,
}
.def(builder);
let composite = builder
.emit()
.composite_construct(
composite_type,
None,
std::iter::repeat(fill_var).take(byte_width),
)
.unwrap();
let result_type = if is_float {
SpirvType::Float(byte_width as u32 * 8)
} else {
SpirvType::Integer(byte_width as u32 * 8, false)
};
builder
.emit()
.bitcast(result_type.def(builder), None, composite)
.unwrap()
}

impl SpirvType {
/// Note: Builder::type_* should be called *nowhere else* but here, to ensure CodegenCx::type_defs stays up-to-date
pub fn def<'tcx>(&self, cx: &CodegenCx<'tcx>) -> Word {
@@ -211,6 +166,8 @@ impl SpirvType {
result
}

/// def_with_id is used by the RecursivePointeeCache to handle OpTypeForwardPointer: when emitting the subsequent
/// OpTypePointer, the ID is already known and must be re-used.
pub fn def_with_id<'tcx>(&self, cx: &CodegenCx<'tcx>, id: Word) -> Word {
if let Some(cached) = cx.type_cache.get(self) {
assert_eq!(cached, id);
@@ -236,6 +193,7 @@ impl SpirvType {
result
}

/// Use this if you want a pretty type printing that recursively prints the types within (e.g. struct fields)
pub fn debug<'cx, 'tcx>(
self,
id: Word,
@@ -281,106 +239,6 @@ impl SpirvType {
SpirvType::Function { .. } => cx.tcx.data_layout.pointer_align.abi,
}
}

pub fn memset_const_pattern<'tcx>(&self, cx: &CodegenCx<'tcx>, fill_byte: u8) -> Word {
match *self {
SpirvType::Void => panic!("memset invalid on void pattern"),
SpirvType::Bool => panic!("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width {
8 => cx.constant_u8(fill_byte).def,
16 => cx.constant_u16(memset_fill_u16(fill_byte)).def,
32 => cx.constant_u32(memset_fill_u32(fill_byte)).def,
64 => cx.constant_u64(memset_fill_u64(fill_byte)).def,
_ => panic!("memset on integer width {} not implemented yet", width),
},
SpirvType::Float(width) => match width {
32 => {
cx.constant_f32(f32::from_bits(memset_fill_u32(fill_byte)))
.def
}
64 => {
cx.constant_f64(f64::from_bits(memset_fill_u64(fill_byte)))
.def
}
_ => panic!("memset on float width {} not implemented yet", width),
},
SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"),
SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"),
SpirvType::Vector { element, count } => {
let elem_pat = cx.lookup_type(element).memset_const_pattern(cx, fill_byte);
cx.constant_composite(self.def(cx), vec![elem_pat; count as usize])
.def
}
SpirvType::Array { element, count } => {
let elem_pat = cx.lookup_type(element).memset_const_pattern(cx, fill_byte);
let count = cx.builder.lookup_const_u64(count).unwrap() as usize;
cx.constant_composite(self.def(cx), vec![elem_pat; count])
.def
}
SpirvType::RuntimeArray { .. } => {
panic!("memset on runtime arrays not implemented yet")
}
SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"),
SpirvType::Function { .. } => panic!("memset on functions not implemented yet"),
}
}

pub fn memset_dynamic_pattern<'a, 'tcx>(
&self,
builder: &Builder<'a, 'tcx>,
fill_var: Word,
) -> Word {
match *self {
SpirvType::Void => panic!("memset invalid on void pattern"),
SpirvType::Bool => panic!("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width {
8 => fill_var,
16 => memset_dynamic_scalar(builder, fill_var, 2, false),
32 => memset_dynamic_scalar(builder, fill_var, 4, false),
64 => memset_dynamic_scalar(builder, fill_var, 8, false),
_ => panic!("memset on integer width {} not implemented yet", width),
},
SpirvType::Float(width) => match width {
32 => memset_dynamic_scalar(builder, fill_var, 4, true),
64 => memset_dynamic_scalar(builder, fill_var, 8, true),
_ => panic!("memset on float width {} not implemented yet", width),
},
SpirvType::Adt { .. } => panic!("memset on structs not implemented yet"),
SpirvType::Opaque { .. } => panic!("memset on opaque type is invalid"),
SpirvType::Array { element, count } => {
let elem_pat = builder
.lookup_type(element)
.memset_dynamic_pattern(builder, fill_var);
let count = builder.builder.lookup_const_u64(count).unwrap() as usize;
builder
.emit()
.composite_construct(
self.def(builder),
None,
std::iter::repeat(elem_pat).take(count),
)
.unwrap()
}
SpirvType::Vector { element, count } => {
let elem_pat = builder
.lookup_type(element)
.memset_dynamic_pattern(builder, fill_var);
builder
.emit()
.composite_construct(
self.def(builder),
None,
std::iter::repeat(elem_pat).take(count as usize),
)
.unwrap()
}
SpirvType::RuntimeArray { .. } => {
panic!("memset on runtime arrays not implemented yet")
}
SpirvType::Pointer { .. } => panic!("memset on pointers not implemented yet"),
SpirvType::Function { .. } => panic!("memset on functions not implemented yet"),
}
}
}

pub struct SpirvTypePrinter<'cx, 'tcx> {
@@ -389,6 +247,9 @@ pub struct SpirvTypePrinter<'cx, 'tcx> {
cx: &'cx CodegenCx<'tcx>,
}

/// Types can be recursive, e.g. a struct can contain a pointer to itself. So, we need to keep track of a stack of what
/// types are currently being printed, to not infinitely loop. Unfortunately, unlike fmt::Display, we can't easily pass
/// down the "stack" of currently-being-printed types, so we use a global static.
static DEBUG_STACK: SyncLazy<Mutex<Vec<Word>>> = SyncLazy::new(|| Mutex::new(Vec::new()));

impl fmt::Debug for SpirvTypePrinter<'_, '_> {
@@ -497,6 +358,10 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_> {
}
}

/// Types can be recursive, e.g. a struct can contain a pointer to itself. So, we need to keep track of a stack of what
/// types are currently being printed, to not infinitely loop. So, we only use fmt::Display::fmt as an "entry point", and
/// then call through to our own (recursive) custom function that has a parameter for the current stack. Make sure to not
/// call Display on a type inside the custom function!
impl fmt::Display for SpirvTypePrinter<'_, '_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.display(&mut Vec::new(), f)

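The doc comments added here describe guarding recursive type printing with a stack of "currently being printed" IDs, with `fmt::Display` acting only as the entry point. A small self-contained sketch of that guard; the `Ty` type, ID scheme, and names are made up for illustration, not the codegen's actual types:

```rust
use std::fmt;

// A toy recursive type: an ID plus an optional pointee ID into the same table.
struct Ty {
    id: u32,
    pointee: Option<u32>,
}

struct Printer<'a> {
    types: &'a [Ty],
}

impl<'a> Printer<'a> {
    // `stack` holds the IDs currently being printed; hitting one again means
    // a cycle, so print a placeholder instead of recursing forever.
    fn print(&self, id: u32, stack: &mut Vec<u32>, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        if stack.contains(&id) {
            return write!(f, "<recursive type %{}>", id);
        }
        stack.push(id);
        let ty = &self.types[id as usize];
        write!(f, "Ty%{}", ty.id)?;
        if let Some(pointee) = ty.pointee {
            write!(f, " -> ")?;
            self.print(pointee, stack, f)?;
        }
        stack.pop();
        Ok(())
    }
}

// Display is only the entry point; the recursion goes through `print`.
struct Entry<'a>(&'a Printer<'a>, u32);
impl fmt::Display for Entry<'_> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        self.0.print(self.1, &mut Vec::new(), f)
    }
}

fn main() {
    // Type 0 points to type 1, which points back to type 0.
    let types = [Ty { id: 0, pointee: Some(1) }, Ty { id: 1, pointee: Some(0) }];
    let printer = Printer { types: &types };
    println!("{}", Entry(&printer, 0)); // Ty%0 -> Ty%1 -> <recursive type %0>
}
```
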
@@ -4,6 +4,11 @@ use rustc_ast::ast::{AttrKind, Attribute};
use rustc_span::symbol::Symbol;
use std::collections::HashMap;

/// Various places in the codebase (mostly attribute parsing) need to compare rustc Symbols to particular keywords.
/// Symbols are interned, as in, they don't actually store the string itself inside them, but rather an index into a
/// global table of strings. Then, whenever a new Symbol is created, the global table is checked to see if the string
/// already exists, deduplicating it if so. This makes things like comparison and cloning really cheap. So, this struct
/// is to allocate all our keywords up front and intern them all, so we can do comparisons really easily and fast.
pub struct Symbols {
pub spirv: Symbol,
pub storage_class: Symbol,
@@ -92,10 +97,13 @@ pub enum SpirvAttribute {
Entry(ExecutionModel),
}

// Note that we could mark thie attr as used via cx.tcx.sess.mark_attr_used(attr), but unused reporting already happens
// Note that we could mark the attr as used via cx.tcx.sess.mark_attr_used(attr), but unused reporting already happens
// even before we get here :(
/// Returns None if this attribute is not a spirv attribute, or if it's a malformed (and an error is reported).
/// Returns None if this attribute is not a spirv attribute, or if it's malformed (and an error is reported).
pub fn parse_attr<'tcx>(cx: &CodegenCx<'tcx>, attr: &Attribute) -> Option<SpirvAttribute> {
// Example attributes that we parse here:
// #[spirv(storage_class = "uniform")]
// #[spirv(entry = "kernel")]
let is_spirv = match attr.kind {
AttrKind::Normal(ref item) => {
// TODO: We ignore the rest of the path. Is this right?

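The new doc comment explains why interned Symbols make keyword comparison cheap: each distinct string is stored once and represented by an index. A toy interner illustrating the idea; this is a simplified model, not rustc's implementation:

```rust
use std::collections::HashMap;

/// A string interner: each distinct string gets one index, so "symbols" are
/// just small Copy integers that compare with a single integer comparison.
#[derive(Default)]
struct Interner {
    indices: HashMap<String, u32>,
    strings: Vec<String>,
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
struct Sym(u32);

impl Interner {
    fn intern(&mut self, s: &str) -> Sym {
        if let Some(&idx) = self.indices.get(s) {
            return Sym(idx);
        }
        let idx = self.strings.len() as u32;
        self.strings.push(s.to_owned());
        self.indices.insert(s.to_owned(), idx);
        Sym(idx)
    }
}

fn main() {
    let mut interner = Interner::default();
    // Pre-intern the keywords once, like the `Symbols` struct above does.
    let spirv = interner.intern("spirv");
    let storage_class = interner.intern("storage_class");
    // Later comparisons are integer equality, not string equality.
    assert_eq!(interner.intern("spirv"), spirv);
    assert_ne!(spirv, storage_class);
}
```
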
@@ -47,9 +47,13 @@ pointer_addrspace!("incoming_ray_payload_nv", IncomingRayPayloadNV);
pointer_addrspace!("shader_record_buffer_nv", ShaderRecordBufferNV);
pointer_addrspace!("physical_storage_buffer", PhysicalStorageBuffer);

/// # Safety
///
/// TODO: Copied from compiler_builtins/mem.rs
/// libcore requires a few external symbols to be defined:
/// https://github.com/rust-lang/rust/blob/c2bc344eb23d8c1d18e803b3f1e631cf99926fbb/library/core/src/lib.rs#L23-L27
/// TODO: This is copied from compiler_builtins/mem.rs. Can we use that one instead? The note in the above link says
/// "[the symbols] can also be provided by the compiler-builtins crate". The memcpy in compiler_builtins is behind a
/// "mem" feature flag - can we enable that somehow?
/// https://github.com/rust-lang/compiler-builtins/blob/eff506cd49b637f1ab5931625a33cef7e91fbbf6/src/mem.rs#L12-L13
#[allow(clippy::missing_safety_doc)]
#[no_mangle]
pub unsafe extern "C" fn memcmp(s1: *const u8, s2: *const u8, n: usize) -> i32 {
let mut i = 0;