Refactor vulkano-shaders

This commit is contained in:
Mai 2024-09-25 15:02:24 +03:00
parent 075e9812d8
commit 2752e20d79
6 changed files with 1102 additions and 1478 deletions

1
Cargo.lock generated
View File

@ -2465,7 +2465,6 @@ name = "vulkano-shaders"
version = "0.34.0"
dependencies = [
"ahash",
"heck",
"proc-macro2",
"quote",
"shaderc",

View File

@ -18,16 +18,14 @@ proc-macro = true
[dependencies]
ahash = { workspace = true }
heck = { workspace = true }
proc-macro2 = { workspace = true }
quote = { workspace = true }
shaderc = { workspace = true }
syn = { workspace = true, features = ["full", "extra-traits"] }
syn = { workspace = true }
vulkano = { workspace = true }
[features]
shaderc-build-from-source = ["shaderc/build-from-source"]
shaderc-debug = []
[lints]
workspace = true

View File

@ -1,135 +1,138 @@
use crate::{
structs::{self, TypeRegistry},
MacroInput,
};
use heck::ToSnakeCase;
use crate::MacroOptions;
use ahash::{HashMap, HashSet};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind};
use shaderc::{CompileOptions, Compiler, EnvVersion, SourceLanguage, TargetEnv};
use shaderc::{
CompilationArtifact, CompileOptions, Compiler, IncludeType, ResolvedInclude, ShaderKind,
SourceLanguage, TargetEnv,
};
use std::{
cell::RefCell,
fs,
iter::Iterator,
path::{Path, PathBuf},
};
use syn::{Error, LitStr};
use vulkano::shader::spirv::Spirv;
pub struct Shader {
pub source: LitStr,
pub name: String,
pub spirv: Spirv,
pub(super) fn compile(
compiler: &Compiler,
macro_options: &MacroOptions,
source_path: Option<&str>,
shader_name: &str,
source_code: &str,
entry_points: HashMap<Option<String>, ShaderKind>,
base_path: &Path,
include_paths: &[PathBuf],
sources_to_include: &RefCell<HashSet<String>>,
) -> Result<Vec<(Option<String>, CompilationArtifact)>, String> {
let mut compile_options = CompileOptions::new().ok_or("failed to create compile options")?;
compile_options.set_source_language(macro_options.source_language);
if let Some(vulkan_version) = macro_options.vulkan_version {
compile_options.set_target_env(TargetEnv::Vulkan, vulkan_version as u32);
}
if let Some(spirv_version) = macro_options.spirv_version {
compile_options.set_target_spirv(spirv_version);
}
compile_options.set_include_callback(
|requested_source, include_type, containing_source, depth| {
include_callback(
Path::new(requested_source),
include_type,
Path::new(containing_source),
depth,
base_path,
include_paths,
source_path.is_none(),
&mut sources_to_include.borrow_mut(),
)
},
);
for (name, value) in &macro_options.macro_defines {
compile_options.add_macro_definition(name, value.as_deref());
}
let file_name = match (source_path, macro_options.source_language) {
(Some(source_path), _) => source_path,
(None, SourceLanguage::GLSL) => &format!("{shader_name}.glsl"),
(None, SourceLanguage::HLSL) => &format!("{shader_name}.hlsl"),
};
entry_points
.into_iter()
.map(|(entry_point, shader_kind)| {
compiler
.compile_into_spirv(
source_code,
shader_kind,
file_name,
entry_point.as_deref().unwrap_or("main"),
Some(&compile_options),
)
.map(|artifact| (entry_point, artifact))
.map_err(|err| err.to_string().replace("\n", " "))
})
.collect()
}
#[allow(clippy::too_many_arguments)]
fn include_callback(
requested_source_path_raw: &str,
directive_type: IncludeType,
contained_within_path_raw: &str,
recursion_depth: usize,
include_directories: &[PathBuf],
root_source_has_path: bool,
requested_source: &Path,
include_type: IncludeType,
containing_source: &Path,
depth: usize,
base_path: &Path,
includes: &mut Vec<String>,
include_paths: &[PathBuf],
embedded_root_source: bool,
sources_to_include: &mut HashSet<String>,
) -> Result<ResolvedInclude, String> {
let file_to_include = match directive_type {
let resolved_path = match include_type {
IncludeType::Relative => {
let requested_source_path = Path::new(requested_source_path_raw);
// If the shader source is embedded within the macro, abort unless we get an absolute
// path.
if !root_source_has_path && recursion_depth == 1 && !requested_source_path.is_absolute()
{
let requested_source_name = requested_source_path
.file_name()
.expect("failed to get the name of the requested source file")
.to_string_lossy();
let requested_source_directory = requested_source_path
.parent()
.expect("failed to get the directory of the requested source file")
.to_string_lossy();
return Err(format!(
"usage of relative paths in imports in embedded GLSL is not allowed, try \
using `#include <{}>` and adding the directory `{}` to the `include` array in \
your `shader!` macro call instead",
requested_source_name, requested_source_directory,
));
if depth == 1 && embedded_root_source && !requested_source.is_absolute() {
return Err(
"you cannot use relative include directives in embedded shader source code; \
try using `#include <...>` instead"
.to_owned(),
);
}
let mut resolved_path = if recursion_depth == 1 {
Path::new(contained_within_path_raw)
.parent()
.map(|parent| base_path.join(parent))
let parent = containing_source.parent().unwrap();
if depth == 1 {
[base_path, parent, requested_source].iter().collect()
} else {
Path::new(contained_within_path_raw)
.parent()
.map(|parent| parent.to_owned())
[parent, requested_source].iter().collect()
}
.unwrap_or_else(|| {
panic!(
"the file `{}` does not reside in a directory, this is an implementation \
error",
contained_within_path_raw,
)
});
resolved_path.push(requested_source_path);
if !resolved_path.is_file() {
return Err(format!(
"invalid inclusion path `{}`, the path does not point to a file",
requested_source_path_raw,
));
}
resolved_path
}
IncludeType::Standard => {
let requested_source_path = Path::new(requested_source_path_raw);
if requested_source_path.is_absolute() {
// This message is printed either when using a missing file with an absolute path
// in the relative include directive or when using absolute paths in a standard
if requested_source.is_absolute() {
// This is returned when attempting to include a missing file by an absolute path
// in a relative include directive and when using an absolute path in a standard
// include directive.
return Err(format!(
"no such file found as specified by the absolute path; keep in mind that \
absolute paths cannot be used with inclusion from standard directories \
(`#include <...>`), try using `#include \"...\"` instead; requested path: {}",
requested_source_path_raw,
));
return Err(
"the specified file was not found; if you're using an absolute path in a \
standard include directive (`#include <...>`), try using `#include \"...\"` \
instead"
.to_owned(),
);
}
let found_requested_source_path = include_directories
include_paths
.iter()
.map(|include_directory| include_directory.join(requested_source_path))
.find(|resolved_requested_source_path| resolved_requested_source_path.is_file());
if let Some(found_requested_source_path) = found_requested_source_path {
found_requested_source_path
} else {
return Err(format!(
"failed to include the file `{}` from any include directories",
requested_source_path_raw,
));
}
.map(|include_path| include_path.join(requested_source))
.find(|source_path| source_path.is_file())
.ok_or("the specified file was not found".to_owned())?
}
};
let content = fs::read_to_string(file_to_include.as_path()).map_err(|err| {
format!(
"failed to read the contents of file `{file_to_include:?}` to be included in the \
shader source: {err}",
)
})?;
let resolved_name = file_to_include
.into_os_string()
.into_string()
.map_err(|_| {
"failed to stringify the file to be included; make sure the path consists of valid \
unicode characters"
})?;
let resolved_name = resolved_path.into_os_string().into_string().unwrap();
includes.push(resolved_name.clone());
let content = fs::read_to_string(&resolved_name)
.map_err(|err| format!("failed to read `{resolved_name}`: {err}"))?;
sources_to_include.insert(resolved_name.clone());
Ok(ResolvedInclude {
resolved_name,
@ -137,618 +140,48 @@ fn include_callback(
})
}
pub(super) fn compile(
input: &MacroInput,
path: Option<String>,
base_path: &Path,
code: &str,
shader_kind: ShaderKind,
) -> Result<(CompilationArtifact, Vec<String>), String> {
let includes = RefCell::new(Vec::new());
let compiler = Compiler::new().ok_or("failed to create shader compiler")?;
let mut compile_options =
CompileOptions::new().ok_or("failed to initialize compile options")?;
pub(super) fn generate_shader_code(
entry_points: &[(Option<&str>, &[u32])],
shader_name: &Option<String>,
) -> TokenStream {
let load_fns = entry_points.iter().map(|(name, words)| {
let load_name = match name {
Some(name) => format_ident!("load_{name}"),
None => format_ident!("load"),
};
let source_language = input.source_language.unwrap_or(SourceLanguage::GLSL);
compile_options.set_source_language(source_language);
compile_options.set_target_env(
TargetEnv::Vulkan,
input.vulkan_version.unwrap_or(EnvVersion::Vulkan1_0) as u32,
);
if let Some(spirv_version) = input.spirv_version {
compile_options.set_target_spirv(spirv_version);
}
let root_source_path = path.as_deref().unwrap_or(
// An arbitrary placeholder file name for embedded shaders.
match source_language {
SourceLanguage::GLSL => "shader.glsl",
SourceLanguage::HLSL => "shader.hlsl",
},
);
// Specify the file resolution callback for the `#include` directive.
compile_options.set_include_callback(
|requested_source_path, directive_type, contained_within_path, recursion_depth| {
include_callback(
requested_source_path,
directive_type,
contained_within_path,
recursion_depth,
&input.include_directories,
path.is_some(),
base_path,
&mut includes.borrow_mut(),
)
},
);
for (macro_name, macro_value) in &input.macro_defines {
compile_options.add_macro_definition(macro_name, Some(macro_value));
}
#[cfg(feature = "shaderc-debug")]
compile_options.set_generate_debug_info();
let content = compiler
.compile_into_spirv(
code,
shader_kind,
root_source_path,
"main",
Some(&compile_options),
)
.map_err(|e| e.to_string().replace("(s): ", "(s):\n"))?;
drop(compile_options);
Ok((content, includes.into_inner()))
}
pub(super) fn reflect(
input: &MacroInput,
source: LitStr,
name: String,
words: &[u32],
input_paths: Vec<String>,
type_registry: &mut TypeRegistry,
) -> Result<(TokenStream, TokenStream), Error> {
let spirv = Spirv::new(words).map_err(|err| {
Error::new_spanned(&source, format!("failed to parse SPIR-V words: {err}"))
})?;
let shader = Shader {
source,
name,
spirv,
};
let include_bytes = input_paths.into_iter().map(|s| {
quote! {
// Using `include_bytes` here ensures that changing the shader will force recompilation.
// The bytes themselves can be optimized out by the compiler as they are unused.
::std::include_bytes!( #s )
#[allow(unsafe_code)]
#[inline]
pub fn #load_name(
device: ::std::sync::Arc<::vulkano::device::Device>,
) -> ::std::result::Result<
::std::sync::Arc<::vulkano::shader::ShaderModule>,
::vulkano::Validated<::vulkano::VulkanError>,
> {
static WORDS: &[u32] = &[ #( #words ),* ];
unsafe {
::vulkano::shader::ShaderModule::new(
device,
::vulkano::shader::ShaderModuleCreateInfo::new(WORDS),
)
}
}
}
});
let load_name = if shader.name.is_empty() {
format_ident!("load")
if let Some(shader_name) = shader_name {
let shader_name = format_ident!("{shader_name}");
quote! {
pub mod #shader_name {
#( #load_fns )*
}
}
} else {
format_ident!("load_{}", shader.name.to_snake_case())
};
let shader_code = quote! {
/// Loads the shader as a `ShaderModule`.
#[allow(unsafe_code)]
#[inline]
pub fn #load_name(
device: ::std::sync::Arc<::vulkano::device::Device>,
) -> ::std::result::Result<
::std::sync::Arc<::vulkano::shader::ShaderModule>,
::vulkano::Validated<::vulkano::VulkanError>,
> {
let _bytes = ( #( #include_bytes ),* );
static WORDS: &[u32] = &[ #( #words ),* ];
unsafe {
::vulkano::shader::ShaderModule::new(
device,
::vulkano::shader::ShaderModuleCreateInfo::new(WORDS),
)
}
quote! {
#( #load_fns )*
}
};
let structs = structs::write_structs(input, &shader, type_registry)?;
Ok((shader_code, structs))
}
#[cfg(test)]
mod tests {
use super::*;
use proc_macro2::Span;
use quote::ToTokens;
use shaderc::SpirvVersion;
use syn::{File, Item};
use vulkano::shader::reflect;
fn spv_to_words(data: &[u8]) -> Vec<u32> {
data.chunks(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
}
fn convert_paths(root_path: &Path, paths: &[PathBuf]) -> Vec<String> {
paths
.iter()
.map(|p| root_path.join(p).into_os_string().into_string().unwrap())
.collect()
}
#[test]
fn spirv_parse() {
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
Spirv::new(&insts).unwrap();
}
#[test]
fn spirv_reflect() {
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/frag.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
assert_eq!(_structs.to_string(), "", "No structs should be generated");
}
#[test]
fn include_resolution() {
let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
let (_compile_relative, _) = compile(
&MacroInput::empty(),
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include "include_dir_a/target_a.glsl"
#include "include_dir_b/target_b.glsl"
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
let (_compile_include_paths, includes) = compile(
&MacroInput {
include_directories: vec![
root_path.join("tests").join("include_dir_a"),
root_path.join("tests").join("include_dir_b"),
],
..MacroInput::empty()
},
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include <target_a.glsl>
#include <target_b.glsl>
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes,
convert_paths(
&root_path,
&[
["tests", "include_dir_a", "target_a.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_b", "target_b.glsl"]
.into_iter()
.collect(),
],
),
);
let (_compile_include_paths_with_relative, includes_with_relative) = compile(
&MacroInput {
include_directories: vec![root_path.join("tests").join("include_dir_a")],
..MacroInput::empty()
},
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include <target_a.glsl>
#include <../include_dir_b/target_b.glsl>
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes_with_relative,
convert_paths(
&root_path,
&[
["tests", "include_dir_a", "target_a.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_a", "../include_dir_b/target_b.glsl"]
.into_iter()
.collect(),
],
),
);
let absolute_path = root_path
.join("tests")
.join("include_dir_a")
.join("target_a.glsl");
let absolute_path_str = absolute_path
.to_str()
.expect("cannot run tests in a folder with non unicode characters");
let (_compile_absolute_path, includes_absolute_path) = compile(
&MacroInput::empty(),
Some(String::from("tests/include_test.glsl")),
&root_path,
&format!(
r#"
#version 450
#include "{absolute_path_str}"
void main() {{}}
"#,
),
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes_absolute_path,
convert_paths(
&root_path,
&[["tests", "include_dir_a", "target_a.glsl"]
.into_iter()
.collect()],
),
);
let (_compile_recursive_, includes_recursive) = compile(
&MacroInput {
include_directories: vec![
root_path.join("tests").join("include_dir_b"),
root_path.join("tests").join("include_dir_c"),
],
..MacroInput::empty()
},
Some(String::from("tests/include_test.glsl")),
&root_path,
r#"
#version 450
#include <target_c.glsl>
void main() {}
"#,
ShaderKind::Vertex,
)
.expect("cannot resolve include files");
assert_eq!(
includes_recursive,
convert_paths(
&root_path,
&[
["tests", "include_dir_c", "target_c.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_c", "../include_dir_a/target_a.glsl"]
.into_iter()
.collect(),
["tests", "include_dir_b", "target_b.glsl"]
.into_iter()
.collect(),
],
),
);
}
#[test]
fn macros() {
let need_defines = r#"
#version 450
#if defined(NAME1) && NAME2 > 29
void main() {}
#endif
"#;
let compile_no_defines = compile(
&MacroInput::empty(),
None,
Path::new(""),
need_defines,
ShaderKind::Vertex,
);
assert!(compile_no_defines.is_err());
compile(
&MacroInput {
macro_defines: vec![("NAME1".into(), "".into()), ("NAME2".into(), "58".into())],
..MacroInput::empty()
},
None,
Path::new(""),
need_defines,
ShaderKind::Vertex,
)
.expect("setting shader macros did not work");
}
/// `entrypoint1.frag.glsl`:
/// ```glsl
/// #version 450
///
/// layout(set = 0, binding = 0) uniform Uniform {
/// uint data;
/// } ubo;
///
/// layout(set = 0, binding = 1) buffer Buffer {
/// uint data;
/// } bo;
///
/// layout(set = 0, binding = 2) uniform sampler textureSampler;
/// layout(set = 0, binding = 3) uniform texture2D imageTexture;
///
/// layout(push_constant) uniform PushConstant {
/// uint data;
/// } push;
///
/// layout(input_attachment_index = 0, set = 0, binding = 4) uniform subpassInput inputAttachment;
///
/// layout(location = 0) out vec4 outColor;
///
/// void entrypoint1() {
/// bo.data = 12;
/// outColor = vec4(
/// float(ubo.data),
/// float(push.data),
/// texture(sampler2D(imageTexture, textureSampler), vec2(0.0, 0.0)).x,
/// subpassLoad(inputAttachment).x
/// );
/// }
/// ```
///
/// `entrypoint2.frag.glsl`:
/// ```glsl
/// #version 450
///
/// layout(input_attachment_index = 0, set = 0, binding = 0) uniform subpassInput inputAttachment2;
///
/// layout(set = 0, binding = 1) buffer Buffer {
/// uint data;
/// } bo2;
///
/// layout(set = 0, binding = 2) uniform Uniform {
/// uint data;
/// } ubo2;
///
/// layout(push_constant) uniform PushConstant {
/// uint data;
/// } push2;
///
/// void entrypoint2() {
/// bo2.data = ubo2.data + push2.data + int(subpassLoad(inputAttachment2).y);
/// }
/// ```
///
/// Compiled and linked with:
/// ```sh
/// glslangvalidator -e entrypoint1 --source-entrypoint entrypoint1 -V100 entrypoint1.frag.glsl -o entrypoint1.spv
/// glslangvalidator -e entrypoint2 --source-entrypoint entrypoint2 -V100 entrypoint2.frag.glsl -o entrypoint2.spv
/// spirv-link entrypoint1.spv entrypoint2.spv -o multiple_entrypoints.spv
/// ```
#[test]
fn descriptor_calculation_with_multiple_entrypoints() {
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
let spirv = Spirv::new(&insts).unwrap();
let mut descriptors = Vec::new();
for (_, info) in reflect::entry_points(&spirv) {
descriptors.push(info.descriptor_binding_requirements);
}
// Check first entrypoint
let e1_descriptors = &descriptors[0];
let mut e1_bindings = Vec::new();
for loc in e1_descriptors.keys() {
e1_bindings.push(*loc);
}
assert_eq!(e1_bindings.len(), 5);
assert!(e1_bindings.contains(&(0, 0)));
assert!(e1_bindings.contains(&(0, 1)));
assert!(e1_bindings.contains(&(0, 2)));
assert!(e1_bindings.contains(&(0, 3)));
assert!(e1_bindings.contains(&(0, 4)));
// Check second entrypoint
let e2_descriptors = &descriptors[1];
let mut e2_bindings = Vec::new();
for loc in e2_descriptors.keys() {
e2_bindings.push(*loc);
}
assert_eq!(e2_bindings.len(), 3);
assert!(e2_bindings.contains(&(0, 0)));
assert!(e2_bindings.contains(&(0, 1)));
assert!(e2_bindings.contains(&(0, 2)));
}
#[test]
fn reflect_descriptor_calculation_with_multiple_entrypoints() {
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("../tests/multiple_entrypoints.spv", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");
let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else {
None
}
})
.collect();
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: u32,}).to_string()
);
}
fn descriptor_calculation_with_multiple_functions_shader() -> (CompilationArtifact, Vec<String>)
{
compile(
&MacroInput {
spirv_version: Some(SpirvVersion::V1_6),
vulkan_version: Some(EnvVersion::Vulkan1_3),
..MacroInput::empty()
},
None,
Path::new(""),
r#"
#version 460
layout(set = 1, binding = 0) buffer Buffer {
vec3 data;
} bo;
layout(set = 2, binding = 0) uniform Uniform {
float data;
} ubo;
layout(set = 3, binding = 1) uniform sampler textureSampler;
layout(set = 3, binding = 2) uniform texture2D imageTexture;
float withMagicSparkles(float data) {
return texture(sampler2D(imageTexture, textureSampler), vec2(data, data)).x;
}
vec3 makeSecretSauce() {
return vec3(withMagicSparkles(ubo.data));
}
void main() {
bo.data = makeSecretSauce();
}
"#,
ShaderKind::Vertex,
)
.unwrap()
}
#[test]
fn descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
let spirv = Spirv::new(artifact.as_binary()).unwrap();
if let Some((_, info)) = reflect::entry_points(&spirv).next() {
let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_binding_requirements {
bindings.push(loc);
}
assert_eq!(bindings.len(), 4);
assert!(bindings.contains(&(1, 0)));
assert!(bindings.contains(&(2, 0)));
assert!(bindings.contains(&(3, 1)));
assert!(bindings.contains(&(3, 2)));
return;
}
panic!("could not find entrypoint");
}
#[test]
fn reflect_descriptor_calculation_with_multiple_functions() {
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new(
"descriptor_calculation_with_multiple_functions_shader",
Span::call_site(),
),
String::new(),
artifact.as_binary(),
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
let structs = _structs.to_string();
assert_ne!(structs, "", "Has some structs");
let file: File = syn::parse2(_structs).unwrap();
let structs: Vec<_> = file
.items
.iter()
.filter_map(|item| {
if let Item::Struct(s) = item {
Some(s)
} else {
None
}
})
.collect();
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
assert_eq!(
buffer.fields.to_token_stream().to_string(),
quote!({pub data: [f32; 3usize],}).to_string()
);
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
assert_eq!(
uniform.fields.to_token_stream().to_string(),
quote!({pub data: f32,}).to_string()
);
}
}

File diff suppressed because it is too large Load Diff

View File

@ -1,44 +0,0 @@
#[cfg(test)]
mod tests {
use crate::{codegen::reflect, structs::TypeRegistry, MacroInput};
use proc_macro2::Span;
use syn::LitStr;
fn spv_to_words(data: &[u8]) -> Vec<u32> {
data.chunks(4)
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()
}
#[test]
fn rust_gpu_reflect_vertex() {
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-vertex.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("rust-gpu vertex shader", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
}
#[test]
fn rust_gpu_reflect_fragment() {
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-fragment.spv"));
let mut type_registry = TypeRegistry::default();
let (_shader_code, _structs) = reflect(
&MacroInput::empty(),
LitStr::new("rust-gpu vertex shader", Span::call_site()),
String::new(),
&insts,
Vec::new(),
&mut type_registry,
)
.expect("reflecting spv failed");
}
}

View File

@ -1,174 +1,148 @@
use crate::{bail, codegen::Shader, LinAlgType, MacroInput};
use crate::{bail, LinAlgTypes, MacroOptions};
use ahash::HashMap;
use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
use std::{cmp::Ordering, num::NonZeroUsize};
use syn::{Error, Ident, Result};
use vulkano::shader::spirv::{Decoration, Id, Instruction};
use syn::{Error, Ident, LitStr, Result};
use vulkano::shader::spirv::{Decoration, Id, Instruction, Spirv};
#[derive(Default)]
pub struct TypeRegistry {
registered_structs: HashMap<Ident, RegisteredType>,
struct Shader {
spirv: Spirv,
name: String,
source: LitStr,
}
impl TypeRegistry {
fn register_struct(&mut self, shader: &Shader, ty: &TypeStruct) -> Result<bool> {
// Checking with registry if this struct is already registered by another shader, and if
// their signatures match.
if let Some(registered) = self.registered_structs.get(&ty.ident) {
registered.validate_signatures(&shader.name, ty)?;
// If the struct is already registered and matches this one, skip the duplicate.
Ok(false)
} else {
self.registered_structs.insert(
ty.ident.clone(),
RegisteredType {
shader: shader.name.clone(),
ty: ty.clone(),
},
);
Ok(true)
}
}
pub(super) struct RegisteredStruct {
members: Vec<Member>,
shader_name: String,
}
struct RegisteredType {
shader: String,
ty: TypeStruct,
}
impl RegisteredType {
fn validate_signatures(&self, other_shader: &str, other_ty: &TypeStruct) -> Result<()> {
let (shader, struct_ident) = (&self.shader, &self.ty.ident);
if self.ty.members.len() > other_ty.members.len() {
let member_ident = &self.ty.members[other_ty.members.len()].ident;
bail!(
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
`{struct_ident}`, but the struct from shader `{shader}` contains an extra field \
`{member_ident}`",
);
}
if self.ty.members.len() < other_ty.members.len() {
let member_ident = &other_ty.members[self.ty.members.len()].ident;
bail!(
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
`{struct_ident}`, but the struct from shader `{other_shader}` contains an extra \
field `{member_ident}`",
);
}
for (index, (member, other_member)) in self
.ty
.members
.iter()
.zip(other_ty.members.iter())
.enumerate()
{
if member.ty != other_member.ty {
let (member_ty, other_member_ty) = (&member.ty, &other_member.ty);
bail!(
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
`{struct_ident}`, but the struct from shader `{shader}` contains a field of \
type `{member_ty:?}` at index `{index}`, whereas the same struct from shader \
`{other_shader}` contains a field of type `{other_member_ty:?}` in the same \
position",
);
}
}
Ok(())
}
}
/// Translates all the structs that are contained in the SPIR-V document as Rust structs.
pub(super) fn write_structs(
input: &MacroInput,
shader: &Shader,
type_registry: &mut TypeRegistry,
/// Generates Rust structs from structs declared in SPIR-V bytecode.
pub(super) fn generate_structs(
macro_options: &MacroOptions,
spirv: Spirv,
shader_name: String,
shader_source: LitStr,
registered_structs: &mut HashMap<Ident, RegisteredStruct>,
) -> Result<TokenStream> {
if !input.generate_structs {
return Ok(TokenStream::new());
}
let mut structs_code = TokenStream::new();
let mut structs = TokenStream::new();
let shader = Shader {
spirv,
name: shader_name,
source: shader_source,
};
for (struct_id, member_type_ids) in shader
let structs = shader
.spirv
.types()
.iter()
.filter_map(|instruction| match *instruction {
.filter_map(|instruction| match instruction {
Instruction::TypeStruct {
result_id,
ref member_types,
} => Some((result_id, member_types)),
member_types,
} => Some((*result_id, member_types)),
_ => None,
})
.filter(|&(struct_id, _)| has_defined_layout(shader, struct_id))
{
let struct_ty = TypeStruct::new(shader, struct_id, member_type_ids)?;
.filter(|&(id, _)| struct_has_defined_layout(&shader.spirv, id));
// Register the type if needed.
if !type_registry.register_struct(shader, &struct_ty)? {
continue;
}
for (struct_id, member_type_ids) in structs {
let ty = TypeStruct::new(&shader, struct_id, member_type_ids)?;
let custom_derives = if struct_ty.size().is_some() {
input.custom_derives.as_slice()
if let Some(registered) = registered_structs.get(&ty.ident) {
validate_members(
&ty.ident,
&ty.members,
&registered.members,
&shader.name,
&registered.shader_name,
)?;
} else {
&[]
};
let struct_ser = Serializer(&struct_ty, input);
let custom_derives = if ty.size().is_some() {
macro_options.custom_derives.as_slice()
} else {
&[]
};
structs.extend(quote! {
#[allow(non_camel_case_types, non_snake_case)]
#[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )]
#[repr(C)]
#struct_ser
})
}
let struct_ser = Serializer(&ty, macro_options);
Ok(structs)
}
structs_code.extend(quote! {
#[allow(non_camel_case_types, non_snake_case)]
#[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )]
#[repr(C)]
#struct_ser
});
fn has_defined_layout(shader: &Shader, struct_id: Id) -> bool {
for member_info in shader.spirv.id(struct_id).members() {
let mut offset_found = false;
for instruction in member_info.decorations() {
match instruction {
Instruction::MemberDecorate {
decoration: Decoration::BuiltIn { .. },
..
} => {
// Ignore the whole struct if a member is built in, which includes
// `gl_Position` for example.
return false;
}
Instruction::MemberDecorate {
decoration: Decoration::Offset { .. },
..
} => {
offset_found = true;
}
_ => (),
}
}
// Some structs don't have `Offset` decorations, in that case they are used as local
// variables only. Ignoring these.
if !offset_found {
return false;
registered_structs.insert(
ty.ident,
RegisteredStruct {
members: ty.members,
shader_name: shader.name.clone(),
},
);
}
}
true
Ok(structs_code)
}
#[derive(Clone, Copy, PartialEq, Eq)]
fn struct_has_defined_layout(spirv: &Spirv, id: Id) -> bool {
spirv.id(id).members().iter().all(|member_info| {
let decorations = member_info
.decorations()
.iter()
.map(|instruction| match instruction {
Instruction::MemberDecorate { decoration, .. } => decoration,
_ => unreachable!(),
});
let has_offset_decoration = decorations
.clone()
.any(|decoration| matches!(decoration, Decoration::Offset { .. }));
let has_builtin_decoration = decorations
.clone()
.any(|decoration| matches!(decoration, Decoration::BuiltIn { .. }));
has_offset_decoration && !has_builtin_decoration
})
}
fn validate_members(
ident: &Ident,
first_members: &[Member],
second_members: &[Member],
first_shader: &str,
second_shader: &str,
) -> Result<()> {
match first_members.len().cmp(&second_members.len()) {
Ordering::Greater => bail!(
"the declaration of struct `{ident}` in shader `{first_shader}` has more fields than \
the declaration in shader `{second_shader}`"
),
Ordering::Less => bail!(
"the declaration of struct `{ident}` in shader `{second_shader}` has more fields than \
the declaration in shader `{first_shader}`"
),
_ => {}
}
for (index, (first_member, second_member)) in
first_members.iter().zip(second_members).enumerate()
{
let (first_type, second_type) = (&first_member.ty, &second_member.ty);
if first_type != second_type {
bail!(
"field {index} of struct `{ident}` is of type `{first_type:?}` in shader \
`{first_shader}` but of type `{second_type:?}` in shader `{second_shader}`"
);
}
}
Ok(())
}
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u8)]
enum Alignment {
A1 = 1,
@ -188,23 +162,11 @@ impl Alignment {
8 => Alignment::A8,
16 => Alignment::A16,
32 => Alignment::A32,
_ => unreachable!(),
_ => panic!(),
}
}
}
impl PartialOrd for Alignment {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Alignment {
fn cmp(&self, other: &Self) -> Ordering {
(*self as usize).cmp(&(*other as usize))
}
}
fn align_up(offset: usize, alignment: Alignment) -> usize {
(offset + alignment as usize - 1) & !(alignment as usize - 1)
}
@ -228,7 +190,10 @@ impl Type {
let id_info = shader.spirv.id(type_id);
let ty = match *id_info.instruction() {
Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"),
Instruction::TypeBool { .. } => bail!(
shader.source,
"SPIR-V Boolean types don't have a defined layout"
),
Instruction::TypeInt {
width, signedness, ..
} => Type::Scalar(TypeScalar::Int(TypeInt::new(shader, width, signedness)?)),
@ -345,7 +310,7 @@ impl TypeInt {
let signed = match signedness {
0 => false,
1 => true,
_ => bail!(shader.source, "signedness must be 0 or 1"),
_ => bail!(shader.source, "signedness must be either 0 or 1"),
};
Ok(TypeInt { width, signed })
@ -473,14 +438,17 @@ impl TypeVector {
let component_count = ComponentCount::new(shader, component_count)?;
let component_type = match *shader.spirv.id(component_type_id).instruction() {
Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"),
Instruction::TypeBool { .. } => bail!(
shader.source,
"SPIR-V Boolean types don't have a defined layout"
),
Instruction::TypeInt {
width, signedness, ..
} => TypeScalar::Int(TypeInt::new(shader, width, signedness)?),
Instruction::TypeFloat { width, .. } => {
TypeScalar::Float(TypeFloat::new(shader, width)?)
}
_ => bail!(shader.source, "vector components must be scalars"),
_ => bail!(shader.source, "vector components must be scalar"),
};
Ok(TypeVector {
@ -616,49 +584,46 @@ impl TypeArray {
})
.transpose()?;
let stride = {
let mut strides =
shader
.spirv
.id(array_id)
.decorations()
.iter()
.filter_map(|instruction| match *instruction {
Instruction::Decorate {
decoration: Decoration::ArrayStride { array_stride },
..
} => Some(array_stride as usize),
_ => None,
});
let stride = strides.next().ok_or_else(|| {
Error::new_spanned(
&shader.source,
"arrays inside structs must have an `ArrayStride` decoration",
)
})?;
let mut strides =
shader
.spirv
.id(array_id)
.decorations()
.iter()
.filter_map(|instruction| match *instruction {
Instruction::Decorate {
decoration: Decoration::ArrayStride { array_stride },
..
} => Some(array_stride as usize),
_ => None,
});
if !strides.all(|s| s == stride) {
bail!(shader.source, "found conflicting `ArrayStride` decorations");
}
let stride = strides.next().ok_or_else(|| {
Error::new_spanned(
&shader.source,
"arrays inside structs must have an `ArrayStride` decoration",
)
})?;
if !is_aligned(stride, element_type.scalar_alignment()) {
bail!(
shader.source,
"array strides must be aligned for the element type",
);
}
if !strides.all(|s| s == stride) {
bail!(shader.source, "found conflicting `ArrayStride` decorations");
}
let element_size = element_type.size().ok_or_else(|| {
Error::new_spanned(&shader.source, "array elements must be sized")
})?;
if !is_aligned(stride, element_type.scalar_alignment()) {
bail!(
shader.source,
"array strides must be aligned to the element type's alignment",
);
}
if stride < element_size {
bail!(shader.source, "array elements must not overlap");
}
stride
let Some(element_size) = element_type.size() else {
bail!(shader.source, "array elements must be sized");
};
if stride < element_size {
bail!(shader.source, "array elements must not overlap");
}
Ok(TypeArray {
element_type,
length,
@ -690,16 +655,18 @@ impl TypeStruct {
.iter()
.find_map(|instruction| match instruction {
Instruction::Name { name, .. } => {
// Replace chars that could potentially cause the ident to be invalid with "_".
// For example, Rust-GPU names structs by their fully qualified rust name (e.g.
// "foo::bar::MyStruct") in which the ":" is an invalid character for idents.
// Replace non-alphanumeric and non-ascii characters with '_' to ensure the name
// is a valid identifier. For example, Rust-GPU names structs by their fully
// qualified rust name (e.g. `foo::bar::MyStruct`) in which `:` makes it an
// invalid identifier.
let mut name =
name.replace(|c: char| !(c.is_ascii_alphanumeric() || c == '_'), "_");
if name.starts_with(|c: char| !c.is_ascii_alphabetic()) {
name.replace(|c: char| !c.is_ascii_alphanumeric() && c != '_', "_");
if name.starts_with(|c: char| c.is_ascii_digit()) {
name.insert(0, '_');
}
// Worst case: invalid idents will get the UnnamedX name below
// Fall-back to `Unnamed{Id}` if it's still invalid
syn::parse_str(&name).ok()
}
_ => None,
@ -725,9 +692,7 @@ impl TypeStruct {
let mut ty = Type::new(shader, member_id)?;
{
// If the member is an array, then matrix-decorations can be applied to it if the
// innermost type of the array is a matrix. Else this will stay being the type of
// the member.
// Matrix decorations can be applied to an array if its innermost type is a matrix.
let mut ty = &mut ty;
while let Type::Array(TypeArray { element_type, .. }) = ty {
ty = element_type;
@ -744,6 +709,7 @@ impl TypeStruct {
_ => None,
},
);
matrix.stride = strides.next().ok_or_else(|| {
Error::new_spanned(
&shader.source,
@ -766,7 +732,7 @@ impl TypeStruct {
);
}
let mut majornessess = member_info.decorations().iter().filter_map(
let mut majornesses = member_info.decorations().iter().filter_map(
|instruction| match *instruction {
Instruction::MemberDecorate {
decoration: Decoration::ColMajor,
@ -779,7 +745,8 @@ impl TypeStruct {
_ => None,
},
);
matrix.majorness = majornessess.next().ok_or_else(|| {
matrix.majorness = majornesses.next().ok_or_else(|| {
Error::new_spanned(
&shader.source,
"matrices inside structs must have a `ColMajor` or `RowMajor` \
@ -787,7 +754,7 @@ impl TypeStruct {
)
})?;
if !majornessess.all(|m| m == matrix.majorness) {
if !majornesses.all(|m| m == matrix.majorness) {
bail!(
shader.source,
"found conflicting matrix majorness decorations",
@ -822,26 +789,27 @@ impl TypeStruct {
if !is_aligned(offset, ty.scalar_alignment()) {
bail!(
shader.source,
"struct member offsets must be aligned for the member type",
"struct member offsets must be aligned to their type's alignment",
);
}
if let Some(last) = members.last() {
if !is_aligned(offset, last.ty.scalar_alignment()) {
if let Some(previous_member) = members.last() {
if !is_aligned(offset, previous_member.ty.scalar_alignment()) {
bail!(
shader.source,
"expected struct member offset to be aligned for the preceding member type",
"expected struct member offset to be aligned to the preceding member \
type's alignment",
);
}
let last_size = last.ty.size().ok_or_else(|| {
let last_size = previous_member.ty.size().ok_or_else(|| {
Error::new_spanned(
&shader.source,
"all members except the last member of a struct must be sized",
)
})?;
if last.offset + last_size > offset {
if previous_member.offset + last_size > offset {
bail!(shader.source, "struct members must not overlap");
}
}
@ -888,8 +856,8 @@ impl PartialEq for Member {
impl Eq for Member {}
/// Helper for serializing a type to tokens with respect to macro input.
struct Serializer<'a, T>(&'a T, &'a MacroInput);
/// Helper for serializing a type as tokens according to the macro options.
struct Serializer<'a, T>(&'a T, &'a MacroOptions);
impl ToTokens for Serializer<'_, Type> {
fn to_tokens(&self, tokens: &mut TokenStream) {
@ -909,18 +877,16 @@ impl ToTokens for Serializer<'_, TypeVector> {
let component_type = &self.0.component_type;
let component_count = self.0.component_count as usize;
match self.1.linalg_type {
LinAlgType::Std => {
match self.1.linalg_types {
LinAlgTypes::Std => {
tokens.extend(quote! { [#component_type; #component_count] });
}
LinAlgType::CgMath => {
LinAlgTypes::Cgmath => {
let vector = format_ident!("Vector{}", component_count);
tokens.extend(quote! { ::cgmath::#vector<#component_type> });
}
LinAlgType::Nalgebra => {
tokens.extend(quote! {
::nalgebra::base::SVector<#component_type, #component_count>
});
LinAlgTypes::Nalgebra => {
tokens.extend(quote! { ::nalgebra::SVector<#component_type, #component_count> });
}
}
}
@ -936,22 +902,22 @@ impl ToTokens for Serializer<'_, TypeMatrix> {
// This can't overflow because the stride must be at least the vector size.
let padding = self.0.stride - self.0.vector_size();
match self.1.linalg_type {
// cgmath only has column-major matrices. It also only has square matrices, and its 3x3
// matrix is not padded right. Fall back to std for anything else.
LinAlgType::CgMath
match self.1.linalg_types {
// cgmath only supports column-major square matrices, and its 3x3 matrix is not padded
// correctly.
LinAlgTypes::Cgmath
if majorness == MatrixMajorness::ColumnMajor
&& padding == 0
&& vector_count == component_count =>
&& vector_count == component_count
&& padding == 0 =>
{
let matrix = format_ident!("Matrix{}", component_count);
tokens.extend(quote! { ::cgmath::#matrix<#component_type> });
}
// nalgebra only has column-major matrices, and its 3xN matrices are not padded right.
// Fall back to std for anything else.
LinAlgType::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => {
// nalgebra only supports column-major matrices, and its 3xN matrices are not padded
// correctly.
LinAlgTypes::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => {
tokens.extend(quote! {
::nalgebra::base::SMatrix<#component_type, #component_count, #vector_count>
::nalgebra::SMatrix<#component_type, #component_count, #vector_count>
});
}
_ => {
@ -964,7 +930,7 @@ impl ToTokens for Serializer<'_, TypeMatrix> {
impl ToTokens for Serializer<'_, TypeArray> {
fn to_tokens(&self, tokens: &mut TokenStream) {
let element_type = &*self.0.element_type;
let element_type = self.0.element_type.as_ref();
// This can't panic because array elements must be sized.
let element_size = element_type.size().unwrap();
// This can't overflow because the stride must be at least the element size.
@ -1047,7 +1013,8 @@ impl ToTokens for Serializer<'_, TypeStruct> {
}
}
/// Helper for wrapping tokens in `Padded`. Doesn't wrap if the padding is `0`.
/// Helper for wrapping tokens in [Padded][struct@vulkano::padded::Padded].
/// Doesn't wrap if the padding is `0`.
struct Padded<T>(T, usize);
impl<T: ToTokens> ToTokens for Padded<T> {