mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 06:45:13 +00:00
Add storage class attribute
This commit is contained in:
parent
ec6608f1b5
commit
b345d519bf
@ -1,8 +1,16 @@
|
||||
#![no_std]
|
||||
#![feature(register_attr)]
|
||||
#![register_attr(spirv)]
|
||||
|
||||
use core::panic::PanicInfo;
|
||||
|
||||
pub fn screaming_bananans() {}
|
||||
#[spirv(storage_class = "private")]
|
||||
pub struct Private<'a, T> {
|
||||
x: &'a T,
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn screaming_bananans(x: Private<u32>) {}
|
||||
|
||||
#[panic_handler]
|
||||
fn panic(_: &PanicInfo) -> ! {
|
||||
|
@ -1,6 +1,7 @@
|
||||
use crate::codegen_cx::CodegenCx;
|
||||
use crate::spirv_type::SpirvType;
|
||||
use rspirv::spirv::{StorageClass, Word};
|
||||
use rustc_ast::ast::AttrKind;
|
||||
use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout};
|
||||
use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind};
|
||||
use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg, RegKind};
|
||||
@ -19,7 +20,13 @@ pub struct RecursivePointeeCache<'tcx> {
|
||||
}
|
||||
|
||||
impl<'tcx> RecursivePointeeCache<'tcx> {
|
||||
fn begin(&self, cx: &CodegenCx<'tcx>, pointee: PointeeTy<'tcx>) -> Option<Word> {
|
||||
fn begin(
|
||||
&self,
|
||||
cx: &CodegenCx<'tcx>,
|
||||
pointee: PointeeTy<'tcx>,
|
||||
storage_class: StorageClass,
|
||||
) -> Option<Word> {
|
||||
// Warning: storage_class must match the one called with end()
|
||||
match self.map.borrow_mut().entry(pointee) {
|
||||
// State: This is the first time we've seen this type. Record that we're beginning to translate this type,
|
||||
// and start doing the translation.
|
||||
@ -33,9 +40,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
|
||||
// emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm)
|
||||
PointeeDefState::Defining => {
|
||||
let new_id = cx.emit_global().id();
|
||||
// StorageClass will be fixed up later
|
||||
cx.emit_global()
|
||||
.type_forward_pointer(new_id, StorageClass::Generic);
|
||||
cx.emit_global().type_forward_pointer(new_id, storage_class);
|
||||
entry.insert(PointeeDefState::DefiningWithForward(new_id));
|
||||
Some(new_id)
|
||||
}
|
||||
@ -54,6 +59,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
|
||||
storage_class: StorageClass,
|
||||
pointee_spv: Word,
|
||||
) -> Word {
|
||||
// Warning: storage_class must match the one called with begin()
|
||||
match self.map.borrow_mut().entry(pointee) {
|
||||
// We should have hit begin() on this type already, which always inserts an entry.
|
||||
Entry::Vacant(_) => panic!("RecursivePointeeCache::end should always have entry"),
|
||||
@ -73,7 +79,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
|
||||
// Make sure to use the same ID.
|
||||
PointeeDefState::DefiningWithForward(id) => {
|
||||
entry.insert(PointeeDefState::Defined(id));
|
||||
cx.builder.fix_up_pointer_forward(id, storage_class);
|
||||
SpirvType::Pointer {
|
||||
storage_class,
|
||||
pointee: pointee_spv,
|
||||
@ -359,21 +364,22 @@ fn trans_scalar<'tcx>(
|
||||
Primitive::F32 => SpirvType::Float(32).def(cx),
|
||||
Primitive::F64 => SpirvType::Float(64).def(cx),
|
||||
Primitive::Pointer => {
|
||||
let pointee_ty = dig_scalar_pointee(cx, ty, index);
|
||||
let (storage_class, pointee_ty) = dig_scalar_pointee(cx, ty, index);
|
||||
// Default to function storage class.
|
||||
let storage_class = storage_class.unwrap_or(StorageClass::Function);
|
||||
// Pointers can be recursive. So, record what we're currently translating, and if we're already translating
|
||||
// the same type, emit an OpTypeForwardPointer and use that ID.
|
||||
if let Some(predefined_result) =
|
||||
cx.type_cache.recursive_pointee_cache.begin(cx, pointee_ty)
|
||||
cx.type_cache
|
||||
.recursive_pointee_cache
|
||||
.begin(cx, pointee_ty, storage_class)
|
||||
{
|
||||
predefined_result
|
||||
} else {
|
||||
let pointee = pointee_ty.spirv_type(cx);
|
||||
cx.type_cache.recursive_pointee_cache.end(
|
||||
cx,
|
||||
pointee_ty,
|
||||
StorageClass::Function,
|
||||
pointee,
|
||||
)
|
||||
cx.type_cache
|
||||
.recursive_pointee_cache
|
||||
.end(cx, pointee_ty, storage_class, pointee)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -393,12 +399,12 @@ fn dig_scalar_pointee<'tcx>(
|
||||
cx: &CodegenCx<'tcx>,
|
||||
ty: TyAndLayout<'tcx>,
|
||||
index: Option<usize>,
|
||||
) -> PointeeTy<'tcx> {
|
||||
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
|
||||
match *ty.ty.kind() {
|
||||
TyKind::Ref(_region, elem_ty, _mutability) => {
|
||||
let elem = cx.layout_of(elem_ty);
|
||||
match index {
|
||||
None => PointeeTy::Ty(elem),
|
||||
None => (None, PointeeTy::Ty(elem)),
|
||||
Some(index) => {
|
||||
if elem.is_unsized() {
|
||||
dig_scalar_pointee(cx, ty.field(cx, index), None)
|
||||
@ -407,7 +413,7 @@ fn dig_scalar_pointee<'tcx>(
|
||||
// of ScalarPair could be deduced, but it's actually e.g. a sized pointer followed by some other
|
||||
// completely unrelated type, not a wide pointer. So, translate this as a single scalar, one
|
||||
// component of that ScalarPair.
|
||||
PointeeTy::Ty(elem)
|
||||
(None, PointeeTy::Ty(elem))
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -415,18 +421,18 @@ fn dig_scalar_pointee<'tcx>(
|
||||
TyKind::RawPtr(type_and_mut) => {
|
||||
let elem = cx.layout_of(type_and_mut.ty);
|
||||
match index {
|
||||
None => PointeeTy::Ty(elem),
|
||||
None => (None, PointeeTy::Ty(elem)),
|
||||
Some(index) => {
|
||||
if elem.is_unsized() {
|
||||
dig_scalar_pointee(cx, ty.field(cx, index), None)
|
||||
} else {
|
||||
// Same comment as TyKind::Ref
|
||||
PointeeTy::Ty(elem)
|
||||
(None, PointeeTy::Ty(elem))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
TyKind::FnPtr(sig) if index.is_none() => PointeeTy::Fn(sig),
|
||||
TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)),
|
||||
TyKind::Adt(def, _) if def.is_box() => {
|
||||
let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty()));
|
||||
dig_scalar_pointee(cx, ptr_ty, index)
|
||||
@ -445,8 +451,11 @@ fn dig_scalar_pointee_adt<'tcx>(
|
||||
cx: &CodegenCx<'tcx>,
|
||||
ty: TyAndLayout<'tcx>,
|
||||
index: Option<usize>,
|
||||
) -> PointeeTy<'tcx> {
|
||||
match &ty.variants {
|
||||
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
|
||||
// Storage classes can only be applied on structs containing a single pointer field (because we said so), so we only
|
||||
// need to handle the attribute here.
|
||||
let storage_class = get_storage_class(cx, ty);
|
||||
let result = match &ty.variants {
|
||||
// If it's a Variants::Multiple, then we want to emit the type of the dataful variant, not the type of the
|
||||
// discriminant. This is because the discriminant can e.g. have type *mut(), whereas we want the full underlying
|
||||
// type, only available in the dataful variant.
|
||||
@ -500,9 +509,76 @@ fn dig_scalar_pointee_adt<'tcx>(
|
||||
},
|
||||
}
|
||||
}
|
||||
};
|
||||
match (storage_class, result) {
|
||||
(storage_class, (None, result)) => (storage_class, result),
|
||||
(None, (storage_class, result)) => (storage_class, result),
|
||||
(Some(one), (Some(two), _)) => panic!(
|
||||
"Double-applied storage class ({:?} and {:?}) on type {}",
|
||||
one, two, ty.ty
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option<StorageClass> {
|
||||
if let TyKind::Adt(adt, _substs) = ty.ty.kind() {
|
||||
// TODO: Split out this attribute parsing
|
||||
for attr in cx.tcx.get_attrs(adt.did) {
|
||||
let is_spirv = match attr.kind {
|
||||
AttrKind::Normal(ref item) => {
|
||||
// TODO: We ignore the rest of the path. Is this right?
|
||||
let last = item.path.segments.last();
|
||||
last.map_or(false, |seg| seg.ident.name == cx.sym.spirv)
|
||||
}
|
||||
AttrKind::DocComment(..) => false,
|
||||
};
|
||||
if !is_spirv {
|
||||
continue;
|
||||
}
|
||||
// Mark it used, since all invalid cases below emit errors.
|
||||
cx.tcx.sess.mark_attr_used(attr);
|
||||
let args = if let Some(args) = attr.meta_item_list() {
|
||||
args
|
||||
} else {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "#[spirv(..)] attribute must have one argument");
|
||||
continue;
|
||||
};
|
||||
if args.len() != 1 {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "#[spirv(..)] attribute must have one argument");
|
||||
continue;
|
||||
}
|
||||
let arg = &args[0];
|
||||
if arg.has_name(cx.sym.storage_class) {
|
||||
if let Some(storage_arg) = arg.value_str() {
|
||||
match cx.sym.symbol_to_storageclass(storage_arg) {
|
||||
Some(storage_class) => return Some(storage_class),
|
||||
None => {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "unknown spir-v storage class");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
cx.tcx.sess.span_err(
|
||||
attr.span,
|
||||
"storage_class must have value: #[spirv(storage_class = \"..\")]",
|
||||
);
|
||||
}
|
||||
} else {
|
||||
cx.tcx
|
||||
.sess
|
||||
.span_err(attr.span, "unknown argument to spirv attribute");
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Word {
|
||||
match ty.fields {
|
||||
FieldsShape::Primitive => panic!(
|
||||
|
@ -1,5 +1,5 @@
|
||||
use rspirv::dr::{Block, Builder, Module, Operand};
|
||||
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, StorageClass, Word};
|
||||
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word};
|
||||
use rspirv::{binary::Assemble, binary::Disassemble};
|
||||
use std::cell::{RefCell, RefMut};
|
||||
use std::{fs::File, io::Write, path::Path};
|
||||
@ -227,19 +227,6 @@ impl BuilderSpirv {
|
||||
module.types_global_values.push(inst);
|
||||
}
|
||||
|
||||
pub fn fix_up_pointer_forward(&self, id: Word, storage_class: StorageClass) {
|
||||
for inst in &mut self.builder.borrow_mut().module_mut().types_global_values {
|
||||
if inst.class.opcode == Op::TypeForwardPointer && inst.operands[0] == Operand::IdRef(id)
|
||||
{
|
||||
if let Operand::StorageClass(storage) = &mut inst.operands[1] {
|
||||
*storage = storage_class;
|
||||
} else {
|
||||
panic!()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn select_block_by_id(&self, id: Word) -> BuilderCursor {
|
||||
fn block_matches(block: &Block, id: Word) -> bool {
|
||||
block.label.as_ref().and_then(|b| b.result_id) == Some(id)
|
||||
|
@ -3,6 +3,7 @@ use crate::builder::ExtInst;
|
||||
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueExt};
|
||||
use crate::finalizing_passes::{block_ordering_pass, zombie_pass};
|
||||
use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache};
|
||||
use crate::symbols::Symbols;
|
||||
use rspirv::dr::{Module, Operand};
|
||||
use rspirv::spirv::{Decoration, FunctionControl, LinkageType, StorageClass, Word};
|
||||
use rustc_codegen_ssa::common::TypeKind;
|
||||
@ -69,6 +70,8 @@ pub struct CodegenCx<'tcx> {
|
||||
/// Invalid spir-v IDs that should be stripped from the final binary
|
||||
zombie_values: RefCell<HashMap<Word, &'static str>>,
|
||||
pub kernel_mode: bool,
|
||||
/// Cache of all the builtin symbols we need
|
||||
pub sym: Box<Symbols>,
|
||||
}
|
||||
|
||||
impl<'tcx> CodegenCx<'tcx> {
|
||||
@ -85,6 +88,7 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
ext_inst: Default::default(),
|
||||
zombie_values: Default::default(),
|
||||
kernel_mode: true,
|
||||
sym: Box::new(Symbols::new()),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -25,6 +25,7 @@ mod codegen_cx;
|
||||
mod finalizing_passes;
|
||||
mod link;
|
||||
mod spirv_type;
|
||||
mod symbols;
|
||||
mod things;
|
||||
|
||||
#[cfg(test)]
|
||||
|
106
rustc_codegen_spirv/src/symbols.rs
Normal file
106
rustc_codegen_spirv/src/symbols.rs
Normal file
@ -0,0 +1,106 @@
|
||||
use rspirv::spirv::StorageClass;
|
||||
use rustc_span::symbol::Symbol;
|
||||
|
||||
pub struct Symbols {
|
||||
pub spirv: Symbol,
|
||||
pub storage_class: Symbol,
|
||||
|
||||
// storage classes
|
||||
pub uniform_constant: Symbol,
|
||||
pub input: Symbol,
|
||||
pub uniform: Symbol,
|
||||
pub output: Symbol,
|
||||
pub workgroup: Symbol,
|
||||
pub cross_workgroup: Symbol,
|
||||
pub private: Symbol,
|
||||
pub function: Symbol,
|
||||
pub generic: Symbol,
|
||||
pub push_constant: Symbol,
|
||||
pub atomic_counter: Symbol,
|
||||
pub image: Symbol,
|
||||
pub storage_buffer: Symbol,
|
||||
pub callable_data_nv: Symbol,
|
||||
pub incoming_callable_data_nv: Symbol,
|
||||
pub ray_payload_nv: Symbol,
|
||||
pub hit_attribute_nv: Symbol,
|
||||
pub incoming_ray_payload_nv: Symbol,
|
||||
pub shader_record_buffer_nv: Symbol,
|
||||
pub physical_storage_buffer: Symbol,
|
||||
}
|
||||
|
||||
impl Symbols {
|
||||
pub fn new() -> Self {
|
||||
Symbols {
|
||||
spirv: Symbol::intern("spirv"),
|
||||
storage_class: Symbol::intern("storage_class"),
|
||||
|
||||
uniform_constant: Symbol::intern("uniform_constant"),
|
||||
input: Symbol::intern("input"),
|
||||
uniform: Symbol::intern("uniform"),
|
||||
output: Symbol::intern("output"),
|
||||
workgroup: Symbol::intern("workgroup"),
|
||||
cross_workgroup: Symbol::intern("cross_workgroup"),
|
||||
private: Symbol::intern("private"),
|
||||
function: Symbol::intern("function"),
|
||||
generic: Symbol::intern("generic"),
|
||||
push_constant: Symbol::intern("push_constant"),
|
||||
atomic_counter: Symbol::intern("atomic_counter"),
|
||||
image: Symbol::intern("image"),
|
||||
storage_buffer: Symbol::intern("storage_buffer"),
|
||||
callable_data_nv: Symbol::intern("callable_data_nv"),
|
||||
incoming_callable_data_nv: Symbol::intern("incoming_callable_data_nv"),
|
||||
ray_payload_nv: Symbol::intern("ray_payload_nv"),
|
||||
hit_attribute_nv: Symbol::intern("hit_attribute_nv"),
|
||||
incoming_ray_payload_nv: Symbol::intern("incoming_ray_payload_nv"),
|
||||
shader_record_buffer_nv: Symbol::intern("shader_record_buffer_nv"),
|
||||
physical_storage_buffer: Symbol::intern("physical_storage_buffer"),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn symbol_to_storageclass(&self, sym: Symbol) -> Option<StorageClass> {
|
||||
let result = if sym == self.uniform_constant {
|
||||
StorageClass::UniformConstant
|
||||
} else if sym == self.input {
|
||||
StorageClass::Input
|
||||
} else if sym == self.uniform {
|
||||
StorageClass::Uniform
|
||||
} else if sym == self.output {
|
||||
StorageClass::Output
|
||||
} else if sym == self.workgroup {
|
||||
StorageClass::Workgroup
|
||||
} else if sym == self.cross_workgroup {
|
||||
StorageClass::CrossWorkgroup
|
||||
} else if sym == self.private {
|
||||
StorageClass::Private
|
||||
} else if sym == self.function {
|
||||
StorageClass::Function
|
||||
} else if sym == self.generic {
|
||||
StorageClass::Generic
|
||||
} else if sym == self.push_constant {
|
||||
StorageClass::PushConstant
|
||||
} else if sym == self.atomic_counter {
|
||||
StorageClass::AtomicCounter
|
||||
} else if sym == self.image {
|
||||
StorageClass::Image
|
||||
} else if sym == self.storage_buffer {
|
||||
StorageClass::StorageBuffer
|
||||
} else if sym == self.callable_data_nv {
|
||||
StorageClass::CallableDataNV
|
||||
} else if sym == self.incoming_callable_data_nv {
|
||||
StorageClass::IncomingCallableDataNV
|
||||
} else if sym == self.ray_payload_nv {
|
||||
StorageClass::RayPayloadNV
|
||||
} else if sym == self.hit_attribute_nv {
|
||||
StorageClass::HitAttributeNV
|
||||
} else if sym == self.incoming_ray_payload_nv {
|
||||
StorageClass::IncomingRayPayloadNV
|
||||
} else if sym == self.shader_record_buffer_nv {
|
||||
StorageClass::ShaderRecordBufferNV
|
||||
} else if sym == self.physical_storage_buffer {
|
||||
StorageClass::PhysicalStorageBuffer
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
Some(result)
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user