Refactor vulkano-shaders

This commit is contained in:
Mai 2024-09-25 15:02:24 +03:00
parent 075e9812d8
commit 2752e20d79
6 changed files with 1102 additions and 1478 deletions

1
Cargo.lock generated
View File

@ -2465,7 +2465,6 @@ name = "vulkano-shaders"
version = "0.34.0" version = "0.34.0"
dependencies = [ dependencies = [
"ahash", "ahash",
"heck",
"proc-macro2", "proc-macro2",
"quote", "quote",
"shaderc", "shaderc",

View File

@ -18,16 +18,14 @@ proc-macro = true
[dependencies] [dependencies]
ahash = { workspace = true } ahash = { workspace = true }
heck = { workspace = true }
proc-macro2 = { workspace = true } proc-macro2 = { workspace = true }
quote = { workspace = true } quote = { workspace = true }
shaderc = { workspace = true } shaderc = { workspace = true }
syn = { workspace = true, features = ["full", "extra-traits"] } syn = { workspace = true }
vulkano = { workspace = true } vulkano = { workspace = true }
[features] [features]
shaderc-build-from-source = ["shaderc/build-from-source"] shaderc-build-from-source = ["shaderc/build-from-source"]
shaderc-debug = []
[lints] [lints]
workspace = true workspace = true

View File

@ -1,135 +1,138 @@
use crate::{ use crate::MacroOptions;
structs::{self, TypeRegistry}, use ahash::{HashMap, HashSet};
MacroInput,
};
use heck::ToSnakeCase;
use proc_macro2::TokenStream; use proc_macro2::TokenStream;
use quote::{format_ident, quote}; use quote::{format_ident, quote};
pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind}; use shaderc::{
use shaderc::{CompileOptions, Compiler, EnvVersion, SourceLanguage, TargetEnv}; CompilationArtifact, CompileOptions, Compiler, IncludeType, ResolvedInclude, ShaderKind,
SourceLanguage, TargetEnv,
};
use std::{ use std::{
cell::RefCell, cell::RefCell,
fs, fs,
iter::Iterator, iter::Iterator,
path::{Path, PathBuf}, path::{Path, PathBuf},
}; };
use syn::{Error, LitStr};
use vulkano::shader::spirv::Spirv;
pub struct Shader { pub(super) fn compile(
pub source: LitStr, compiler: &Compiler,
pub name: String, macro_options: &MacroOptions,
pub spirv: Spirv, source_path: Option<&str>,
shader_name: &str,
source_code: &str,
entry_points: HashMap<Option<String>, ShaderKind>,
base_path: &Path,
include_paths: &[PathBuf],
sources_to_include: &RefCell<HashSet<String>>,
) -> Result<Vec<(Option<String>, CompilationArtifact)>, String> {
let mut compile_options = CompileOptions::new().ok_or("failed to create compile options")?;
compile_options.set_source_language(macro_options.source_language);
if let Some(vulkan_version) = macro_options.vulkan_version {
compile_options.set_target_env(TargetEnv::Vulkan, vulkan_version as u32);
}
if let Some(spirv_version) = macro_options.spirv_version {
compile_options.set_target_spirv(spirv_version);
}
compile_options.set_include_callback(
|requested_source, include_type, containing_source, depth| {
include_callback(
Path::new(requested_source),
include_type,
Path::new(containing_source),
depth,
base_path,
include_paths,
source_path.is_none(),
&mut sources_to_include.borrow_mut(),
)
},
);
for (name, value) in &macro_options.macro_defines {
compile_options.add_macro_definition(name, value.as_deref());
}
let file_name = match (source_path, macro_options.source_language) {
(Some(source_path), _) => source_path,
(None, SourceLanguage::GLSL) => &format!("{shader_name}.glsl"),
(None, SourceLanguage::HLSL) => &format!("{shader_name}.hlsl"),
};
entry_points
.into_iter()
.map(|(entry_point, shader_kind)| {
compiler
.compile_into_spirv(
source_code,
shader_kind,
file_name,
entry_point.as_deref().unwrap_or("main"),
Some(&compile_options),
)
.map(|artifact| (entry_point, artifact))
.map_err(|err| err.to_string().replace("\n", " "))
})
.collect()
} }
#[allow(clippy::too_many_arguments)]
fn include_callback( fn include_callback(
requested_source_path_raw: &str, requested_source: &Path,
directive_type: IncludeType, include_type: IncludeType,
contained_within_path_raw: &str, containing_source: &Path,
recursion_depth: usize, depth: usize,
include_directories: &[PathBuf],
root_source_has_path: bool,
base_path: &Path, base_path: &Path,
includes: &mut Vec<String>, include_paths: &[PathBuf],
embedded_root_source: bool,
sources_to_include: &mut HashSet<String>,
) -> Result<ResolvedInclude, String> { ) -> Result<ResolvedInclude, String> {
let file_to_include = match directive_type { let resolved_path = match include_type {
IncludeType::Relative => { IncludeType::Relative => {
let requested_source_path = Path::new(requested_source_path_raw); if depth == 1 && embedded_root_source && !requested_source.is_absolute() {
// If the shader source is embedded within the macro, abort unless we get an absolute return Err(
// path. "you cannot use relative include directives in embedded shader source code; \
if !root_source_has_path && recursion_depth == 1 && !requested_source_path.is_absolute() try using `#include <...>` instead"
{ .to_owned(),
let requested_source_name = requested_source_path );
.file_name()
.expect("failed to get the name of the requested source file")
.to_string_lossy();
let requested_source_directory = requested_source_path
.parent()
.expect("failed to get the directory of the requested source file")
.to_string_lossy();
return Err(format!(
"usage of relative paths in imports in embedded GLSL is not allowed, try \
using `#include <{}>` and adding the directory `{}` to the `include` array in \
your `shader!` macro call instead",
requested_source_name, requested_source_directory,
));
} }
let mut resolved_path = if recursion_depth == 1 { let parent = containing_source.parent().unwrap();
Path::new(contained_within_path_raw)
.parent() if depth == 1 {
.map(|parent| base_path.join(parent)) [base_path, parent, requested_source].iter().collect()
} else { } else {
Path::new(contained_within_path_raw) [parent, requested_source].iter().collect()
.parent()
.map(|parent| parent.to_owned())
} }
.unwrap_or_else(|| {
panic!(
"the file `{}` does not reside in a directory, this is an implementation \
error",
contained_within_path_raw,
)
});
resolved_path.push(requested_source_path);
if !resolved_path.is_file() {
return Err(format!(
"invalid inclusion path `{}`, the path does not point to a file",
requested_source_path_raw,
));
}
resolved_path
} }
IncludeType::Standard => { IncludeType::Standard => {
let requested_source_path = Path::new(requested_source_path_raw); if requested_source.is_absolute() {
// This is returned when attempting to include a missing file by an absolute path
if requested_source_path.is_absolute() { // in a relative include directive and when using an absolute path in a standard
// This message is printed either when using a missing file with an absolute path
// in the relative include directive or when using absolute paths in a standard
// include directive. // include directive.
return Err(format!( return Err(
"no such file found as specified by the absolute path; keep in mind that \ "the specified file was not found; if you're using an absolute path in a \
absolute paths cannot be used with inclusion from standard directories \ standard include directive (`#include <...>`), try using `#include \"...\"` \
(`#include <...>`), try using `#include \"...\"` instead; requested path: {}", instead"
requested_source_path_raw, .to_owned(),
)); );
} }
let found_requested_source_path = include_directories include_paths
.iter() .iter()
.map(|include_directory| include_directory.join(requested_source_path)) .map(|include_path| include_path.join(requested_source))
.find(|resolved_requested_source_path| resolved_requested_source_path.is_file()); .find(|source_path| source_path.is_file())
.ok_or("the specified file was not found".to_owned())?
if let Some(found_requested_source_path) = found_requested_source_path {
found_requested_source_path
} else {
return Err(format!(
"failed to include the file `{}` from any include directories",
requested_source_path_raw,
));
}
} }
}; };
let content = fs::read_to_string(file_to_include.as_path()).map_err(|err| { let resolved_name = resolved_path.into_os_string().into_string().unwrap();
format!(
"failed to read the contents of file `{file_to_include:?}` to be included in the \
shader source: {err}",
)
})?;
let resolved_name = file_to_include
.into_os_string()
.into_string()
.map_err(|_| {
"failed to stringify the file to be included; make sure the path consists of valid \
unicode characters"
})?;
includes.push(resolved_name.clone()); let content = fs::read_to_string(&resolved_name)
.map_err(|err| format!("failed to read `{resolved_name}`: {err}"))?;
sources_to_include.insert(resolved_name.clone());
Ok(ResolvedInclude { Ok(ResolvedInclude {
resolved_name, resolved_name,
@ -137,109 +140,17 @@ fn include_callback(
}) })
} }
pub(super) fn compile( pub(super) fn generate_shader_code(
input: &MacroInput, entry_points: &[(Option<&str>, &[u32])],
path: Option<String>, shader_name: &Option<String>,
base_path: &Path, ) -> TokenStream {
code: &str, let load_fns = entry_points.iter().map(|(name, words)| {
shader_kind: ShaderKind, let load_name = match name {
) -> Result<(CompilationArtifact, Vec<String>), String> { Some(name) => format_ident!("load_{name}"),
let includes = RefCell::new(Vec::new()); None => format_ident!("load"),
let compiler = Compiler::new().ok_or("failed to create shader compiler")?;
let mut compile_options =
CompileOptions::new().ok_or("failed to initialize compile options")?;
let source_language = input.source_language.unwrap_or(SourceLanguage::GLSL);
compile_options.set_source_language(source_language);
compile_options.set_target_env(
TargetEnv::Vulkan,
input.vulkan_version.unwrap_or(EnvVersion::Vulkan1_0) as u32,
);
if let Some(spirv_version) = input.spirv_version {
compile_options.set_target_spirv(spirv_version);
}
let root_source_path = path.as_deref().unwrap_or(
// An arbitrary placeholder file name for embedded shaders.
match source_language {
SourceLanguage::GLSL => "shader.glsl",
SourceLanguage::HLSL => "shader.hlsl",
},
);
// Specify the file resolution callback for the `#include` directive.
compile_options.set_include_callback(
|requested_source_path, directive_type, contained_within_path, recursion_depth| {
include_callback(
requested_source_path,
directive_type,
contained_within_path,
recursion_depth,
&input.include_directories,
path.is_some(),
base_path,
&mut includes.borrow_mut(),
)
},
);
for (macro_name, macro_value) in &input.macro_defines {
compile_options.add_macro_definition(macro_name, Some(macro_value));
}
#[cfg(feature = "shaderc-debug")]
compile_options.set_generate_debug_info();
let content = compiler
.compile_into_spirv(
code,
shader_kind,
root_source_path,
"main",
Some(&compile_options),
)
.map_err(|e| e.to_string().replace("(s): ", "(s):\n"))?;
drop(compile_options);
Ok((content, includes.into_inner()))
}
pub(super) fn reflect(
input: &MacroInput,
source: LitStr,
name: String,
words: &[u32],
input_paths: Vec<String>,
type_registry: &mut TypeRegistry,
) -> Result<(TokenStream, TokenStream), Error> {
let spirv = Spirv::new(words).map_err(|err| {
Error::new_spanned(&source, format!("failed to parse SPIR-V words: {err}"))
})?;
let shader = Shader {
source,
name,
spirv,
}; };
let include_bytes = input_paths.into_iter().map(|s| {
quote! { quote! {
// Using `include_bytes` here ensures that changing the shader will force recompilation.
// The bytes themselves can be optimized out by the compiler as they are unused.
::std::include_bytes!( #s )
}
});
let load_name = if shader.name.is_empty() {
format_ident!("load")
} else {
format_ident!("load_{}", shader.name.to_snake_case())
};
let shader_code = quote! {
/// Loads the shader as a `ShaderModule`.
#[allow(unsafe_code)] #[allow(unsafe_code)]
#[inline] #[inline]
pub fn #load_name( pub fn #load_name(
@ -248,8 +159,6 @@ pub(super) fn reflect(
::std::sync::Arc<::vulkano::shader::ShaderModule>, ::std::sync::Arc<::vulkano::shader::ShaderModule>,
::vulkano::Validated<::vulkano::VulkanError>, ::vulkano::Validated<::vulkano::VulkanError>,
> { > {
let _bytes = ( #( #include_bytes ),* );
static WORDS: &[u32] = &[ #( #words ),* ]; static WORDS: &[u32] = &[ #( #words ),* ];
unsafe { unsafe {
@ -259,496 +168,20 @@ pub(super) fn reflect(
) )
} }
} }
};
let structs = structs::write_structs(input, &shader, type_registry)?;
Ok((shader_code, structs))
}
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::Span;
use quote::ToTokens;
use shaderc::SpirvVersion;
use syn::{File, Item};
use vulkano::shader::reflect;
fn spv_to_words(data: &[u8]) -> Vec<u32> {
data.chunks(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
} }
});
fn convert_paths(root_path: &Path, paths: &[PathBuf]) -> Vec<String> { if let Some(shader_name) = shader_name {
paths let shader_name = format_ident!("{shader_name}");
.iter()
.map(|p| root_path.join(p).into_os_string().into_string().unwrap()) quote! {
.collect() pub mod #shader_name {
#( #load_fns )*
} }
#[test]
fn spirv_parse() {
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
Spirv::new(&insts).unwrap();
} }
#[test]
fn spirv_reflect() {
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/frag.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
assert_eq!(_structs.to_string(), "", "No structs should be generated");
}
#[test]
fn include_resolution() {
let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let (_compile_relative, _) = compile(
&MacroInput::empty(),
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include "include_dir_a/target_a.glsl"
#include "include_dir_b/target_b.glsl"
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
let (_compile_include_paths, includes) = compile(
&MacroInput {
include_directories: vec![
root_path.join("tests").join("include_dir_a"),
root_path.join("tests").join("include_dir_b"),
],
..MacroInput::empty()
},
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include <target_a.glsl>
#include <target_b.glsl>
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes,
convert_paths(
&root_path,
&[
["tests", "include_dir_a", "target_a.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_b", "target_b.glsl"]
.into_iter()
.collect(),
],
),
);
let (_compile_include_paths_with_relative, includes_with_relative) = compile(
&MacroInput {
include_directories: vec![root_path.join("tests").join("include_dir_a")],
..MacroInput::empty()
},
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include <target_a.glsl>
#include <../include_dir_b/target_b.glsl>
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes_with_relative,
convert_paths(
&root_path,
&[
["tests", "include_dir_a", "target_a.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_a", "../include_dir_b/target_b.glsl"]
.into_iter()
.collect(),
],
),
);
let absolute_path = root_path
.join("tests")
.join("include_dir_a")
.join("target_a.glsl");
let absolute_path_str = absolute_path
.to_str()
.expect("cannot run tests in a folder with non unicode characters");
let (_compile_absolute_path, includes_absolute_path) = compile(
&MacroInput::empty(),
Some(String::from("tests/include_test.glsl")),
&root_path,
&format!(
r#"
#version 450
#include "{absolute_path_str}"
void main() {{}}
"#,
),
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes_absolute_path,
convert_paths(
&root_path,
&[["tests", "include_dir_a", "target_a.glsl"]
.into_iter()
.collect()],
),
);
let (_compile_recursive_, includes_recursive) = compile(
&MacroInput {
include_directories: vec![
root_path.join("tests").join("include_dir_b"),
root_path.join("tests").join("include_dir_c"),
],
..MacroInput::empty()
},
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include <target_c.glsl>
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes_recursive,
convert_paths(
&root_path,
&[
["tests", "include_dir_c", "target_c.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_c", "../include_dir_a/target_a.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_b", "target_b.glsl"]
.into_iter()
.collect(),
],
),
);
}
#[test]
fn macros() {
let need_defines = r#"
#version 450
#if defined(NAME1) && NAME2 > 29
void main() {}
#endif
"#;
let compile_no_defines = compile(
&MacroInput::empty(),
None,
Path::new(""),
need_defines,
ShaderKind::Vertex,
);
assert!(compile_no_defines.is_err());
compile(
&MacroInput {
macro_defines: vec![("NAME1".into(), "".into()), ("NAME2".into(), "58".into())],
..MacroInput::empty()
},
None,
Path::new(""),
need_defines,
ShaderKind::Vertex,
)
.expect("setting shader macros did not work");
}
/// `entrypoint1.frag.glsl`:
/// ```glsl
/// #version 450
///
/// layout(set = 0, binding = 0) uniform Uniform {
/// uint data;
/// } ubo;
///
/// layout(set = 0, binding = 1) buffer Buffer {
/// uint data;
/// } bo;
///
/// layout(set = 0, binding = 2) uniform sampler textureSampler;
/// layout(set = 0, binding = 3) uniform texture2D imageTexture;
///
/// layout(push_constant) uniform PushConstant {
/// uint data;
/// } push;
///
/// layout(input_attachment_index = 0, set = 0, binding = 4) uniform subpassInput inputAttachment;
///
/// layout(location = 0) out vec4 outColor;
///
/// void entrypoint1() {
/// bo.data = 12;
/// outColor = vec4(
/// float(ubo.data),
/// float(push.data),
/// texture(sampler2D(imageTexture, textureSampler), vec2(0.0, 0.0)).x,
/// subpassLoad(inputAttachment).x
/// );
/// }
/// ```
///
/// `entrypoint2.frag.glsl`:
/// ```glsl
/// #version 450
///
/// layout(input_attachment_index = 0, set = 0, binding = 0) uniform subpassInput inputAttachment2;
///
/// layout(set = 0, binding = 1) buffer Buffer {
/// uint data;
/// } bo2;
///
/// layout(set = 0, binding = 2) uniform Uniform {
/// uint data;
/// } ubo2;
///
/// layout(push_constant) uniform PushConstant {
/// uint data;
/// } push2;
///
/// void entrypoint2() {
/// bo2.data = ubo2.data + push2.data + int(subpassLoad(inputAttachment2).y);
/// }
/// ```
///
/// Compiled and linked with:
/// ```sh
/// glslangvalidator -e entrypoint1 --source-entrypoint entrypoint1 -V100 entrypoint1.frag.glsl -o entrypoint1.spv
/// glslangvalidator -e entrypoint2 --source-entrypoint entrypoint2 -V100 entrypoint2.frag.glsl -o entrypoint2.spv
/// spirv-link entrypoint1.spv entrypoint2.spv -o multiple_entrypoints.spv
/// ```
#[test]
fn descriptor_calculation_with_multiple_entrypoints() {
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
let spirv = Spirv::new(&insts).unwrap();
let mut descriptors = Vec::new();
for (_, info) in reflect::entry_points(&spirv) {
descriptors.push(info.descriptor_binding_requirements);
}
// Check first entrypoint
let e1_descriptors = &descriptors[0];
let mut e1_bindings = Vec::new();
for loc in e1_descriptors.keys() {
e1_bindings.push(*loc);
}
assert_eq!(e1_bindings.len(), 5);
assert!(e1_bindings.contains(&(0, 0)));
assert!(e1_bindings.contains(&(0, 1)));
assert!(e1_bindings.contains(&(0, 2)));
assert!(e1_bindings.contains(&(0, 3)));
assert!(e1_bindings.contains(&(0, 4)));
// Check second entrypoint
let e2_descriptors = &descriptors[1];
let mut e2_bindings = Vec::new();
for loc in e2_descriptors.keys() {
e2_bindings.push(*loc);
}
assert_eq!(e2_bindings.len(), 3);
assert!(e2_bindings.contains(&(0, 0)));
assert!(e2_bindings.contains(&(0, 1)));
assert!(e2_bindings.contains(&(0, 2)));
}
#[test]
fn reflect_descriptor_calculation_with_multiple_entrypoints() {
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/multiple_entrypoints.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");
let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else { } else {
None quote! {
#( #load_fns )*
} }
})
.collect();
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);
}
fn descriptor_calculation_with_multiple_functions_shader() -> (CompilationArtifact, Vec<String>)
{
compile(
&MacroInput {
spirv_version: Some(SpirvVersion::V1_6),
vulkan_version: Some(EnvVersion::Vulkan1_3),
..MacroInput::empty()
},
None,
Path::new(""),
r#"
#version 460
layout(set = 1, binding = 0) buffer Buffer {
vec3 data;
} bo;
layout(set = 2, binding = 0) uniform Uniform {
float data;
} ubo;
layout(set = 3, binding = 1) uniform sampler textureSampler;
layout(set = 3, binding = 2) uniform texture2D imageTexture;
float withMagicSparkles(float data) {
return texture(sampler2D(imageTexture, textureSampler), vec2(data, data)).x;
}
vec3 makeSecretSauce() {
return vec3(withMagicSparkles(ubo.data));
}
void main() {
bo.data = makeSecretSauce();
}
"#,
ShaderKind::Vertex,
)
.unwrap()
}
#[test]
fn descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
let spirv = Spirv::new(artifact.as_binary()).unwrap();
if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_binding_requirements {
bindings.push(loc);
}
assert_eq!(bindings.len(), 4);
assert!(bindings.contains(&(1, 0)));
assert!(bindings.contains(&(2, 0)));
assert!(bindings.contains(&(3, 1)));
assert!(bindings.contains(&(3, 2)));
return;
}
panic!("could not find entrypoint");
}
#[test]
fn reflect_descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new(
"descriptor_calculation_with_multiple_functions_shader",
Span::call_site(),
),
String::new(),
artifact.as_binary(),
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");
let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else {
None
}
})
.collect();
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: [f32; 3usize],}).to_string()
);
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: f32,}).to_string()
);
} }
} }

File diff suppressed because it is too large Load Diff

View File

@ -1,44 +0,0 @@
#[cfg(test)]
mod tests {
use crate::{codegen::reflect, structs::TypeRegistry, MacroInput};
use proc_macro2::Span;
use syn::LitStr;
fn spv_to_words(data: &[u8]) -> Vec<u32> {
data.chunks(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
}
#[test]
fn rust_gpu_reflect_vertex() {
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-vertex.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("rust-gpu vertex shader", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
}
#[test]
fn rust_gpu_reflect_fragment() {
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-fragment.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("rust-gpu vertex shader", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
}
}

View File

@ -1,174 +1,148 @@
use crate::{bail, codegen::Shader, LinAlgType, MacroInput}; use crate::{bail, LinAlgTypes, MacroOptions};
use ahash::HashMap; use ahash::HashMap;
use proc_macro2::{Span, TokenStream}; use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens, TokenStreamExt}; use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use std::{cmp::Ordering, num::NonZeroUsize}; use std::{cmp::Ordering, num::NonZeroUsize};
use syn::{Error, Ident, Result}; use syn::{Error, Ident, LitStr, Result};
use vulkano::shader::spirv::{Decoration, Id, Instruction}; use vulkano::shader::spirv::{Decoration, Id, Instruction, Spirv};
#[derive(Default)] struct Shader {
pub struct TypeRegistry { spirv: Spirv,
registered_structs: HashMap<Ident, RegisteredType>, name: String,
source: LitStr,
} }
impl TypeRegistry { pub(super) struct RegisteredStruct {
fn register_struct(&mut self, shader: &Shader, ty: &TypeStruct) -> Result<bool> { members: Vec<Member>,
// Checking with registry if this struct is already registered by another shader, and if shader_name: String,
// their signatures match. }
if let Some(registered) = self.registered_structs.get(&ty.ident) {
registered.validate_signatures(&shader.name, ty)?;
// If the struct is already registered and matches this one, skip the duplicate. /// Generates Rust structs from structs declared in SPIR-V bytecode.
Ok(false) pub(super) fn generate_structs(
macro_options: &MacroOptions,
spirv: Spirv,
shader_name: String,
shader_source: LitStr,
registered_structs: &mut HashMap<Ident, RegisteredStruct>,
) -> Result<TokenStream> {
let mut structs_code = TokenStream::new();
let shader = Shader {
spirv,
name: shader_name,
source: shader_source,
};
let structs = shader
.spirv
.types()
.iter()
.filter_map(|instruction| match instruction {
Instruction::TypeStruct {
result_id,
member_types,
} => Some((*result_id, member_types)),
_ => None,
})
.filter(|&(id, _)| struct_has_defined_layout(&shader.spirv, id));
for (struct_id, member_type_ids) in structs {
let ty = TypeStruct::new(&shader, struct_id, member_type_ids)?;
if let Some(registered) = registered_structs.get(&ty.ident) {
validate_members(
&ty.ident,
&ty.members,
&registered.members,
&shader.name,
&registered.shader_name,
)?;
} else { } else {
self.registered_structs.insert( let custom_derives = if ty.size().is_some() {
ty.ident.clone(), macro_options.custom_derives.as_slice()
RegisteredType { } else {
shader: shader.name.clone(), &[]
ty: ty.clone(), };
let struct_ser = Serializer(&ty, macro_options);
structs_code.extend(quote! {
#[allow(non_camel_case_types, non_snake_case)]
#[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )]
#[repr(C)]
#struct_ser
});
registered_structs.insert(
ty.ident,
RegisteredStruct {
members: ty.members,
shader_name: shader.name.clone(),
}, },
); );
}
}
Ok(true) Ok(structs_code)
}
}
} }
struct RegisteredType { fn struct_has_defined_layout(spirv: &Spirv, id: Id) -> bool {
shader: String, spirv.id(id).members().iter().all(|member_info| {
ty: TypeStruct, let decorations = member_info
} .decorations()
impl RegisteredType {
fn validate_signatures(&self, other_shader: &str, other_ty: &TypeStruct) -> Result<()> {
let (shader, struct_ident) = (&self.shader, &self.ty.ident);
if self.ty.members.len() > other_ty.members.len() {
let member_ident = &self.ty.members[other_ty.members.len()].ident;
bail!(
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
`{struct_ident}`, but the struct from shader `{shader}` contains an extra field \
`{member_ident}`",
);
}
if self.ty.members.len() < other_ty.members.len() {
let member_ident = &other_ty.members[self.ty.members.len()].ident;
bail!(
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
`{struct_ident}`, but the struct from shader `{other_shader}` contains an extra \
field `{member_ident}`",
);
}
for (index, (member, other_member)) in self
.ty
.members
.iter() .iter()
.zip(other_ty.members.iter()) .map(|instruction| match instruction {
.enumerate() Instruction::MemberDecorate { decoration, .. } => decoration,
_ => unreachable!(),
});
let has_offset_decoration = decorations
.clone()
.any(|decoration| matches!(decoration, Decoration::Offset { .. }));
let has_builtin_decoration = decorations
.clone()
.any(|decoration| matches!(decoration, Decoration::BuiltIn { .. }));
has_offset_decoration && !has_builtin_decoration
})
}
fn validate_members(
ident: &Ident,
first_members: &[Member],
second_members: &[Member],
first_shader: &str,
second_shader: &str,
) -> Result<()> {
match first_members.len().cmp(&second_members.len()) {
Ordering::Greater => bail!(
"the declaration of struct `{ident}` in shader `{first_shader}` has more fields than \
the declaration in shader `{second_shader}`"
),
Ordering::Less => bail!(
"the declaration of struct `{ident}` in shader `{second_shader}` has more fields than \
the declaration in shader `{first_shader}`"
),
_ => {}
}
for (index, (first_member, second_member)) in
first_members.iter().zip(second_members).enumerate()
{ {
if member.ty != other_member.ty { let (first_type, second_type) = (&first_member.ty, &second_member.ty);
let (member_ty, other_member_ty) = (&member.ty, &other_member.ty); if first_type != second_type {
bail!( bail!(
"shaders `{shader}` and `{other_shader}` declare structs with the same name \ "field {index} of struct `{ident}` is of type `{first_type:?}` in shader \
`{struct_ident}`, but the struct from shader `{shader}` contains a field of \ `{first_shader}` but of type `{second_type:?}` in shader `{second_shader}`"
type `{member_ty:?}` at index `{index}`, whereas the same struct from shader \
`{other_shader}` contains a field of type `{other_member_ty:?}` in the same \
position",
); );
} }
} }
Ok(()) Ok(())
}
} }
/// Translates all the structs that are contained in the SPIR-V document as Rust structs. #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(super) fn write_structs(
input: &MacroInput,
shader: &Shader,
type_registry: &mut TypeRegistry,
) -> Result<TokenStream> {
if !input.generate_structs {
return Ok(TokenStream::new());
}
let mut structs = TokenStream::new();
for (struct_id, member_type_ids) in shader
.spirv
.types()
.iter()
.filter_map(|instruction| match *instruction {
Instruction::TypeStruct {
result_id,
ref member_types,
} => Some((result_id, member_types)),
_ => None,
})
.filter(|&(struct_id, _)| has_defined_layout(shader, struct_id))
{
let struct_ty = TypeStruct::new(shader, struct_id, member_type_ids)?;
// Register the type if needed.
if !type_registry.register_struct(shader, &struct_ty)? {
continue;
}
let custom_derives = if struct_ty.size().is_some() {
input.custom_derives.as_slice()
} else {
&[]
};
let struct_ser = Serializer(&struct_ty, input);
structs.extend(quote! {
#[allow(non_camel_case_types, non_snake_case)]
#[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )]
#[repr(C)]
#struct_ser
})
}
Ok(structs)
}
fn has_defined_layout(shader: &Shader, struct_id: Id) -> bool {
for member_info in shader.spirv.id(struct_id).members() {
let mut offset_found = false;
for instruction in member_info.decorations() {
match instruction {
Instruction::MemberDecorate {
decoration: Decoration::BuiltIn { .. },
..
} => {
// Ignore the whole struct if a member is built in, which includes
// `gl_Position` for example.
return false;
}
Instruction::MemberDecorate {
decoration: Decoration::Offset { .. },
..
} => {
offset_found = true;
}
_ => (),
}
}
// Some structs don't have `Offset` decorations, in that case they are used as local
// variables only. Ignoring these.
if !offset_found {
return false;
}
}
true
}
#[derive(Clone, Copy, PartialEq, Eq)]
#[repr(u8)] #[repr(u8)]
enum Alignment { enum Alignment {
A1 = 1, A1 = 1,
@ -188,23 +162,11 @@ impl Alignment {
8 => Alignment::A8, 8 => Alignment::A8,
16 => Alignment::A16, 16 => Alignment::A16,
32 => Alignment::A32, 32 => Alignment::A32,
_ => unreachable!(), _ => panic!(),
} }
} }
} }
impl PartialOrd for Alignment {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Alignment {
fn cmp(&self, other: &Self) -> Ordering {
(*self as usize).cmp(&(*other as usize))
}
}
fn align_up(offset: usize, alignment: Alignment) -> usize { fn align_up(offset: usize, alignment: Alignment) -> usize {
(offset + alignment as usize - 1) & !(alignment as usize - 1) (offset + alignment as usize - 1) & !(alignment as usize - 1)
} }
@ -228,7 +190,10 @@ impl Type {
let id_info = shader.spirv.id(type_id); let id_info = shader.spirv.id(type_id);
let ty = match *id_info.instruction() { let ty = match *id_info.instruction() {
Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"), Instruction::TypeBool { .. } => bail!(
shader.source,
"SPIR-V Boolean types don't have a defined layout"
),
Instruction::TypeInt { Instruction::TypeInt {
width, signedness, .. width, signedness, ..
} => Type::Scalar(TypeScalar::Int(TypeInt::new(shader, width, signedness)?)), } => Type::Scalar(TypeScalar::Int(TypeInt::new(shader, width, signedness)?)),
@ -345,7 +310,7 @@ impl TypeInt {
let signed = match signedness { let signed = match signedness {
0 => false, 0 => false,
1 => true, 1 => true,
_ => bail!(shader.source, "signedness must be 0 or 1"), _ => bail!(shader.source, "signedness must be either 0 or 1"),
}; };
Ok(TypeInt { width, signed }) Ok(TypeInt { width, signed })
@ -473,14 +438,17 @@ impl TypeVector {
let component_count = ComponentCount::new(shader, component_count)?; let component_count = ComponentCount::new(shader, component_count)?;
let component_type = match *shader.spirv.id(component_type_id).instruction() { let component_type = match *shader.spirv.id(component_type_id).instruction() {
Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"), Instruction::TypeBool { .. } => bail!(
shader.source,
"SPIR-V Boolean types don't have a defined layout"
),
Instruction::TypeInt { Instruction::TypeInt {
width, signedness, .. width, signedness, ..
} => TypeScalar::Int(TypeInt::new(shader, width, signedness)?), } => TypeScalar::Int(TypeInt::new(shader, width, signedness)?),
Instruction::TypeFloat { width, .. } => { Instruction::TypeFloat { width, .. } => {
TypeScalar::Float(TypeFloat::new(shader, width)?) TypeScalar::Float(TypeFloat::new(shader, width)?)
} }
_ => bail!(shader.source, "vector components must be scalars"), _ => bail!(shader.source, "vector components must be scalar"),
}; };
Ok(TypeVector { Ok(TypeVector {
@ -616,7 +584,6 @@ impl TypeArray {
}) })
.transpose()?; .transpose()?;
let stride = {
let mut strides = let mut strides =
shader shader
.spirv .spirv
@ -630,6 +597,7 @@ impl TypeArray {
} => Some(array_stride as usize), } => Some(array_stride as usize),
_ => None, _ => None,
}); });
let stride = strides.next().ok_or_else(|| { let stride = strides.next().ok_or_else(|| {
Error::new_spanned( Error::new_spanned(
&shader.source, &shader.source,
@ -644,21 +612,18 @@ impl TypeArray {
if !is_aligned(stride, element_type.scalar_alignment()) { if !is_aligned(stride, element_type.scalar_alignment()) {
bail!( bail!(
shader.source, shader.source,
"array strides must be aligned for the element type", "array strides must be aligned to the element type's alignment",
); );
} }
let element_size = element_type.size().ok_or_else(|| { let Some(element_size) = element_type.size() else {
Error::new_spanned(&shader.source, "array elements must be sized") bail!(shader.source, "array elements must be sized");
})?; };
if stride < element_size { if stride < element_size {
bail!(shader.source, "array elements must not overlap"); bail!(shader.source, "array elements must not overlap");
} }
stride
};
Ok(TypeArray { Ok(TypeArray {
element_type, element_type,
length, length,
@ -690,16 +655,18 @@ impl TypeStruct {
.iter() .iter()
.find_map(|instruction| match instruction { .find_map(|instruction| match instruction {
Instruction::Name { name, .. } => { Instruction::Name { name, .. } => {
// Replace chars that could potentially cause the ident to be invalid with "_". // Replace non-alphanumeric and non-ascii characters with '_' to ensure the name
// For example, Rust-GPU names structs by their fully qualified rust name (e.g. // is a valid identifier. For example, Rust-GPU names structs by their fully
// "foo::bar::MyStruct") in which the ":" is an invalid character for idents. // qualified rust name (e.g. `foo::bar::MyStruct`) in which `:` makes it an
// invalid identifier.
let mut name = let mut name =
name.replace(|c: char| !(c.is_ascii_alphanumeric() || c == '_'), "_"); name.replace(|c: char| !c.is_ascii_alphanumeric() && c != '_', "_");
if name.starts_with(|c: char| !c.is_ascii_alphabetic()) {
if name.starts_with(|c: char| c.is_ascii_digit()) {
name.insert(0, '_'); name.insert(0, '_');
} }
// Worst case: invalid idents will get the UnnamedX name below // Fall-back to `Unnamed{Id}` if it's still invalid
syn::parse_str(&name).ok() syn::parse_str(&name).ok()
} }
_ => None, _ => None,
@ -725,9 +692,7 @@ impl TypeStruct {
let mut ty = Type::new(shader, member_id)?; let mut ty = Type::new(shader, member_id)?;
{ {
// If the member is an array, then matrix-decorations can be applied to it if the // Matrix decorations can be applied to an array if its innermost type is a matrix.
// innermost type of the array is a matrix. Else this will stay being the type of
// the member.
let mut ty = &mut ty; let mut ty = &mut ty;
while let Type::Array(TypeArray { element_type, .. }) = ty { while let Type::Array(TypeArray { element_type, .. }) = ty {
ty = element_type; ty = element_type;
@ -744,6 +709,7 @@ impl TypeStruct {
_ => None, _ => None,
}, },
); );
matrix.stride = strides.next().ok_or_else(|| { matrix.stride = strides.next().ok_or_else(|| {
Error::new_spanned( Error::new_spanned(
&shader.source, &shader.source,
@ -766,7 +732,7 @@ impl TypeStruct {
); );
} }
let mut majornessess = member_info.decorations().iter().filter_map( let mut majornesses = member_info.decorations().iter().filter_map(
|instruction| match *instruction { |instruction| match *instruction {
Instruction::MemberDecorate { Instruction::MemberDecorate {
decoration: Decoration::ColMajor, decoration: Decoration::ColMajor,
@ -779,7 +745,8 @@ impl TypeStruct {
_ => None, _ => None,
}, },
); );
matrix.majorness = majornessess.next().ok_or_else(|| {
matrix.majorness = majornesses.next().ok_or_else(|| {
Error::new_spanned( Error::new_spanned(
&shader.source, &shader.source,
"matrices inside structs must have a `ColMajor` or `RowMajor` \ "matrices inside structs must have a `ColMajor` or `RowMajor` \
@ -787,7 +754,7 @@ impl TypeStruct {
) )
})?; })?;
if !majornessess.all(|m| m == matrix.majorness) { if !majornesses.all(|m| m == matrix.majorness) {
bail!( bail!(
shader.source, shader.source,
"found conflicting matrix majorness decorations", "found conflicting matrix majorness decorations",
@ -822,26 +789,27 @@ impl TypeStruct {
if !is_aligned(offset, ty.scalar_alignment()) { if !is_aligned(offset, ty.scalar_alignment()) {
bail!( bail!(
shader.source, shader.source,
"struct member offsets must be aligned for the member type", "struct member offsets must be aligned to their type's alignment",
); );
} }
if let Some(last) = members.last() { if let Some(previous_member) = members.last() {
if !is_aligned(offset, last.ty.scalar_alignment()) { if !is_aligned(offset, previous_member.ty.scalar_alignment()) {
bail!( bail!(
shader.source, shader.source,
"expected struct member offset to be aligned for the preceding member type", "expected struct member offset to be aligned to the preceding member \
type's alignment",
); );
} }
let last_size = last.ty.size().ok_or_else(|| { let last_size = previous_member.ty.size().ok_or_else(|| {
Error::new_spanned( Error::new_spanned(
&shader.source, &shader.source,
"all members except the last member of a struct must be sized", "all members except the last member of a struct must be sized",
) )
})?; })?;
if last.offset + last_size > offset { if previous_member.offset + last_size > offset {
bail!(shader.source, "struct members must not overlap"); bail!(shader.source, "struct members must not overlap");
} }
} }
@ -888,8 +856,8 @@ impl PartialEq for Member {
impl Eq for Member {} impl Eq for Member {}
/// Helper for serializing a type to tokens with respect to macro input. /// Helper for serializing a type as tokens according to the macro options.
struct Serializer<'a, T>(&'a T, &'a MacroInput); struct Serializer<'a, T>(&'a T, &'a MacroOptions);
impl ToTokens for Serializer<'_, Type> { impl ToTokens for Serializer<'_, Type> {
fn to_tokens(&self, tokens: &mut TokenStream) { fn to_tokens(&self, tokens: &mut TokenStream) {
@ -909,18 +877,16 @@ impl ToTokens for Serializer<'_, TypeVector> {
let component_type = &self.0.component_type; let component_type = &self.0.component_type;
let component_count = self.0.component_count as usize; let component_count = self.0.component_count as usize;
match self.1.linalg_type { match self.1.linalg_types {
LinAlgType::Std => { LinAlgTypes::Std => {
tokens.extend(quote! { [#component_type; #component_count] }); tokens.extend(quote! { [#component_type; #component_count] });
} }
LinAlgType::CgMath => { LinAlgTypes::Cgmath => {
let vector = format_ident!("Vector{}", component_count); let vector = format_ident!("Vector{}", component_count);
tokens.extend(quote! { ::cgmath::#vector<#component_type> }); tokens.extend(quote! { ::cgmath::#vector<#component_type> });
} }
LinAlgType::Nalgebra => { LinAlgTypes::Nalgebra => {
tokens.extend(quote! { tokens.extend(quote! { ::nalgebra::SVector<#component_type, #component_count> });
::nalgebra::base::SVector<#component_type, #component_count>
});
} }
} }
} }
@ -936,22 +902,22 @@ impl ToTokens for Serializer<'_, TypeMatrix> {
// This can't overflow because the stride must be at least the vector size. // This can't overflow because the stride must be at least the vector size.
let padding = self.0.stride - self.0.vector_size(); let padding = self.0.stride - self.0.vector_size();
match self.1.linalg_type { match self.1.linalg_types {
// cgmath only has column-major matrices. It also only has square matrices, and its 3x3 // cgmath only supports column-major square matrices, and its 3x3 matrix is not padded
// matrix is not padded right. Fall back to std for anything else. // correctly.
LinAlgType::CgMath LinAlgTypes::Cgmath
if majorness == MatrixMajorness::ColumnMajor if majorness == MatrixMajorness::ColumnMajor
&& padding == 0 && vector_count == component_count
&& vector_count == component_count => && padding == 0 =>
{ {
let matrix = format_ident!("Matrix{}", component_count); let matrix = format_ident!("Matrix{}", component_count);
tokens.extend(quote! { ::cgmath::#matrix<#component_type> }); tokens.extend(quote! { ::cgmath::#matrix<#component_type> });
} }
// nalgebra only has column-major matrices, and its 3xN matrices are not padded right. // nalgebra only supports column-major matrices, and its 3xN matrices are not padded
// Fall back to std for anything else. // correctly.
LinAlgType::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => { LinAlgTypes::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => {
tokens.extend(quote! { tokens.extend(quote! {
::nalgebra::base::SMatrix<#component_type, #component_count, #vector_count> ::nalgebra::SMatrix<#component_type, #component_count, #vector_count>
}); });
} }
_ => { _ => {
@ -964,7 +930,7 @@ impl ToTokens for Serializer<'_, TypeMatrix> {
impl ToTokens for Serializer<'_, TypeArray> { impl ToTokens for Serializer<'_, TypeArray> {
fn to_tokens(&self, tokens: &mut TokenStream) { fn to_tokens(&self, tokens: &mut TokenStream) {
let element_type = &*self.0.element_type; let element_type = self.0.element_type.as_ref();
// This can't panic because array elements must be sized. // This can't panic because array elements must be sized.
let element_size = element_type.size().unwrap(); let element_size = element_type.size().unwrap();
// This can't overflow because the stride must be at least the element size. // This can't overflow because the stride must be at least the element size.
@ -1047,7 +1013,8 @@ impl ToTokens for Serializer<'_, TypeStruct> {
} }
} }
/// Helper for wrapping tokens in `Padded`. Doesn't wrap if the padding is `0`. /// Helper for wrapping tokens in [Padded][struct@vulkano::padded::Padded].
/// Doesn't wrap if the padding is `0`.
struct Padded<T>(T, usize); struct Padded<T>(T, usize);
impl<T: ToTokens> ToTokens for Padded<T> { impl<T: ToTokens> ToTokens for Padded<T> {