Implement ByteAddressableBuffer prototype detached from bindless (#735)

This commit is contained in:
Ashley Hauck 2021-08-31 10:26:30 +02:00 committed by GitHub
parent f0560824fe
commit 66d6c554d8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 657 additions and 2 deletions

View File

@ -90,6 +90,8 @@ pub enum SpirvAttribute {
// `fn`/closure attributes:
UnrollLoops,
BufferLoadIntrinsic,
BufferStoreIntrinsic,
}
// HACK(eddyb) this is similar to `rustc_span::Spanned` but with `value` as the
@ -123,6 +125,8 @@ pub struct AggregatedSpirvAttributes {
// `fn`/closure attributes:
pub unroll_loops: Option<Spanned<()>>,
pub buffer_load_intrinsic: Option<Spanned<()>>,
pub buffer_store_intrinsic: Option<Spanned<()>>,
}
struct MultipleAttrs {
@ -210,6 +214,18 @@ impl AggregatedSpirvAttributes {
"#[spirv(attachment_index)]",
),
UnrollLoops => try_insert(&mut self.unroll_loops, (), span, "#[spirv(unroll_loops)]"),
BufferLoadIntrinsic => try_insert(
&mut self.buffer_load_intrinsic,
(),
span,
"#[spirv(buffer_load_intrinsic)]",
),
BufferStoreIntrinsic => try_insert(
&mut self.buffer_store_intrinsic,
(),
span,
"#[spirv(buffer_store_intrinsic)]",
),
}
}
}
@ -343,6 +359,12 @@ impl CheckSpirvAttrVisitor<'_> {
_ => Err(Expected("function or closure")),
},
SpirvAttribute::BufferLoadIntrinsic | SpirvAttribute::BufferStoreIntrinsic => {
match target {
Target::Fn => Ok(()),
_ => Err(Expected("function")),
}
}
};
match valid_target {
Err(Expected(expected_target)) => self.tcx.sess.span_err(

View File

@ -1850,8 +1850,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
| SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. } => element,
other => self.fatal(&format!(
"extract_value not implemented on type {:?}",
other
"extract_value not implemented on type {}",
other.debug(agg_val.ty, self)
)),
};
self.emit()
@ -2201,6 +2201,24 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// needing to materialize `&core::panic::Location` or `format_args!`.
self.abort();
self.undef(result_type)
} else if self
.buffer_load_intrinsic_fn_id
.borrow()
.contains(&callee_val)
{
self.codegen_buffer_load_intrinsic(result_type, args)
} else if self
.buffer_store_intrinsic_fn_id
.borrow()
.contains(&callee_val)
{
self.codegen_buffer_store_intrinsic(args);
let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
SpirvValue {
kind: SpirvValueKind::IllegalTypeUsed(void_ty),
ty: void_ty,
}
} else {
let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
self.emit()

View File

@ -0,0 +1,347 @@
use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt};
use crate::spirv_type::SpirvType;
use core::array::IntoIter;
use rspirv::spirv::Word;
use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods};
use rustc_span::DUMMY_SP;
use rustc_target::abi::Align;
impl<'a, 'tcx> Builder<'a, 'tcx> {
    /// Reports a compile error for a type that can't be loaded from an
    /// untyped buffer, and returns an `OpUndef` of `invalid_type` so that
    /// codegen can keep going after the error.
    ///
    /// `original_type` is the outermost type the user asked to load;
    /// `invalid_type` is the (possibly nested) type that caused the failure.
    fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue {
        let mut err = self.struct_err(&format!(
            "Cannot load type {} in an untyped buffer load",
            self.debug_type(original_type)
        ));
        // Only mention the nested offending type when it differs from the
        // outermost requested type, to avoid a redundant note.
        if original_type != invalid_type {
            err.note(&format!(
                "due to containing type {}",
                self.debug_type(invalid_type)
            ));
        }
        err.emit();
        self.undef(invalid_type)
    }
    /// Loads one 32-bit word from `array` at index
    /// `dynamic_index + constant_offset` (both measured in words, not bytes).
    fn load_u32(
        &mut self,
        array: SpirvValue,
        dynamic_index: SpirvValue,
        constant_offset: u32,
    ) -> SpirvValue {
        // Fold the compile-time offset into the runtime index, skipping the
        // add instruction entirely when the offset is zero.
        let actual_index = if constant_offset != 0 {
            let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset);
            self.add(dynamic_index, const_offset_val)
        } else {
            dynamic_index
        };
        let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
        let u32_ptr = self.type_ptr_to(u32_ty);
        // OpInBoundsAccessChain into the word array, then load through the
        // resulting pointer.
        let ptr = self
            .emit()
            .in_bounds_access_chain(
                u32_ptr,
                None,
                array.def(self),
                IntoIter::new([actual_index.def(self)]),
            )
            .unwrap()
            .with_type(u32_ptr);
        self.load(u32_ty, ptr, Align::ONE)
    }
    /// Loads a vector or array of `count` elements of type `element` by
    /// recursively loading each element at its word offset and assembling the
    /// result with `OpCompositeConstruct`.
    ///
    /// Fails (via `load_err`) when the element size is unknown or is not a
    /// multiple of 4 bytes.
    #[allow(clippy::too_many_arguments)]
    fn load_vec_or_arr(
        &mut self,
        original_type: Word,
        result_type: Word,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
        element: Word,
        count: u32,
    ) -> SpirvValue {
        let element_size_bytes = match self.lookup_type(element).sizeof(self) {
            Some(size) => size,
            None => return self.load_err(original_type, result_type),
        };
        // Elements must occupy a whole number of 32-bit words.
        if element_size_bytes.bytes() % 4 != 0 {
            return self.load_err(original_type, result_type);
        }
        let element_size_words = (element_size_bytes.bytes() / 4) as u32;
        // Load each element at its own word offset within the composite.
        let args = (0..count)
            .map(|index| {
                self.recurse_load_type(
                    original_type,
                    element,
                    array,
                    dynamic_word_index,
                    constant_word_offset + element_size_words * index,
                )
                .def(self)
            })
            .collect::<Vec<_>>();
        self.emit()
            .composite_construct(result_type, None, args)
            .unwrap()
            .with_type(result_type)
    }
    /// Recursively loads a value of `result_type` from `array`, starting at
    /// word index `dynamic_word_index + constant_word_offset`.
    ///
    /// Supported types: 32-bit integers and floats, vectors, fixed-length
    /// arrays, and sized ADTs whose fields are all 4-byte aligned. Anything
    /// else reports an error via `load_err` (with `original_type` naming the
    /// outermost requested type in the diagnostic).
    fn recurse_load_type(
        &mut self,
        original_type: Word,
        result_type: Word,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
    ) -> SpirvValue {
        match self.lookup_type(result_type) {
            SpirvType::Integer(32, signed) => {
                // Load the raw word, then cast to the requested signedness.
                let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
                self.intcast(val, result_type, signed)
            }
            SpirvType::Float(32) => {
                // Load the raw word and reinterpret its bits as a float.
                let val = self.load_u32(array, dynamic_word_index, constant_word_offset);
                self.bitcast(val, result_type)
            }
            SpirvType::Vector { element, count } => self.load_vec_or_arr(
                original_type,
                result_type,
                array,
                dynamic_word_index,
                constant_word_offset,
                element,
                count,
            ),
            SpirvType::Array { element, count } => {
                // Array lengths are SPIR-V constants; give up if the length
                // isn't a known compile-time u64.
                let count = match self.builder.lookup_const_u64(count) {
                    Some(count) => count as u32,
                    None => return self.load_err(original_type, result_type),
                };
                self.load_vec_or_arr(
                    original_type,
                    result_type,
                    array,
                    dynamic_word_index,
                    constant_word_offset,
                    element,
                    count,
                )
            }
            SpirvType::Adt {
                size: Some(_),
                field_types,
                field_offsets,
                ..
            } => {
                // Load every field at its byte offset (each must be 4-byte
                // aligned); a `None` from any field aborts the whole load.
                let args = field_types
                    .iter()
                    .zip(field_offsets)
                    .map(|(&field_type, byte_offset)| {
                        if byte_offset.bytes() % 4 != 0 {
                            return None;
                        }
                        let word_offset = (byte_offset.bytes() / 4) as u32;
                        Some(
                            self.recurse_load_type(
                                original_type,
                                field_type,
                                array,
                                dynamic_word_index,
                                constant_word_offset + word_offset,
                            )
                            .def(self),
                        )
                    })
                    .collect::<Option<Vec<_>>>();
                match args {
                    None => self.load_err(original_type, result_type),
                    Some(args) => self
                        .emit()
                        .composite_construct(result_type, None, args)
                        .unwrap()
                        .with_type(result_type),
                }
            }
            // Unsized ADTs, pointers, bools, non-32-bit scalars, etc.
            _ => self.load_err(original_type, result_type),
        }
    }
    /// Expands a call to the `#[spirv(buffer_load_intrinsic)]` marker
    /// function into actual SPIR-V loads.
    ///
    /// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
    pub fn codegen_buffer_load_intrinsic(
        &mut self,
        result_type: Word,
        args: &[SpirvValue],
    ) -> SpirvValue {
        // Signature: fn load<T>(array: &[u32], index: u32) -> T;
        if args.len() != 3 {
            self.fatal(&format!(
                "buffer_load_intrinsic should have 3 args, it has {}",
                args.len()
            ));
        }
        // Note that the &[u32] gets split into two arguments - pointer, length
        // (args[1], the length, is unused here — see the bounds-checking note).
        let array = args[0];
        let byte_index = args[2];
        // Word index = byte index / 4, rounding down (logical shift right 2).
        let two = self.constant_u32(DUMMY_SP, 2);
        let word_index = self.lshr(byte_index, two);
        self.recurse_load_type(result_type, result_type, array, word_index, 0)
    }
fn store_err(&mut self, original_type: Word, value: SpirvValue) {
let mut err = self.struct_err(&format!(
"Cannot load type {} in an untyped buffer store",
self.debug_type(original_type)
));
if original_type != value.ty {
err.note(&format!("due to containing type {}", value.ty));
}
err.emit();
}
    /// Stores one 32-bit word `value` into `array` at index
    /// `dynamic_index + constant_offset` (both measured in words, not bytes).
    fn store_u32(
        &mut self,
        array: SpirvValue,
        dynamic_index: SpirvValue,
        constant_offset: u32,
        value: SpirvValue,
    ) {
        // Fold the compile-time offset into the runtime index, skipping the
        // add instruction entirely when the offset is zero.
        let actual_index = if constant_offset != 0 {
            let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset);
            self.add(dynamic_index, const_offset_val)
        } else {
            dynamic_index
        };
        let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
        let u32_ptr = self.type_ptr_to(u32_ty);
        // OpInBoundsAccessChain into the word array, then store through the
        // resulting pointer.
        let ptr = self
            .emit()
            .in_bounds_access_chain(
                u32_ptr,
                None,
                array.def(self),
                IntoIter::new([actual_index.def(self)]),
            )
            .unwrap()
            .with_type(u32_ptr);
        self.store(value, ptr, Align::ONE);
    }
    /// Stores a vector or array of `count` elements of type `element` by
    /// extracting each element from `value` and recursively storing it at its
    /// word offset.
    ///
    /// Fails (via `store_err`) when the element size is unknown or is not a
    /// multiple of 4 bytes.
    #[allow(clippy::too_many_arguments)]
    fn store_vec_or_arr(
        &mut self,
        original_type: Word,
        value: SpirvValue,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
        element: Word,
        count: u32,
    ) {
        let element_size_bytes = match self.lookup_type(element).sizeof(self) {
            Some(size) => size,
            None => return self.store_err(original_type, value),
        };
        // Elements must occupy a whole number of 32-bit words.
        if element_size_bytes.bytes() % 4 != 0 {
            return self.store_err(original_type, value);
        }
        let element_size_words = (element_size_bytes.bytes() / 4) as u32;
        for index in 0..count {
            let element = self.extract_value(value, index as u64);
            self.recurse_store_type(
                original_type,
                element,
                array,
                dynamic_word_index,
                constant_word_offset + element_size_words * index,
            );
        }
    }
    /// Recursively stores `value` into `array`, starting at word index
    /// `dynamic_word_index + constant_word_offset`.
    ///
    /// Mirrors `recurse_load_type`: 32-bit integers and floats, vectors,
    /// fixed-length arrays, and sized ADTs with 4-byte-aligned fields are
    /// supported; anything else reports an error via `store_err`.
    fn recurse_store_type(
        &mut self,
        original_type: Word,
        value: SpirvValue,
        array: SpirvValue,
        dynamic_word_index: SpirvValue,
        constant_word_offset: u32,
    ) {
        match self.lookup_type(value.ty) {
            SpirvType::Integer(32, signed) => {
                // Cast to an unsigned word before storing the raw bits.
                let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
                let value_u32 = self.intcast(value, u32_ty, signed);
                self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32);
            }
            SpirvType::Float(32) => {
                // Reinterpret the float's bits as a word before storing.
                let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
                let value_u32 = self.bitcast(value, u32_ty);
                self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32);
            }
            SpirvType::Vector { element, count } => self.store_vec_or_arr(
                original_type,
                value,
                array,
                dynamic_word_index,
                constant_word_offset,
                element,
                count,
            ),
            SpirvType::Array { element, count } => {
                // Array lengths are SPIR-V constants; give up if the length
                // isn't a known compile-time u64.
                let count = match self.builder.lookup_const_u64(count) {
                    Some(count) => count as u32,
                    None => return self.store_err(original_type, value),
                };
                self.store_vec_or_arr(
                    original_type,
                    value,
                    array,
                    dynamic_word_index,
                    constant_word_offset,
                    element,
                    count,
                );
            }
            SpirvType::Adt {
                size: Some(_),
                field_offsets,
                ..
            } => {
                // Extract each field by index and store it at its byte offset
                // (which must be 4-byte aligned). Unlike the load path, field
                // types come from `extract_value` rather than `field_types`.
                for (index, byte_offset) in field_offsets.iter().enumerate() {
                    if byte_offset.bytes() % 4 != 0 {
                        return self.store_err(original_type, value);
                    }
                    let word_offset = (byte_offset.bytes() / 4) as u32;
                    let field = self.extract_value(value, index as u64);
                    self.recurse_store_type(
                        original_type,
                        field,
                        array,
                        dynamic_word_index,
                        constant_word_offset + word_offset,
                    );
                }
            }
            // Unsized ADTs, pointers, bools, non-32-bit scalars, etc.
            _ => self.store_err(original_type, value),
        }
    }
    /// Expands a call to the `#[spirv(buffer_store_intrinsic)]` marker
    /// function into actual SPIR-V stores.
    ///
    /// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
    pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue]) {
        // Signature: fn store<T>(array: &[u32], index: u32, value: T);
        if args.len() != 4 {
            self.fatal(&format!(
                "buffer_store_intrinsic should have 4 args, it has {}",
                args.len()
            ));
        }
        // Note that the &[u32] gets split into two arguments - pointer, length
        // (args[1], the length, is unused here — see the bounds-checking note).
        let array = args[0];
        let byte_index = args[2];
        // Word index = byte index / 4, rounding down (logical shift right 2).
        let two = self.constant_u32(DUMMY_SP, 2);
        let word_index = self.lshr(byte_index, two);
        let value = args[3];
        self.recurse_store_type(value.ty, value, array, word_index, 0);
    }
}

View File

@ -1,4 +1,5 @@
mod builder_methods;
mod byte_addressable_buffer;
mod ext_inst;
mod intrinsics;
pub mod libm_intrinsics;

View File

@ -25,6 +25,13 @@ pub enum SpirvValueKind {
/// of such constants, instead of where they're generated (and cached).
IllegalConst(Word),
/// This can only happen in one specific case - which is as a result of
/// `codegen_buffer_store_intrinsic`, that function is supposed to return
/// OpTypeVoid, however because it gets inline by the compiler it can't.
/// Instead we return this, and trigger an error if we ever end up using the
/// result of this function call (which we can't).
IllegalTypeUsed(Word),
// FIXME(eddyb) this shouldn't be needed, but `rustc_codegen_ssa` still relies
// on converting `Function`s to `Value`s even for direct calls, the `Builder`
// should just have direct and indirect `call` variants (or a `Callee` enum).
@ -132,6 +139,16 @@ impl SpirvValue {
id
}
SpirvValueKind::IllegalTypeUsed(id) => {
cx.tcx
.sess
.struct_span_err(span, "Can't use type as a value")
.note(&format!("Type: *{}", cx.debug_type(id)))
.emit();
id
}
SpirvValueKind::FnAddr { .. } => {
if cx.is_system_crate() {
cx.builder

View File

@ -121,6 +121,12 @@ impl<'tcx> CodegenCx<'tcx> {
if attrs.unroll_loops.is_some() {
self.unroll_loops_decorations.borrow_mut().insert(fn_id);
}
if attrs.buffer_load_intrinsic.is_some() {
self.buffer_load_intrinsic_fn_id.borrow_mut().insert(fn_id);
}
if attrs.buffer_store_intrinsic.is_some() {
self.buffer_store_intrinsic_fn_id.borrow_mut().insert(fn_id);
}
let instance_def_id = instance.def_id();

View File

@ -66,6 +66,10 @@ pub struct CodegenCx<'tcx> {
/// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`.
pub panic_fn_id: Cell<Option<Word>>,
/// Intrinsic for loading a <T> from a &[u32]
pub buffer_load_intrinsic_fn_id: RefCell<FxHashSet<Word>>,
/// Intrinsic for storing a <T> into a &[u32]
pub buffer_store_intrinsic_fn_id: RefCell<FxHashSet<Word>>,
/// Builtin bounds-checking panics (from MIR `Assert`s) call `#[lang = "panic_bounds_check"]`.
pub panic_bounds_check_fn_id: Cell<Option<Word>>,
@ -123,6 +127,8 @@ impl<'tcx> CodegenCx<'tcx> {
instruction_table: InstructionTable::new(),
libm_intrinsics: Default::default(),
panic_fn_id: Default::default(),
buffer_load_intrinsic_fn_id: Default::default(),
buffer_store_intrinsic_fn_id: Default::default(),
panic_bounds_check_fn_id: Default::default(),
i8_i16_atomics_allowed: false,
codegen_args,

View File

@ -339,6 +339,11 @@ impl Symbols {
SpirvAttribute::IntrinsicType(IntrinsicType::Matrix),
),
("unroll_loops", SpirvAttribute::UnrollLoops),
("buffer_load_intrinsic", SpirvAttribute::BufferLoadIntrinsic),
(
"buffer_store_intrinsic",
SpirvAttribute::BufferStoreIntrinsic,
),
]
.iter()
.cloned();

View File

@ -0,0 +1,66 @@
use core::mem;
/// Marker intrinsic: the attribute makes the compiler backend replace calls
/// to this function with generated SPIR-V loads, so the body never runs.
#[spirv(buffer_load_intrinsic)]
#[spirv_std_macros::gpu_only]
#[allow(improper_ctypes_definitions)]
unsafe extern "unadjusted" fn buffer_load_intrinsic<T>(_buffer: &[u32], _offset: u32) -> T {
    unimplemented!()
} // actually implemented in the compiler

/// Marker intrinsic: the attribute makes the compiler backend replace calls
/// to this function with generated SPIR-V stores, so the body never runs.
#[spirv(buffer_store_intrinsic)]
#[spirv_std_macros::gpu_only]
#[allow(improper_ctypes_definitions)]
unsafe extern "unadjusted" fn buffer_store_intrinsic<T>(
    _buffer: &mut [u32],
    _offset: u32,
    _value: T,
) {
    unimplemented!()
} // actually implemented in the compiler
/// `ByteAddressableBuffer` is an untyped blob of data, allowing loads and stores of arbitrary
/// basic data types at arbitrary indices. However, all data must be aligned to size 4, each
/// element within the data (e.g. struct fields) must have a size and alignment of a multiple of 4,
/// and the `byte_index` passed to load and store must be a multiple of 4 (`byte_index` will be
/// rounded down to the nearest multiple of 4). So, it's not technically a *byte* addressable
/// buffer, but rather a *word* buffer, but this naming and behavior was inherited from HLSL (where
/// it's UB to pass in an index not a multiple of 4).
// NOTE: the doc comment above was previously attached to the `impl` block;
// it belongs on the type so rustdoc shows it on `ByteAddressableBuffer`.
#[repr(transparent)]
pub struct ByteAddressableBuffer<'a> {
    /// The underlying word (32-bit) storage.
    pub data: &'a mut [u32],
}

impl<'a> ByteAddressableBuffer<'a> {
    /// Creates a `ByteAddressableBuffer` from the given backing slice of words.
    #[inline]
    pub fn new(data: &'a mut [u32]) -> Self {
        Self { data }
    }

    /// Loads an arbitrary type from the buffer. `byte_index` must be a multiple of 4, otherwise,
    /// it will get silently rounded down to the nearest multiple of 4.
    ///
    /// # Panics
    /// Panics if `byte_index + size_of::<T>()` exceeds the buffer's size in bytes.
    ///
    /// # Safety
    /// This function allows writing a type to an untyped buffer, then reading a different type
    /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
    /// transmute)
    pub unsafe fn load<T>(self, byte_index: u32) -> T {
        // `data.len()` counts 32-bit words, so multiply by 4 to compare in
        // bytes (previously the byte index was compared against the word
        // count, rejecting valid accesses past the first quarter of the
        // buffer).
        if byte_index + mem::size_of::<T>() as u32 > self.data.len() as u32 * 4 {
            panic!("Index out of range")
        }
        buffer_load_intrinsic(self.data, byte_index)
    }

    /// Stores an arbitrary type into the buffer. `byte_index` must be a multiple of 4, otherwise,
    /// it will get silently rounded down to the nearest multiple of 4.
    ///
    /// # Panics
    /// Panics if `byte_index + size_of::<T>()` exceeds the buffer's size in bytes.
    ///
    /// # Safety
    /// This function allows writing a type to an untyped buffer, then reading a different type
    /// from the same buffer, allowing all sorts of safety guarantees to be bypassed (effectively a
    /// transmute)
    pub unsafe fn store<T>(self, byte_index: u32, value: T) {
        // See `load` for why the length is scaled from words to bytes.
        if byte_index + mem::size_of::<T>() as u32 > self.data.len() as u32 * 4 {
            panic!("Index out of range")
        }
        buffer_store_intrinsic(self.data, byte_index, value);
    }
}

View File

@ -96,6 +96,7 @@
pub extern crate spirv_std_macros as macros;
pub mod arch;
pub mod byte_addressable_buffer;
pub mod float;
pub mod image;
pub mod integer;
@ -109,6 +110,7 @@ pub mod vector;
pub use self::sampler::Sampler;
pub use crate::macros::Image;
pub use byte_addressable_buffer::ByteAddressableBuffer;
pub use num_traits;
pub use runtime_array::*;

View File

@ -0,0 +1,25 @@
// build-pass
use spirv_std::{glam::Vec4, ByteAddressableBuffer};
// Exercises untyped load/store of a fixed-size array (`[i32; 4]`).
// Byte index 5 is not a multiple of 4, so per the ByteAddressableBuffer
// contract it is rounded down to 4.
#[spirv(fragment)]
pub fn load(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    out: &mut [i32; 4],
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        *out = buf.load(5);
    }
}

#[spirv(fragment)]
pub fn store(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    val: [i32; 4],
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        buf.store(5, val);
    }
}

View File

@ -0,0 +1,34 @@
// build-pass
use spirv_std::ByteAddressableBuffer;
// Exercises untyped load/store of a plain multi-field struct where every
// field is one 32-bit word.
pub struct BigStruct {
    a: u32,
    b: u32,
    c: u32,
    d: u32,
    e: u32,
    f: u32,
}

#[spirv(fragment)]
pub fn load(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    out: &mut BigStruct,
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        *out = buf.load(5);
    }
}

#[spirv(fragment)]
pub fn store(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    val: BigStruct,
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        buf.store(5, val);
    }
}

View File

@ -0,0 +1,40 @@
// build-pass
use spirv_std::{glam::Vec2, ByteAddressableBuffer};
// A nested struct mixing scalars, another struct, a vector, and arrays.
// NOTE(review): `Complex` is declared but the entry points below only
// load/store `Nesty` — confirm whether `buf.load::<Complex>()` coverage was
// intended here.
pub struct Complex {
    x: u32,
    y: f32,
    n: Nesty,
    v: Vec2,
    a: [f32; 7],
    m: [Nesty; 2],
}

pub struct Nesty {
    x: f32,
    y: f32,
    z: f32,
}

#[spirv(fragment)]
pub fn load(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    out: &mut Nesty,
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        *out = buf.load(5);
    }
}

#[spirv(fragment)]
pub fn store(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    val: Nesty,
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        buf.store(5, val);
    }
}

View File

@ -0,0 +1,22 @@
// build-pass
use spirv_std::ByteAddressableBuffer;
// Exercises untyped load/store of a single `f32` (the bitcast path).
#[spirv(fragment)]
pub fn load(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    out: &mut f32,
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        *out = buf.load(5);
    }
}

#[spirv(fragment)]
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: f32) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        buf.store(5, val);
    }
}

View File

@ -0,0 +1,22 @@
// build-pass
use spirv_std::ByteAddressableBuffer;
// Exercises untyped load/store of a single `u32` (the intcast path).
#[spirv(fragment)]
pub fn load(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    out: &mut u32,
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        *out = buf.load(5);
    }
}

#[spirv(fragment)]
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: u32) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        buf.store(5, val);
    }
}

View File

@ -0,0 +1,22 @@
// build-pass
use spirv_std::{glam::Vec4, ByteAddressableBuffer};
// Exercises untyped load/store of a glam `Vec4` (the vector path).
#[spirv(fragment)]
pub fn load(
    #[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
    out: &mut Vec4,
) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        *out = buf.load(5);
    }
}

#[spirv(fragment)]
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: Vec4) {
    unsafe {
        let buf = ByteAddressableBuffer::new(buf);
        buf.store(5, val);
    }
}