Error when int doesn't have spirv(flat) (#815)

This commit is contained in:
Ashley Hauck 2021-12-06 11:31:26 +01:00 committed by GitHub
parent 0652153f1d
commit 340f4bb70a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
47 changed files with 127 additions and 69 deletions

View File

@ -517,11 +517,12 @@ impl<'tcx> CodegenCx<'tcx> {
);
}
self.check_for_bools(
self.check_for_bad_types(
hir_param.ty_span,
var_ptr_spirv_type,
storage_class,
attrs.builtin.is_some(),
attrs.flat.is_some(),
);
// Assign locations from left to right, incrementing each storage class
@ -563,7 +564,15 @@ impl<'tcx> CodegenCx<'tcx> {
}
// Booleans are only allowed in some storage classes. Error if they're in others.
fn check_for_bools(&self, span: Span, ty: Word, storage_class: StorageClass, is_builtin: bool) {
// Integers and f64s must be decorated with `#[spirv(flat)]`.
fn check_for_bad_types(
&self,
span: Span,
ty: Word,
storage_class: StorageClass,
is_builtin: bool,
is_flat: bool,
) {
// private and function are allowed here, but they can't happen.
// SPIR-V technically allows all input/output variables to be booleans, not just builtins,
// but has a note:
@ -578,15 +587,31 @@ impl<'tcx> CodegenCx<'tcx> {
{
return;
}
if recurse(self, ty) {
let mut has_bool = false;
let mut must_be_flat = false;
recurse(self, ty, &mut has_bool, &mut must_be_flat);
if has_bool {
self.tcx
.sess
.span_err(span, "entrypoint parameter cannot contain a boolean");
}
fn recurse(cx: &CodegenCx<'_>, ty: Word) -> bool {
if matches!(storage_class, StorageClass::Input | StorageClass::Output)
&& must_be_flat
&& !is_flat
{
self.tcx
.sess
.span_err(span, "parameter must be decorated with #[spirv(flat)]");
}
fn recurse(cx: &CodegenCx<'_>, ty: Word, has_bool: &mut bool, must_be_flat: &mut bool) {
match cx.lookup_type(ty) {
SpirvType::Bool => true,
SpirvType::Adt { field_types, .. } => field_types.iter().any(|&f| recurse(cx, f)),
SpirvType::Bool => *has_bool = true,
SpirvType::Integer(_, _) | SpirvType::Float(64) => *must_be_flat = true,
SpirvType::Adt { field_types, .. } => {
for f in field_types {
recurse(cx, f, has_bool, must_be_flat);
}
}
SpirvType::Vector { element, .. }
| SpirvType::Matrix { element, .. }
| SpirvType::Array { element, .. }
@ -594,12 +619,17 @@ impl<'tcx> CodegenCx<'tcx> {
| SpirvType::Pointer { pointee: element }
| SpirvType::InterfaceBlock {
inner_type: element,
} => recurse(cx, element),
} => recurse(cx, element, has_bool, must_be_flat),
SpirvType::Function {
return_type,
arguments,
} => recurse(cx, return_type) || arguments.iter().any(|&a| recurse(cx, a)),
_ => false,
} => {
recurse(cx, return_type, has_bool, must_be_flat);
for a in arguments {
recurse(cx, a, has_bool, must_be_flat);
}
}
_ => (),
}
}
}

View File

@ -5,7 +5,7 @@ use spirv_std::{glam::Vec4, ByteAddressableBuffer};
#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut [i32; 4],
#[spirv(flat)] out: &mut [i32; 4],
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
@ -16,7 +16,7 @@ pub fn load(
#[spirv(fragment)]
pub fn store(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
val: [i32; 4],
#[spirv(flat)] val: [i32; 4],
) {
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);

View File

@ -14,7 +14,7 @@ pub struct BigStruct {
#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut BigStruct,
#[spirv(flat)] out: &mut BigStruct,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
@ -25,7 +25,7 @@ pub fn load(
#[spirv(fragment)]
pub fn store(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
val: BigStruct,
#[spirv(flat)] val: BigStruct,
) {
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);

View File

@ -5,7 +5,7 @@ use spirv_std::ByteAddressableBuffer;
#[spirv(fragment)]
pub fn load(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
out: &mut u32,
#[spirv(flat)] out: &mut u32,
) {
unsafe {
let buf = ByteAddressableBuffer::new(buf);
@ -14,7 +14,10 @@ pub fn load(
}
#[spirv(fragment)]
pub fn store(#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32], val: u32) {
pub fn store(
#[spirv(descriptor_set = 0, binding = 0, storage_buffer)] buf: &mut [u32],
#[spirv(flat)] val: u32,
) {
unsafe {
let mut buf = ByteAddressableBuffer::new(buf);
buf.store(5, val);

View File

@ -24,7 +24,7 @@ impl Foo {
}
#[spirv(fragment)]
pub fn main(in_packed: u64, out_sum: &mut u32) {
pub fn main(#[spirv(flat)] in_packed: u64, #[spirv(flat)] out_sum: &mut u32) {
let foo = Foo::unpack(in_packed);
*out_sum = foo.a + (foo.b + foo.c) as u32;
}

View File

@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
#[spirv(fragment)]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled),
output: &mut u32,
#[spirv(flat)] output: &mut u32,
) {
*output = image.query_levels();
}

View File

@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
#[spirv(fragment)]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled, multisampled),
output: &mut u32,
#[spirv(flat)] output: &mut u32,
) {
*output = image.query_samples();
}

View File

@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
#[spirv(fragment)]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled=false),
output: &mut glam::UVec2,
#[spirv(flat)] output: &mut glam::UVec2,
) {
*output = image.query_size();
}

View File

@ -6,7 +6,7 @@ use spirv_std::{arch, Image};
#[spirv(fragment)]
pub fn main(
#[spirv(descriptor_set = 0, binding = 0)] image: &Image!(2D, type=f32, sampled),
output: &mut glam::UVec2,
#[spirv(flat)] output: &mut glam::UVec2,
) {
*output = image.query_size_lod(0);
}

View File

@ -8,7 +8,7 @@ use spirv_std as _;
use glam::Vec4;
#[spirv(fragment)]
pub fn main(#[spirv(push_constant)] array_in: &[Vec4; 16], i: u32, out: &mut Vec4) {
pub fn main(#[spirv(push_constant)] array_in: &[Vec4; 16], #[spirv(flat)] i: u32, out: &mut Vec4) {
unsafe {
asm!(
"%val_ptr = OpAccessChain _ {array_ptr} {index}",

View File

@ -10,7 +10,7 @@ use glam::Vec4;
#[spirv(fragment)]
pub fn main(
#[spirv(storage_buffer, descriptor_set = 0, binding = 0)] slice_in: &[Vec4],
i: u32,
#[spirv(flat)] i: u32,
out: &mut Vec4,
) {
unsafe {

View File

@ -10,6 +10,6 @@ const OFFSETS: [f32; 18] = [
#[allow(unused_attributes)]
#[spirv(fragment)]
pub fn main(x: &mut u32) {
pub fn main(#[spirv(flat)] x: &mut u32) {
*x = OFFSETS.len() as u32;
}

View File

@ -18,11 +18,11 @@ fn array3_deep_load(r: &'static [&'static u32; 3]) -> [u32; 3] {
}
#[spirv(fragment)]
pub fn main_pair(pair_out: &mut (u32, f32)) {
pub fn main_pair(#[spirv(flat)] pair_out: &mut (u32, f32)) {
*pair_out = pair_deep_load(&(&123, &3.14));
}
#[spirv(fragment)]
pub fn main_array3(array3_out: &mut [u32; 3]) {
pub fn main_array3(#[spirv(flat)] array3_out: &mut [u32; 3]) {
*array3_out = array3_deep_load(&[&0, &1, &2]);
}

View File

@ -22,9 +22,9 @@ fn deep_transpose(r: &'static &'static Mat2) -> Mat2 {
#[spirv(fragment)]
pub fn main(
scalar_out: &mut u32,
#[spirv(flat)] scalar_out: &mut u32,
#[spirv(push_constant)] vec_in: &Vec2,
bool_out: &mut u32,
#[spirv(flat)] bool_out: &mut u32,
vec_out: &mut Vec2,
) {
*scalar_out = deep_load(&&123);

View File

@ -15,7 +15,12 @@ fn scalar_load(r: &'static u32) -> u32 {
const ROT90: Mat2 = const_mat2![[0.0, 1.0], [-1.0, 0.0]];
#[spirv(fragment)]
pub fn main(scalar_out: &mut u32, vec_in: Vec2, bool_out: &mut u32, vec_out: &mut Vec2) {
pub fn main(
#[spirv(flat)] scalar_out: &mut u32,
vec_in: Vec2,
#[spirv(flat)] bool_out: &mut u32,
vec_out: &mut Vec2,
) {
*scalar_out = scalar_load(&123);
*bool_out = (vec_in == Vec2::ZERO) as u32;
*vec_out = ROT90.transpose() * vec_in;

View File

@ -9,7 +9,7 @@ fn closure_user<F: FnMut(&u32, u32)>(ptr: &u32, xmax: u32, mut callback: F) {
}
#[spirv(fragment)]
pub fn main(ptr: &mut u32) {
pub fn main(#[spirv(flat)] ptr: &mut u32) {
closure_user(ptr, 10, |ptr, i| {
if *ptr == i {
spirv_std::arch::kill();

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 32 {
let current_position = 0;
if i < current_position {

View File

@ -3,6 +3,6 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
for _ in 0..i {}
}

View File

@ -25,6 +25,6 @@ impl<T: Num + Ord + Copy> Iterator for RangeIter<T> {
}
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
for _ in RangeIter(0..i) {}
}

View File

@ -3,6 +3,6 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i > 0 {}
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i > 0 {
} else {
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i > 0 {
} else if i < 0 {
} else {

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i > 0 {
if i < 10 {}
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i < 10 {
return;
} else {

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i < 10 {
return;
} else {

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i == 0 {
while i < 10 {}
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
if i > 0 {}
if i > 1 {}
}

View File

@ -3,6 +3,6 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {}
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
break;
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
continue;
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
if i == 0 {
break;

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
if i == 0 {
break;

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
if i == 0 {
break;

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
if i == 0 {
continue;

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
if i == 0 {
continue;

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 10 {
return;
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 20 {
while i < 10 {}
}

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 20 {
while i < 10 {
break;

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 20 {
while i < 10 {
continue;

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 20 {
while i < 10 {
if i > 10 {

View File

@ -3,7 +3,7 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(i: i32) {
pub fn main(#[spirv(flat)] i: i32) {
while i < 20 {
while i < 10 {
if i > 5 {

View File

@ -11,6 +11,6 @@ fn has_two_decimal_digits(x: u32) -> bool {
}
#[spirv(fragment)]
pub fn main(i: u32, o: &mut u32) {
pub fn main(#[spirv(flat)] i: u32, #[spirv(flat)] o: &mut u32) {
*o = has_two_decimal_digits(i) as u32;
}

View File

@ -7,6 +7,6 @@
use spirv_std as _;
#[spirv(fragment)]
pub fn main(out: &mut u32) {
pub fn main(#[spirv(flat)] out: &mut u32) {
*out = None.unwrap_or(15);
}

View File

@ -5,61 +5,61 @@ use spirv_std::float::*;
use spirv_std::glam::{Vec2, Vec4};
#[spirv(fragment)]
pub fn test_vec2_to_f16x2(i: Vec2, o: &mut u32) {
pub fn test_vec2_to_f16x2(i: Vec2, #[spirv(flat)] o: &mut u32) {
*o = vec2_to_f16x2(i);
}
#[spirv(fragment)]
pub fn test_f16x2_to_vec2(i: u32, o: &mut Vec2) {
pub fn test_f16x2_to_vec2(#[spirv(flat)] i: u32, o: &mut Vec2) {
*o = f16x2_to_vec2(i);
}
#[spirv(fragment)]
pub fn test_f32_to_f16(i: f32, o: &mut u32) {
pub fn test_f32_to_f16(i: f32, #[spirv(flat)] o: &mut u32) {
*o = f32_to_f16(i);
}
#[spirv(fragment)]
pub fn test_f16_to_f32(i: u32, o: &mut f32) {
pub fn test_f16_to_f32(#[spirv(flat)] i: u32, o: &mut f32) {
*o = f16_to_f32(i);
}
#[spirv(fragment)]
pub fn test_vec4_to_u8x4_snorm(i: Vec4, o: &mut u32) {
pub fn test_vec4_to_u8x4_snorm(i: Vec4, #[spirv(flat)] o: &mut u32) {
*o = vec4_to_u8x4_snorm(i);
}
#[spirv(fragment)]
pub fn test_vec4_to_u8x4_unorm(i: Vec4, o: &mut u32) {
pub fn test_vec4_to_u8x4_unorm(i: Vec4, #[spirv(flat)] o: &mut u32) {
*o = vec4_to_u8x4_unorm(i);
}
#[spirv(fragment)]
pub fn test_vec2_to_u16x2_snorm(i: Vec2, o: &mut u32) {
pub fn test_vec2_to_u16x2_snorm(i: Vec2, #[spirv(flat)] o: &mut u32) {
*o = vec2_to_u16x2_snorm(i);
}
#[spirv(fragment)]
pub fn test_vec2_to_u16x2_unorm(i: Vec2, o: &mut u32) {
pub fn test_vec2_to_u16x2_unorm(i: Vec2, #[spirv(flat)] o: &mut u32) {
*o = vec2_to_u16x2_unorm(i);
}
#[spirv(fragment)]
pub fn test_u8x4_to_vec4_snorm(i: u32, o: &mut Vec4) {
pub fn test_u8x4_to_vec4_snorm(#[spirv(flat)] i: u32, o: &mut Vec4) {
*o = u8x4_to_vec4_snorm(i);
}
#[spirv(fragment)]
pub fn test_u8x4_to_vec4_unorm(i: u32, o: &mut Vec4) {
pub fn test_u8x4_to_vec4_unorm(#[spirv(flat)] i: u32, o: &mut Vec4) {
*o = u8x4_to_vec4_unorm(i);
}
#[spirv(fragment)]
pub fn test_u16x2_to_vec2_snorm(i: u32, o: &mut Vec2) {
pub fn test_u16x2_to_vec2_snorm(#[spirv(flat)] i: u32, o: &mut Vec2) {
*o = u16x2_to_vec2_snorm(i);
}
#[spirv(fragment)]
pub fn test_u16x2_to_vec2_unorm(i: u32, o: &mut Vec2) {
pub fn test_u16x2_to_vec2_unorm(#[spirv(flat)] i: u32, o: &mut Vec2) {
*o = u16x2_to_vec2_unorm(i);
}

View File

@ -12,6 +12,6 @@ fn track_caller_maybe_panic(x: u32) {
}
#[spirv(fragment)]
pub fn main(x: u32) {
pub fn main(#[spirv(flat)] x: u32) {
track_caller_maybe_panic(x);
}

View File

@ -0,0 +1,6 @@
// build-fail
use spirv_std as _;
#[spirv(fragment)]
pub fn fragment(int: u32, double: f64) {}

View File

@ -0,0 +1,14 @@
error: parameter must be decorated with #[spirv(flat)]
--> $DIR/int-without-flat.rs:6:22
|
6 | pub fn fragment(int: u32, double: f64) {}
| ^^^
error: parameter must be decorated with #[spirv(flat)]
--> $DIR/int-without-flat.rs:6:35
|
6 | pub fn fragment(int: u32, double: f64) {}
| ^^^
error: aborting due to 2 previous errors