Vulkano-shaders struct generation refactor (#1945)

This commit is contained in:
Rua 2022-08-07 12:54:38 +02:00 committed by GitHub
parent b0668f5a31
commit 69b72362ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 353 additions and 340 deletions

View File

@ -246,13 +246,8 @@ where
let entry_points = reflect::entry_points(&spirv)
.map(|(name, model, info)| entry_point::write_entry_point(&name, model, &info));
let specialization_constants = structs::write_specialization_constants(
prefix,
&spirv,
types_meta,
shared_constants,
types_registry,
);
let specialization_constants =
structs::write_specialization_constants(prefix, &spirv, shared_constants, types_registry);
let load_name = if prefix.is_empty() {
format_ident!("load")

View File

@ -10,126 +10,122 @@
use crate::{RegisteredType, TypesMeta};
use heck::ToUpperCamelCase;
use proc_macro2::{Span, TokenStream};
use std::borrow::Cow;
use std::collections::HashMap;
use std::mem;
use syn::Ident;
use syn::LitStr;
use std::{borrow::Cow, collections::HashMap, mem};
use syn::{Ident, LitStr};
use vulkano::shader::spirv::{Decoration, Id, Instruction, Spirv};
/// Translates all the structs that are contained in the SPIR-V document as Rust structs.
pub(super) fn write_structs<'a>(
shader: &'a str,
spirv: &Spirv,
types_meta: &TypesMeta,
spirv: &'a Spirv,
types_meta: &'a TypesMeta,
types_registry: &'a mut HashMap<String, RegisteredType>,
) -> TokenStream {
let structs = spirv
spirv
.iter_global()
.filter_map(|instruction| match instruction {
Instruction::TypeStruct {
&Instruction::TypeStruct {
result_id,
member_types,
} => Some(
write_struct(
shader,
spirv,
*result_id,
member_types,
types_meta,
Some(types_registry),
)
.0,
),
ref member_types,
} => Some((result_id, member_types)),
_ => None,
});
})
.filter(|&(struct_id, _member_types)| has_defined_layout(spirv, struct_id))
.filter_map(|(struct_id, member_types)| {
let (rust_members, is_sized) =
write_struct_members(shader, spirv, struct_id, member_types);
quote! {
#( #structs )*
}
let struct_name = spirv
.id(struct_id)
.iter_name()
.find_map(|instruction| match instruction {
Instruction::Name { name, .. } => Some(name.as_str()),
_ => None,
})
.unwrap_or("__unnamed");
// Register the type if needed
if !register_struct(types_registry, shader, &rust_members, struct_name) {
return None;
}
let struct_ident = format_ident!("{}", struct_name);
let members = rust_members
.iter()
.map(|Member { name, ty, .. }| quote!(pub #name: #ty,));
let struct_body = quote! {
#[repr(C)]
#[allow(non_snake_case)]
pub struct #struct_ident {
#( #members )*
}
};
Some(if is_sized {
let derives = write_derives(types_meta);
let impls = write_impls(types_meta, &struct_name, &rust_members);
quote! {
#derives
#struct_body
#(#impls)*
}
} else {
struct_body
})
})
.collect()
}
/// Analyzes a single struct, returns a string containing its Rust definition, plus its size.
fn write_struct<'a>(
// The members of this struct.
struct Member {
name: Ident,
is_dummy: bool,
ty: TokenStream,
signature: Cow<'static, str>,
}
fn write_struct_members<'a>(
shader: &'a str,
spirv: &Spirv,
struct_id: Id,
members: &[Id],
types_meta: &TypesMeta,
types_registry: Option<&'a mut HashMap<String, RegisteredType>>,
) -> (TokenStream, Option<usize>) {
let id_info = spirv.id(struct_id);
let name = Ident::new(
id_info
.iter_name()
.find_map(|instruction| match instruction {
Instruction::Name { name, .. } => Some(name.as_str()),
_ => None,
})
.unwrap_or("__unnamed"),
Span::call_site(),
);
// The members of this struct.
struct Member {
name: Ident,
dummy: bool,
ty: TokenStream,
signature: Cow<'static, str>,
}
) -> (Vec<Member>, bool) {
let mut rust_members = Vec::with_capacity(members.len());
// Padding structs will be named `_paddingN` where `N` is determined by this variable.
let mut next_padding_num = 0;
// Dummy members will be named `_dummyN` where `N` is determined by this variable.
let mut next_dummy_num = 0;
// Contains the offset of the next field.
// Equals to `None` if there's a runtime-sized field in there.
let mut current_rust_offset = Some(0);
for (&member, member_info) in members.iter().zip(id_info.iter_members()) {
for (member_index, (&member, member_info)) in members
.iter()
.zip(spirv.id(struct_id).iter_members())
.enumerate()
{
// Compute infos about the member.
let (ty, signature, rust_size, rust_align) =
type_from_id(shader, spirv, member, types_meta);
let (ty, signature, rust_size, rust_align) = type_from_id(shader, spirv, member);
let member_name = member_info
.iter_name()
.find_map(|instruction| match instruction {
Instruction::MemberName { name, .. } => Some(name.as_str()),
Instruction::MemberName { name, .. } => Some(Cow::from(name.as_str())),
_ => None,
})
.unwrap_or("__unnamed");
// Ignore the whole struct is a member is built in, which includes
// `gl_Position` for example.
if member_info.iter_decoration().any(|instruction| {
matches!(
instruction,
Instruction::MemberDecorate {
decoration: Decoration::BuiltIn { .. },
..
}
)
}) {
return (quote! {}, None); // TODO: is this correct? shouldn't it return a correct struct but with a flag or something?
}
.unwrap_or_else(|| Cow::from(format!("__unnamed{}", member_index)));
// Finding offset of the current member, as requested by the SPIR-V code.
let spirv_offset =
member_info
.iter_decoration()
.find_map(|instruction| match instruction {
Instruction::MemberDecorate {
decoration: Decoration::Offset { byte_offset },
..
} => Some(*byte_offset),
_ => None,
});
// Some structs don't have `Offset` decorations, in the case they are used as local
// variables only. Ignoring these.
let spirv_offset = match spirv_offset {
Some(o) => o as usize,
None => return (quote! {}, None), // TODO: shouldn't we return and let the caller ignore it instead?
};
let spirv_offset = member_info
.iter_decoration()
.find_map(|instruction| match instruction {
Instruction::MemberDecorate {
decoration: Decoration::Offset { byte_offset },
..
} => Some(*byte_offset as usize),
_ => None,
})
.unwrap();
// We need to add a dummy field if necessary.
{
@ -146,14 +142,13 @@ fn write_struct<'a>(
if spirv_offset != *current_rust_offset {
let diff = spirv_offset.checked_sub(*current_rust_offset).unwrap();
let padding_num = next_padding_num;
next_padding_num += 1;
rust_members.push(Member {
name: Ident::new(&format!("_dummy{}", padding_num), Span::call_site()),
dummy: true,
name: format_ident!("_dummy{}", next_dummy_num.to_string()),
is_dummy: true,
ty: quote! { [u8; #diff] },
signature: Cow::from(format!("[u8; {}]", diff)),
});
next_dummy_num += 1;
*current_rust_offset += diff;
}
}
@ -167,194 +162,159 @@ fn write_struct<'a>(
rust_members.push(Member {
name: Ident::new(&member_name, Span::call_site()),
dummy: false,
is_dummy: false,
ty,
signature,
});
}
// Try determine the total size of the struct in order to add padding at the end of the struct.
let mut spirv_req_total_size = None;
for inst in spirv.iter_global() {
match inst {
Instruction::TypeArray {
result_id,
element_type,
..
}
| Instruction::TypeRuntimeArray {
result_id,
element_type,
} if *element_type == struct_id => {
spirv_req_total_size =
spirv
.id(*result_id)
.iter_decoration()
.find_map(|instruction| match instruction {
Instruction::Decorate {
decoration: Decoration::ArrayStride { array_stride },
..
} => Some(*array_stride),
_ => None,
})
}
_ => (),
}
}
// Adding the final padding members, if the struct is sized.
if let Some(cur_size) = current_rust_offset {
// Try to determine the total size of the struct.
if let Some(req_size) = struct_size_from_array_stride(spirv, struct_id) {
let diff = req_size.checked_sub(cur_size as u32).unwrap();
// Adding the final padding members.
if let (Some(cur_size), Some(req_size)) = (current_rust_offset, spirv_req_total_size) {
let diff = req_size.checked_sub(cur_size as u32).unwrap();
if diff >= 1 {
rust_members.push(Member {
name: Ident::new(&format!("_dummy{}", next_padding_num), Span::call_site()),
dummy: true,
ty: quote! { [u8; #diff as usize] },
signature: Cow::from(format!("[u8; {}]", diff)),
});
}
}
let total_size = spirv_req_total_size
.map(|sz| sz as usize)
.or(current_rust_offset);
// For single shader-mode registration mechanism skipped.
if let Some(types_registry) = types_registry {
let target_type = RegisteredType {
shader: shader.to_string(),
signature: rust_members
.iter()
.map(|member| (member.name.to_string(), member.signature.clone()))
.collect(),
};
let name = name.to_string();
// Checking with Registry if this struct already registered by another shader, and if their
// signatures match.
if let Some(registered) = types_registry.get(name.as_str()) {
registered.assert_signatures(name.as_str(), &target_type);
// If the struct already registered and matches this one, skip duplicate.
return (quote! {}, total_size);
}
assert!(types_registry.insert(name, target_type).is_none());
}
// We can only implement Clone if there's no unsized member in the struct.
let (clone_impl, copy_derive) =
if current_rust_offset.is_some() && (types_meta.clone || types_meta.copy) {
(
if types_meta.clone {
let mut copies = vec![];
for member in &rust_members {
let name = &member.name;
copies.push(quote! { #name: self.#name, });
}
// Clone is implemented manually because members can be large arrays
// that do not implement Clone, but do implement Copy
quote! {
impl Clone for #name {
fn clone(&self) -> Self {
#name {
#( #copies )*
}
}
}
}
} else {
quote! {}
},
if types_meta.copy {
quote! { #[derive(Copy)] }
} else {
quote! {}
},
)
} else {
(quote! {}, quote! {})
};
let partial_eq_impl = if current_rust_offset.is_some() && types_meta.partial_eq {
let mut fields = vec![];
for member in &rust_members {
if !member.dummy {
let name = &member.name;
fields.push(quote! {
if self.#name != other.#name {
return false
}
if diff >= 1 {
rust_members.push(Member {
name: Ident::new(&format!("_dummy{}", next_dummy_num), Span::call_site()),
is_dummy: true,
ty: quote! { [u8; #diff as usize] },
signature: Cow::from(format!("[u8; {}]", diff)),
});
}
}
}
(rust_members, current_rust_offset.is_some())
}
fn register_struct(
types_registry: &mut HashMap<String, RegisteredType>,
shader: &str,
rust_members: &[Member],
struct_name: &str,
) -> bool {
let target_type = RegisteredType {
shader: shader.to_string(),
signature: rust_members
.iter()
.map(|member| (member.name.to_string(), member.signature.clone()))
.collect(),
};
// Checking with Registry if this struct already registered by another shader, and if their
// signatures match.
if let Some(registered) = types_registry.get(struct_name) {
registered.assert_signatures(struct_name, &target_type);
// If the struct already registered and matches this one, skip duplicate.
false
} else {
assert!(types_registry
.insert(struct_name.to_owned(), target_type)
.is_none());
true
}
}
fn write_derives(types_meta: &TypesMeta) -> TokenStream {
let mut derives = vec![];
if types_meta.clone {
derives.push(quote! { Clone });
}
if types_meta.copy {
derives.push(quote! { Copy });
}
derives.extend(
types_meta
.custom_derives
.iter()
.map(|derive| quote! { #derive }),
);
if !derives.is_empty() {
quote! {
#[derive(#(#derives),*)]
}
} else {
quote! {}
}
}
fn write_impls<'a>(
types_meta: &'a TypesMeta,
struct_name: &'a str,
rust_members: &'a [Member],
) -> impl Iterator<Item = TokenStream> + 'a {
let struct_ident = format_ident!("{}", struct_name);
(types_meta.partial_eq.then(|| {
let fields = rust_members
.iter()
.filter(|Member { is_dummy, .. }| !is_dummy)
.map(|Member { name, .. }| {
quote! {
if self.#name != other.#name {
return false
}
}
});
quote! {
impl PartialEq for #name {
impl PartialEq for #struct_ident {
fn eq(&self, other: &Self) -> bool {
#( #fields )*
true
}
}
}
} else {
quote! {}
};
let (debug_impl, display_impl) = if current_rust_offset.is_some()
&& (types_meta.debug || types_meta.display)
{
let mut fields = vec![];
for member in &rust_members {
if !member.dummy {
let name = &member.name;
}).into_iter())
.chain(types_meta.debug.then(|| {
let fields = rust_members
.iter()
.filter(|Member { is_dummy, .. }| !is_dummy)
.map(|Member { name, .. }| {
let name_string = LitStr::new(name.to_string().as_ref(), name.span());
quote! { .field(#name_string, &self.#name) }
});
fields.push(quote! {.field(#name_string, &self.#name)});
quote! {
impl std::fmt::Debug for #struct_ident {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
formatter
.debug_struct(#struct_name)
#( #fields )*
.finish()
}
}
}
}))
.chain(types_meta.display.then(|| {
let fields = rust_members
.iter()
.filter(|Member { is_dummy, .. }| !is_dummy)
.map(|Member { name, .. }| {
let name_string = LitStr::new(name.to_string().as_ref(), name.span());
quote! { .field(#name_string, &self.#name) }
});
let name_string = LitStr::new(name.to_string().as_ref(), name.span());
(
if types_meta.debug {
quote! {
impl std::fmt::Debug for #name {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
formatter
.debug_struct(#name_string)
#( #fields )*
.finish()
}
}
}
} else {
quote! {}
},
if types_meta.display {
quote! {
impl std::fmt::Display for #name {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
formatter
.debug_struct(#name_string)
#( #fields )*
.finish()
}
}
}
} else {
quote! {}
},
)
} else {
(quote! {}, quote! {})
};
let default_impl = if current_rust_offset.is_some() && types_meta.default {
quote! {
impl Default for #name {
impl std::fmt::Display for #struct_ident {
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> Result<(), std::fmt::Error> {
formatter
.debug_struct(#struct_name)
#( #fields )*
.finish()
}
}
}
}))
.chain(types_meta.default.then(|| {
quote! {
impl Default for #struct_ident {
fn default() -> Self {
unsafe {
std::mem::MaybeUninit::<Self>::zeroed().assume_init()
@ -362,53 +322,73 @@ fn write_struct<'a>(
}
}
}
} else {
quote! {}
};
}))
.chain(types_meta.impls.iter().map(move |i| quote!{ #i for #struct_ident {} }))
}
// If the struct has unsized members none of custom impls applied.
let custom_impls = if current_rust_offset.is_some() {
let impls = &types_meta.impls;
fn has_defined_layout(spirv: &Spirv, struct_id: Id) -> bool {
for member_info in spirv.id(struct_id).iter_members() {
let mut offset_found = false;
quote! {
#( #impls for #name {} )*
for instruction in member_info.iter_decoration() {
match instruction {
Instruction::MemberDecorate {
decoration: Decoration::BuiltIn { .. },
..
} => {
// Ignore the whole struct if a member is built in, which includes
// `gl_Position` for example.
return false;
}
Instruction::MemberDecorate {
decoration: Decoration::Offset { .. },
..
} => {
offset_found = true;
}
_ => (),
}
}
} else {
quote! {}
};
// If the struct has unsized members none of custom derives applied.
let custom_derives = if current_rust_offset.is_some() && !types_meta.custom_derives.is_empty() {
let derive_list = &types_meta.custom_derives;
quote! { #[derive(#( #derive_list ),*)] }
} else {
quote! {}
};
let mut members = vec![];
for member in &rust_members {
let name = &member.name;
let ty = &member.ty;
members.push(quote!(pub #name: #ty,));
// Some structs don't have `Offset` decorations, in the case they are used as local
// variables only. Ignoring these.
if !offset_found {
return false;
}
}
let ast = quote! {
#[repr(C)]
#copy_derive
#custom_derives
#[allow(non_snake_case)]
pub struct #name {
#( #members )*
}
#clone_impl
#partial_eq_impl
#debug_impl
#display_impl
#default_impl
#custom_impls
};
true
}
(ast, total_size)
fn struct_size_from_array_stride(spirv: &Spirv, type_id: Id) -> Option<u32> {
let mut iter = spirv.iter_global().filter_map(|inst| match inst {
Instruction::TypeArray {
result_id,
element_type,
..
}
| Instruction::TypeRuntimeArray {
result_id,
element_type,
} if *element_type == type_id => {
spirv
.id(*result_id)
.iter_decoration()
.find_map(|instruction| match instruction {
Instruction::Decorate {
decoration: Decoration::ArrayStride { array_stride },
..
} => Some(*array_stride),
_ => None,
})
}
_ => None,
});
iter.next().map(|first_stride| {
// Ensure that all strides we find match the first one.
debug_assert!(iter.all(|array_stride| array_stride == first_stride));
first_stride
})
}
/// Returns the type name to put in the Rust struct, and its size and alignment.
@ -417,10 +397,9 @@ fn write_struct<'a>(
pub(super) fn type_from_id(
shader: &str,
spirv: &Spirv,
searched: Id,
types_meta: &TypesMeta,
type_id: Id,
) -> (TokenStream, Cow<'static, str>, Option<usize>, usize) {
let id_info = spirv.id(searched);
let id_info = spirv.id(type_id);
match id_info.instruction() {
Instruction::TypeBool { .. } => {
@ -570,8 +549,7 @@ pub(super) fn type_from_id(
..
} => {
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
let (ty, item, t_size, t_align) =
type_from_id(shader, spirv, component_type, types_meta);
let (ty, item, t_size, t_align) = type_from_id(shader, spirv, component_type);
let array_length = component_count as usize;
let size = t_size.map(|s| s * component_count as usize);
return (
@ -588,7 +566,7 @@ pub(super) fn type_from_id(
} => {
// FIXME: row-major or column-major
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
let (ty, item, t_size, t_align) = type_from_id(shader, spirv, column_type, types_meta);
let (ty, item, t_size, t_align) = type_from_id(shader, spirv, column_type);
let array_length = column_count as usize;
let size = t_size.map(|s| s * column_count as usize);
return (
@ -604,13 +582,18 @@ pub(super) fn type_from_id(
..
} => {
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
let (ty, item, t_size, t_align) = type_from_id(shader, spirv, element_type, types_meta);
let t_size = t_size.expect("array components must be sized");
let len = match spirv.id(length).instruction() {
&Instruction::Constant { ref value, .. } => value,
let (element_type, element_type_string, element_size, element_align) =
type_from_id(shader, spirv, element_type);
let element_size = element_size.expect("array components must be sized");
let array_length = match spirv.id(length).instruction() {
&Instruction::Constant { ref value, .. } => {
value.iter().rev().fold(0u64, |a, &b| (a << 32) | b as u64)
}
_ => panic!("failed to find array length"),
};
let len = len.iter().rev().fold(0u64, |a, &b| (a << 32) | b as u64);
} as usize;
let stride = id_info
.iter_decoration()
.find_map(|instruction| match instruction {
@ -621,51 +604,87 @@ pub(super) fn type_from_id(
_ => None,
})
.unwrap();
if stride as usize > t_size {
if stride as usize > element_size {
panic!("Not possible to generate a rust array with the correct alignment since the SPIR-V \
ArrayStride is larger than the size of the array element in rust. Try wrapping \
the array element in a struct or rounding up the size of a vector or matrix \
(e.g. increase a vec3 to a vec4)")
}
let array_length = len as usize;
let size = Some(t_size * len as usize);
return (
quote! { [#ty; #array_length] },
Cow::from(format!("[{}; {}]", item, array_length)),
size,
t_align,
quote! { [#element_type; #array_length] },
Cow::from(format!("[{}; {}]", element_type_string, array_length)),
Some(element_size * array_length as usize),
element_align,
);
}
&Instruction::TypeRuntimeArray { element_type, .. } => {
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
let (ty, name, _, t_align) = type_from_id(shader, spirv, element_type, types_meta);
let (element_type, element_type_string, _, element_align) =
type_from_id(shader, spirv, element_type);
return (
quote! { [#ty] },
Cow::from(format!("[{}]", name)),
quote! { [#element_type] },
Cow::from(format!("[{}]", element_type_string)),
None,
t_align,
element_align,
);
}
Instruction::TypeStruct { member_types, .. } => {
// TODO: take the Offset member decorate into account?
let size = if !has_defined_layout(spirv, type_id) {
None
} else {
// If the struct appears in an array, then first try to get the size from the
// array stride.
struct_size_from_array_stride(spirv, type_id)
.map(|size| size as usize)
.or_else(|| {
// We haven't found any strides, so we have to calculate the size based
// on the offset and size of the last member.
member_types
.iter()
.zip(spirv.id(type_id).iter_members())
.last()
.map_or(Some(0), |(&member, member_info)| {
let spirv_offset = member_info
.iter_decoration()
.find_map(|instruction| match instruction {
Instruction::MemberDecorate {
decoration: Decoration::Offset { byte_offset },
..
} => Some(*byte_offset as usize),
_ => None,
})
.unwrap();
let (_, _, rust_size, _) = type_from_id(shader, spirv, member);
rust_size.map(|rust_size| spirv_offset + rust_size)
})
})
};
let align = member_types
.iter()
.map(|&t| type_from_id(shader, spirv, t).3)
.max()
.unwrap_or(1);
let name_string = id_info
.iter_name()
.find_map(|instruction| match instruction {
Instruction::Name { name, .. } => Some(name.as_str()),
Instruction::Name { name, .. } => Some(Cow::from(name.clone())),
_ => None,
})
.unwrap_or("__unnamed");
let name = Ident::new(&name_string, Span::call_site());
let ty = quote! { #name };
let (_, size) = write_struct(shader, spirv, searched, member_types, types_meta, None);
let align = member_types
.iter()
.map(|&t| type_from_id(shader, spirv, t, types_meta).3)
.max()
.unwrap_or(1);
return (ty, Cow::from(name_string.to_owned()), size, align);
.unwrap_or(Cow::from("__unnamed"));
let name = {
let name = format_ident!("{}", name_string);
quote! { #name }
};
return (name, name_string, size, align);
}
_ => panic!("Type #{} not found", searched),
_ => panic!("Type #{} not found", type_id),
}
}
@ -674,7 +693,6 @@ pub(super) fn type_from_id(
pub(super) fn write_specialization_constants<'a>(
shader: &'a str,
spirv: &Spirv,
types_meta: &TypesMeta,
shared_constants: bool,
types_registry: &'a mut HashMap<String, RegisteredType>,
) -> TokenStream {
@ -735,7 +753,7 @@ pub(super) fn write_specialization_constants<'a>(
Some(mem::size_of::<u32>()),
mem::align_of::<u32>(),
),
_ => type_from_id(shader, spirv, result_type_id, types_meta),
_ => type_from_id(shader, spirv, result_type_id),
};
let rust_size = rust_size.expect("Found runtime-sized specialization constant");