General shader improvements, specifically targeting rust-gpu (#2482)

* added tests for codegen reflect()

* added tests for a typical rust-gpu shader

* make idents that are invalid be named UnnamedX instead of panicing

* add generate_structs option to shader! macro, to disable struct generating

* get rid of macro_use extern crate

* disallow specifying a shader type for binary shaders

* fix mesh shaders output interface not being arrayed

* structs in shader interface panic with explanatory message

* fix clippy lints

Co-authored-by: Rua <ruawhitepaw@gmail.com>

---------

Co-authored-by: Firestar99 <4696087-firestar99@users.noreply.gitlab.com>
Co-authored-by: Rua <ruawhitepaw@gmail.com>
This commit is contained in:
Firestar99 2024-03-03 15:22:39 +01:00 committed by GitHub
parent 50d82cdcde
commit 093b43e98d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 293 additions and 44 deletions

View File

@ -4,6 +4,7 @@ use crate::{
};
use heck::ToSnakeCase;
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind};
use shaderc::{CompileOptions, Compiler, EnvVersion, TargetEnv};
use std::{
@ -262,9 +263,18 @@ pub(super) fn reflect(
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::Span;
use quote::ToTokens;
use shaderc::SpirvVersion;
use syn::{File, Item};
use vulkano::shader::reflect;
fn spv_to_words(data: &[u8]) -> Vec<u32> {
data.chunks(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
}
fn convert_paths(root_path: &Path, paths: &[PathBuf]) -> Vec<String> {
paths
.iter()
@ -274,17 +284,28 @@ mod tests {
#[test]
fn spirv_parse() {
let data = include_bytes!("../tests/frag.spv");
let insts: Vec<_> = data
.chunks(4)
.map(|c| {
((c[3] as u32) << 24) | ((c[2] as u32) << 16) | ((c[1] as u32) << 8) | c[0] as u32
})
.collect();
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
Spirv::new(&insts).unwrap();
}
#[test]
fn spirv_reflect() {
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/frag.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
assert_eq!(_structs.to_string(), "", "No structs should be generated");
}
#[test]
fn include_resolution() {
let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
@ -536,14 +557,8 @@ mod tests {
/// ```
#[test]
fn descriptor_calculation_with_multiple_entrypoints() {
let data = include_bytes!("../tests/multiple_entrypoints.spv");
let instructions: Vec<u32> = data
.chunks(4)
.map(|c| {
((c[3] as u32) << 24) | ((c[2] as u32) << 16) | ((c[1] as u32) << 8) | c[0] as u32
})
.collect();
let spirv = Spirv::new(&instructions).unwrap();
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
let spirv = Spirv::new(&insts).unwrap();
let mut descriptors = Vec::new();
for (_, info) in reflect::entry_points(&spirv) {
@ -578,8 +593,52 @@ mod tests {
}
#[test]
fn descriptor_calculation_with_multiple_functions() {
let (comp, _) = compile(
fn reflect_descriptor_calculation_with_multiple_entrypoints() {
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/multiple_entrypoints.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");
let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else {
None
}
})
.collect();
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);
}
fn descriptor_calculation_with_multiple_functions_shader() -> (CompilationArtifact, Vec<String>)
{
compile(
&MacroInput {
spirv_version: Some(SpirvVersion::V1_6),
vulkan_version: Some(EnvVersion::Vulkan1_3),
@ -615,8 +674,13 @@ mod tests {
"#,
ShaderKind::Vertex,
)
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
.unwrap()
}
#[test]
fn descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
let spirv = Spirv::new(artifact.as_binary()).unwrap();
if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
@ -634,4 +698,51 @@ mod tests {
}
panic!("could not find entrypoint");
}
#[test]
fn reflect_descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new(
"descriptor_calculation_with_multiple_functions_shader",
Span::call_site(),
),
String::new(),
artifact.as_binary(),
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");
let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else {
None
}
})
.collect();
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: [f32; 3usize],}).to_string()
);
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: f32,}).to_string()
);
}
}

View File

@ -39,7 +39,7 @@
//! appropriate features enabled.
//! - If the `shaders` option is used, then instead of one `load` constructor, there is one for
//! each shader. They are named based on the provided names, `load_first`, `load_second` etc.
//! - A Rust struct translated from each struct contained in the shader data. By default each
//! - A Rust struct translated from each struct contained in the shader data. By default, each
//! structure has a `Clone` and a `Copy` implementation. This behavior could be customized
//! through the `custom_derives` macro option (see below for details). Each struct also has an
//! implementation of [`BufferContents`], so that it can be read from/written to a buffer.
@ -122,8 +122,8 @@
//! ## `bytes: "..."`
//!
//! Provides the path to precompiled SPIR-V bytecode, relative to your `Cargo.toml`. Cannot be used
//! in conjunction with the `src` or `path` field. This allows using shaders compiled through a
//! separate build system.
//! in conjunction with the `src` or `path` field, and may also not specify a shader `ty` type.
//! This allows using shaders compiled through a separate build system.
//!
//! ## `root_path_env: "..."`
//!
@ -143,7 +143,7 @@
//!
//! With these options the user can compile several shaders in a single macro invocation. Each
//! entry key will be the suffix of the generated `load` function (`load_first` in this case).
//! However all other Rust structs translated from the shader source will be shared between
//! However, all other Rust structs translated from the shader source will be shared between
//! shaders. The macro checks that the source structs with the same names between different shaders
//! have the same declaration signature, and throws a compile-time error if they don't.
//!
@ -172,14 +172,21 @@
//! The generated code must be supported by the device at runtime. If not, then an error will be
//! returned when calling `load`.
//!
//! ## `generate_structs: true`
//!
//! Generate rust structs that represent the structs contained in the shader. They all implement
//! [`BufferContents`], which allows then to be passed to the shader, without having to worry about
//! the layout of the struct manually. However, some use-cases, such as Rust-GPU, may not have any
//! use for such structs, and may choose to disable them.
//!
//! ## `custom_derives: [Clone, Default, PartialEq, ...]`
//!
//! Extends the list of derive macros that are added to the `derive` attribute of Rust structs that
//! represent shader structs.
//!
//! By default each generated struct has a derive for `Clone` and `Copy`. If the struct has unsized
//! members none of the derives are applied on the struct, except [`BufferContents`], which is
//! always derived.
//! By default, each generated struct derives `Clone` and `Copy`. If the struct has unsized members
//! none of the derives are applied on the struct, except [`BufferContents`], which is always
//! derived.
//!
//! ## `linalg_type: "..."`
//!
@ -221,14 +228,10 @@
#![allow(clippy::needless_borrowed_reference)]
#![warn(rust_2018_idioms, rust_2021_compatibility)]
#[macro_use]
extern crate quote;
#[macro_use]
extern crate syn;
use crate::codegen::ShaderKind;
use ahash::HashMap;
use proc_macro2::{Span, TokenStream};
use quote::quote;
use shaderc::{EnvVersion, SpirvVersion};
use std::{
env, fs, mem,
@ -236,11 +239,13 @@ use std::{
};
use structs::TypeRegistry;
use syn::{
braced, bracketed, parenthesized,
parse::{Parse, ParseStream, Result},
Error, Ident, LitBool, LitStr, Path as SynPath,
parse_macro_input, parse_quote, Error, Ident, LitBool, LitStr, Path as SynPath, Token,
};
mod codegen;
mod rust_gpu;
mod structs;
#[proc_macro]
@ -286,9 +291,14 @@ fn shader_inner(mut input: MacroInput) -> Result<TokenStream> {
for (name, (shader_kind, source_kind)) in shaders {
let (code, types) = match source_kind {
SourceKind::Src(source) => {
let (artifact, includes) =
codegen::compile(&input, None, root_path, &source.value(), shader_kind)
.map_err(|err| Error::new_spanned(&source, err))?;
let (artifact, includes) = codegen::compile(
&input,
None,
root_path,
&source.value(),
shader_kind.unwrap(),
)
.map_err(|err| Error::new_spanned(&source, err))?;
let words = artifact.as_binary();
@ -313,7 +323,7 @@ fn shader_inner(mut input: MacroInput) -> Result<TokenStream> {
Some(path.value()),
root_path,
&source_code,
shader_kind,
shader_kind.unwrap(),
)
.map_err(|err| Error::new_spanned(&path, err))?;
@ -371,9 +381,10 @@ struct MacroInput {
root_path_env: Option<LitStr>,
include_directories: Vec<PathBuf>,
macro_defines: Vec<(String, String)>,
shaders: HashMap<String, (ShaderKind, SourceKind)>,
shaders: HashMap<String, (Option<ShaderKind>, SourceKind)>,
spirv_version: Option<SpirvVersion>,
vulkan_version: Option<EnvVersion>,
generate_structs: bool,
custom_derives: Vec<SynPath>,
linalg_type: LinAlgType,
dump: LitBool,
@ -389,6 +400,7 @@ impl MacroInput {
shaders: HashMap::default(),
vulkan_version: None,
spirv_version: None,
generate_structs: true,
custom_derives: Vec::new(),
linalg_type: LinAlgType::default(),
dump: LitBool::new(false, Span::call_site()),
@ -406,6 +418,7 @@ impl Parse for MacroInput {
let mut shaders = HashMap::default();
let mut vulkan_version = None;
let mut spirv_version = None;
let mut generate_structs = true;
let mut custom_derives = None;
let mut linalg_type = None;
let mut dump = None;
@ -643,6 +656,10 @@ impl Parse for MacroInput {
),
});
}
"generate_structs" => {
let lit = input.parse::<LitBool>()?;
generate_structs = lit.value;
}
"custom_derives" => {
let in_brackets;
bracketed!(in_brackets in input);
@ -696,8 +713,8 @@ impl Parse for MacroInput {
field => bail!(
field_ident,
"expected `bytes`, `src`, `path`, `ty`, `shaders`, `define`, `include`, \
`vulkan_version`, `spirv_version`, `custom_derives`, `linalg_type` or `dump` \
as a field, found `{field}`",
`vulkan_version`, `spirv_version`, `generate_structs`, `custom_derives`, \
`linalg_type` or `dump` as a field, found `{field}`",
),
}
@ -711,6 +728,13 @@ impl Parse for MacroInput {
}
match shaders.get("") {
// if source is bytes, the shader type should not be declared
Some((None, Some(SourceKind::Bytes(_)))) => {}
Some((_, Some(SourceKind::Bytes(_)))) => {
bail!(
r#"one may not specify a shader type when including precompiled SPIR-V binaries. Please remove the `ty:` declaration"#
);
}
Some((None, _)) => {
bail!(r#"please specify the type of the shader e.g. `ty: "vertex"`"#);
}
@ -727,11 +751,12 @@ impl Parse for MacroInput {
shaders: shaders
.into_iter()
.map(|(key, (shader_kind, shader_source))| {
(key, (shader_kind.unwrap(), shader_source.unwrap()))
(key, (shader_kind, shader_source.unwrap()))
})
.collect(),
vulkan_version,
spirv_version,
generate_structs,
custom_derives: custom_derives.unwrap_or_else(|| {
vec![
parse_quote! { ::std::clone::Clone },

View File

@ -0,0 +1,44 @@
#[cfg(test)]
mod tests {
use crate::{codegen::reflect, structs::TypeRegistry, MacroInput};
use proc_macro2::Span;
use syn::LitStr;
fn spv_to_words(data: &[u8]) -> Vec<u32> {
data.chunks(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
}
#[test]
fn rust_gpu_reflect_vertex() {
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-vertex.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("rust-gpu vertex shader", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
}
#[test]
fn rust_gpu_reflect_fragment() {
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-fragment.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("rust-gpu vertex shader", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
}
}

View File

@ -1,7 +1,7 @@
use crate::{bail, codegen::Shader, LinAlgType, MacroInput};
use ahash::HashMap;
use proc_macro2::{Span, TokenStream};
use quote::{ToTokens, TokenStreamExt};
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use std::{cmp::Ordering, num::NonZeroUsize};
use syn::{Error, Ident, Result};
use vulkano::shader::spirv::{Decoration, Id, Instruction};
@ -90,6 +90,10 @@ pub(super) fn write_structs(
shader: &Shader,
type_registry: &mut TypeRegistry,
) -> Result<TokenStream> {
if !input.generate_structs {
return Ok(TokenStream::new());
}
let mut structs = TokenStream::new();
for (struct_id, member_type_ids) in shader
@ -685,7 +689,13 @@ impl TypeStruct {
.names()
.iter()
.find_map(|instruction| match instruction {
Instruction::Name { name, .. } => Some(Ident::new(name, Span::call_site())),
Instruction::Name { name, .. } => {
// rust-gpu uses fully qualified rust paths as names which contain `:`.
// I sady don't know how to check against all kinds of invalid ident chars.
let name = name.replace(':', "_");
// Worst case: invalid idents will get the UnnamedX name below
syn::parse_str(&name).ok()
}
_ => None,
})
.unwrap_or_else(|| format_ident!("Unnamed{}", struct_id.as_raw()));

Binary file not shown.

View File

@ -0,0 +1,53 @@
//! This is the rust-gpu source code used to generate the spv output files found in this directory.
//! A pre-release version of rust-gpu 0.10 was used, specifically git 3689d11a. Spirv-Builder was
//! configured like this:
//!
//! ```norun
//! SpirvBuilder::new(shader_crate, "spirv-unknown-vulkan1.2")
//! .multimodule(true)
//! .spirv_metadata(SpirvMetadata::Full)
//! ```
use crate::test_shader::some_module::Bla;
use glam::{vec4, Vec4};
use spirv_std::spirv;
mod some_module {
use super::*;
pub struct Bla {
pub value: Vec4,
pub value2: Vec4,
pub decider: f32,
}
impl Bla {
pub fn lerp(&self) -> Vec4 {
self.value * (1. - self.decider) + self.value2 * self.decider
}
}
}
#[spirv(vertex)]
pub fn vertex(
#[spirv(vertex_index)] vertex_id: u32,
#[spirv(position)] position: &mut Vec4,
vtx_color: &mut Bla,
) {
let corners = [
vec4(-0.5, -0.5, 0.1, 1.),
vec4(0.5, -0.5, 0.1, 1.),
vec4(0., 0.5, 0.1, 1.),
];
*position = corners[(vertex_id % 3) as usize];
*vtx_color = Bla {
value: vec4(1., 1., 0., 1.),
value2: vec4(0., 1., 1., 1.),
decider: f32::max(vertex_id as f32, 1.),
};
}
#[spirv(fragment)]
pub fn fragment(vtx_color: Bla, f_color: &mut Vec4) {
*f_color = vtx_color.lerp();
}

View File

@ -63,7 +63,10 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = (Id, EntryPointInfo)>
spirv,
interface,
StorageClass::Output,
matches!(execution_model, ExecutionModel::TessellationControl),
matches!(
execution_model,
ExecutionModel::TessellationControl | ExecutionModel::MeshEXT
),
);
Some((
@ -1378,6 +1381,9 @@ fn shader_interface_type_of(
Instruction::TypePointer { ty, .. } => {
shader_interface_type_of(spirv, ty, ignore_first_array)
}
Instruction::TypeStruct { .. } => {
panic!("Structs are not yet supported in shader in/out interface!");
}
_ => panic!("Type {} not found or invalid", id),
}
}