From 2752e20d79e14609f912eca5a462b3767b41a02b Mon Sep 17 00:00:00 2001 From: Mai <63251610+interacsion@users.noreply.github.com> Date: Wed, 25 Sep 2024 15:02:24 +0300 Subject: [PATCH] Refactor vulkano-shaders --- Cargo.lock | 1 - vulkano-shaders/Cargo.toml | 4 +- vulkano-shaders/src/codegen.rs | 851 ++++----------------- vulkano-shaders/src/lib.rs | 1245 +++++++++++++++++++------------ vulkano-shaders/src/rust_gpu.rs | 44 -- vulkano-shaders/src/structs.rs | 435 +++++------ 6 files changed, 1102 insertions(+), 1478 deletions(-) delete mode 100644 vulkano-shaders/src/rust_gpu.rs diff --git a/Cargo.lock b/Cargo.lock index b0cc2b6f..96bf996d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2465,7 +2465,6 @@ name = "vulkano-shaders" version = "0.34.0" dependencies = [ "ahash", - "heck", "proc-macro2", "quote", "shaderc", diff --git a/vulkano-shaders/Cargo.toml b/vulkano-shaders/Cargo.toml index e21e1e5a..b4ba481c 100644 --- a/vulkano-shaders/Cargo.toml +++ b/vulkano-shaders/Cargo.toml @@ -18,16 +18,14 @@ proc-macro = true [dependencies] ahash = { workspace = true } -heck = { workspace = true } proc-macro2 = { workspace = true } quote = { workspace = true } shaderc = { workspace = true } -syn = { workspace = true, features = ["full", "extra-traits"] } +syn = { workspace = true } vulkano = { workspace = true } [features] shaderc-build-from-source = ["shaderc/build-from-source"] -shaderc-debug = [] [lints] workspace = true diff --git a/vulkano-shaders/src/codegen.rs b/vulkano-shaders/src/codegen.rs index af184c15..1307d708 100644 --- a/vulkano-shaders/src/codegen.rs +++ b/vulkano-shaders/src/codegen.rs @@ -1,135 +1,138 @@ -use crate::{ - structs::{self, TypeRegistry}, - MacroInput, -}; -use heck::ToSnakeCase; +use crate::MacroOptions; +use ahash::{HashMap, HashSet}; use proc_macro2::TokenStream; use quote::{format_ident, quote}; -pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind}; -use shaderc::{CompileOptions, Compiler, EnvVersion, SourceLanguage, TargetEnv}; +use shaderc::{ + CompilationArtifact, CompileOptions, Compiler, IncludeType, ResolvedInclude, ShaderKind, + SourceLanguage, TargetEnv, +}; use std::{ cell::RefCell, fs, iter::Iterator, path::{Path, PathBuf}, }; -use syn::{Error, LitStr}; -use vulkano::shader::spirv::Spirv; -pub struct Shader { - pub source: LitStr, - pub name: String, - pub spirv: Spirv, +pub(super) fn compile( + compiler: &Compiler, + macro_options: &MacroOptions, + source_path: Option<&str>, + shader_name: &str, + source_code: &str, + entry_points: HashMap, ShaderKind>, + base_path: &Path, + include_paths: &[PathBuf], + sources_to_include: &RefCell>, +) -> Result, CompilationArtifact)>, String> { + let mut compile_options = CompileOptions::new().ok_or("failed to create compile options")?; + + compile_options.set_source_language(macro_options.source_language); + + if let Some(vulkan_version) = macro_options.vulkan_version { + compile_options.set_target_env(TargetEnv::Vulkan, vulkan_version as u32); + } + + if let Some(spirv_version) = macro_options.spirv_version { + compile_options.set_target_spirv(spirv_version); + } + + compile_options.set_include_callback( + |requested_source, include_type, containing_source, depth| { + include_callback( + Path::new(requested_source), + include_type, + Path::new(containing_source), + depth, + base_path, + include_paths, + source_path.is_none(), + &mut sources_to_include.borrow_mut(), + ) + }, + ); + + for (name, value) in ¯o_options.macro_defines { + compile_options.add_macro_definition(name, value.as_deref()); + } + + let file_name = match (source_path, macro_options.source_language) { + (Some(source_path), _) => source_path, + (None, SourceLanguage::GLSL) => &format!("{shader_name}.glsl"), + (None, SourceLanguage::HLSL) => &format!("{shader_name}.hlsl"), + }; + + entry_points + .into_iter() + .map(|(entry_point, shader_kind)| { + compiler + .compile_into_spirv( + source_code, + shader_kind, + file_name, + entry_point.as_deref().unwrap_or("main"), + Some(&compile_options), + ) + .map(|artifact| (entry_point, artifact)) + .map_err(|err| err.to_string().replace("\n", " ")) + }) + .collect() } -#[allow(clippy::too_many_arguments)] fn include_callback( - requested_source_path_raw: &str, - directive_type: IncludeType, - contained_within_path_raw: &str, - recursion_depth: usize, - include_directories: &[PathBuf], - root_source_has_path: bool, + requested_source: &Path, + include_type: IncludeType, + containing_source: &Path, + depth: usize, base_path: &Path, - includes: &mut Vec, + include_paths: &[PathBuf], + embedded_root_source: bool, + sources_to_include: &mut HashSet, ) -> Result { - let file_to_include = match directive_type { + let resolved_path = match include_type { IncludeType::Relative => { - let requested_source_path = Path::new(requested_source_path_raw); - // If the shader source is embedded within the macro, abort unless we get an absolute - // path. - if !root_source_has_path && recursion_depth == 1 && !requested_source_path.is_absolute() - { - let requested_source_name = requested_source_path - .file_name() - .expect("failed to get the name of the requested source file") - .to_string_lossy(); - let requested_source_directory = requested_source_path - .parent() - .expect("failed to get the directory of the requested source file") - .to_string_lossy(); - - return Err(format!( - "usage of relative paths in imports in embedded GLSL is not allowed, try \ - using `#include <{}>` and adding the directory `{}` to the `include` array in \ - your `shader!` macro call instead", - requested_source_name, requested_source_directory, - )); + if depth == 1 && embedded_root_source && !requested_source.is_absolute() { + return Err( + "you cannot use relative include directives in embedded shader source code; \ + try using `#include <...>` instead" + .to_owned(), + ); } - let mut resolved_path = if recursion_depth == 1 { - Path::new(contained_within_path_raw) - .parent() - .map(|parent| base_path.join(parent)) + let parent = containing_source.parent().unwrap(); + + if depth == 1 { + [base_path, parent, requested_source].iter().collect() } else { - Path::new(contained_within_path_raw) - .parent() - .map(|parent| parent.to_owned()) + [parent, requested_source].iter().collect() } - .unwrap_or_else(|| { - panic!( - "the file `{}` does not reside in a directory, this is an implementation \ - error", - contained_within_path_raw, - ) - }); - resolved_path.push(requested_source_path); - - if !resolved_path.is_file() { - return Err(format!( - "invalid inclusion path `{}`, the path does not point to a file", - requested_source_path_raw, - )); - } - - resolved_path } IncludeType::Standard => { - let requested_source_path = Path::new(requested_source_path_raw); - - if requested_source_path.is_absolute() { - // This message is printed either when using a missing file with an absolute path - // in the relative include directive or when using absolute paths in a standard + if requested_source.is_absolute() { + // This is returned when attempting to include a missing file by an absolute path + // in a relative include directive and when using an absolute path in a standard // include directive. - return Err(format!( - "no such file found as specified by the absolute path; keep in mind that \ - absolute paths cannot be used with inclusion from standard directories \ - (`#include <...>`), try using `#include \"...\"` instead; requested path: {}", - requested_source_path_raw, - )); + return Err( + "the specified file was not found; if you're using an absolute path in a \ + standard include directive (`#include <...>`), try using `#include \"...\"` \ + instead" + .to_owned(), + ); } - let found_requested_source_path = include_directories + include_paths .iter() - .map(|include_directory| include_directory.join(requested_source_path)) - .find(|resolved_requested_source_path| resolved_requested_source_path.is_file()); - - if let Some(found_requested_source_path) = found_requested_source_path { - found_requested_source_path - } else { - return Err(format!( - "failed to include the file `{}` from any include directories", - requested_source_path_raw, - )); - } + .map(|include_path| include_path.join(requested_source)) + .find(|source_path| source_path.is_file()) + .ok_or("the specified file was not found".to_owned())? } }; - let content = fs::read_to_string(file_to_include.as_path()).map_err(|err| { - format!( - "failed to read the contents of file `{file_to_include:?}` to be included in the \ - shader source: {err}", - ) - })?; - let resolved_name = file_to_include - .into_os_string() - .into_string() - .map_err(|_| { - "failed to stringify the file to be included; make sure the path consists of valid \ - unicode characters" - })?; + let resolved_name = resolved_path.into_os_string().into_string().unwrap(); - includes.push(resolved_name.clone()); + let content = fs::read_to_string(&resolved_name) + .map_err(|err| format!("failed to read `{resolved_name}`: {err}"))?; + + sources_to_include.insert(resolved_name.clone()); Ok(ResolvedInclude { resolved_name, @@ -137,618 +140,48 @@ fn include_callback( }) } -pub(super) fn compile( - input: &MacroInput, - path: Option, - base_path: &Path, - code: &str, - shader_kind: ShaderKind, -) -> Result<(CompilationArtifact, Vec), String> { - let includes = RefCell::new(Vec::new()); - let compiler = Compiler::new().ok_or("failed to create shader compiler")?; - let mut compile_options = - CompileOptions::new().ok_or("failed to initialize compile options")?; +pub(super) fn generate_shader_code( + entry_points: &[(Option<&str>, &[u32])], + shader_name: &Option, +) -> TokenStream { + let load_fns = entry_points.iter().map(|(name, words)| { + let load_name = match name { + Some(name) => format_ident!("load_{name}"), + None => format_ident!("load"), + }; - let source_language = input.source_language.unwrap_or(SourceLanguage::GLSL); - compile_options.set_source_language(source_language); - - compile_options.set_target_env( - TargetEnv::Vulkan, - input.vulkan_version.unwrap_or(EnvVersion::Vulkan1_0) as u32, - ); - - if let Some(spirv_version) = input.spirv_version { - compile_options.set_target_spirv(spirv_version); - } - - let root_source_path = path.as_deref().unwrap_or( - // An arbitrary placeholder file name for embedded shaders. - match source_language { - SourceLanguage::GLSL => "shader.glsl", - SourceLanguage::HLSL => "shader.hlsl", - }, - ); - - // Specify the file resolution callback for the `#include` directive. - compile_options.set_include_callback( - |requested_source_path, directive_type, contained_within_path, recursion_depth| { - include_callback( - requested_source_path, - directive_type, - contained_within_path, - recursion_depth, - &input.include_directories, - path.is_some(), - base_path, - &mut includes.borrow_mut(), - ) - }, - ); - - for (macro_name, macro_value) in &input.macro_defines { - compile_options.add_macro_definition(macro_name, Some(macro_value)); - } - - #[cfg(feature = "shaderc-debug")] - compile_options.set_generate_debug_info(); - - let content = compiler - .compile_into_spirv( - code, - shader_kind, - root_source_path, - "main", - Some(&compile_options), - ) - .map_err(|e| e.to_string().replace("(s): ", "(s):\n"))?; - - drop(compile_options); - - Ok((content, includes.into_inner())) -} - -pub(super) fn reflect( - input: &MacroInput, - source: LitStr, - name: String, - words: &[u32], - input_paths: Vec, - type_registry: &mut TypeRegistry, -) -> Result<(TokenStream, TokenStream), Error> { - let spirv = Spirv::new(words).map_err(|err| { - Error::new_spanned(&source, format!("failed to parse SPIR-V words: {err}")) - })?; - let shader = Shader { - source, - name, - spirv, - }; - - let include_bytes = input_paths.into_iter().map(|s| { quote! { - // Using `include_bytes` here ensures that changing the shader will force recompilation. - // The bytes themselves can be optimized out by the compiler as they are unused. - ::std::include_bytes!( #s ) + #[allow(unsafe_code)] + #[inline] + pub fn #load_name( + device: ::std::sync::Arc<::vulkano::device::Device>, + ) -> ::std::result::Result< + ::std::sync::Arc<::vulkano::shader::ShaderModule>, + ::vulkano::Validated<::vulkano::VulkanError>, + > { + static WORDS: &[u32] = &[ #( #words ),* ]; + + unsafe { + ::vulkano::shader::ShaderModule::new( + device, + ::vulkano::shader::ShaderModuleCreateInfo::new(WORDS), + ) + } + } } }); - let load_name = if shader.name.is_empty() { - format_ident!("load") + if let Some(shader_name) = shader_name { + let shader_name = format_ident!("{shader_name}"); + + quote! { + pub mod #shader_name { + #( #load_fns )* + } + } } else { - format_ident!("load_{}", shader.name.to_snake_case()) - }; - - let shader_code = quote! { - /// Loads the shader as a `ShaderModule`. - #[allow(unsafe_code)] - #[inline] - pub fn #load_name( - device: ::std::sync::Arc<::vulkano::device::Device>, - ) -> ::std::result::Result< - ::std::sync::Arc<::vulkano::shader::ShaderModule>, - ::vulkano::Validated<::vulkano::VulkanError>, - > { - let _bytes = ( #( #include_bytes ),* ); - - static WORDS: &[u32] = &[ #( #words ),* ]; - - unsafe { - ::vulkano::shader::ShaderModule::new( - device, - ::vulkano::shader::ShaderModuleCreateInfo::new(WORDS), - ) - } + quote! { + #( #load_fns )* } - }; - - let structs = structs::write_structs(input, &shader, type_registry)?; - - Ok((shader_code, structs)) -} - -#[cfg(test)] -mod tests { - use super::*; - use proc_macro2::Span; - use quote::ToTokens; - use shaderc::SpirvVersion; - use syn::{File, Item}; - use vulkano::shader::reflect; - - fn spv_to_words(data: &[u8]) -> Vec { - data.chunks(4) - .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]])) - .collect() - } - - fn convert_paths(root_path: &Path, paths: &[PathBuf]) -> Vec { - paths - .iter() - .map(|p| root_path.join(p).into_os_string().into_string().unwrap()) - .collect() - } - - #[test] - fn spirv_parse() { - let insts = spv_to_words(include_bytes!("../tests/frag.spv")); - Spirv::new(&insts).unwrap(); - } - - #[test] - fn spirv_reflect() { - let insts = spv_to_words(include_bytes!("../tests/frag.spv")); - - let mut type_registry = TypeRegistry::default(); - let (_shader_code, _structs) = reflect( - &MacroInput::empty(), - LitStr::new("../tests/frag.spv", Span::call_site()), - String::new(), - &insts, - Vec::new(), - &mut type_registry, - ) - .expect("reflecting spv failed"); - - assert_eq!(_structs.to_string(), "", "No structs should be generated"); - } - - #[test] - fn include_resolution() { - let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")); - - let (_compile_relative, _) = compile( - &MacroInput::empty(), - Some(String::from("tests/include_test.glsl")), - &root_path, - r#" - #version 450 - #include "include_dir_a/target_a.glsl" - #include "include_dir_b/target_b.glsl" - void main() {} - "#, - ShaderKind::Vertex, - ) - .expect("cannot resolve include files"); - - let (_compile_include_paths, includes) = compile( - &MacroInput { - include_directories: vec![ - root_path.join("tests").join("include_dir_a"), - root_path.join("tests").join("include_dir_b"), - ], - ..MacroInput::empty() - }, - Some(String::from("tests/include_test.glsl")), - &root_path, - r#" - #version 450 - #include - #include - void main() {} - "#, - ShaderKind::Vertex, - ) - .expect("cannot resolve include files"); - - assert_eq!( - includes, - convert_paths( - &root_path, - &[ - ["tests", "include_dir_a", "target_a.glsl"] - .into_iter() - .collect(), - ["tests", "include_dir_b", "target_b.glsl"] - .into_iter() - .collect(), - ], - ), - ); - - let (_compile_include_paths_with_relative, includes_with_relative) = compile( - &MacroInput { - include_directories: vec![root_path.join("tests").join("include_dir_a")], - ..MacroInput::empty() - }, - Some(String::from("tests/include_test.glsl")), - &root_path, - r#" - #version 450 - #include - #include <../include_dir_b/target_b.glsl> - void main() {} - "#, - ShaderKind::Vertex, - ) - .expect("cannot resolve include files"); - - assert_eq!( - includes_with_relative, - convert_paths( - &root_path, - &[ - ["tests", "include_dir_a", "target_a.glsl"] - .into_iter() - .collect(), - ["tests", "include_dir_a", "../include_dir_b/target_b.glsl"] - .into_iter() - .collect(), - ], - ), - ); - - let absolute_path = root_path - .join("tests") - .join("include_dir_a") - .join("target_a.glsl"); - let absolute_path_str = absolute_path - .to_str() - .expect("cannot run tests in a folder with non unicode characters"); - let (_compile_absolute_path, includes_absolute_path) = compile( - &MacroInput::empty(), - Some(String::from("tests/include_test.glsl")), - &root_path, - &format!( - r#" - #version 450 - #include "{absolute_path_str}" - void main() {{}} - "#, - ), - ShaderKind::Vertex, - ) - .expect("cannot resolve include files"); - - assert_eq!( - includes_absolute_path, - convert_paths( - &root_path, - &[["tests", "include_dir_a", "target_a.glsl"] - .into_iter() - .collect()], - ), - ); - - let (_compile_recursive_, includes_recursive) = compile( - &MacroInput { - include_directories: vec![ - root_path.join("tests").join("include_dir_b"), - root_path.join("tests").join("include_dir_c"), - ], - ..MacroInput::empty() - }, - Some(String::from("tests/include_test.glsl")), - &root_path, - r#" - #version 450 - #include - void main() {} - "#, - ShaderKind::Vertex, - ) - .expect("cannot resolve include files"); - - assert_eq!( - includes_recursive, - convert_paths( - &root_path, - &[ - ["tests", "include_dir_c", "target_c.glsl"] - .into_iter() - .collect(), - ["tests", "include_dir_c", "../include_dir_a/target_a.glsl"] - .into_iter() - .collect(), - ["tests", "include_dir_b", "target_b.glsl"] - .into_iter() - .collect(), - ], - ), - ); - } - - #[test] - fn macros() { - let need_defines = r#" - #version 450 - #if defined(NAME1) && NAME2 > 29 - void main() {} - #endif - "#; - - let compile_no_defines = compile( - &MacroInput::empty(), - None, - Path::new(""), - need_defines, - ShaderKind::Vertex, - ); - assert!(compile_no_defines.is_err()); - - compile( - &MacroInput { - macro_defines: vec![("NAME1".into(), "".into()), ("NAME2".into(), "58".into())], - ..MacroInput::empty() - }, - None, - Path::new(""), - need_defines, - ShaderKind::Vertex, - ) - .expect("setting shader macros did not work"); - } - - /// `entrypoint1.frag.glsl`: - /// ```glsl - /// #version 450 - /// - /// layout(set = 0, binding = 0) uniform Uniform { - /// uint data; - /// } ubo; - /// - /// layout(set = 0, binding = 1) buffer Buffer { - /// uint data; - /// } bo; - /// - /// layout(set = 0, binding = 2) uniform sampler textureSampler; - /// layout(set = 0, binding = 3) uniform texture2D imageTexture; - /// - /// layout(push_constant) uniform PushConstant { - /// uint data; - /// } push; - /// - /// layout(input_attachment_index = 0, set = 0, binding = 4) uniform subpassInput inputAttachment; - /// - /// layout(location = 0) out vec4 outColor; - /// - /// void entrypoint1() { - /// bo.data = 12; - /// outColor = vec4( - /// float(ubo.data), - /// float(push.data), - /// texture(sampler2D(imageTexture, textureSampler), vec2(0.0, 0.0)).x, - /// subpassLoad(inputAttachment).x - /// ); - /// } - /// ``` - /// - /// `entrypoint2.frag.glsl`: - /// ```glsl - /// #version 450 - /// - /// layout(input_attachment_index = 0, set = 0, binding = 0) uniform subpassInput inputAttachment2; - /// - /// layout(set = 0, binding = 1) buffer Buffer { - /// uint data; - /// } bo2; - /// - /// layout(set = 0, binding = 2) uniform Uniform { - /// uint data; - /// } ubo2; - /// - /// layout(push_constant) uniform PushConstant { - /// uint data; - /// } push2; - /// - /// void entrypoint2() { - /// bo2.data = ubo2.data + push2.data + int(subpassLoad(inputAttachment2).y); - /// } - /// ``` - /// - /// Compiled and linked with: - /// ```sh - /// glslangvalidator -e entrypoint1 --source-entrypoint entrypoint1 -V100 entrypoint1.frag.glsl -o entrypoint1.spv - /// glslangvalidator -e entrypoint2 --source-entrypoint entrypoint2 -V100 entrypoint2.frag.glsl -o entrypoint2.spv - /// spirv-link entrypoint1.spv entrypoint2.spv -o multiple_entrypoints.spv - /// ``` - #[test] - fn descriptor_calculation_with_multiple_entrypoints() { - let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv")); - let spirv = Spirv::new(&insts).unwrap(); - - let mut descriptors = Vec::new(); - for (_, info) in reflect::entry_points(&spirv) { - descriptors.push(info.descriptor_binding_requirements); - } - - // Check first entrypoint - let e1_descriptors = &descriptors[0]; - let mut e1_bindings = Vec::new(); - for loc in e1_descriptors.keys() { - e1_bindings.push(*loc); - } - - assert_eq!(e1_bindings.len(), 5); - assert!(e1_bindings.contains(&(0, 0))); - assert!(e1_bindings.contains(&(0, 1))); - assert!(e1_bindings.contains(&(0, 2))); - assert!(e1_bindings.contains(&(0, 3))); - assert!(e1_bindings.contains(&(0, 4))); - - // Check second entrypoint - let e2_descriptors = &descriptors[1]; - let mut e2_bindings = Vec::new(); - for loc in e2_descriptors.keys() { - e2_bindings.push(*loc); - } - - assert_eq!(e2_bindings.len(), 3); - assert!(e2_bindings.contains(&(0, 0))); - assert!(e2_bindings.contains(&(0, 1))); - assert!(e2_bindings.contains(&(0, 2))); - } - - #[test] - fn reflect_descriptor_calculation_with_multiple_entrypoints() { - let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv")); - - let mut type_registry = TypeRegistry::default(); - let (_shader_code, _structs) = reflect( - &MacroInput::empty(), - LitStr::new("../tests/multiple_entrypoints.spv", Span::call_site()), - String::new(), - &insts, - Vec::new(), - &mut type_registry, - ) - .expect("reflecting spv failed"); - - let structs = _structs.to_string(); - assert_ne!(structs, "", "Has some structs"); - - let file: File = syn::parse2(_structs).unwrap(); - let structs: Vec<_> = file - .items - .iter() - .filter_map(|item| { - if let Item::Struct(s) = item { - Some(s) - } else { - None - } - }) - .collect(); - - let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap(); - assert_eq!( - buffer.fields.to_token_stream().to_string(), - quote!({pub data: u32,}).to_string() - ); - - let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap(); - assert_eq!( - uniform.fields.to_token_stream().to_string(), - quote!({pub data: u32,}).to_string() - ); - } - - fn descriptor_calculation_with_multiple_functions_shader() -> (CompilationArtifact, Vec) - { - compile( - &MacroInput { - spirv_version: Some(SpirvVersion::V1_6), - vulkan_version: Some(EnvVersion::Vulkan1_3), - ..MacroInput::empty() - }, - None, - Path::new(""), - r#" - #version 460 - - layout(set = 1, binding = 0) buffer Buffer { - vec3 data; - } bo; - - layout(set = 2, binding = 0) uniform Uniform { - float data; - } ubo; - - layout(set = 3, binding = 1) uniform sampler textureSampler; - layout(set = 3, binding = 2) uniform texture2D imageTexture; - - float withMagicSparkles(float data) { - return texture(sampler2D(imageTexture, textureSampler), vec2(data, data)).x; - } - - vec3 makeSecretSauce() { - return vec3(withMagicSparkles(ubo.data)); - } - - void main() { - bo.data = makeSecretSauce(); - } - "#, - ShaderKind::Vertex, - ) - .unwrap() - } - - #[test] - fn descriptor_calculation_with_multiple_functions() { - let (artifact, _) = descriptor_calculation_with_multiple_functions_shader(); - let spirv = Spirv::new(artifact.as_binary()).unwrap(); - - if let Some((_, info)) = reflect::entry_points(&spirv).next() { - let mut bindings = Vec::new(); - for (loc, _reqs) in info.descriptor_binding_requirements { - bindings.push(loc); - } - - assert_eq!(bindings.len(), 4); - assert!(bindings.contains(&(1, 0))); - assert!(bindings.contains(&(2, 0))); - assert!(bindings.contains(&(3, 1))); - assert!(bindings.contains(&(3, 2))); - - return; - } - panic!("could not find entrypoint"); - } - - #[test] - fn reflect_descriptor_calculation_with_multiple_functions() { - let (artifact, _) = descriptor_calculation_with_multiple_functions_shader(); - - let mut type_registry = TypeRegistry::default(); - let (_shader_code, _structs) = reflect( - &MacroInput::empty(), - LitStr::new( - "descriptor_calculation_with_multiple_functions_shader", - Span::call_site(), - ), - String::new(), - artifact.as_binary(), - Vec::new(), - &mut type_registry, - ) - .expect("reflecting spv failed"); - - let structs = _structs.to_string(); - assert_ne!(structs, "", "Has some structs"); - - let file: File = syn::parse2(_structs).unwrap(); - let structs: Vec<_> = file - .items - .iter() - .filter_map(|item| { - if let Item::Struct(s) = item { - Some(s) - } else { - None - } - }) - .collect(); - - let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap(); - assert_eq!( - buffer.fields.to_token_stream().to_string(), - quote!({pub data: [f32; 3usize],}).to_string() - ); - - let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap(); - assert_eq!( - uniform.fields.to_token_stream().to_string(), - quote!({pub data: f32,}).to_string() - ); } } diff --git a/vulkano-shaders/src/lib.rs b/vulkano-shaders/src/lib.rs index 961b7f9c..34a3d26b 100644 --- a/vulkano-shaders/src/lib.rs +++ b/vulkano-shaders/src/lib.rs @@ -231,24 +231,26 @@ #![doc(html_logo_url = "https://raw.githubusercontent.com/vulkano-rs/vulkano/master/logo.png")] #![recursion_limit = "1024"] -use crate::codegen::ShaderKind; -use ahash::HashMap; +use ahash::{HashMap, HashSet}; use proc_macro2::{Span, TokenStream}; -use quote::quote; -use shaderc::{EnvVersion, SourceLanguage, SpirvVersion}; +use quote::{format_ident, quote}; +use shaderc::{Compiler, EnvVersion, ShaderKind, SourceLanguage, SpirvVersion}; use std::{ - env, fs, mem, + cell::RefCell, + env::{self, VarError}, + fs, iter, path::{Path, PathBuf}, }; -use structs::TypeRegistry; use syn::{ braced, bracketed, parenthesized, parse::{Parse, ParseStream, Result}, - parse_macro_input, parse_quote, Error, Ident, LitBool, LitStr, Path as SynPath, Token, + parse_macro_input, parse_quote, + token::Paren, + Error, Ident, LitBool, LitStr, Token, }; +use vulkano::shader::spirv::{self, Spirv}; mod codegen; -mod rust_gpu; mod structs; #[proc_macro] @@ -260,483 +262,281 @@ pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { .into() } -fn shader_inner(mut input: MacroInput) -> Result { - let (root, relative_path_error_msg) = match input.root_path_env.as_ref() { - None => ( - env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".into()), - "to your Cargo.toml".to_owned(), - ), - Some(root_path_env) => { - let root = match env::var(root_path_env.value()) { - Ok(e) => e, - Err(e) => { - bail!( - root_path_env, - "failed to fetch environment variable: {e}; typical parameters are \ - `OUT_DIR` to gather results from your build script, or left default to \ - search relative to your Cargo.toml", - ) - } - }; - let env = root_path_env.value(); - let error = format!("to the path `{root}` specified by the env variable `{env:?}`"); - (root, error) +fn shader_inner(input: MacroInput) -> Result { + let base_path = match &input.macro_options.base_path { + Some(BasePath::Path(path)) => { + if path.is_absolute() { + path.clone() + } else { + Path::new(env::var("CARGO_MANIFEST_DIR").as_deref().unwrap_or(".")).join(path) + } } + Some(BasePath::Env(var)) => match env::var(var.value()) { + Ok(path) => path.into(), + Err(VarError::NotPresent) => bail!(var, "environment variable not found"), + Err(VarError::NotUnicode(_)) => bail!(var, "environment variable is not valid unicode"), + }, + None => env::var("CARGO_MANIFEST_DIR") + .as_deref() + .unwrap_or(".") + .into(), }; - let root_path = Path::new(&root); - let shaders = mem::take(&mut input.shaders); // yoink + let include_paths = input + .macro_options + .include_paths + .iter() + .map(|path| { + if path.is_absolute() { + path.clone() + } else { + Path::new(env::var("CARGO_MANIFEST_DIR").as_deref().unwrap_or(".")).join(path) + } + }) + .collect::>(); - let mut shaders_code = Vec::with_capacity(shaders.len()); - let mut types_code = Vec::with_capacity(shaders.len()); - let mut type_registry = TypeRegistry::default(); + let compiler = Compiler::new() + .ok_or_else(|| Error::new(Span::call_site(), "failed to create shader compiler"))?; - for (name, (shader_kind, source_kind)) in shaders { - let (code, types) = match source_kind { - SourceKind::Src(source) => { - let (artifact, includes) = codegen::compile( - &input, + let sources_to_include = RefCell::new(HashSet::default()); + let mut registered_structs = HashMap::default(); + + let mut shaders_code = TokenStream::new(); + let mut structs_code = TokenStream::new(); + + for (shader_name, shader_source) in input.shaders { + let some_shader_name = shader_name.as_deref().unwrap_or("shader"); + match shader_source { + ShaderSource::Src { src, entry_points } => { + let entry_points = codegen::compile( + &compiler, + &input.macro_options, None, - root_path, - &source.value(), - shader_kind.unwrap(), + some_shader_name, + &src.value(), + entry_points, + &base_path, + &include_paths, + &sources_to_include, ) - .map_err(|err| Error::new_spanned(&source, err))?; + .or_else(|err| bail!(src, "{err}"))?; - let words = artifact.as_binary(); + let entry_points = entry_points + .iter() + .map(|(name, artifact)| (name.as_deref(), artifact.as_binary())) + .collect::>(); - codegen::reflect(&input, source, name, words, includes, &mut type_registry)? + shaders_code.extend(codegen::generate_shader_code(&entry_points, &shader_name)); + + if input.macro_options.generate_structs { + for (_, words) in entry_points { + let spirv = Spirv::new(words) + .or_else(|err| bail!(src, "failed to parse SPIR-V: {err}"))?; + + structs_code.extend(structs::generate_structs( + &input.macro_options, + spirv, + some_shader_name.to_owned(), + src.clone(), + &mut registered_structs, + )?); + } + } } - SourceKind::Path(path) => { - let full_path = root_path.join(path.value()); + ShaderSource::Path { + path: path_lit, + entry_points, + } => { + let path = base_path + .join(path_lit.value()) + .into_os_string() + .into_string() + .unwrap(); - if !full_path.is_file() { - bail!( - path, - "file `{full_path:?}` was not found, note that the path must be relative \ - {relative_path_error_msg}", - ); + let src = fs::read_to_string(&path) + .or_else(|err| bail!(path_lit, "failed to read {path}: {err}"))?; + + let entry_points = codegen::compile( + &compiler, + &input.macro_options, + Some(&path), + some_shader_name, + &src, + entry_points, + &base_path, + &include_paths, + &sources_to_include, + ) + .or_else(|err| bail!(path_lit, "{err}"))?; + + let entry_points = entry_points + .iter() + .map(|(name, artifact)| (name.as_deref(), artifact.as_binary())) + .collect::>(); + + shaders_code.extend(codegen::generate_shader_code(&entry_points, &shader_name)); + + if input.macro_options.generate_structs { + for (_, words) in entry_points { + let spirv = Spirv::new(words) + .or_else(|err| bail!(path_lit, "failed to parse SPIR-V: {err}"))?; + + structs_code.extend(structs::generate_structs( + &input.macro_options, + spirv, + some_shader_name.to_owned(), + path_lit.clone(), + &mut registered_structs, + )?); + } } - let source_code = fs::read_to_string(&full_path) - .or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?; - - let (artifact, mut includes) = codegen::compile( - &input, - Some(path.value()), - root_path, - &source_code, - shader_kind.unwrap(), - ) - .map_err(|err| Error::new_spanned(&path, err))?; - - let words = artifact.as_binary(); - - includes.push(full_path.into_os_string().into_string().unwrap()); - - codegen::reflect(&input, path, name, words, includes, &mut type_registry)? + sources_to_include.borrow_mut().insert(path); } - SourceKind::Bytes(path) => { - let full_path = root_path.join(path.value()); + ShaderSource::Bytes { path: path_lit } => { + let path = base_path + .join(path_lit.value()) + .into_os_string() + .into_string() + .unwrap(); - if !full_path.is_file() { + let bytes = fs::read(&path) + .or_else(|err| bail!(path_lit, "failed to read `{path}`: {err}"))?; + + let words = spirv::bytes_to_words(&bytes).or_else(|_| { bail!( - path, - "file `{full_path:?}` was not found, note that the path must be relative \ - {relative_path_error_msg}", - ); + path_lit, + "the byte length of `{path}` is not a multiple of 4" + ) + })?; + + shaders_code.extend(codegen::generate_shader_code( + &[(None, &words)], + &shader_name, + )); + + if input.macro_options.generate_structs { + let spirv = Spirv::new(&words) + .or_else(|err| bail!(path_lit, "failed to parse SPIR-V: {err}"))?; + + structs_code.extend(structs::generate_structs( + &input.macro_options, + spirv, + some_shader_name.to_owned(), + path_lit, + &mut registered_structs, + )?); } - let bytes = fs::read(&full_path) - .or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?; - - let words = vulkano::shader::spirv::bytes_to_words(&bytes) - .or_else(|err| bail!(path, "failed to read source `{full_path:?}`: {err}"))?; - - codegen::reflect(&input, path, name, &words, Vec::new(), &mut type_registry)? + sources_to_include.borrow_mut().insert(path); } - }; - - shaders_code.push(code); - types_code.push(types); + } } - let result = quote! { - #( #shaders_code )* - #( #types_code )* - }; + let includes = sources_to_include + .take() + .into_iter() + .enumerate() + .map(|(index, source)| { + let ident = format_ident!("_INCLUDE{index}"); + quote! { + const #ident: &[u8] = include_bytes!(#source); + } + }); - if input.dump.value { - println!("{}", result); - bail!(input.dump, "`shader!` Rust codegen dumped"); - } - - Ok(result) -} - -enum SourceKind { - Src(LitStr), - Path(LitStr), - Bytes(LitStr), + Ok(quote! { + #structs_code + #shaders_code + #( #includes )* + }) } struct MacroInput { - root_path_env: Option, - include_directories: Vec, - macro_defines: Vec<(String, String)>, - shaders: HashMap, SourceKind)>, - source_language: Option, - spirv_version: Option, - vulkan_version: Option, - generate_structs: bool, - custom_derives: Vec, - linalg_type: LinAlgType, - dump: LitBool, + shaders: HashMap, ShaderSource>, + macro_options: MacroOptions, } -impl MacroInput { - #[cfg(test)] - fn empty() -> Self { - MacroInput { - root_path_env: None, - source_language: None, - include_directories: Vec::new(), - macro_defines: Vec::new(), - shaders: HashMap::default(), - vulkan_version: None, - spirv_version: None, - generate_structs: true, - custom_derives: Vec::new(), - linalg_type: LinAlgType::default(), - dump: LitBool::new(false, Span::call_site()), - } - } +enum ShaderSource { + Src { + src: LitStr, + entry_points: HashMap, ShaderKind>, + }, + Path { + path: LitStr, + entry_points: HashMap, ShaderKind>, + }, + Bytes { + path: LitStr, + }, +} + +struct MacroOptions { + source_language: SourceLanguage, + vulkan_version: Option, + spirv_version: Option, + base_path: Option, + include_paths: Vec, + macro_defines: Vec<(String, Option)>, + generate_structs: bool, + custom_derives: Vec, + linalg_types: LinAlgTypes, +} + +enum BasePath { + Path(PathBuf), + Env(LitStr), +} + +#[derive(Default)] +enum LinAlgTypes { + #[default] + Std, + Cgmath, + Nalgebra, } impl Parse for MacroInput { fn parse(input: ParseStream<'_>) -> Result { - let root = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".into()); + let mut shaders = None; + let mut src = None; + let mut path = None; + let mut bytes = None; + let mut entry_points = None; + let mut stage = None; - let mut root_path_env = None; let mut source_language = None; - let mut include_directories = Vec::new(); - let mut macro_defines = Vec::new(); - let mut shaders = HashMap::default(); let mut vulkan_version = None; let mut spirv_version = None; + let mut base_path = None; + let mut base_path_env = None; + let mut include_paths = None; + let mut macro_defines = None; let mut generate_structs = None; let mut custom_derives = None; - let mut linalg_type = None; - let mut dump = None; - - fn parse_shader_fields( - output: &mut (Option, Option), - name: &str, - input: ParseStream<'_>, - ) -> Result<()> { - match name { - "ty" => { - let lit = input.parse::()?; - if output.0.is_some() { - bail!(lit, "field `ty` is already defined"); - } - - output.0 = Some(match lit.value().as_str() { - "vertex" => ShaderKind::Vertex, - "tess_ctrl" => ShaderKind::TessControl, - "tess_eval" => ShaderKind::TessEvaluation, - "geometry" => ShaderKind::Geometry, - "task" => ShaderKind::Task, - "mesh" => ShaderKind::Mesh, - "fragment" => ShaderKind::Fragment, - "compute" => ShaderKind::Compute, - "raygen" => ShaderKind::RayGeneration, - "anyhit" => ShaderKind::AnyHit, - "closesthit" => ShaderKind::ClosestHit, - "miss" => ShaderKind::Miss, - "intersection" => ShaderKind::Intersection, - "callable" => ShaderKind::Callable, - ty => bail!( - lit, - "expected `vertex`, `tess_ctrl`, `tess_eval`, `geometry`, `task`, \ - `mesh`, `fragment` `compute`, `raygen`, `anyhit`, `closesthit`, \ - `miss`, `intersection` or `callable`, found `{ty}`", - ), - }); - } - "bytes" => { - let lit = input.parse::()?; - if output.1.is_some() { - bail!( - lit, - "only one of `src`, `path`, or `bytes` can be defined per shader entry", - ); - } - - output.1 = Some(SourceKind::Bytes(lit)); - } - "path" => { - let lit = input.parse::()?; - if output.1.is_some() { - bail!( - lit, - "only one of `src`, `path` or `bytes` can be defined per shader entry", - ); - } - - output.1 = Some(SourceKind::Path(lit)); - } - "src" => { - let lit = input.parse::()?; - if output.1.is_some() { - bail!( - lit, - "only one of `src`, `path` or `bytes` can be defined per shader entry", - ); - } - - output.1 = Some(SourceKind::Src(lit)); - } - _ => unreachable!(), - } - - Ok(()) - } + let mut linalg_types = None; while !input.is_empty() { - let field_ident = input.parse::()?; + let field = input.parse::()?; input.parse::()?; - let field = field_ident.to_string(); - match field.as_str() { - "bytes" | "src" | "path" | "ty" => { - if shaders.len() > 1 || (shaders.len() == 1 && !shaders.contains_key("")) { - bail!( - field_ident, - "only one of `src`, `path`, `bytes` or `shaders` can be defined", - ); - } - - parse_shader_fields(shaders.entry(String::new()).or_default(), &field, input)?; - } - "shaders" => { - if !shaders.is_empty() { - bail!( - field_ident, - "only one of `src`, `path`, `bytes` or `shaders` can be defined", - ); - } - - let in_braces; - braced!(in_braces in input); - - while !in_braces.is_empty() { - let name_ident = in_braces.parse::()?; - let name = name_ident.to_string(); - - if shaders.contains_key(&name) { - bail!(name_ident, "shader entry `{name}` is already defined"); - } - - in_braces.parse::()?; - - let in_shader_definition; - braced!(in_shader_definition in in_braces); - - while !in_shader_definition.is_empty() { - let field_ident = in_shader_definition.parse::()?; - in_shader_definition.parse::()?; - let field = field_ident.to_string(); - - match field.as_str() { - "bytes" | "src" | "path" | "ty" => { - parse_shader_fields( - shaders.entry(name.clone()).or_default(), - &field, - &in_shader_definition, - )?; - } - field => bail!( - field_ident, - "expected `bytes`, `src`, `path` or `ty` as a field, found \ - `{field}`", - ), - } - - if !in_shader_definition.is_empty() { - in_shader_definition.parse::()?; - } - } - - if !in_braces.is_empty() { - in_braces.parse::()?; - } - - match shaders.get(&name).unwrap() { - (None, _) => bail!( - "please specify a type for shader `{name}` e.g. `ty: \"vertex\"`", - ), - (_, None) => bail!( - "please specify a source for shader `{name}` e.g. \ - `path: \"entry_point.glsl\"`", - ), - _ => (), - } - } - - if shaders.is_empty() { - bail!("at least one shader entry must be defined"); - } - } - "define" => { - let array_input; - bracketed!(array_input in input); - - while !array_input.is_empty() { - let tuple_input; - parenthesized!(tuple_input in array_input); - - let name = tuple_input.parse::()?; - tuple_input.parse::()?; - let value = tuple_input.parse::()?; - macro_defines.push((name.value(), value.value())); - - if !array_input.is_empty() { - array_input.parse::()?; - } - } - } - "root_path_env" => { - let lit = input.parse::()?; - if root_path_env.is_some() { - bail!(lit, "field `root_path_env` is already defined"); - } - root_path_env = Some(lit); - } - "include" => { - let in_brackets; - bracketed!(in_brackets in input); - - while !in_brackets.is_empty() { - let path = in_brackets.parse::()?; - - include_directories.push([&root, &path.value()].into_iter().collect()); - - if !in_brackets.is_empty() { - in_brackets.parse::()?; - } - } - } - "lang" => { - let lit = input.parse::()?; - if source_language.is_some() { - bail!(lit, "field `lang` is already defined"); - } - - source_language = Some(match lit.value().as_str() { - "glsl" => SourceLanguage::GLSL, - "hlsl" => SourceLanguage::HLSL, - lang => bail!(lit, "expected `glsl` or `hlsl`, found `{lang}`"), - }) - } - "vulkan_version" => { - let lit = input.parse::()?; - if vulkan_version.is_some() { - bail!(lit, "field `vulkan_version` is already defined"); - } - - vulkan_version = Some(match lit.value().as_str() { - "1.0" => EnvVersion::Vulkan1_0, - "1.1" => EnvVersion::Vulkan1_1, - "1.2" => EnvVersion::Vulkan1_2, - "1.3" => EnvVersion::Vulkan1_3, - ver => bail!(lit, "expected `1.0`, `1.1`, `1.2` or `1.3`, found `{ver}`"), - }); - } - "spirv_version" => { - let lit = input.parse::()?; - if spirv_version.is_some() { - bail!(lit, "field `spirv_version` is already defined"); - } - - spirv_version = Some(match lit.value().as_str() { - "1.0" => SpirvVersion::V1_0, - "1.1" => SpirvVersion::V1_1, - "1.2" => SpirvVersion::V1_2, - "1.3" => SpirvVersion::V1_3, - "1.4" => SpirvVersion::V1_4, - "1.5" => SpirvVersion::V1_5, - "1.6" => SpirvVersion::V1_6, - ver => bail!( - lit, - "expected `1.0`, `1.1`, `1.2`, `1.3`, `1.4`, `1.5` or `1.6`, found \ - `{ver}`", - ), - }); - } - "generate_structs" => { - let lit = input.parse::()?; - if generate_structs.is_some() { - bail!(lit, "field `generate_structs` is already defined"); - } - generate_structs = Some(lit.value); - } - "custom_derives" => { - let in_brackets; - bracketed!(in_brackets in input); - - while !in_brackets.is_empty() { - if custom_derives.is_none() { - custom_derives = Some(Vec::new()); - } - - custom_derives - .as_mut() - .unwrap() - .push(in_brackets.parse::()?); - - if !in_brackets.is_empty() { - in_brackets.parse::()?; - } - } - } - "types_meta" => { - bail!( - field_ident, - "you no longer need to add any derives to use the generated structs in \ - buffers, and you also no longer need bytemuck as a dependency, because \ - `BufferContents` is derived automatically for the generated structs; if \ - you need to add additional derives (e.g. `Debug`, `PartialEq`) then please \ - use the `custom_derives` field of the macro", - ); - } - "linalg_type" => { - let lit = input.parse::()?; - if linalg_type.is_some() { - bail!(lit, "field `linalg_type` is already defined"); - } - - linalg_type = Some(match lit.value().as_str() { - "std" => LinAlgType::Std, - "cgmath" => LinAlgType::CgMath, - "nalgebra" => LinAlgType::Nalgebra, - ty => bail!(lit, "expected `std`, `cgmath` or `nalgebra`, found `{ty}`"), - }); - } - "dump" => { - let lit = input.parse::()?; - if dump.is_some() { - bail!(lit, "field `dump` is already defined"); - } - - dump = Some(lit); - } - field => bail!( - field_ident, - "expected `bytes`, `src`, `path`, `ty`, `shaders`, `define`, `include`, \ - `vulkan_version`, `spirv_version`, `generate_structs`, `custom_derives`, \ - `linalg_type` or `dump` as a field, found `{field}`", - ), + match field.to_string().as_str() { + "shaders" => parse_shaders(input, field, &mut shaders)?, + "src" => parse_src(input, field, &mut src)?, + "path" => parse_path(input, field, &mut path)?, + "bytes" => parse_bytes(input, field, &mut bytes)?, + "entry_points" => parse_entry_points(input, field, &mut entry_points)?, + "stage" => parse_stage(input, field, &mut stage)?, + "lang" => parse_lang(input, field, &mut source_language)?, + "vulkan_version" => parse_vulkan_version(input, field, &mut vulkan_version)?, + "spirv_version" => parse_spirv_version(input, field, &mut spirv_version)?, + "base_path" => parse_base_path(input, field, &mut base_path)?, + "base_path_env" => parse_base_path_env(input, field, &mut base_path_env)?, + "include" => parse_include(input, field, &mut include_paths)?, + "define" => parse_define(input, field, &mut macro_defines)?, + "generate_structs" => parse_generate_structs(input, field, &mut generate_structs)?, + "custom_derives" => parse_custom_derives(input, field, &mut custom_derives)?, + "linalg_types" => parse_linalg_types(input, field, &mut linalg_types)?, + other => bail!(field, "unsupported field `{other}`"), } if !input.is_empty() { @@ -744,59 +544,530 @@ impl Parse for MacroInput { } } - if shaders.is_empty() { - bail!(r#"please specify at least one shader e.g. `ty: "vertex", src: ""`"#); - } + let shaders = match (src, path, bytes, shaders) { + (Some(src), None, None, None) => { + let entry_points = match (entry_points, stage) { + (Some(entry_points), None) => entry_points, + (None, Some(stage)) => iter::once((None, stage)).collect(), + _ => bail!( + "exactly one of the fields `entry_points` and `stage` must be defined" + ), + }; - match shaders.get("") { - // if source is bytes, the shader type should not be declared - Some((None, Some(SourceKind::Bytes(_)))) => {} - Some((_, Some(SourceKind::Bytes(_)))) => { - bail!( - r#"one may not specify a shader type when including precompiled SPIR-V binaries. Please remove the `ty:` declaration"# - ); + iter::once((None, ShaderSource::Src { src, entry_points })).collect() } - Some((None, _)) => { - bail!(r#"please specify the type of the shader e.g. `ty: "vertex"`"#); + (None, Some(path), None, None) => { + let entry_points = match (entry_points, stage) { + (Some(entry_points), None) => entry_points, + (None, Some(stage)) => iter::once((None, stage)).collect(), + _ => bail!( + "exactly one of the fields `entry_points` and `stage` must be defined" + ), + }; + + iter::once((None, ShaderSource::Path { path, entry_points })).collect() } - Some((_, None)) => { - bail!(r#"please specify the source of the shader e.g. `src: ""`"#); + (None, None, Some(bytes), None) => { + if entry_points.is_some() { + bail!("field `entry_points` cannot be defined when `bytes` is used"); + } + + if stage.is_some() { + bail!("field `stage` cannot be defined when `bytes` is used"); + } + + iter::once((None, ShaderSource::Bytes { path: bytes })).collect() } - _ => {} - } + (None, None, None, Some(shaders)) => { + if entry_points.is_some() { + bail!("field `entry_points` cannot be defined when `shaders` is used"); + } + + if stage.is_some() { + bail!("field `stage` cannot be defined when `shaders` is used"); + } + + shaders + } + _ => bail!( + "exactly one of the fields `src`, `path`, `bytes` and `shaders` must be defined" + ), + }; + + let base_path = match (base_path, base_path_env) { + (None, None) => None, + (Some(base_path), None) => Some(BasePath::Path(base_path)), + (None, Some(base_path_env)) => Some(BasePath::Env(base_path_env)), + (Some(_), Some(_)) => { + bail!("only one of the fields `base_path` and `base_path_env` can be defined"); + } + }; Ok(MacroInput { - root_path_env, - include_directories, - macro_defines, - shaders: shaders - .into_iter() - .map(|(key, (shader_kind, shader_source))| { - (key, (shader_kind, shader_source.unwrap())) - }) - .collect(), - source_language, - vulkan_version, - spirv_version, - generate_structs: generate_structs.unwrap_or(true), - custom_derives: custom_derives.unwrap_or_else(|| { - vec![ - parse_quote! { ::std::clone::Clone }, - parse_quote! { ::std::marker::Copy }, - ] - }), - linalg_type: linalg_type.unwrap_or_default(), - dump: dump.unwrap_or_else(|| LitBool::new(false, Span::call_site())), + shaders, + macro_options: MacroOptions { + source_language: source_language.unwrap_or(SourceLanguage::GLSL), + vulkan_version, + spirv_version, + base_path, + include_paths: include_paths.unwrap_or_default(), + macro_defines: macro_defines.unwrap_or_default(), + generate_structs: generate_structs.unwrap_or(true), + custom_derives: custom_derives.unwrap_or_else(|| { + vec![ + parse_quote!(::std::clone::Clone), + parse_quote!(::std::marker::Copy), + ] + }), + linalg_types: linalg_types.unwrap_or_default(), + }, }) } } -#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)] -enum LinAlgType { - #[default] - Std, - CgMath, - Nalgebra, +fn parse_shaders( + input: ParseStream<'_>, + field: Ident, + shaders: &mut Option, ShaderSource>>, +) -> Result<()> { + if shaders.is_some() { + bail!(field, "field `shaders` was already defined"); + } + + let shaders = shaders.insert(HashMap::default()); + + let in_braces; + let braces = braced!(in_braces in input); + + while !in_braces.is_empty() { + let mut src = None; + let mut path = None; + let mut bytes = None; + let mut entry_points = None; + let mut stage = None; + + let ident = in_braces.parse::()?; + in_braces.parse::()?; + let name = Some(ident.to_string()); + + if shaders.contains_key(&name) { + bail!(ident, "shader `{ident}` was already defined"); + } + + let in_shader; + let shader = braced!(in_shader in in_braces); + + while !in_shader.is_empty() { + let field = in_shader.parse::()?; + in_shader.parse::()?; + + match field.to_string().as_str() { + "src" => parse_src(&in_shader, field, &mut src)?, + "path" => parse_path(&in_shader, field, &mut path)?, + "bytes" => parse_bytes(&in_shader, field, &mut bytes)?, + "entry_points" => parse_entry_points(&in_shader, field, &mut entry_points)?, + "stage" => parse_stage(&in_shader, field, &mut stage)?, + other => bail!(field, "unsupported field `{other}`"), + } + + if !in_shader.is_empty() { + in_shader.parse::()?; + } + } + + let source = match (src, path, bytes) { + (Some(src), None, None) => { + let entry_points = match (entry_points, stage) { + (Some(entry_points), None) => entry_points, + (None, Some(stage)) => iter::once((None, stage)).collect(), + _ => bail!( + ident, + "exactly one of the fields `entry_points` and `stage` must be defined" + ), + }; + + ShaderSource::Src { src, entry_points } + } + (None, Some(path), None) => { + let entry_points = match (entry_points, stage) { + (Some(entry_points), None) => entry_points, + (None, Some(stage)) => iter::once((None, stage)).collect(), + _ => bail!( + ident, + "exactly one of the fields `entry_points` and `stage` must be defined" + ), + }; + + ShaderSource::Path { path, entry_points } + } + (None, None, Some(bytes)) => { + if entry_points.is_some() { + bail!( + ident, + "field `entry_points` cannot be defined when `bytes` is used" + ); + } + + if stage.is_some() { + bail!( + ident, + "field `stage` cannot be defined when `bytes` is used" + ); + } + + ShaderSource::Bytes { path: bytes } + } + _ => { + return Err(Error::new( + shader.span.join(), + "exactly one of the fields `src`, `path` and `bytes` must be defined", + )) + } + }; + + shaders.insert(name, source); + + if !in_braces.is_empty() { + in_braces.parse::()?; + } + } + + if shaders.is_empty() { + return Err(Error::new( + braces.span.join(), + "at least one shader must be defined", + )); + } + + Ok(()) +} + +fn parse_src(input: ParseStream<'_>, field: Ident, src: &mut Option) -> Result<()> { + if src.is_some() { + bail!(field, "field `src` was already defined"); + } + + *src = Some(input.parse::()?); + Ok(()) +} + +fn parse_path(input: ParseStream<'_>, field: Ident, path: &mut Option) -> Result<()> { + if path.is_some() { + bail!(field, "field `path` was already defined"); + } + + *path = Some(input.parse::()?); + Ok(()) +} + +fn parse_bytes(input: ParseStream<'_>, field: Ident, bytes: &mut Option) -> Result<()> { + if bytes.is_some() { + bail!(field, "field `bytes` was already defined"); + } + + *bytes = Some(input.parse::()?); + Ok(()) +} + +fn parse_entry_points( + input: ParseStream<'_>, + field: Ident, + entry_points: &mut Option, ShaderKind>>, +) -> Result<()> { + if entry_points.is_some() { + bail!(field, "field `entry_points` was already defined"); + } + + let entry_points = entry_points.insert(HashMap::default()); + + let in_braces; + let braces = braced!(in_braces in input); + + while !in_braces.is_empty() { + let ident = in_braces.parse::()?; + in_braces.parse::()?; + let name = Some(ident.to_string()); + + if entry_points.contains_key(&name) { + bail!(ident, "entry point `{ident}` was already defined"); + } + + entry_points.insert(name, parse_shader_kind(&in_braces)?); + + if !in_braces.is_empty() { + in_braces.parse::()?; + } + } + + if entry_points.is_empty() { + return Err(Error::new( + braces.span.join(), + "at least one entry point must be defined", + )); + } + + Ok(()) +} + +fn parse_stage(input: ParseStream<'_>, field: Ident, stage: &mut Option) -> Result<()> { + if stage.is_some() { + bail!(field, "field `stage` was already defined"); + } + + *stage = Some(parse_shader_kind(input)?); + Ok(()) +} + +fn parse_shader_kind(input: ParseStream<'_>) -> Result { + let lit = input.parse::()?; + + Ok(match lit.value().as_str() { + "vertex" => ShaderKind::Vertex, + "fragment" => ShaderKind::Fragment, + "compute" => ShaderKind::Compute, + "geometry" => ShaderKind::Geometry, + "tess_ctrl" => ShaderKind::TessControl, + "tess_eval" => ShaderKind::TessEvaluation, + "task" => ShaderKind::Task, + "mesh" => ShaderKind::Mesh, + "raygen" => ShaderKind::RayGeneration, + "any_hit" => ShaderKind::AnyHit, + "closest_hit" => ShaderKind::ClosestHit, + "miss" => ShaderKind::Miss, + "intersection" => ShaderKind::Intersection, + "callable" => ShaderKind::Callable, + other => bail!( + lit, + "expected `vertex`, `fragment`, `compute`, `geometry`, `tess_ctrl`, `tess_eval`, \ + `task`, `mesh`, `raygen`, `any_hit`, `closest_hit`, `miss`, `intersection` or \ + `callable, found `{other}`", + ), + }) +} + +fn parse_lang( + input: ParseStream<'_>, + field: Ident, + source_language: &mut Option, +) -> Result<()> { + if source_language.is_some() { + bail!(field, "field `lang` was already defined"); + } + + let lit = input.parse::()?; + + *source_language = Some(match lit.value().as_str() { + "glsl" => SourceLanguage::GLSL, + "hlsl" => SourceLanguage::HLSL, + other => bail!(lit, "expected `glsl` or `hlsl`, found `{other}`"), + }); + + Ok(()) +} + +fn parse_vulkan_version( + input: ParseStream<'_>, + field: Ident, + vulkan_version: &mut Option, +) -> Result<()> { + if vulkan_version.is_some() { + bail!(field, "field `vulkan_version` was already defined"); + } + + let lit = input.parse::()?; + + *vulkan_version = Some(match lit.value().as_str() { + "1.0" => EnvVersion::Vulkan1_0, + "1.1" => EnvVersion::Vulkan1_1, + "1.2" => EnvVersion::Vulkan1_2, + "1.3" => EnvVersion::Vulkan1_3, + other => bail!( + lit, + "expected `1.0`, `1.1`, `1.2` or `1.3`, found `{other}`" + ), + }); + + Ok(()) +} + +fn parse_spirv_version( + input: ParseStream<'_>, + field: Ident, + spirv_version: &mut Option, +) -> Result<()> { + if spirv_version.is_some() { + bail!(field, "field `spirv_version` was already defined"); + } + + let lit = input.parse::()?; + + *spirv_version = Some(match lit.value().as_str() { + "1.0" => SpirvVersion::V1_0, + "1.1" => SpirvVersion::V1_1, + "1.2" => SpirvVersion::V1_2, + "1.3" => SpirvVersion::V1_3, + "1.4" => SpirvVersion::V1_4, + "1.5" => SpirvVersion::V1_5, + "1.6" => SpirvVersion::V1_6, + other => bail!( + lit, + "expected `1.0`, `1.1`, `1.2`, `1.3`, `1.4`, `1.5` or `1.6`, found `{other}`", + ), + }); + + Ok(()) +} + +fn parse_base_path( + input: ParseStream<'_>, + field: Ident, + base_path: &mut Option, +) -> Result<()> { + if base_path.is_some() { + bail!(field, "field `base_path` was already defined"); + } + + *base_path = Some(input.parse::()?.value().into()); + Ok(()) +} + +fn parse_base_path_env( + input: ParseStream<'_>, + field: Ident, + base_path_env: &mut Option, +) -> Result<()> { + if base_path_env.is_some() { + bail!(field, "field `base_path_env` was already defined"); + } + + *base_path_env = Some(input.parse::()?); + Ok(()) +} + +fn parse_include( + input: ParseStream<'_>, + field: Ident, + include_paths: &mut Option>, +) -> Result<()> { + if include_paths.is_some() { + bail!(field, "field `include` was already defined"); + } + + let include_paths = include_paths.insert(Vec::new()); + + let in_brackets; + bracketed!(in_brackets in input); + + while !in_brackets.is_empty() { + include_paths.push(in_brackets.parse::()?.value().into()); + + if !in_brackets.is_empty() { + in_brackets.parse::()?; + } + } + + Ok(()) +} + +fn parse_define( + input: ParseStream<'_>, + field: Ident, + macro_defines: &mut Option)>>, +) -> Result<()> { + if macro_defines.is_some() { + bail!(field, "field `define` was already defined"); + } + + let macro_defines = macro_defines.insert(Vec::new()); + + let in_brackets; + bracketed!(in_brackets in input); + + while !in_brackets.is_empty() { + if in_brackets.peek(LitStr) { + let name = in_brackets.parse::()?; + macro_defines.push((name.value(), None)); + } else if in_brackets.peek(Paren) { + let in_parens; + parenthesized!(in_parens in in_brackets); + + let name = in_parens.parse::()?; + in_parens.parse::()?; + let value = in_parens.parse::()?; + + macro_defines.push((name.value(), Some(value.value()))); + } else { + return Err(in_brackets.error("expected a string literal or `(`")); + } + + if !in_brackets.is_empty() { + in_brackets.parse::()?; + } + } + + Ok(()) +} + +fn parse_generate_structs( + input: ParseStream<'_>, + field: Ident, + generate_structs: &mut Option, +) -> Result<()> { + if generate_structs.is_some() { + bail!(field, "field `generate_structs` was already defined"); + } + + *generate_structs = Some(input.parse::()?.value); + Ok(()) +} + +fn parse_custom_derives( + input: ParseStream<'_>, + field: Ident, + custom_derives: &mut Option>, +) -> Result<()> { + if custom_derives.is_some() { + bail!(field, "field `custom_derives` was already defined"); + } + + let custom_derives = custom_derives.insert(Vec::new()); + + let in_brackets; + bracketed!(in_brackets in input); + + while !in_brackets.is_empty() { + custom_derives.push(in_brackets.parse::()?); + + if !in_brackets.is_empty() { + in_brackets.parse::()?; + } + } + + Ok(()) +} + +fn parse_linalg_types( + input: ParseStream<'_>, + field: Ident, + linalg_types: &mut Option, +) -> Result<()> { + if linalg_types.is_some() { + bail!(field, "field `linalg_types` was already defined"); + } + + let lit = input.parse::()?; + + *linalg_types = Some(match lit.value().as_str() { + "std" => LinAlgTypes::Std, + "cgmath" => LinAlgTypes::Cgmath, + "nalgebra" => LinAlgTypes::Nalgebra, + other => bail!( + lit, + "expected `std`, `cgmath` or `nalgebra`, found `{other}`" + ), + }); + + Ok(()) } macro_rules! bail { diff --git a/vulkano-shaders/src/rust_gpu.rs b/vulkano-shaders/src/rust_gpu.rs deleted file mode 100644 index 1c4a8914..00000000 --- a/vulkano-shaders/src/rust_gpu.rs +++ /dev/null @@ -1,44 +0,0 @@ -#[cfg(test)] -mod tests { - use crate::{codegen::reflect, structs::TypeRegistry, MacroInput}; - use proc_macro2::Span; - use syn::LitStr; - - fn spv_to_words(data: &[u8]) -> Vec { - data.chunks(4) - .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]])) - .collect() - } - - #[test] - fn rust_gpu_reflect_vertex() { - let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-vertex.spv")); - - let mut type_registry = TypeRegistry::default(); - let (_shader_code, _structs) = reflect( - &MacroInput::empty(), - LitStr::new("rust-gpu vertex shader", Span::call_site()), - String::new(), - &insts, - Vec::new(), - &mut type_registry, - ) - .expect("reflecting spv failed"); - } - - #[test] - fn rust_gpu_reflect_fragment() { - let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-fragment.spv")); - - let mut type_registry = TypeRegistry::default(); - let (_shader_code, _structs) = reflect( - &MacroInput::empty(), - LitStr::new("rust-gpu vertex shader", Span::call_site()), - String::new(), - &insts, - Vec::new(), - &mut type_registry, - ) - .expect("reflecting spv failed"); - } -} diff --git a/vulkano-shaders/src/structs.rs b/vulkano-shaders/src/structs.rs index e79edcb3..f1aa10dc 100644 --- a/vulkano-shaders/src/structs.rs +++ b/vulkano-shaders/src/structs.rs @@ -1,174 +1,148 @@ -use crate::{bail, codegen::Shader, LinAlgType, MacroInput}; +use crate::{bail, LinAlgTypes, MacroOptions}; use ahash::HashMap; use proc_macro2::{Span, TokenStream}; use quote::{format_ident, quote, ToTokens, TokenStreamExt}; use std::{cmp::Ordering, num::NonZeroUsize}; -use syn::{Error, Ident, Result}; -use vulkano::shader::spirv::{Decoration, Id, Instruction}; +use syn::{Error, Ident, LitStr, Result}; +use vulkano::shader::spirv::{Decoration, Id, Instruction, Spirv}; -#[derive(Default)] -pub struct TypeRegistry { - registered_structs: HashMap, +struct Shader { + spirv: Spirv, + name: String, + source: LitStr, } -impl TypeRegistry { - fn register_struct(&mut self, shader: &Shader, ty: &TypeStruct) -> Result { - // Checking with registry if this struct is already registered by another shader, and if - // their signatures match. - if let Some(registered) = self.registered_structs.get(&ty.ident) { - registered.validate_signatures(&shader.name, ty)?; - - // If the struct is already registered and matches this one, skip the duplicate. - Ok(false) - } else { - self.registered_structs.insert( - ty.ident.clone(), - RegisteredType { - shader: shader.name.clone(), - ty: ty.clone(), - }, - ); - - Ok(true) - } - } +pub(super) struct RegisteredStruct { + members: Vec, + shader_name: String, } -struct RegisteredType { - shader: String, - ty: TypeStruct, -} - -impl RegisteredType { - fn validate_signatures(&self, other_shader: &str, other_ty: &TypeStruct) -> Result<()> { - let (shader, struct_ident) = (&self.shader, &self.ty.ident); - - if self.ty.members.len() > other_ty.members.len() { - let member_ident = &self.ty.members[other_ty.members.len()].ident; - bail!( - "shaders `{shader}` and `{other_shader}` declare structs with the same name \ - `{struct_ident}`, but the struct from shader `{shader}` contains an extra field \ - `{member_ident}`", - ); - } - - if self.ty.members.len() < other_ty.members.len() { - let member_ident = &other_ty.members[self.ty.members.len()].ident; - bail!( - "shaders `{shader}` and `{other_shader}` declare structs with the same name \ - `{struct_ident}`, but the struct from shader `{other_shader}` contains an extra \ - field `{member_ident}`", - ); - } - - for (index, (member, other_member)) in self - .ty - .members - .iter() - .zip(other_ty.members.iter()) - .enumerate() - { - if member.ty != other_member.ty { - let (member_ty, other_member_ty) = (&member.ty, &other_member.ty); - bail!( - "shaders `{shader}` and `{other_shader}` declare structs with the same name \ - `{struct_ident}`, but the struct from shader `{shader}` contains a field of \ - type `{member_ty:?}` at index `{index}`, whereas the same struct from shader \ - `{other_shader}` contains a field of type `{other_member_ty:?}` in the same \ - position", - ); - } - } - - Ok(()) - } -} - -/// Translates all the structs that are contained in the SPIR-V document as Rust structs. -pub(super) fn write_structs( - input: &MacroInput, - shader: &Shader, - type_registry: &mut TypeRegistry, +/// Generates Rust structs from structs declared in SPIR-V bytecode. +pub(super) fn generate_structs( + macro_options: &MacroOptions, + spirv: Spirv, + shader_name: String, + shader_source: LitStr, + registered_structs: &mut HashMap, ) -> Result { - if !input.generate_structs { - return Ok(TokenStream::new()); - } + let mut structs_code = TokenStream::new(); - let mut structs = TokenStream::new(); + let shader = Shader { + spirv, + name: shader_name, + source: shader_source, + }; - for (struct_id, member_type_ids) in shader + let structs = shader .spirv .types() .iter() - .filter_map(|instruction| match *instruction { + .filter_map(|instruction| match instruction { Instruction::TypeStruct { result_id, - ref member_types, - } => Some((result_id, member_types)), + member_types, + } => Some((*result_id, member_types)), _ => None, }) - .filter(|&(struct_id, _)| has_defined_layout(shader, struct_id)) - { - let struct_ty = TypeStruct::new(shader, struct_id, member_type_ids)?; + .filter(|&(id, _)| struct_has_defined_layout(&shader.spirv, id)); - // Register the type if needed. - if !type_registry.register_struct(shader, &struct_ty)? { - continue; - } + for (struct_id, member_type_ids) in structs { + let ty = TypeStruct::new(&shader, struct_id, member_type_ids)?; - let custom_derives = if struct_ty.size().is_some() { - input.custom_derives.as_slice() + if let Some(registered) = registered_structs.get(&ty.ident) { + validate_members( + &ty.ident, + &ty.members, + ®istered.members, + &shader.name, + ®istered.shader_name, + )?; } else { - &[] - }; - let struct_ser = Serializer(&struct_ty, input); + let custom_derives = if ty.size().is_some() { + macro_options.custom_derives.as_slice() + } else { + &[] + }; - structs.extend(quote! { - #[allow(non_camel_case_types, non_snake_case)] - #[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )] - #[repr(C)] - #struct_ser - }) - } + let struct_ser = Serializer(&ty, macro_options); - Ok(structs) -} + structs_code.extend(quote! { + #[allow(non_camel_case_types, non_snake_case)] + #[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )] + #[repr(C)] + #struct_ser + }); -fn has_defined_layout(shader: &Shader, struct_id: Id) -> bool { - for member_info in shader.spirv.id(struct_id).members() { - let mut offset_found = false; - - for instruction in member_info.decorations() { - match instruction { - Instruction::MemberDecorate { - decoration: Decoration::BuiltIn { .. }, - .. - } => { - // Ignore the whole struct if a member is built in, which includes - // `gl_Position` for example. - return false; - } - Instruction::MemberDecorate { - decoration: Decoration::Offset { .. }, - .. - } => { - offset_found = true; - } - _ => (), - } - } - - // Some structs don't have `Offset` decorations, in that case they are used as local - // variables only. Ignoring these. - if !offset_found { - return false; + registered_structs.insert( + ty.ident, + RegisteredStruct { + members: ty.members, + shader_name: shader.name.clone(), + }, + ); } } - true + Ok(structs_code) } -#[derive(Clone, Copy, PartialEq, Eq)] +fn struct_has_defined_layout(spirv: &Spirv, id: Id) -> bool { + spirv.id(id).members().iter().all(|member_info| { + let decorations = member_info + .decorations() + .iter() + .map(|instruction| match instruction { + Instruction::MemberDecorate { decoration, .. } => decoration, + _ => unreachable!(), + }); + + let has_offset_decoration = decorations + .clone() + .any(|decoration| matches!(decoration, Decoration::Offset { .. })); + + let has_builtin_decoration = decorations + .clone() + .any(|decoration| matches!(decoration, Decoration::BuiltIn { .. })); + + has_offset_decoration && !has_builtin_decoration + }) +} + +fn validate_members( + ident: &Ident, + first_members: &[Member], + second_members: &[Member], + first_shader: &str, + second_shader: &str, +) -> Result<()> { + match first_members.len().cmp(&second_members.len()) { + Ordering::Greater => bail!( + "the declaration of struct `{ident}` in shader `{first_shader}` has more fields than \ + the declaration in shader `{second_shader}`" + ), + Ordering::Less => bail!( + "the declaration of struct `{ident}` in shader `{second_shader}` has more fields than \ + the declaration in shader `{first_shader}`" + ), + _ => {} + } + + for (index, (first_member, second_member)) in + first_members.iter().zip(second_members).enumerate() + { + let (first_type, second_type) = (&first_member.ty, &second_member.ty); + if first_type != second_type { + bail!( + "field {index} of struct `{ident}` is of type `{first_type:?}` in shader \ + `{first_shader}` but of type `{second_type:?}` in shader `{second_shader}`" + ); + } + } + + Ok(()) +} + +#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] #[repr(u8)] enum Alignment { A1 = 1, @@ -188,23 +162,11 @@ impl Alignment { 8 => Alignment::A8, 16 => Alignment::A16, 32 => Alignment::A32, - _ => unreachable!(), + _ => panic!(), } } } -impl PartialOrd for Alignment { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } -} - -impl Ord for Alignment { - fn cmp(&self, other: &Self) -> Ordering { - (*self as usize).cmp(&(*other as usize)) - } -} - fn align_up(offset: usize, alignment: Alignment) -> usize { (offset + alignment as usize - 1) & !(alignment as usize - 1) } @@ -228,7 +190,10 @@ impl Type { let id_info = shader.spirv.id(type_id); let ty = match *id_info.instruction() { - Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"), + Instruction::TypeBool { .. } => bail!( + shader.source, + "SPIR-V Boolean types don't have a defined layout" + ), Instruction::TypeInt { width, signedness, .. } => Type::Scalar(TypeScalar::Int(TypeInt::new(shader, width, signedness)?)), @@ -345,7 +310,7 @@ impl TypeInt { let signed = match signedness { 0 => false, 1 => true, - _ => bail!(shader.source, "signedness must be 0 or 1"), + _ => bail!(shader.source, "signedness must be either 0 or 1"), }; Ok(TypeInt { width, signed }) @@ -473,14 +438,17 @@ impl TypeVector { let component_count = ComponentCount::new(shader, component_count)?; let component_type = match *shader.spirv.id(component_type_id).instruction() { - Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"), + Instruction::TypeBool { .. } => bail!( + shader.source, + "SPIR-V Boolean types don't have a defined layout" + ), Instruction::TypeInt { width, signedness, .. } => TypeScalar::Int(TypeInt::new(shader, width, signedness)?), Instruction::TypeFloat { width, .. } => { TypeScalar::Float(TypeFloat::new(shader, width)?) } - _ => bail!(shader.source, "vector components must be scalars"), + _ => bail!(shader.source, "vector components must be scalar"), }; Ok(TypeVector { @@ -616,49 +584,46 @@ impl TypeArray { }) .transpose()?; - let stride = { - let mut strides = - shader - .spirv - .id(array_id) - .decorations() - .iter() - .filter_map(|instruction| match *instruction { - Instruction::Decorate { - decoration: Decoration::ArrayStride { array_stride }, - .. - } => Some(array_stride as usize), - _ => None, - }); - let stride = strides.next().ok_or_else(|| { - Error::new_spanned( - &shader.source, - "arrays inside structs must have an `ArrayStride` decoration", - ) - })?; + let mut strides = + shader + .spirv + .id(array_id) + .decorations() + .iter() + .filter_map(|instruction| match *instruction { + Instruction::Decorate { + decoration: Decoration::ArrayStride { array_stride }, + .. + } => Some(array_stride as usize), + _ => None, + }); - if !strides.all(|s| s == stride) { - bail!(shader.source, "found conflicting `ArrayStride` decorations"); - } + let stride = strides.next().ok_or_else(|| { + Error::new_spanned( + &shader.source, + "arrays inside structs must have an `ArrayStride` decoration", + ) + })?; - if !is_aligned(stride, element_type.scalar_alignment()) { - bail!( - shader.source, - "array strides must be aligned for the element type", - ); - } + if !strides.all(|s| s == stride) { + bail!(shader.source, "found conflicting `ArrayStride` decorations"); + } - let element_size = element_type.size().ok_or_else(|| { - Error::new_spanned(&shader.source, "array elements must be sized") - })?; + if !is_aligned(stride, element_type.scalar_alignment()) { + bail!( + shader.source, + "array strides must be aligned to the element type's alignment", + ); + } - if stride < element_size { - bail!(shader.source, "array elements must not overlap"); - } - - stride + let Some(element_size) = element_type.size() else { + bail!(shader.source, "array elements must be sized"); }; + if stride < element_size { + bail!(shader.source, "array elements must not overlap"); + } + Ok(TypeArray { element_type, length, @@ -690,16 +655,18 @@ impl TypeStruct { .iter() .find_map(|instruction| match instruction { Instruction::Name { name, .. } => { - // Replace chars that could potentially cause the ident to be invalid with "_". - // For example, Rust-GPU names structs by their fully qualified rust name (e.g. - // "foo::bar::MyStruct") in which the ":" is an invalid character for idents. + // Replace non-alphanumeric and non-ascii characters with '_' to ensure the name + // is a valid identifier. For example, Rust-GPU names structs by their fully + // qualified rust name (e.g. `foo::bar::MyStruct`) in which `:` makes it an + // invalid identifier. let mut name = - name.replace(|c: char| !(c.is_ascii_alphanumeric() || c == '_'), "_"); - if name.starts_with(|c: char| !c.is_ascii_alphabetic()) { + name.replace(|c: char| !c.is_ascii_alphanumeric() && c != '_', "_"); + + if name.starts_with(|c: char| c.is_ascii_digit()) { name.insert(0, '_'); } - // Worst case: invalid idents will get the UnnamedX name below + // Fall-back to `Unnamed{Id}` if it's still invalid syn::parse_str(&name).ok() } _ => None, @@ -725,9 +692,7 @@ impl TypeStruct { let mut ty = Type::new(shader, member_id)?; { - // If the member is an array, then matrix-decorations can be applied to it if the - // innermost type of the array is a matrix. Else this will stay being the type of - // the member. + // Matrix decorations can be applied to an array if its innermost type is a matrix. let mut ty = &mut ty; while let Type::Array(TypeArray { element_type, .. }) = ty { ty = element_type; @@ -744,6 +709,7 @@ impl TypeStruct { _ => None, }, ); + matrix.stride = strides.next().ok_or_else(|| { Error::new_spanned( &shader.source, @@ -766,7 +732,7 @@ impl TypeStruct { ); } - let mut majornessess = member_info.decorations().iter().filter_map( + let mut majornesses = member_info.decorations().iter().filter_map( |instruction| match *instruction { Instruction::MemberDecorate { decoration: Decoration::ColMajor, @@ -779,7 +745,8 @@ impl TypeStruct { _ => None, }, ); - matrix.majorness = majornessess.next().ok_or_else(|| { + + matrix.majorness = majornesses.next().ok_or_else(|| { Error::new_spanned( &shader.source, "matrices inside structs must have a `ColMajor` or `RowMajor` \ @@ -787,7 +754,7 @@ impl TypeStruct { ) })?; - if !majornessess.all(|m| m == matrix.majorness) { + if !majornesses.all(|m| m == matrix.majorness) { bail!( shader.source, "found conflicting matrix majorness decorations", @@ -822,26 +789,27 @@ impl TypeStruct { if !is_aligned(offset, ty.scalar_alignment()) { bail!( shader.source, - "struct member offsets must be aligned for the member type", + "struct member offsets must be aligned to their type's alignment", ); } - if let Some(last) = members.last() { - if !is_aligned(offset, last.ty.scalar_alignment()) { + if let Some(previous_member) = members.last() { + if !is_aligned(offset, previous_member.ty.scalar_alignment()) { bail!( shader.source, - "expected struct member offset to be aligned for the preceding member type", + "expected struct member offset to be aligned to the preceding member \ + type's alignment", ); } - let last_size = last.ty.size().ok_or_else(|| { + let last_size = previous_member.ty.size().ok_or_else(|| { Error::new_spanned( &shader.source, "all members except the last member of a struct must be sized", ) })?; - if last.offset + last_size > offset { + if previous_member.offset + last_size > offset { bail!(shader.source, "struct members must not overlap"); } } @@ -888,8 +856,8 @@ impl PartialEq for Member { impl Eq for Member {} -/// Helper for serializing a type to tokens with respect to macro input. -struct Serializer<'a, T>(&'a T, &'a MacroInput); +/// Helper for serializing a type as tokens according to the macro options. +struct Serializer<'a, T>(&'a T, &'a MacroOptions); impl ToTokens for Serializer<'_, Type> { fn to_tokens(&self, tokens: &mut TokenStream) { @@ -909,18 +877,16 @@ impl ToTokens for Serializer<'_, TypeVector> { let component_type = &self.0.component_type; let component_count = self.0.component_count as usize; - match self.1.linalg_type { - LinAlgType::Std => { + match self.1.linalg_types { + LinAlgTypes::Std => { tokens.extend(quote! { [#component_type; #component_count] }); } - LinAlgType::CgMath => { + LinAlgTypes::Cgmath => { let vector = format_ident!("Vector{}", component_count); tokens.extend(quote! { ::cgmath::#vector<#component_type> }); } - LinAlgType::Nalgebra => { - tokens.extend(quote! { - ::nalgebra::base::SVector<#component_type, #component_count> - }); + LinAlgTypes::Nalgebra => { + tokens.extend(quote! { ::nalgebra::SVector<#component_type, #component_count> }); } } } @@ -936,22 +902,22 @@ impl ToTokens for Serializer<'_, TypeMatrix> { // This can't overflow because the stride must be at least the vector size. let padding = self.0.stride - self.0.vector_size(); - match self.1.linalg_type { - // cgmath only has column-major matrices. It also only has square matrices, and its 3x3 - // matrix is not padded right. Fall back to std for anything else. - LinAlgType::CgMath + match self.1.linalg_types { + // cgmath only supports column-major square matrices, and its 3x3 matrix is not padded + // correctly. + LinAlgTypes::Cgmath if majorness == MatrixMajorness::ColumnMajor - && padding == 0 - && vector_count == component_count => + && vector_count == component_count + && padding == 0 => { let matrix = format_ident!("Matrix{}", component_count); tokens.extend(quote! { ::cgmath::#matrix<#component_type> }); } - // nalgebra only has column-major matrices, and its 3xN matrices are not padded right. - // Fall back to std for anything else. - LinAlgType::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => { + // nalgebra only supports column-major matrices, and its 3xN matrices are not padded + // correctly. + LinAlgTypes::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => { tokens.extend(quote! { - ::nalgebra::base::SMatrix<#component_type, #component_count, #vector_count> + ::nalgebra::SMatrix<#component_type, #component_count, #vector_count> }); } _ => { @@ -964,7 +930,7 @@ impl ToTokens for Serializer<'_, TypeMatrix> { impl ToTokens for Serializer<'_, TypeArray> { fn to_tokens(&self, tokens: &mut TokenStream) { - let element_type = &*self.0.element_type; + let element_type = self.0.element_type.as_ref(); // This can't panic because array elements must be sized. let element_size = element_type.size().unwrap(); // This can't overflow because the stride must be at least the element size. @@ -1047,7 +1013,8 @@ impl ToTokens for Serializer<'_, TypeStruct> { } } -/// Helper for wrapping tokens in `Padded`. Doesn't wrap if the padding is `0`. +/// Helper for wrapping tokens in [Padded][struct@vulkano::padded::Padded]. +/// Doesn't wrap if the padding is `0`. struct Padded(T, usize); impl ToTokens for Padded {