Add storage class attribute

This commit is contained in:
khyperia 2020-09-15 14:42:39 +02:00
parent ec6608f1b5
commit b345d519bf
6 changed files with 218 additions and 36 deletions

View File

@ -1,8 +1,16 @@
#![no_std]
#![feature(register_attr)]
#![register_attr(spirv)]
use core::panic::PanicInfo;
pub fn screaming_bananans() {}
#[spirv(storage_class = "private")]
pub struct Private<'a, T> {
x: &'a T,
}
#[no_mangle]
pub extern "C" fn screaming_bananans(x: Private<u32>) {}
#[panic_handler]
fn panic(_: &PanicInfo) -> ! {

View File

@ -1,6 +1,7 @@
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use rspirv::spirv::{StorageClass, Word};
use rustc_ast::ast::AttrKind;
use rustc_middle::ty::layout::{FnAbiExt, TyAndLayout};
use rustc_middle::ty::{GeneratorSubsts, PolyFnSig, Ty, TyKind};
use rustc_target::abi::call::{CastTarget, FnAbi, PassMode, Reg, RegKind};
@ -19,7 +20,13 @@ pub struct RecursivePointeeCache<'tcx> {
}
impl<'tcx> RecursivePointeeCache<'tcx> {
fn begin(&self, cx: &CodegenCx<'tcx>, pointee: PointeeTy<'tcx>) -> Option<Word> {
fn begin(
&self,
cx: &CodegenCx<'tcx>,
pointee: PointeeTy<'tcx>,
storage_class: StorageClass,
) -> Option<Word> {
// Warning: storage_class must match the one called with end()
match self.map.borrow_mut().entry(pointee) {
// State: This is the first time we've seen this type. Record that we're beginning to translate this type,
// and start doing the translation.
@ -33,9 +40,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
// emit an OpTypeForwardPointer, and use that ID. (This is the juicy part of this algorithm)
PointeeDefState::Defining => {
let new_id = cx.emit_global().id();
// StorageClass will be fixed up later
cx.emit_global()
.type_forward_pointer(new_id, StorageClass::Generic);
cx.emit_global().type_forward_pointer(new_id, storage_class);
entry.insert(PointeeDefState::DefiningWithForward(new_id));
Some(new_id)
}
@ -54,6 +59,7 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
storage_class: StorageClass,
pointee_spv: Word,
) -> Word {
// Warning: storage_class must match the one called with begin()
match self.map.borrow_mut().entry(pointee) {
// We should have hit begin() on this type already, which always inserts an entry.
Entry::Vacant(_) => panic!("RecursivePointeeCache::end should always have entry"),
@ -73,7 +79,6 @@ impl<'tcx> RecursivePointeeCache<'tcx> {
// Make sure to use the same ID.
PointeeDefState::DefiningWithForward(id) => {
entry.insert(PointeeDefState::Defined(id));
cx.builder.fix_up_pointer_forward(id, storage_class);
SpirvType::Pointer {
storage_class,
pointee: pointee_spv,
@ -359,21 +364,22 @@ fn trans_scalar<'tcx>(
Primitive::F32 => SpirvType::Float(32).def(cx),
Primitive::F64 => SpirvType::Float(64).def(cx),
Primitive::Pointer => {
let pointee_ty = dig_scalar_pointee(cx, ty, index);
let (storage_class, pointee_ty) = dig_scalar_pointee(cx, ty, index);
// Default to function storage class.
let storage_class = storage_class.unwrap_or(StorageClass::Function);
// Pointers can be recursive. So, record what we're currently translating, and if we're already translating
// the same type, emit an OpTypeForwardPointer and use that ID.
if let Some(predefined_result) =
cx.type_cache.recursive_pointee_cache.begin(cx, pointee_ty)
cx.type_cache
.recursive_pointee_cache
.begin(cx, pointee_ty, storage_class)
{
predefined_result
} else {
let pointee = pointee_ty.spirv_type(cx);
cx.type_cache.recursive_pointee_cache.end(
cx,
pointee_ty,
StorageClass::Function,
pointee,
)
cx.type_cache
.recursive_pointee_cache
.end(cx, pointee_ty, storage_class, pointee)
}
}
}
@ -393,12 +399,12 @@ fn dig_scalar_pointee<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
index: Option<usize>,
) -> PointeeTy<'tcx> {
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
match *ty.ty.kind() {
TyKind::Ref(_region, elem_ty, _mutability) => {
let elem = cx.layout_of(elem_ty);
match index {
None => PointeeTy::Ty(elem),
None => (None, PointeeTy::Ty(elem)),
Some(index) => {
if elem.is_unsized() {
dig_scalar_pointee(cx, ty.field(cx, index), None)
@ -407,7 +413,7 @@ fn dig_scalar_pointee<'tcx>(
// of ScalarPair could be deduced, but it's actually e.g. a sized pointer followed by some other
// completely unrelated type, not a wide pointer. So, translate this as a single scalar, one
// component of that ScalarPair.
PointeeTy::Ty(elem)
(None, PointeeTy::Ty(elem))
}
}
}
@ -415,18 +421,18 @@ fn dig_scalar_pointee<'tcx>(
TyKind::RawPtr(type_and_mut) => {
let elem = cx.layout_of(type_and_mut.ty);
match index {
None => PointeeTy::Ty(elem),
None => (None, PointeeTy::Ty(elem)),
Some(index) => {
if elem.is_unsized() {
dig_scalar_pointee(cx, ty.field(cx, index), None)
} else {
// Same comment as TyKind::Ref
PointeeTy::Ty(elem)
(None, PointeeTy::Ty(elem))
}
}
}
}
TyKind::FnPtr(sig) if index.is_none() => PointeeTy::Fn(sig),
TyKind::FnPtr(sig) if index.is_none() => (None, PointeeTy::Fn(sig)),
TyKind::Adt(def, _) if def.is_box() => {
let ptr_ty = cx.layout_of(cx.tcx.mk_mut_ptr(ty.ty.boxed_ty()));
dig_scalar_pointee(cx, ptr_ty, index)
@ -445,8 +451,11 @@ fn dig_scalar_pointee_adt<'tcx>(
cx: &CodegenCx<'tcx>,
ty: TyAndLayout<'tcx>,
index: Option<usize>,
) -> PointeeTy<'tcx> {
match &ty.variants {
) -> (Option<StorageClass>, PointeeTy<'tcx>) {
// Storage classes can only be applied on structs containing a single pointer field (because we said so), so we only
// need to handle the attribute here.
let storage_class = get_storage_class(cx, ty);
let result = match &ty.variants {
// If it's a Variants::Multiple, then we want to emit the type of the dataful variant, not the type of the
// discriminant. This is because the discriminant can e.g. have type *mut(), whereas we want the full underlying
// type, only available in the dataful variant.
@ -500,9 +509,76 @@ fn dig_scalar_pointee_adt<'tcx>(
},
}
}
};
match (storage_class, result) {
(storage_class, (None, result)) => (storage_class, result),
(None, (storage_class, result)) => (storage_class, result),
(Some(one), (Some(two), _)) => panic!(
"Double-applied storage class ({:?} and {:?}) on type {}",
one, two, ty.ty
),
}
}
fn get_storage_class<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Option<StorageClass> {
if let TyKind::Adt(adt, _substs) = ty.ty.kind() {
// TODO: Split out this attribute parsing
for attr in cx.tcx.get_attrs(adt.did) {
let is_spirv = match attr.kind {
AttrKind::Normal(ref item) => {
// TODO: We ignore the rest of the path. Is this right?
let last = item.path.segments.last();
last.map_or(false, |seg| seg.ident.name == cx.sym.spirv)
}
AttrKind::DocComment(..) => false,
};
if !is_spirv {
continue;
}
// Mark it used, since all invalid cases below emit errors.
cx.tcx.sess.mark_attr_used(attr);
let args = if let Some(args) = attr.meta_item_list() {
args
} else {
cx.tcx
.sess
.span_err(attr.span, "#[spirv(..)] attribute must have one argument");
continue;
};
if args.len() != 1 {
cx.tcx
.sess
.span_err(attr.span, "#[spirv(..)] attribute must have one argument");
continue;
}
let arg = &args[0];
if arg.has_name(cx.sym.storage_class) {
if let Some(storage_arg) = arg.value_str() {
match cx.sym.symbol_to_storageclass(storage_arg) {
Some(storage_class) => return Some(storage_class),
None => {
cx.tcx
.sess
.span_err(attr.span, "unknown spir-v storage class");
continue;
}
}
} else {
cx.tcx.sess.span_err(
attr.span,
"storage_class must have value: #[spirv(storage_class = \"..\")]",
);
}
} else {
cx.tcx
.sess
.span_err(attr.span, "unknown argument to spirv attribute");
}
}
}
None
}
fn trans_aggregate<'tcx>(cx: &CodegenCx<'tcx>, ty: TyAndLayout<'tcx>) -> Word {
match ty.fields {
FieldsShape::Primitive => panic!(

View File

@ -1,5 +1,5 @@
use rspirv::dr::{Block, Builder, Module, Operand};
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, StorageClass, Word};
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word};
use rspirv::{binary::Assemble, binary::Disassemble};
use std::cell::{RefCell, RefMut};
use std::{fs::File, io::Write, path::Path};
@ -227,19 +227,6 @@ impl BuilderSpirv {
module.types_global_values.push(inst);
}
pub fn fix_up_pointer_forward(&self, id: Word, storage_class: StorageClass) {
for inst in &mut self.builder.borrow_mut().module_mut().types_global_values {
if inst.class.opcode == Op::TypeForwardPointer && inst.operands[0] == Operand::IdRef(id)
{
if let Operand::StorageClass(storage) = &mut inst.operands[1] {
*storage = storage_class;
} else {
panic!()
}
}
}
}
pub fn select_block_by_id(&self, id: Word) -> BuilderCursor {
fn block_matches(block: &Block, id: Word) -> bool {
block.label.as_ref().and_then(|b| b.result_id) == Some(id)

View File

@ -3,6 +3,7 @@ use crate::builder::ExtInst;
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueExt};
use crate::finalizing_passes::{block_ordering_pass, zombie_pass};
use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache};
use crate::symbols::Symbols;
use rspirv::dr::{Module, Operand};
use rspirv::spirv::{Decoration, FunctionControl, LinkageType, StorageClass, Word};
use rustc_codegen_ssa::common::TypeKind;
@ -69,6 +70,8 @@ pub struct CodegenCx<'tcx> {
/// Invalid spir-v IDs that should be stripped from the final binary
zombie_values: RefCell<HashMap<Word, &'static str>>,
pub kernel_mode: bool,
/// Cache of all the builtin symbols we need
pub sym: Box<Symbols>,
}
impl<'tcx> CodegenCx<'tcx> {
@ -85,6 +88,7 @@ impl<'tcx> CodegenCx<'tcx> {
ext_inst: Default::default(),
zombie_values: Default::default(),
kernel_mode: true,
sym: Box::new(Symbols::new()),
}
}

View File

@ -25,6 +25,7 @@ mod codegen_cx;
mod finalizing_passes;
mod link;
mod spirv_type;
mod symbols;
mod things;
#[cfg(test)]

View File

@ -0,0 +1,106 @@
use rspirv::spirv::StorageClass;
use rustc_span::symbol::Symbol;
pub struct Symbols {
pub spirv: Symbol,
pub storage_class: Symbol,
// storage classes
pub uniform_constant: Symbol,
pub input: Symbol,
pub uniform: Symbol,
pub output: Symbol,
pub workgroup: Symbol,
pub cross_workgroup: Symbol,
pub private: Symbol,
pub function: Symbol,
pub generic: Symbol,
pub push_constant: Symbol,
pub atomic_counter: Symbol,
pub image: Symbol,
pub storage_buffer: Symbol,
pub callable_data_nv: Symbol,
pub incoming_callable_data_nv: Symbol,
pub ray_payload_nv: Symbol,
pub hit_attribute_nv: Symbol,
pub incoming_ray_payload_nv: Symbol,
pub shader_record_buffer_nv: Symbol,
pub physical_storage_buffer: Symbol,
}
impl Symbols {
pub fn new() -> Self {
Symbols {
spirv: Symbol::intern("spirv"),
storage_class: Symbol::intern("storage_class"),
uniform_constant: Symbol::intern("uniform_constant"),
input: Symbol::intern("input"),
uniform: Symbol::intern("uniform"),
output: Symbol::intern("output"),
workgroup: Symbol::intern("workgroup"),
cross_workgroup: Symbol::intern("cross_workgroup"),
private: Symbol::intern("private"),
function: Symbol::intern("function"),
generic: Symbol::intern("generic"),
push_constant: Symbol::intern("push_constant"),
atomic_counter: Symbol::intern("atomic_counter"),
image: Symbol::intern("image"),
storage_buffer: Symbol::intern("storage_buffer"),
callable_data_nv: Symbol::intern("callable_data_nv"),
incoming_callable_data_nv: Symbol::intern("incoming_callable_data_nv"),
ray_payload_nv: Symbol::intern("ray_payload_nv"),
hit_attribute_nv: Symbol::intern("hit_attribute_nv"),
incoming_ray_payload_nv: Symbol::intern("incoming_ray_payload_nv"),
shader_record_buffer_nv: Symbol::intern("shader_record_buffer_nv"),
physical_storage_buffer: Symbol::intern("physical_storage_buffer"),
}
}
pub fn symbol_to_storageclass(&self, sym: Symbol) -> Option<StorageClass> {
let result = if sym == self.uniform_constant {
StorageClass::UniformConstant
} else if sym == self.input {
StorageClass::Input
} else if sym == self.uniform {
StorageClass::Uniform
} else if sym == self.output {
StorageClass::Output
} else if sym == self.workgroup {
StorageClass::Workgroup
} else if sym == self.cross_workgroup {
StorageClass::CrossWorkgroup
} else if sym == self.private {
StorageClass::Private
} else if sym == self.function {
StorageClass::Function
} else if sym == self.generic {
StorageClass::Generic
} else if sym == self.push_constant {
StorageClass::PushConstant
} else if sym == self.atomic_counter {
StorageClass::AtomicCounter
} else if sym == self.image {
StorageClass::Image
} else if sym == self.storage_buffer {
StorageClass::StorageBuffer
} else if sym == self.callable_data_nv {
StorageClass::CallableDataNV
} else if sym == self.incoming_callable_data_nv {
StorageClass::IncomingCallableDataNV
} else if sym == self.ray_payload_nv {
StorageClass::RayPayloadNV
} else if sym == self.hit_attribute_nv {
StorageClass::HitAttributeNV
} else if sym == self.incoming_ray_payload_nv {
StorageClass::IncomingRayPayloadNV
} else if sym == self.shader_record_buffer_nv {
StorageClass::ShaderRecordBufferNV
} else if sym == self.physical_storage_buffer {
StorageClass::PhysicalStorageBuffer
} else {
return None;
};
Some(result)
}
}