Fix ByteAddressableBuffer PassMode::Pair (#837)

This commit is contained in:
Ashley Hauck 2022-01-10 10:35:03 +01:00 committed by GitHub
parent b99fc516e6
commit fe5c7716ed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 181 additions and 47 deletions

View File

@ -2188,7 +2188,17 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
for (argument, argument_type) in args.iter().zip(argument_types) { for (argument, argument_type) in args.iter().zip(argument_types) {
assert_ty_eq!(self, argument.ty, argument_type); assert_ty_eq!(self, argument.ty, argument_type);
} }
let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).cloned(); let libm_intrinsic = self.libm_intrinsics.borrow().get(&callee_val).copied();
let buffer_load_intrinsic = self
.buffer_load_intrinsic_fn_id
.borrow()
.get(&callee_val)
.copied();
let buffer_store_intrinsic = self
.buffer_store_intrinsic_fn_id
.borrow()
.get(&callee_val)
.copied();
if let Some(libm_intrinsic) = libm_intrinsic { if let Some(libm_intrinsic) = libm_intrinsic {
let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args); let result = self.call_libm_intrinsic(libm_intrinsic, result_type, args);
if result_type != result.ty { if result_type != result.ty {
@ -2207,18 +2217,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// needing to materialize `&core::panic::Location` or `format_args!`. // needing to materialize `&core::panic::Location` or `format_args!`.
self.abort(); self.abort();
self.undef(result_type) self.undef(result_type)
} else if self } else if let Some(mode) = buffer_load_intrinsic {
.buffer_load_intrinsic_fn_id self.codegen_buffer_load_intrinsic(result_type, args, mode)
.borrow() } else if let Some(mode) = buffer_store_intrinsic {
.contains(&callee_val) self.codegen_buffer_store_intrinsic(args, mode);
{
self.codegen_buffer_load_intrinsic(result_type, args)
} else if self
.buffer_store_intrinsic_fn_id
.borrow()
.contains(&callee_val)
{
self.codegen_buffer_store_intrinsic(args);
let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self); let void_ty = SpirvType::Void.def(rustc_span::DUMMY_SP, self);
SpirvValue { SpirvValue {

View File

@ -1,10 +1,12 @@
use super::Builder; use super::Builder;
use crate::builder_spirv::{SpirvValue, SpirvValueExt}; use crate::builder_spirv::{SpirvValue, SpirvValueExt, SpirvValueKind};
use crate::spirv_type::SpirvType; use crate::spirv_type::SpirvType;
use rspirv::spirv::Word; use rspirv::spirv::Word;
use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods}; use rustc_codegen_ssa::traits::{BaseTypeMethods, BuilderMethods};
use rustc_errors::ErrorReported;
use rustc_span::DUMMY_SP; use rustc_span::DUMMY_SP;
use rustc_target::abi::Align; use rustc_target::abi::call::PassMode;
use rustc_target::abi::{Align, Size};
impl<'a, 'tcx> Builder<'a, 'tcx> { impl<'a, 'tcx> Builder<'a, 'tcx> {
fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue { fn load_err(&mut self, original_type: Word, invalid_type: Word) -> SpirvValue {
@ -168,7 +170,25 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
&mut self, &mut self,
result_type: Word, result_type: Word,
args: &[SpirvValue], args: &[SpirvValue],
pass_mode: PassMode,
) -> SpirvValue { ) -> SpirvValue {
match pass_mode {
PassMode::Ignore => {
return SpirvValue {
kind: SpirvValueKind::IllegalTypeUsed(result_type),
ty: result_type,
}
}
// PassMode::Pair is identical to PassMode::Direct - it's returned as a struct
PassMode::Direct(_) | PassMode::Pair(_, _) => (),
PassMode::Cast(_) => {
self.fatal("PassMode::Cast not supported in codegen_buffer_load_intrinsic")
}
PassMode::Indirect { .. } => {
self.fatal("PassMode::Indirect not supported in codegen_buffer_load_intrinsic")
}
}
// Signature: fn load<T>(array: &[u32], index: u32) -> T; // Signature: fn load<T>(array: &[u32], index: u32) -> T;
if args.len() != 3 { if args.len() != 3 {
self.fatal(&format!( self.fatal(&format!(
@ -184,15 +204,16 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.recurse_load_type(result_type, result_type, array, word_index, 0) self.recurse_load_type(result_type, result_type, array, word_index, 0)
} }
fn store_err(&mut self, original_type: Word, value: SpirvValue) { fn store_err(&mut self, original_type: Word, value: SpirvValue) -> Result<(), ErrorReported> {
let mut err = self.struct_err(&format!( let mut err = self.struct_err(&format!(
"Cannot load type {} in an untyped buffer store", "Cannot store type {} in an untyped buffer store",
self.debug_type(original_type) self.debug_type(original_type)
)); ));
if original_type != value.ty { if original_type != value.ty {
err.note(&format!("due to containing type {}", value.ty)); err.note(&format!("due to containing type {}", value.ty));
} }
err.emit(); err.emit();
Err(ErrorReported)
} }
fn store_u32( fn store_u32(
@ -201,7 +222,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
dynamic_index: SpirvValue, dynamic_index: SpirvValue,
constant_offset: u32, constant_offset: u32,
value: SpirvValue, value: SpirvValue,
) { ) -> Result<(), ErrorReported> {
let actual_index = if constant_offset != 0 { let actual_index = if constant_offset != 0 {
let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset); let const_offset_val = self.constant_u32(DUMMY_SP, constant_offset);
self.add(dynamic_index, const_offset_val) self.add(dynamic_index, const_offset_val)
@ -216,6 +237,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.unwrap() .unwrap()
.with_type(u32_ptr); .with_type(u32_ptr);
self.store(value, ptr, Align::ONE); self.store(value, ptr, Align::ONE);
Ok(())
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@ -228,7 +250,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
constant_word_offset: u32, constant_word_offset: u32,
element: Word, element: Word,
count: u32, count: u32,
) { ) -> Result<(), ErrorReported> {
let element_size_bytes = match self.lookup_type(element).sizeof(self) { let element_size_bytes = match self.lookup_type(element).sizeof(self) {
Some(size) => size, Some(size) => size,
None => return self.store_err(original_type, value), None => return self.store_err(original_type, value),
@ -245,8 +267,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
array, array,
dynamic_word_index, dynamic_word_index,
constant_word_offset + element_size_words * index, constant_word_offset + element_size_words * index,
); )?;
} }
Ok(())
} }
fn recurse_store_type( fn recurse_store_type(
@ -256,17 +279,17 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
array: SpirvValue, array: SpirvValue,
dynamic_word_index: SpirvValue, dynamic_word_index: SpirvValue,
constant_word_offset: u32, constant_word_offset: u32,
) { ) -> Result<(), ErrorReported> {
match self.lookup_type(value.ty) { match self.lookup_type(value.ty) {
SpirvType::Integer(32, signed) => { SpirvType::Integer(32, signed) => {
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self); let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let value_u32 = self.intcast(value, u32_ty, signed); let value_u32 = self.intcast(value, u32_ty, signed);
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32); self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
} }
SpirvType::Float(32) => { SpirvType::Float(32) => {
let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self); let u32_ty = SpirvType::Integer(32, false).def(DUMMY_SP, self);
let value_u32 = self.bitcast(value, u32_ty); let value_u32 = self.bitcast(value, u32_ty);
self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32); self.store_u32(array, dynamic_word_index, constant_word_offset, value_u32)
} }
SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self SpirvType::Vector { element, count } | SpirvType::Matrix { element, count } => self
.store_vec_mat_arr( .store_vec_mat_arr(
@ -291,7 +314,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
constant_word_offset, constant_word_offset,
element, element,
count, count,
); )
} }
SpirvType::Adt { SpirvType::Adt {
size: Some(_), size: Some(_),
@ -310,8 +333,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
array, array,
dynamic_word_index, dynamic_word_index,
constant_word_offset + word_offset, constant_word_offset + word_offset,
); )?;
} }
Ok(())
} }
_ => self.store_err(original_type, value), _ => self.store_err(original_type, value),
@ -319,11 +343,25 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
} }
/// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller. /// Note: DOES NOT do bounds checking! Bounds checking is expected to be done in the caller.
pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue]) { pub fn codegen_buffer_store_intrinsic(&mut self, args: &[SpirvValue], pass_mode: PassMode) {
// Signature: fn store<T>(array: &[u32], index: u32, value: T); // Signature: fn store<T>(array: &[u32], index: u32, value: T);
if args.len() != 4 { let is_pair = match pass_mode {
// haha shrug
PassMode::Ignore => return,
PassMode::Direct(_) => false,
PassMode::Pair(_, _) => true,
PassMode::Cast(_) => {
self.fatal("PassMode::Cast not supported in codegen_buffer_store_intrinsic")
}
PassMode::Indirect { .. } => {
self.fatal("PassMode::Indirect not supported in codegen_buffer_store_intrinsic")
}
};
let expected_args = if is_pair { 5 } else { 4 };
if args.len() != expected_args {
self.fatal(&format!( self.fatal(&format!(
"buffer_store_intrinsic should have 4 args, it has {}", "buffer_store_intrinsic should have {} args, it has {}",
expected_args,
args.len() args.len()
)); ));
} }
@ -332,7 +370,20 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let byte_index = args[2]; let byte_index = args[2];
let two = self.constant_u32(DUMMY_SP, 2); let two = self.constant_u32(DUMMY_SP, 2);
let word_index = self.lshr(byte_index, two); let word_index = self.lshr(byte_index, two);
if is_pair {
let value_one = args[3];
let value_two = args[4];
let one_result = self.recurse_store_type(value_one.ty, value_one, array, word_index, 0);
let size_of_one = self.lookup_type(value_one.ty).sizeof(self);
if one_result.is_ok() && size_of_one != Some(Size::from_bytes(4)) {
self.fatal("Expected PassMode::Pair first element to have size 4");
}
let _ = self.recurse_store_type(value_two.ty, value_two, array, word_index, 1);
} else {
let value = args[3]; let value = args[3];
self.recurse_store_type(value.ty, value, array, word_index, 0); let _ = self.recurse_store_type(value.ty, value, array, word_index, 0);
}
} }
} }

View File

@ -120,10 +120,16 @@ impl<'tcx> CodegenCx<'tcx> {
self.unroll_loops_decorations.borrow_mut().insert(fn_id); self.unroll_loops_decorations.borrow_mut().insert(fn_id);
} }
if attrs.buffer_load_intrinsic.is_some() { if attrs.buffer_load_intrinsic.is_some() {
self.buffer_load_intrinsic_fn_id.borrow_mut().insert(fn_id); let mode = fn_abi.ret.mode;
self.buffer_load_intrinsic_fn_id
.borrow_mut()
.insert(fn_id, mode);
} }
if attrs.buffer_store_intrinsic.is_some() { if attrs.buffer_store_intrinsic.is_some() {
self.buffer_store_intrinsic_fn_id.borrow_mut().insert(fn_id); let mode = fn_abi.args.last().unwrap().mode;
self.buffer_store_intrinsic_fn_id
.borrow_mut()
.insert(fn_id, mode);
} }
let instance_def_id = instance.def_id(); let instance_def_id = instance.def_id();

View File

@ -66,7 +66,7 @@ impl<'tcx> CodegenCx<'tcx> {
} }
// FIXME(eddyb) support these (by just ignoring them) - if there // FIXME(eddyb) support these (by just ignoring them) - if there
// is any validation concern, it should be done on the types. // is any validation concern, it should be done on the types.
PassMode::Ignore => self.tcx.sess.span_err( PassMode::Ignore => self.tcx.sess.span_fatal(
hir_param.ty_span, hir_param.ty_span,
&format!( &format!(
"entry point parameter type not yet supported \ "entry point parameter type not yet supported \

View File

@ -29,7 +29,7 @@ use rustc_session::Session;
use rustc_span::def_id::{DefId, LOCAL_CRATE}; use rustc_span::def_id::{DefId, LOCAL_CRATE};
use rustc_span::symbol::{sym, Symbol}; use rustc_span::symbol::{sym, Symbol};
use rustc_span::{SourceFile, Span, DUMMY_SP}; use rustc_span::{SourceFile, Span, DUMMY_SP};
use rustc_target::abi::call::FnAbi; use rustc_target::abi::call::{FnAbi, PassMode};
use rustc_target::abi::{HasDataLayout, TargetDataLayout}; use rustc_target::abi::{HasDataLayout, TargetDataLayout};
use rustc_target::spec::{HasTargetSpec, Target}; use rustc_target::spec::{HasTargetSpec, Target};
use std::cell::{Cell, RefCell}; use std::cell::{Cell, RefCell};
@ -66,10 +66,10 @@ pub struct CodegenCx<'tcx> {
/// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`. /// Simple `panic!("...")` and builtin panics (from MIR `Assert`s) call `#[lang = "panic"]`.
pub panic_fn_id: Cell<Option<Word>>, pub panic_fn_id: Cell<Option<Word>>,
/// Intrinsic for loading a <T> from a &[u32] /// Intrinsic for loading a <T> from a &[u32]. The PassMode is the mode of the <T>.
pub buffer_load_intrinsic_fn_id: RefCell<FxHashSet<Word>>, pub buffer_load_intrinsic_fn_id: RefCell<FxHashMap<Word, PassMode>>,
/// Intrinsic for storing a <T> into a &[u32] /// Intrinsic for storing a <T> into a &[u32]. The PassMode is the mode of the <T>.
pub buffer_store_intrinsic_fn_id: RefCell<FxHashSet<Word>>, pub buffer_store_intrinsic_fn_id: RefCell<FxHashMap<Word, PassMode>>,
/// Builtin bounds-checking panics (from MIR `Assert`s) call `#[lang = "panic_bounds_check"]`. /// Builtin bounds-checking panics (from MIR `Assert`s) call `#[lang = "panic_bounds_check"]`.
pub panic_bounds_check_fn_id: Cell<Option<Word>>, pub panic_bounds_check_fn_id: Cell<Option<Word>>,

View File

@ -5,18 +5,14 @@ use core::mem;
#[spirv(buffer_load_intrinsic)] #[spirv(buffer_load_intrinsic)]
#[spirv_std_macros::gpu_only] #[spirv_std_macros::gpu_only]
#[allow(improper_ctypes_definitions)] #[allow(improper_ctypes_definitions)]
unsafe extern "unadjusted" fn buffer_load_intrinsic<T>(_buffer: &[u32], _offset: u32) -> T { unsafe fn buffer_load_intrinsic<T>(_buffer: &[u32], _offset: u32) -> T {
unimplemented!() unimplemented!()
} // actually implemented in the compiler } // actually implemented in the compiler
#[spirv(buffer_store_intrinsic)] #[spirv(buffer_store_intrinsic)]
#[spirv_std_macros::gpu_only] #[spirv_std_macros::gpu_only]
#[allow(improper_ctypes_definitions)] #[allow(improper_ctypes_definitions)]
unsafe extern "unadjusted" fn buffer_store_intrinsic<T>( unsafe fn buffer_store_intrinsic<T>(_buffer: &mut [u32], _offset: u32, _value: T) {
_buffer: &mut [u32],
_offset: u32,
_value: T,
) {
unimplemented!() unimplemented!()
} // actually implemented in the compiler } // actually implemented in the compiler

View File

@ -2,7 +2,6 @@
#![cfg_attr( #![cfg_attr(
target_arch = "spirv", target_arch = "spirv",
feature( feature(
abi_unadjusted,
asm, asm,
asm_const, asm_const,
asm_experimental_arch, asm_experimental_arch,

23
tests/README.md Normal file
View File

@ -0,0 +1,23 @@
# Compiletests
This folder contains tests known as "compiletests". Each file in the `ui` folder corresponds to a
single compiletest. The way they work is a tool iterates over every file, and tries to compile it.
At the start of the file, there's some meta-comments about the expected result of the compile:
whether it should succeed compilation, or fail. If it is expected to fail, there's a corresponding
.stderr file next to the file that contains the expected compiler error message.
The `src` folder here is the tool that iterates over every file in the `ui` folder. It uses the
`compiletests` library, taken from rustc's own compiletest framework.
You can run compiletests via `cargo compiletests`. This is an alias set up in `.cargo/config` for
`cargo run --release -p compiletests --`. You can filter to run specific tests by passing the
(partial) filenames to `cargo compiletests some_file_name`, and update the `.stderr` files to
contain new output via the `--bless` flag (with `--bless`, make sure you're actually supposed to be
changing the .stderr files due to an intentional change, and hand-validate the output is correct
afterwards).
Keep in mind that tests here here are not executed, merely checked for errors (including validating
the resulting binary with spirv-val). Because of this, there might be some strange code in here -
the point isn't to make a fully functional shader every time (that would take an annoying amount of
effort), but rather validate that specific parts of the compiler are doing their job correctly
(either succeeding as they should, or erroring as they should).

View File

@ -96,9 +96,9 @@ error[E0277]: the trait bound `{float}: Vector<f32, 2_usize>` is not satisfied
<DVec2 as Vector<f64, 2_usize>> <DVec2 as Vector<f64, 2_usize>>
and 13 others and 13 others
note: required by a bound in `debug_printf_assert_is_vector` note: required by a bound in `debug_printf_assert_is_vector`
--> $SPIRV_STD_SRC/lib.rs:146:8 --> $SPIRV_STD_SRC/lib.rs:145:8
| |
146 | V: crate::vector::Vector<TY, SIZE>, 145 | V: crate::vector::Vector<TY, SIZE>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector` | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `debug_printf_assert_is_vector`
error[E0308]: mismatched types error[E0308]: mismatched types

View File

@ -0,0 +1,25 @@
// build-pass
use spirv_std::ByteAddressableBuffer;
pub struct EmptyStruct {}
#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] out: &mut EmptyStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
*out = buf.load(5);
}
}
#[spirv(fragment)]
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32]) {
let val = EmptyStruct {};
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
buf.store(5, val);
}
}

View File

@ -0,0 +1,32 @@
// build-pass
use spirv_std::ByteAddressableBuffer;
pub struct SmallStruct {
a: u32,
b: u32,
}
#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] out: &mut SmallStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
*out = buf.load(5);
}
}
#[spirv(fragment)]
pub fn store(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] a: u32,
#[spirv(flat)] b: u32,
) {
let val = SmallStruct { a, b };
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
buf.store(5, val);
}
}