mirror of
https://github.com/vulkano-rs/vulkano.git
synced 2024-11-21 14:24:18 +00:00
Refactor vulkano-shaders
This commit is contained in:
parent
075e9812d8
commit
2752e20d79
1
Cargo.lock
generated
1
Cargo.lock
generated
@ -2465,7 +2465,6 @@ name = "vulkano-shaders"
|
||||
version = "0.34.0"
|
||||
dependencies = [
|
||||
"ahash",
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"shaderc",
|
||||
|
@ -18,16 +18,14 @@ proc-macro = true
|
||||
|
||||
[dependencies]
|
||||
ahash = { workspace = true }
|
||||
heck = { workspace = true }
|
||||
proc-macro2 = { workspace = true }
|
||||
quote = { workspace = true }
|
||||
shaderc = { workspace = true }
|
||||
syn = { workspace = true, features = ["full", "extra-traits"] }
|
||||
syn = { workspace = true }
|
||||
vulkano = { workspace = true }
|
||||
|
||||
[features]
|
||||
shaderc-build-from-source = ["shaderc/build-from-source"]
|
||||
shaderc-debug = []
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
@ -1,135 +1,138 @@
|
||||
use crate::{
|
||||
structs::{self, TypeRegistry},
|
||||
MacroInput,
|
||||
};
|
||||
use heck::ToSnakeCase;
|
||||
use crate::MacroOptions;
|
||||
use ahash::{HashMap, HashSet};
|
||||
use proc_macro2::TokenStream;
|
||||
use quote::{format_ident, quote};
|
||||
pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind};
|
||||
use shaderc::{CompileOptions, Compiler, EnvVersion, SourceLanguage, TargetEnv};
|
||||
use shaderc::{
|
||||
CompilationArtifact, CompileOptions, Compiler, IncludeType, ResolvedInclude, ShaderKind,
|
||||
SourceLanguage, TargetEnv,
|
||||
};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
fs,
|
||||
iter::Iterator,
|
||||
path::{Path, PathBuf},
|
||||
};
|
||||
use syn::{Error, LitStr};
|
||||
use vulkano::shader::spirv::Spirv;
|
||||
|
||||
pub struct Shader {
|
||||
pub source: LitStr,
|
||||
pub name: String,
|
||||
pub spirv: Spirv,
|
||||
pub(super) fn compile(
|
||||
compiler: &Compiler,
|
||||
macro_options: &MacroOptions,
|
||||
source_path: Option<&str>,
|
||||
shader_name: &str,
|
||||
source_code: &str,
|
||||
entry_points: HashMap<Option<String>, ShaderKind>,
|
||||
base_path: &Path,
|
||||
include_paths: &[PathBuf],
|
||||
sources_to_include: &RefCell<HashSet<String>>,
|
||||
) -> Result<Vec<(Option<String>, CompilationArtifact)>, String> {
|
||||
let mut compile_options = CompileOptions::new().ok_or("failed to create compile options")?;
|
||||
|
||||
compile_options.set_source_language(macro_options.source_language);
|
||||
|
||||
if let Some(vulkan_version) = macro_options.vulkan_version {
|
||||
compile_options.set_target_env(TargetEnv::Vulkan, vulkan_version as u32);
|
||||
}
|
||||
|
||||
if let Some(spirv_version) = macro_options.spirv_version {
|
||||
compile_options.set_target_spirv(spirv_version);
|
||||
}
|
||||
|
||||
compile_options.set_include_callback(
|
||||
|requested_source, include_type, containing_source, depth| {
|
||||
include_callback(
|
||||
Path::new(requested_source),
|
||||
include_type,
|
||||
Path::new(containing_source),
|
||||
depth,
|
||||
base_path,
|
||||
include_paths,
|
||||
source_path.is_none(),
|
||||
&mut sources_to_include.borrow_mut(),
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
for (name, value) in ¯o_options.macro_defines {
|
||||
compile_options.add_macro_definition(name, value.as_deref());
|
||||
}
|
||||
|
||||
let file_name = match (source_path, macro_options.source_language) {
|
||||
(Some(source_path), _) => source_path,
|
||||
(None, SourceLanguage::GLSL) => &format!("{shader_name}.glsl"),
|
||||
(None, SourceLanguage::HLSL) => &format!("{shader_name}.hlsl"),
|
||||
};
|
||||
|
||||
entry_points
|
||||
.into_iter()
|
||||
.map(|(entry_point, shader_kind)| {
|
||||
compiler
|
||||
.compile_into_spirv(
|
||||
source_code,
|
||||
shader_kind,
|
||||
file_name,
|
||||
entry_point.as_deref().unwrap_or("main"),
|
||||
Some(&compile_options),
|
||||
)
|
||||
.map(|artifact| (entry_point, artifact))
|
||||
.map_err(|err| err.to_string().replace("\n", " "))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn include_callback(
|
||||
requested_source_path_raw: &str,
|
||||
directive_type: IncludeType,
|
||||
contained_within_path_raw: &str,
|
||||
recursion_depth: usize,
|
||||
include_directories: &[PathBuf],
|
||||
root_source_has_path: bool,
|
||||
requested_source: &Path,
|
||||
include_type: IncludeType,
|
||||
containing_source: &Path,
|
||||
depth: usize,
|
||||
base_path: &Path,
|
||||
includes: &mut Vec<String>,
|
||||
include_paths: &[PathBuf],
|
||||
embedded_root_source: bool,
|
||||
sources_to_include: &mut HashSet<String>,
|
||||
) -> Result<ResolvedInclude, String> {
|
||||
let file_to_include = match directive_type {
|
||||
let resolved_path = match include_type {
|
||||
IncludeType::Relative => {
|
||||
let requested_source_path = Path::new(requested_source_path_raw);
|
||||
// If the shader source is embedded within the macro, abort unless we get an absolute
|
||||
// path.
|
||||
if !root_source_has_path && recursion_depth == 1 && !requested_source_path.is_absolute()
|
||||
{
|
||||
let requested_source_name = requested_source_path
|
||||
.file_name()
|
||||
.expect("failed to get the name of the requested source file")
|
||||
.to_string_lossy();
|
||||
let requested_source_directory = requested_source_path
|
||||
.parent()
|
||||
.expect("failed to get the directory of the requested source file")
|
||||
.to_string_lossy();
|
||||
|
||||
return Err(format!(
|
||||
"usage of relative paths in imports in embedded GLSL is not allowed, try \
|
||||
using `#include <{}>` and adding the directory `{}` to the `include` array in \
|
||||
your `shader!` macro call instead",
|
||||
requested_source_name, requested_source_directory,
|
||||
));
|
||||
if depth == 1 && embedded_root_source && !requested_source.is_absolute() {
|
||||
return Err(
|
||||
"you cannot use relative include directives in embedded shader source code; \
|
||||
try using `#include <...>` instead"
|
||||
.to_owned(),
|
||||
);
|
||||
}
|
||||
|
||||
let mut resolved_path = if recursion_depth == 1 {
|
||||
Path::new(contained_within_path_raw)
|
||||
.parent()
|
||||
.map(|parent| base_path.join(parent))
|
||||
let parent = containing_source.parent().unwrap();
|
||||
|
||||
if depth == 1 {
|
||||
[base_path, parent, requested_source].iter().collect()
|
||||
} else {
|
||||
Path::new(contained_within_path_raw)
|
||||
.parent()
|
||||
.map(|parent| parent.to_owned())
|
||||
[parent, requested_source].iter().collect()
|
||||
}
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"the file `{}` does not reside in a directory, this is an implementation \
|
||||
error",
|
||||
contained_within_path_raw,
|
||||
)
|
||||
});
|
||||
resolved_path.push(requested_source_path);
|
||||
|
||||
if !resolved_path.is_file() {
|
||||
return Err(format!(
|
||||
"invalid inclusion path `{}`, the path does not point to a file",
|
||||
requested_source_path_raw,
|
||||
));
|
||||
}
|
||||
|
||||
resolved_path
|
||||
}
|
||||
IncludeType::Standard => {
|
||||
let requested_source_path = Path::new(requested_source_path_raw);
|
||||
|
||||
if requested_source_path.is_absolute() {
|
||||
// This message is printed either when using a missing file with an absolute path
|
||||
// in the relative include directive or when using absolute paths in a standard
|
||||
if requested_source.is_absolute() {
|
||||
// This is returned when attempting to include a missing file by an absolute path
|
||||
// in a relative include directive and when using an absolute path in a standard
|
||||
// include directive.
|
||||
return Err(format!(
|
||||
"no such file found as specified by the absolute path; keep in mind that \
|
||||
absolute paths cannot be used with inclusion from standard directories \
|
||||
(`#include <...>`), try using `#include \"...\"` instead; requested path: {}",
|
||||
requested_source_path_raw,
|
||||
));
|
||||
return Err(
|
||||
"the specified file was not found; if you're using an absolute path in a \
|
||||
standard include directive (`#include <...>`), try using `#include \"...\"` \
|
||||
instead"
|
||||
.to_owned(),
|
||||
);
|
||||
}
|
||||
|
||||
let found_requested_source_path = include_directories
|
||||
include_paths
|
||||
.iter()
|
||||
.map(|include_directory| include_directory.join(requested_source_path))
|
||||
.find(|resolved_requested_source_path| resolved_requested_source_path.is_file());
|
||||
|
||||
if let Some(found_requested_source_path) = found_requested_source_path {
|
||||
found_requested_source_path
|
||||
} else {
|
||||
return Err(format!(
|
||||
"failed to include the file `{}` from any include directories",
|
||||
requested_source_path_raw,
|
||||
));
|
||||
}
|
||||
.map(|include_path| include_path.join(requested_source))
|
||||
.find(|source_path| source_path.is_file())
|
||||
.ok_or("the specified file was not found".to_owned())?
|
||||
}
|
||||
};
|
||||
|
||||
let content = fs::read_to_string(file_to_include.as_path()).map_err(|err| {
|
||||
format!(
|
||||
"failed to read the contents of file `{file_to_include:?}` to be included in the \
|
||||
shader source: {err}",
|
||||
)
|
||||
})?;
|
||||
let resolved_name = file_to_include
|
||||
.into_os_string()
|
||||
.into_string()
|
||||
.map_err(|_| {
|
||||
"failed to stringify the file to be included; make sure the path consists of valid \
|
||||
unicode characters"
|
||||
})?;
|
||||
let resolved_name = resolved_path.into_os_string().into_string().unwrap();
|
||||
|
||||
includes.push(resolved_name.clone());
|
||||
let content = fs::read_to_string(&resolved_name)
|
||||
.map_err(|err| format!("failed to read `{resolved_name}`: {err}"))?;
|
||||
|
||||
sources_to_include.insert(resolved_name.clone());
|
||||
|
||||
Ok(ResolvedInclude {
|
||||
resolved_name,
|
||||
@ -137,618 +140,48 @@ fn include_callback(
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn compile(
|
||||
input: &MacroInput,
|
||||
path: Option<String>,
|
||||
base_path: &Path,
|
||||
code: &str,
|
||||
shader_kind: ShaderKind,
|
||||
) -> Result<(CompilationArtifact, Vec<String>), String> {
|
||||
let includes = RefCell::new(Vec::new());
|
||||
let compiler = Compiler::new().ok_or("failed to create shader compiler")?;
|
||||
let mut compile_options =
|
||||
CompileOptions::new().ok_or("failed to initialize compile options")?;
|
||||
pub(super) fn generate_shader_code(
|
||||
entry_points: &[(Option<&str>, &[u32])],
|
||||
shader_name: &Option<String>,
|
||||
) -> TokenStream {
|
||||
let load_fns = entry_points.iter().map(|(name, words)| {
|
||||
let load_name = match name {
|
||||
Some(name) => format_ident!("load_{name}"),
|
||||
None => format_ident!("load"),
|
||||
};
|
||||
|
||||
let source_language = input.source_language.unwrap_or(SourceLanguage::GLSL);
|
||||
compile_options.set_source_language(source_language);
|
||||
|
||||
compile_options.set_target_env(
|
||||
TargetEnv::Vulkan,
|
||||
input.vulkan_version.unwrap_or(EnvVersion::Vulkan1_0) as u32,
|
||||
);
|
||||
|
||||
if let Some(spirv_version) = input.spirv_version {
|
||||
compile_options.set_target_spirv(spirv_version);
|
||||
}
|
||||
|
||||
let root_source_path = path.as_deref().unwrap_or(
|
||||
// An arbitrary placeholder file name for embedded shaders.
|
||||
match source_language {
|
||||
SourceLanguage::GLSL => "shader.glsl",
|
||||
SourceLanguage::HLSL => "shader.hlsl",
|
||||
},
|
||||
);
|
||||
|
||||
// Specify the file resolution callback for the `#include` directive.
|
||||
compile_options.set_include_callback(
|
||||
|requested_source_path, directive_type, contained_within_path, recursion_depth| {
|
||||
include_callback(
|
||||
requested_source_path,
|
||||
directive_type,
|
||||
contained_within_path,
|
||||
recursion_depth,
|
||||
&input.include_directories,
|
||||
path.is_some(),
|
||||
base_path,
|
||||
&mut includes.borrow_mut(),
|
||||
)
|
||||
},
|
||||
);
|
||||
|
||||
for (macro_name, macro_value) in &input.macro_defines {
|
||||
compile_options.add_macro_definition(macro_name, Some(macro_value));
|
||||
}
|
||||
|
||||
#[cfg(feature = "shaderc-debug")]
|
||||
compile_options.set_generate_debug_info();
|
||||
|
||||
let content = compiler
|
||||
.compile_into_spirv(
|
||||
code,
|
||||
shader_kind,
|
||||
root_source_path,
|
||||
"main",
|
||||
Some(&compile_options),
|
||||
)
|
||||
.map_err(|e| e.to_string().replace("(s): ", "(s):\n"))?;
|
||||
|
||||
drop(compile_options);
|
||||
|
||||
Ok((content, includes.into_inner()))
|
||||
}
|
||||
|
||||
pub(super) fn reflect(
|
||||
input: &MacroInput,
|
||||
source: LitStr,
|
||||
name: String,
|
||||
words: &[u32],
|
||||
input_paths: Vec<String>,
|
||||
type_registry: &mut TypeRegistry,
|
||||
) -> Result<(TokenStream, TokenStream), Error> {
|
||||
let spirv = Spirv::new(words).map_err(|err| {
|
||||
Error::new_spanned(&source, format!("failed to parse SPIR-V words: {err}"))
|
||||
})?;
|
||||
let shader = Shader {
|
||||
source,
|
||||
name,
|
||||
spirv,
|
||||
};
|
||||
|
||||
let include_bytes = input_paths.into_iter().map(|s| {
|
||||
quote! {
|
||||
// Using `include_bytes` here ensures that changing the shader will force recompilation.
|
||||
// The bytes themselves can be optimized out by the compiler as they are unused.
|
||||
::std::include_bytes!( #s )
|
||||
#[allow(unsafe_code)]
|
||||
#[inline]
|
||||
pub fn #load_name(
|
||||
device: ::std::sync::Arc<::vulkano::device::Device>,
|
||||
) -> ::std::result::Result<
|
||||
::std::sync::Arc<::vulkano::shader::ShaderModule>,
|
||||
::vulkano::Validated<::vulkano::VulkanError>,
|
||||
> {
|
||||
static WORDS: &[u32] = &[ #( #words ),* ];
|
||||
|
||||
unsafe {
|
||||
::vulkano::shader::ShaderModule::new(
|
||||
device,
|
||||
::vulkano::shader::ShaderModuleCreateInfo::new(WORDS),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let load_name = if shader.name.is_empty() {
|
||||
format_ident!("load")
|
||||
if let Some(shader_name) = shader_name {
|
||||
let shader_name = format_ident!("{shader_name}");
|
||||
|
||||
quote! {
|
||||
pub mod #shader_name {
|
||||
#( #load_fns )*
|
||||
}
|
||||
}
|
||||
} else {
|
||||
format_ident!("load_{}", shader.name.to_snake_case())
|
||||
};
|
||||
|
||||
let shader_code = quote! {
|
||||
/// Loads the shader as a `ShaderModule`.
|
||||
#[allow(unsafe_code)]
|
||||
#[inline]
|
||||
pub fn #load_name(
|
||||
device: ::std::sync::Arc<::vulkano::device::Device>,
|
||||
) -> ::std::result::Result<
|
||||
::std::sync::Arc<::vulkano::shader::ShaderModule>,
|
||||
::vulkano::Validated<::vulkano::VulkanError>,
|
||||
> {
|
||||
let _bytes = ( #( #include_bytes ),* );
|
||||
|
||||
static WORDS: &[u32] = &[ #( #words ),* ];
|
||||
|
||||
unsafe {
|
||||
::vulkano::shader::ShaderModule::new(
|
||||
device,
|
||||
::vulkano::shader::ShaderModuleCreateInfo::new(WORDS),
|
||||
)
|
||||
}
|
||||
quote! {
|
||||
#( #load_fns )*
|
||||
}
|
||||
};
|
||||
|
||||
let structs = structs::write_structs(input, &shader, type_registry)?;
|
||||
|
||||
Ok((shader_code, structs))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use proc_macro2::Span;
|
||||
use quote::ToTokens;
|
||||
use shaderc::SpirvVersion;
|
||||
use syn::{File, Item};
|
||||
use vulkano::shader::reflect;
|
||||
|
||||
fn spv_to_words(data: &[u8]) -> Vec<u32> {
|
||||
data.chunks(4)
|
||||
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn convert_paths(root_path: &Path, paths: &[PathBuf]) -> Vec<String> {
|
||||
paths
|
||||
.iter()
|
||||
.map(|p| root_path.join(p).into_os_string().into_string().unwrap())
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spirv_parse() {
|
||||
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
|
||||
Spirv::new(&insts).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn spirv_reflect() {
|
||||
let insts = spv_to_words(include_bytes!("../tests/frag.spv"));
|
||||
|
||||
let mut type_registry = TypeRegistry::default();
|
||||
let (_shader_code, _structs) = reflect(
|
||||
&MacroInput::empty(),
|
||||
LitStr::new("../tests/frag.spv", Span::call_site()),
|
||||
String::new(),
|
||||
&insts,
|
||||
Vec::new(),
|
||||
&mut type_registry,
|
||||
)
|
||||
.expect("reflecting spv failed");
|
||||
|
||||
assert_eq!(_structs.to_string(), "", "No structs should be generated");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn include_resolution() {
|
||||
let root_path = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
|
||||
|
||||
let (_compile_relative, _) = compile(
|
||||
&MacroInput::empty(),
|
||||
Some(String::from("tests/include_test.glsl")),
|
||||
&root_path,
|
||||
r#"
|
||||
#version 450
|
||||
#include "include_dir_a/target_a.glsl"
|
||||
#include "include_dir_b/target_b.glsl"
|
||||
void main() {}
|
||||
"#,
|
||||
ShaderKind::Vertex,
|
||||
)
|
||||
.expect("cannot resolve include files");
|
||||
|
||||
let (_compile_include_paths, includes) = compile(
|
||||
&MacroInput {
|
||||
include_directories: vec![
|
||||
root_path.join("tests").join("include_dir_a"),
|
||||
root_path.join("tests").join("include_dir_b"),
|
||||
],
|
||||
..MacroInput::empty()
|
||||
},
|
||||
Some(String::from("tests/include_test.glsl")),
|
||||
&root_path,
|
||||
r#"
|
||||
#version 450
|
||||
#include <target_a.glsl>
|
||||
#include <target_b.glsl>
|
||||
void main() {}
|
||||
"#,
|
||||
ShaderKind::Vertex,
|
||||
)
|
||||
.expect("cannot resolve include files");
|
||||
|
||||
assert_eq!(
|
||||
includes,
|
||||
convert_paths(
|
||||
&root_path,
|
||||
&[
|
||||
["tests", "include_dir_a", "target_a.glsl"]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
["tests", "include_dir_b", "target_b.glsl"]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
],
|
||||
),
|
||||
);
|
||||
|
||||
let (_compile_include_paths_with_relative, includes_with_relative) = compile(
|
||||
&MacroInput {
|
||||
include_directories: vec![root_path.join("tests").join("include_dir_a")],
|
||||
..MacroInput::empty()
|
||||
},
|
||||
Some(String::from("tests/include_test.glsl")),
|
||||
&root_path,
|
||||
r#"
|
||||
#version 450
|
||||
#include <target_a.glsl>
|
||||
#include <../include_dir_b/target_b.glsl>
|
||||
void main() {}
|
||||
"#,
|
||||
ShaderKind::Vertex,
|
||||
)
|
||||
.expect("cannot resolve include files");
|
||||
|
||||
assert_eq!(
|
||||
includes_with_relative,
|
||||
convert_paths(
|
||||
&root_path,
|
||||
&[
|
||||
["tests", "include_dir_a", "target_a.glsl"]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
["tests", "include_dir_a", "../include_dir_b/target_b.glsl"]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
],
|
||||
),
|
||||
);
|
||||
|
||||
let absolute_path = root_path
|
||||
.join("tests")
|
||||
.join("include_dir_a")
|
||||
.join("target_a.glsl");
|
||||
let absolute_path_str = absolute_path
|
||||
.to_str()
|
||||
.expect("cannot run tests in a folder with non unicode characters");
|
||||
let (_compile_absolute_path, includes_absolute_path) = compile(
|
||||
&MacroInput::empty(),
|
||||
Some(String::from("tests/include_test.glsl")),
|
||||
&root_path,
|
||||
&format!(
|
||||
r#"
|
||||
#version 450
|
||||
#include "{absolute_path_str}"
|
||||
void main() {{}}
|
||||
"#,
|
||||
),
|
||||
ShaderKind::Vertex,
|
||||
)
|
||||
.expect("cannot resolve include files");
|
||||
|
||||
assert_eq!(
|
||||
includes_absolute_path,
|
||||
convert_paths(
|
||||
&root_path,
|
||||
&[["tests", "include_dir_a", "target_a.glsl"]
|
||||
.into_iter()
|
||||
.collect()],
|
||||
),
|
||||
);
|
||||
|
||||
let (_compile_recursive_, includes_recursive) = compile(
|
||||
&MacroInput {
|
||||
include_directories: vec![
|
||||
root_path.join("tests").join("include_dir_b"),
|
||||
root_path.join("tests").join("include_dir_c"),
|
||||
],
|
||||
..MacroInput::empty()
|
||||
},
|
||||
Some(String::from("tests/include_test.glsl")),
|
||||
&root_path,
|
||||
r#"
|
||||
#version 450
|
||||
#include <target_c.glsl>
|
||||
void main() {}
|
||||
"#,
|
||||
ShaderKind::Vertex,
|
||||
)
|
||||
.expect("cannot resolve include files");
|
||||
|
||||
assert_eq!(
|
||||
includes_recursive,
|
||||
convert_paths(
|
||||
&root_path,
|
||||
&[
|
||||
["tests", "include_dir_c", "target_c.glsl"]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
["tests", "include_dir_c", "../include_dir_a/target_a.glsl"]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
["tests", "include_dir_b", "target_b.glsl"]
|
||||
.into_iter()
|
||||
.collect(),
|
||||
],
|
||||
),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn macros() {
|
||||
let need_defines = r#"
|
||||
#version 450
|
||||
#if defined(NAME1) && NAME2 > 29
|
||||
void main() {}
|
||||
#endif
|
||||
"#;
|
||||
|
||||
let compile_no_defines = compile(
|
||||
&MacroInput::empty(),
|
||||
None,
|
||||
Path::new(""),
|
||||
need_defines,
|
||||
ShaderKind::Vertex,
|
||||
);
|
||||
assert!(compile_no_defines.is_err());
|
||||
|
||||
compile(
|
||||
&MacroInput {
|
||||
macro_defines: vec![("NAME1".into(), "".into()), ("NAME2".into(), "58".into())],
|
||||
..MacroInput::empty()
|
||||
},
|
||||
None,
|
||||
Path::new(""),
|
||||
need_defines,
|
||||
ShaderKind::Vertex,
|
||||
)
|
||||
.expect("setting shader macros did not work");
|
||||
}
|
||||
|
||||
/// `entrypoint1.frag.glsl`:
|
||||
/// ```glsl
|
||||
/// #version 450
|
||||
///
|
||||
/// layout(set = 0, binding = 0) uniform Uniform {
|
||||
/// uint data;
|
||||
/// } ubo;
|
||||
///
|
||||
/// layout(set = 0, binding = 1) buffer Buffer {
|
||||
/// uint data;
|
||||
/// } bo;
|
||||
///
|
||||
/// layout(set = 0, binding = 2) uniform sampler textureSampler;
|
||||
/// layout(set = 0, binding = 3) uniform texture2D imageTexture;
|
||||
///
|
||||
/// layout(push_constant) uniform PushConstant {
|
||||
/// uint data;
|
||||
/// } push;
|
||||
///
|
||||
/// layout(input_attachment_index = 0, set = 0, binding = 4) uniform subpassInput inputAttachment;
|
||||
///
|
||||
/// layout(location = 0) out vec4 outColor;
|
||||
///
|
||||
/// void entrypoint1() {
|
||||
/// bo.data = 12;
|
||||
/// outColor = vec4(
|
||||
/// float(ubo.data),
|
||||
/// float(push.data),
|
||||
/// texture(sampler2D(imageTexture, textureSampler), vec2(0.0, 0.0)).x,
|
||||
/// subpassLoad(inputAttachment).x
|
||||
/// );
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// `entrypoint2.frag.glsl`:
|
||||
/// ```glsl
|
||||
/// #version 450
|
||||
///
|
||||
/// layout(input_attachment_index = 0, set = 0, binding = 0) uniform subpassInput inputAttachment2;
|
||||
///
|
||||
/// layout(set = 0, binding = 1) buffer Buffer {
|
||||
/// uint data;
|
||||
/// } bo2;
|
||||
///
|
||||
/// layout(set = 0, binding = 2) uniform Uniform {
|
||||
/// uint data;
|
||||
/// } ubo2;
|
||||
///
|
||||
/// layout(push_constant) uniform PushConstant {
|
||||
/// uint data;
|
||||
/// } push2;
|
||||
///
|
||||
/// void entrypoint2() {
|
||||
/// bo2.data = ubo2.data + push2.data + int(subpassLoad(inputAttachment2).y);
|
||||
/// }
|
||||
/// ```
|
||||
///
|
||||
/// Compiled and linked with:
|
||||
/// ```sh
|
||||
/// glslangvalidator -e entrypoint1 --source-entrypoint entrypoint1 -V100 entrypoint1.frag.glsl -o entrypoint1.spv
|
||||
/// glslangvalidator -e entrypoint2 --source-entrypoint entrypoint2 -V100 entrypoint2.frag.glsl -o entrypoint2.spv
|
||||
/// spirv-link entrypoint1.spv entrypoint2.spv -o multiple_entrypoints.spv
|
||||
/// ```
|
||||
#[test]
|
||||
fn descriptor_calculation_with_multiple_entrypoints() {
|
||||
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
|
||||
let spirv = Spirv::new(&insts).unwrap();
|
||||
|
||||
let mut descriptors = Vec::new();
|
||||
for (_, info) in reflect::entry_points(&spirv) {
|
||||
descriptors.push(info.descriptor_binding_requirements);
|
||||
}
|
||||
|
||||
// Check first entrypoint
|
||||
let e1_descriptors = &descriptors[0];
|
||||
let mut e1_bindings = Vec::new();
|
||||
for loc in e1_descriptors.keys() {
|
||||
e1_bindings.push(*loc);
|
||||
}
|
||||
|
||||
assert_eq!(e1_bindings.len(), 5);
|
||||
assert!(e1_bindings.contains(&(0, 0)));
|
||||
assert!(e1_bindings.contains(&(0, 1)));
|
||||
assert!(e1_bindings.contains(&(0, 2)));
|
||||
assert!(e1_bindings.contains(&(0, 3)));
|
||||
assert!(e1_bindings.contains(&(0, 4)));
|
||||
|
||||
// Check second entrypoint
|
||||
let e2_descriptors = &descriptors[1];
|
||||
let mut e2_bindings = Vec::new();
|
||||
for loc in e2_descriptors.keys() {
|
||||
e2_bindings.push(*loc);
|
||||
}
|
||||
|
||||
assert_eq!(e2_bindings.len(), 3);
|
||||
assert!(e2_bindings.contains(&(0, 0)));
|
||||
assert!(e2_bindings.contains(&(0, 1)));
|
||||
assert!(e2_bindings.contains(&(0, 2)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reflect_descriptor_calculation_with_multiple_entrypoints() {
|
||||
let insts = spv_to_words(include_bytes!("../tests/multiple_entrypoints.spv"));
|
||||
|
||||
let mut type_registry = TypeRegistry::default();
|
||||
let (_shader_code, _structs) = reflect(
|
||||
&MacroInput::empty(),
|
||||
LitStr::new("../tests/multiple_entrypoints.spv", Span::call_site()),
|
||||
String::new(),
|
||||
&insts,
|
||||
Vec::new(),
|
||||
&mut type_registry,
|
||||
)
|
||||
.expect("reflecting spv failed");
|
||||
|
||||
let structs = _structs.to_string();
|
||||
assert_ne!(structs, "", "Has some structs");
|
||||
|
||||
let file: File = syn::parse2(_structs).unwrap();
|
||||
let structs: Vec<_> = file
|
||||
.items
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
if let Item::Struct(s) = item {
|
||||
Some(s)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
|
||||
assert_eq!(
|
||||
buffer.fields.to_token_stream().to_string(),
|
||||
quote!({pub data: u32,}).to_string()
|
||||
);
|
||||
|
||||
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
|
||||
assert_eq!(
|
||||
uniform.fields.to_token_stream().to_string(),
|
||||
quote!({pub data: u32,}).to_string()
|
||||
);
|
||||
}
|
||||
|
||||
fn descriptor_calculation_with_multiple_functions_shader() -> (CompilationArtifact, Vec<String>)
|
||||
{
|
||||
compile(
|
||||
&MacroInput {
|
||||
spirv_version: Some(SpirvVersion::V1_6),
|
||||
vulkan_version: Some(EnvVersion::Vulkan1_3),
|
||||
..MacroInput::empty()
|
||||
},
|
||||
None,
|
||||
Path::new(""),
|
||||
r#"
|
||||
#version 460
|
||||
|
||||
layout(set = 1, binding = 0) buffer Buffer {
|
||||
vec3 data;
|
||||
} bo;
|
||||
|
||||
layout(set = 2, binding = 0) uniform Uniform {
|
||||
float data;
|
||||
} ubo;
|
||||
|
||||
layout(set = 3, binding = 1) uniform sampler textureSampler;
|
||||
layout(set = 3, binding = 2) uniform texture2D imageTexture;
|
||||
|
||||
float withMagicSparkles(float data) {
|
||||
return texture(sampler2D(imageTexture, textureSampler), vec2(data, data)).x;
|
||||
}
|
||||
|
||||
vec3 makeSecretSauce() {
|
||||
return vec3(withMagicSparkles(ubo.data));
|
||||
}
|
||||
|
||||
void main() {
|
||||
bo.data = makeSecretSauce();
|
||||
}
|
||||
"#,
|
||||
ShaderKind::Vertex,
|
||||
)
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn descriptor_calculation_with_multiple_functions() {
|
||||
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
|
||||
let spirv = Spirv::new(artifact.as_binary()).unwrap();
|
||||
|
||||
if let Some((_, info)) = reflect::entry_points(&spirv).next() {
|
||||
let mut bindings = Vec::new();
|
||||
for (loc, _reqs) in info.descriptor_binding_requirements {
|
||||
bindings.push(loc);
|
||||
}
|
||||
|
||||
assert_eq!(bindings.len(), 4);
|
||||
assert!(bindings.contains(&(1, 0)));
|
||||
assert!(bindings.contains(&(2, 0)));
|
||||
assert!(bindings.contains(&(3, 1)));
|
||||
assert!(bindings.contains(&(3, 2)));
|
||||
|
||||
return;
|
||||
}
|
||||
panic!("could not find entrypoint");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn reflect_descriptor_calculation_with_multiple_functions() {
|
||||
let (artifact, _) = descriptor_calculation_with_multiple_functions_shader();
|
||||
|
||||
let mut type_registry = TypeRegistry::default();
|
||||
let (_shader_code, _structs) = reflect(
|
||||
&MacroInput::empty(),
|
||||
LitStr::new(
|
||||
"descriptor_calculation_with_multiple_functions_shader",
|
||||
Span::call_site(),
|
||||
),
|
||||
String::new(),
|
||||
artifact.as_binary(),
|
||||
Vec::new(),
|
||||
&mut type_registry,
|
||||
)
|
||||
.expect("reflecting spv failed");
|
||||
|
||||
let structs = _structs.to_string();
|
||||
assert_ne!(structs, "", "Has some structs");
|
||||
|
||||
let file: File = syn::parse2(_structs).unwrap();
|
||||
let structs: Vec<_> = file
|
||||
.items
|
||||
.iter()
|
||||
.filter_map(|item| {
|
||||
if let Item::Struct(s) = item {
|
||||
Some(s)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
let buffer = structs.iter().find(|s| s.ident == "Buffer").unwrap();
|
||||
assert_eq!(
|
||||
buffer.fields.to_token_stream().to_string(),
|
||||
quote!({pub data: [f32; 3usize],}).to_string()
|
||||
);
|
||||
|
||||
let uniform = structs.iter().find(|s| s.ident == "Uniform").unwrap();
|
||||
assert_eq!(
|
||||
uniform.fields.to_token_stream().to_string(),
|
||||
quote!({pub data: f32,}).to_string()
|
||||
);
|
||||
}
|
||||
}
|
||||
|
File diff suppressed because it is too large
Load Diff
@ -1,44 +0,0 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use crate::{codegen::reflect, structs::TypeRegistry, MacroInput};
|
||||
use proc_macro2::Span;
|
||||
use syn::LitStr;
|
||||
|
||||
fn spv_to_words(data: &[u8]) -> Vec<u32> {
|
||||
data.chunks(4)
|
||||
.map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rust_gpu_reflect_vertex() {
|
||||
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-vertex.spv"));
|
||||
|
||||
let mut type_registry = TypeRegistry::default();
|
||||
let (_shader_code, _structs) = reflect(
|
||||
&MacroInput::empty(),
|
||||
LitStr::new("rust-gpu vertex shader", Span::call_site()),
|
||||
String::new(),
|
||||
&insts,
|
||||
Vec::new(),
|
||||
&mut type_registry,
|
||||
)
|
||||
.expect("reflecting spv failed");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rust_gpu_reflect_fragment() {
|
||||
let insts = spv_to_words(include_bytes!("../tests/rust-gpu/test_shader-fragment.spv"));
|
||||
|
||||
let mut type_registry = TypeRegistry::default();
|
||||
let (_shader_code, _structs) = reflect(
|
||||
&MacroInput::empty(),
|
||||
LitStr::new("rust-gpu vertex shader", Span::call_site()),
|
||||
String::new(),
|
||||
&insts,
|
||||
Vec::new(),
|
||||
&mut type_registry,
|
||||
)
|
||||
.expect("reflecting spv failed");
|
||||
}
|
||||
}
|
@ -1,174 +1,148 @@
|
||||
use crate::{bail, codegen::Shader, LinAlgType, MacroInput};
|
||||
use crate::{bail, LinAlgTypes, MacroOptions};
|
||||
use ahash::HashMap;
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use quote::{format_ident, quote, ToTokens, TokenStreamExt};
|
||||
use std::{cmp::Ordering, num::NonZeroUsize};
|
||||
use syn::{Error, Ident, Result};
|
||||
use vulkano::shader::spirv::{Decoration, Id, Instruction};
|
||||
use syn::{Error, Ident, LitStr, Result};
|
||||
use vulkano::shader::spirv::{Decoration, Id, Instruction, Spirv};
|
||||
|
||||
#[derive(Default)]
|
||||
pub struct TypeRegistry {
|
||||
registered_structs: HashMap<Ident, RegisteredType>,
|
||||
struct Shader {
|
||||
spirv: Spirv,
|
||||
name: String,
|
||||
source: LitStr,
|
||||
}
|
||||
|
||||
impl TypeRegistry {
|
||||
fn register_struct(&mut self, shader: &Shader, ty: &TypeStruct) -> Result<bool> {
|
||||
// Checking with registry if this struct is already registered by another shader, and if
|
||||
// their signatures match.
|
||||
if let Some(registered) = self.registered_structs.get(&ty.ident) {
|
||||
registered.validate_signatures(&shader.name, ty)?;
|
||||
|
||||
// If the struct is already registered and matches this one, skip the duplicate.
|
||||
Ok(false)
|
||||
} else {
|
||||
self.registered_structs.insert(
|
||||
ty.ident.clone(),
|
||||
RegisteredType {
|
||||
shader: shader.name.clone(),
|
||||
ty: ty.clone(),
|
||||
},
|
||||
);
|
||||
|
||||
Ok(true)
|
||||
}
|
||||
}
|
||||
pub(super) struct RegisteredStruct {
|
||||
members: Vec<Member>,
|
||||
shader_name: String,
|
||||
}
|
||||
|
||||
struct RegisteredType {
|
||||
shader: String,
|
||||
ty: TypeStruct,
|
||||
}
|
||||
|
||||
impl RegisteredType {
|
||||
fn validate_signatures(&self, other_shader: &str, other_ty: &TypeStruct) -> Result<()> {
|
||||
let (shader, struct_ident) = (&self.shader, &self.ty.ident);
|
||||
|
||||
if self.ty.members.len() > other_ty.members.len() {
|
||||
let member_ident = &self.ty.members[other_ty.members.len()].ident;
|
||||
bail!(
|
||||
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
|
||||
`{struct_ident}`, but the struct from shader `{shader}` contains an extra field \
|
||||
`{member_ident}`",
|
||||
);
|
||||
}
|
||||
|
||||
if self.ty.members.len() < other_ty.members.len() {
|
||||
let member_ident = &other_ty.members[self.ty.members.len()].ident;
|
||||
bail!(
|
||||
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
|
||||
`{struct_ident}`, but the struct from shader `{other_shader}` contains an extra \
|
||||
field `{member_ident}`",
|
||||
);
|
||||
}
|
||||
|
||||
for (index, (member, other_member)) in self
|
||||
.ty
|
||||
.members
|
||||
.iter()
|
||||
.zip(other_ty.members.iter())
|
||||
.enumerate()
|
||||
{
|
||||
if member.ty != other_member.ty {
|
||||
let (member_ty, other_member_ty) = (&member.ty, &other_member.ty);
|
||||
bail!(
|
||||
"shaders `{shader}` and `{other_shader}` declare structs with the same name \
|
||||
`{struct_ident}`, but the struct from shader `{shader}` contains a field of \
|
||||
type `{member_ty:?}` at index `{index}`, whereas the same struct from shader \
|
||||
`{other_shader}` contains a field of type `{other_member_ty:?}` in the same \
|
||||
position",
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Translates all the structs that are contained in the SPIR-V document as Rust structs.
|
||||
pub(super) fn write_structs(
|
||||
input: &MacroInput,
|
||||
shader: &Shader,
|
||||
type_registry: &mut TypeRegistry,
|
||||
/// Generates Rust structs from structs declared in SPIR-V bytecode.
|
||||
pub(super) fn generate_structs(
|
||||
macro_options: &MacroOptions,
|
||||
spirv: Spirv,
|
||||
shader_name: String,
|
||||
shader_source: LitStr,
|
||||
registered_structs: &mut HashMap<Ident, RegisteredStruct>,
|
||||
) -> Result<TokenStream> {
|
||||
if !input.generate_structs {
|
||||
return Ok(TokenStream::new());
|
||||
}
|
||||
let mut structs_code = TokenStream::new();
|
||||
|
||||
let mut structs = TokenStream::new();
|
||||
let shader = Shader {
|
||||
spirv,
|
||||
name: shader_name,
|
||||
source: shader_source,
|
||||
};
|
||||
|
||||
for (struct_id, member_type_ids) in shader
|
||||
let structs = shader
|
||||
.spirv
|
||||
.types()
|
||||
.iter()
|
||||
.filter_map(|instruction| match *instruction {
|
||||
.filter_map(|instruction| match instruction {
|
||||
Instruction::TypeStruct {
|
||||
result_id,
|
||||
ref member_types,
|
||||
} => Some((result_id, member_types)),
|
||||
member_types,
|
||||
} => Some((*result_id, member_types)),
|
||||
_ => None,
|
||||
})
|
||||
.filter(|&(struct_id, _)| has_defined_layout(shader, struct_id))
|
||||
{
|
||||
let struct_ty = TypeStruct::new(shader, struct_id, member_type_ids)?;
|
||||
.filter(|&(id, _)| struct_has_defined_layout(&shader.spirv, id));
|
||||
|
||||
// Register the type if needed.
|
||||
if !type_registry.register_struct(shader, &struct_ty)? {
|
||||
continue;
|
||||
}
|
||||
for (struct_id, member_type_ids) in structs {
|
||||
let ty = TypeStruct::new(&shader, struct_id, member_type_ids)?;
|
||||
|
||||
let custom_derives = if struct_ty.size().is_some() {
|
||||
input.custom_derives.as_slice()
|
||||
if let Some(registered) = registered_structs.get(&ty.ident) {
|
||||
validate_members(
|
||||
&ty.ident,
|
||||
&ty.members,
|
||||
®istered.members,
|
||||
&shader.name,
|
||||
®istered.shader_name,
|
||||
)?;
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
let struct_ser = Serializer(&struct_ty, input);
|
||||
let custom_derives = if ty.size().is_some() {
|
||||
macro_options.custom_derives.as_slice()
|
||||
} else {
|
||||
&[]
|
||||
};
|
||||
|
||||
structs.extend(quote! {
|
||||
#[allow(non_camel_case_types, non_snake_case)]
|
||||
#[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )]
|
||||
#[repr(C)]
|
||||
#struct_ser
|
||||
})
|
||||
}
|
||||
let struct_ser = Serializer(&ty, macro_options);
|
||||
|
||||
Ok(structs)
|
||||
}
|
||||
structs_code.extend(quote! {
|
||||
#[allow(non_camel_case_types, non_snake_case)]
|
||||
#[derive(::vulkano::buffer::BufferContents #(, #custom_derives )* )]
|
||||
#[repr(C)]
|
||||
#struct_ser
|
||||
});
|
||||
|
||||
fn has_defined_layout(shader: &Shader, struct_id: Id) -> bool {
|
||||
for member_info in shader.spirv.id(struct_id).members() {
|
||||
let mut offset_found = false;
|
||||
|
||||
for instruction in member_info.decorations() {
|
||||
match instruction {
|
||||
Instruction::MemberDecorate {
|
||||
decoration: Decoration::BuiltIn { .. },
|
||||
..
|
||||
} => {
|
||||
// Ignore the whole struct if a member is built in, which includes
|
||||
// `gl_Position` for example.
|
||||
return false;
|
||||
}
|
||||
Instruction::MemberDecorate {
|
||||
decoration: Decoration::Offset { .. },
|
||||
..
|
||||
} => {
|
||||
offset_found = true;
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
|
||||
// Some structs don't have `Offset` decorations, in that case they are used as local
|
||||
// variables only. Ignoring these.
|
||||
if !offset_found {
|
||||
return false;
|
||||
registered_structs.insert(
|
||||
ty.ident,
|
||||
RegisteredStruct {
|
||||
members: ty.members,
|
||||
shader_name: shader.name.clone(),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
true
|
||||
Ok(structs_code)
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq)]
|
||||
fn struct_has_defined_layout(spirv: &Spirv, id: Id) -> bool {
|
||||
spirv.id(id).members().iter().all(|member_info| {
|
||||
let decorations = member_info
|
||||
.decorations()
|
||||
.iter()
|
||||
.map(|instruction| match instruction {
|
||||
Instruction::MemberDecorate { decoration, .. } => decoration,
|
||||
_ => unreachable!(),
|
||||
});
|
||||
|
||||
let has_offset_decoration = decorations
|
||||
.clone()
|
||||
.any(|decoration| matches!(decoration, Decoration::Offset { .. }));
|
||||
|
||||
let has_builtin_decoration = decorations
|
||||
.clone()
|
||||
.any(|decoration| matches!(decoration, Decoration::BuiltIn { .. }));
|
||||
|
||||
has_offset_decoration && !has_builtin_decoration
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_members(
|
||||
ident: &Ident,
|
||||
first_members: &[Member],
|
||||
second_members: &[Member],
|
||||
first_shader: &str,
|
||||
second_shader: &str,
|
||||
) -> Result<()> {
|
||||
match first_members.len().cmp(&second_members.len()) {
|
||||
Ordering::Greater => bail!(
|
||||
"the declaration of struct `{ident}` in shader `{first_shader}` has more fields than \
|
||||
the declaration in shader `{second_shader}`"
|
||||
),
|
||||
Ordering::Less => bail!(
|
||||
"the declaration of struct `{ident}` in shader `{second_shader}` has more fields than \
|
||||
the declaration in shader `{first_shader}`"
|
||||
),
|
||||
_ => {}
|
||||
}
|
||||
|
||||
for (index, (first_member, second_member)) in
|
||||
first_members.iter().zip(second_members).enumerate()
|
||||
{
|
||||
let (first_type, second_type) = (&first_member.ty, &second_member.ty);
|
||||
if first_type != second_type {
|
||||
bail!(
|
||||
"field {index} of struct `{ident}` is of type `{first_type:?}` in shader \
|
||||
`{first_shader}` but of type `{second_type:?}` in shader `{second_shader}`"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
|
||||
#[repr(u8)]
|
||||
enum Alignment {
|
||||
A1 = 1,
|
||||
@ -188,23 +162,11 @@ impl Alignment {
|
||||
8 => Alignment::A8,
|
||||
16 => Alignment::A16,
|
||||
32 => Alignment::A32,
|
||||
_ => unreachable!(),
|
||||
_ => panic!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl PartialOrd for Alignment {
|
||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||
Some(self.cmp(other))
|
||||
}
|
||||
}
|
||||
|
||||
impl Ord for Alignment {
|
||||
fn cmp(&self, other: &Self) -> Ordering {
|
||||
(*self as usize).cmp(&(*other as usize))
|
||||
}
|
||||
}
|
||||
|
||||
fn align_up(offset: usize, alignment: Alignment) -> usize {
|
||||
(offset + alignment as usize - 1) & !(alignment as usize - 1)
|
||||
}
|
||||
@ -228,7 +190,10 @@ impl Type {
|
||||
let id_info = shader.spirv.id(type_id);
|
||||
|
||||
let ty = match *id_info.instruction() {
|
||||
Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"),
|
||||
Instruction::TypeBool { .. } => bail!(
|
||||
shader.source,
|
||||
"SPIR-V Boolean types don't have a defined layout"
|
||||
),
|
||||
Instruction::TypeInt {
|
||||
width, signedness, ..
|
||||
} => Type::Scalar(TypeScalar::Int(TypeInt::new(shader, width, signedness)?)),
|
||||
@ -345,7 +310,7 @@ impl TypeInt {
|
||||
let signed = match signedness {
|
||||
0 => false,
|
||||
1 => true,
|
||||
_ => bail!(shader.source, "signedness must be 0 or 1"),
|
||||
_ => bail!(shader.source, "signedness must be either 0 or 1"),
|
||||
};
|
||||
|
||||
Ok(TypeInt { width, signed })
|
||||
@ -473,14 +438,17 @@ impl TypeVector {
|
||||
let component_count = ComponentCount::new(shader, component_count)?;
|
||||
|
||||
let component_type = match *shader.spirv.id(component_type_id).instruction() {
|
||||
Instruction::TypeBool { .. } => bail!(shader.source, "can't put booleans in structs"),
|
||||
Instruction::TypeBool { .. } => bail!(
|
||||
shader.source,
|
||||
"SPIR-V Boolean types don't have a defined layout"
|
||||
),
|
||||
Instruction::TypeInt {
|
||||
width, signedness, ..
|
||||
} => TypeScalar::Int(TypeInt::new(shader, width, signedness)?),
|
||||
Instruction::TypeFloat { width, .. } => {
|
||||
TypeScalar::Float(TypeFloat::new(shader, width)?)
|
||||
}
|
||||
_ => bail!(shader.source, "vector components must be scalars"),
|
||||
_ => bail!(shader.source, "vector components must be scalar"),
|
||||
};
|
||||
|
||||
Ok(TypeVector {
|
||||
@ -616,49 +584,46 @@ impl TypeArray {
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let stride = {
|
||||
let mut strides =
|
||||
shader
|
||||
.spirv
|
||||
.id(array_id)
|
||||
.decorations()
|
||||
.iter()
|
||||
.filter_map(|instruction| match *instruction {
|
||||
Instruction::Decorate {
|
||||
decoration: Decoration::ArrayStride { array_stride },
|
||||
..
|
||||
} => Some(array_stride as usize),
|
||||
_ => None,
|
||||
});
|
||||
let stride = strides.next().ok_or_else(|| {
|
||||
Error::new_spanned(
|
||||
&shader.source,
|
||||
"arrays inside structs must have an `ArrayStride` decoration",
|
||||
)
|
||||
})?;
|
||||
let mut strides =
|
||||
shader
|
||||
.spirv
|
||||
.id(array_id)
|
||||
.decorations()
|
||||
.iter()
|
||||
.filter_map(|instruction| match *instruction {
|
||||
Instruction::Decorate {
|
||||
decoration: Decoration::ArrayStride { array_stride },
|
||||
..
|
||||
} => Some(array_stride as usize),
|
||||
_ => None,
|
||||
});
|
||||
|
||||
if !strides.all(|s| s == stride) {
|
||||
bail!(shader.source, "found conflicting `ArrayStride` decorations");
|
||||
}
|
||||
let stride = strides.next().ok_or_else(|| {
|
||||
Error::new_spanned(
|
||||
&shader.source,
|
||||
"arrays inside structs must have an `ArrayStride` decoration",
|
||||
)
|
||||
})?;
|
||||
|
||||
if !is_aligned(stride, element_type.scalar_alignment()) {
|
||||
bail!(
|
||||
shader.source,
|
||||
"array strides must be aligned for the element type",
|
||||
);
|
||||
}
|
||||
if !strides.all(|s| s == stride) {
|
||||
bail!(shader.source, "found conflicting `ArrayStride` decorations");
|
||||
}
|
||||
|
||||
let element_size = element_type.size().ok_or_else(|| {
|
||||
Error::new_spanned(&shader.source, "array elements must be sized")
|
||||
})?;
|
||||
if !is_aligned(stride, element_type.scalar_alignment()) {
|
||||
bail!(
|
||||
shader.source,
|
||||
"array strides must be aligned to the element type's alignment",
|
||||
);
|
||||
}
|
||||
|
||||
if stride < element_size {
|
||||
bail!(shader.source, "array elements must not overlap");
|
||||
}
|
||||
|
||||
stride
|
||||
let Some(element_size) = element_type.size() else {
|
||||
bail!(shader.source, "array elements must be sized");
|
||||
};
|
||||
|
||||
if stride < element_size {
|
||||
bail!(shader.source, "array elements must not overlap");
|
||||
}
|
||||
|
||||
Ok(TypeArray {
|
||||
element_type,
|
||||
length,
|
||||
@ -690,16 +655,18 @@ impl TypeStruct {
|
||||
.iter()
|
||||
.find_map(|instruction| match instruction {
|
||||
Instruction::Name { name, .. } => {
|
||||
// Replace chars that could potentially cause the ident to be invalid with "_".
|
||||
// For example, Rust-GPU names structs by their fully qualified rust name (e.g.
|
||||
// "foo::bar::MyStruct") in which the ":" is an invalid character for idents.
|
||||
// Replace non-alphanumeric and non-ascii characters with '_' to ensure the name
|
||||
// is a valid identifier. For example, Rust-GPU names structs by their fully
|
||||
// qualified rust name (e.g. `foo::bar::MyStruct`) in which `:` makes it an
|
||||
// invalid identifier.
|
||||
let mut name =
|
||||
name.replace(|c: char| !(c.is_ascii_alphanumeric() || c == '_'), "_");
|
||||
if name.starts_with(|c: char| !c.is_ascii_alphabetic()) {
|
||||
name.replace(|c: char| !c.is_ascii_alphanumeric() && c != '_', "_");
|
||||
|
||||
if name.starts_with(|c: char| c.is_ascii_digit()) {
|
||||
name.insert(0, '_');
|
||||
}
|
||||
|
||||
// Worst case: invalid idents will get the UnnamedX name below
|
||||
// Fall-back to `Unnamed{Id}` if it's still invalid
|
||||
syn::parse_str(&name).ok()
|
||||
}
|
||||
_ => None,
|
||||
@ -725,9 +692,7 @@ impl TypeStruct {
|
||||
let mut ty = Type::new(shader, member_id)?;
|
||||
|
||||
{
|
||||
// If the member is an array, then matrix-decorations can be applied to it if the
|
||||
// innermost type of the array is a matrix. Else this will stay being the type of
|
||||
// the member.
|
||||
// Matrix decorations can be applied to an array if its innermost type is a matrix.
|
||||
let mut ty = &mut ty;
|
||||
while let Type::Array(TypeArray { element_type, .. }) = ty {
|
||||
ty = element_type;
|
||||
@ -744,6 +709,7 @@ impl TypeStruct {
|
||||
_ => None,
|
||||
},
|
||||
);
|
||||
|
||||
matrix.stride = strides.next().ok_or_else(|| {
|
||||
Error::new_spanned(
|
||||
&shader.source,
|
||||
@ -766,7 +732,7 @@ impl TypeStruct {
|
||||
);
|
||||
}
|
||||
|
||||
let mut majornessess = member_info.decorations().iter().filter_map(
|
||||
let mut majornesses = member_info.decorations().iter().filter_map(
|
||||
|instruction| match *instruction {
|
||||
Instruction::MemberDecorate {
|
||||
decoration: Decoration::ColMajor,
|
||||
@ -779,7 +745,8 @@ impl TypeStruct {
|
||||
_ => None,
|
||||
},
|
||||
);
|
||||
matrix.majorness = majornessess.next().ok_or_else(|| {
|
||||
|
||||
matrix.majorness = majornesses.next().ok_or_else(|| {
|
||||
Error::new_spanned(
|
||||
&shader.source,
|
||||
"matrices inside structs must have a `ColMajor` or `RowMajor` \
|
||||
@ -787,7 +754,7 @@ impl TypeStruct {
|
||||
)
|
||||
})?;
|
||||
|
||||
if !majornessess.all(|m| m == matrix.majorness) {
|
||||
if !majornesses.all(|m| m == matrix.majorness) {
|
||||
bail!(
|
||||
shader.source,
|
||||
"found conflicting matrix majorness decorations",
|
||||
@ -822,26 +789,27 @@ impl TypeStruct {
|
||||
if !is_aligned(offset, ty.scalar_alignment()) {
|
||||
bail!(
|
||||
shader.source,
|
||||
"struct member offsets must be aligned for the member type",
|
||||
"struct member offsets must be aligned to their type's alignment",
|
||||
);
|
||||
}
|
||||
|
||||
if let Some(last) = members.last() {
|
||||
if !is_aligned(offset, last.ty.scalar_alignment()) {
|
||||
if let Some(previous_member) = members.last() {
|
||||
if !is_aligned(offset, previous_member.ty.scalar_alignment()) {
|
||||
bail!(
|
||||
shader.source,
|
||||
"expected struct member offset to be aligned for the preceding member type",
|
||||
"expected struct member offset to be aligned to the preceding member \
|
||||
type's alignment",
|
||||
);
|
||||
}
|
||||
|
||||
let last_size = last.ty.size().ok_or_else(|| {
|
||||
let last_size = previous_member.ty.size().ok_or_else(|| {
|
||||
Error::new_spanned(
|
||||
&shader.source,
|
||||
"all members except the last member of a struct must be sized",
|
||||
)
|
||||
})?;
|
||||
|
||||
if last.offset + last_size > offset {
|
||||
if previous_member.offset + last_size > offset {
|
||||
bail!(shader.source, "struct members must not overlap");
|
||||
}
|
||||
}
|
||||
@ -888,8 +856,8 @@ impl PartialEq for Member {
|
||||
|
||||
impl Eq for Member {}
|
||||
|
||||
/// Helper for serializing a type to tokens with respect to macro input.
|
||||
struct Serializer<'a, T>(&'a T, &'a MacroInput);
|
||||
/// Helper for serializing a type as tokens according to the macro options.
|
||||
struct Serializer<'a, T>(&'a T, &'a MacroOptions);
|
||||
|
||||
impl ToTokens for Serializer<'_, Type> {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
@ -909,18 +877,16 @@ impl ToTokens for Serializer<'_, TypeVector> {
|
||||
let component_type = &self.0.component_type;
|
||||
let component_count = self.0.component_count as usize;
|
||||
|
||||
match self.1.linalg_type {
|
||||
LinAlgType::Std => {
|
||||
match self.1.linalg_types {
|
||||
LinAlgTypes::Std => {
|
||||
tokens.extend(quote! { [#component_type; #component_count] });
|
||||
}
|
||||
LinAlgType::CgMath => {
|
||||
LinAlgTypes::Cgmath => {
|
||||
let vector = format_ident!("Vector{}", component_count);
|
||||
tokens.extend(quote! { ::cgmath::#vector<#component_type> });
|
||||
}
|
||||
LinAlgType::Nalgebra => {
|
||||
tokens.extend(quote! {
|
||||
::nalgebra::base::SVector<#component_type, #component_count>
|
||||
});
|
||||
LinAlgTypes::Nalgebra => {
|
||||
tokens.extend(quote! { ::nalgebra::SVector<#component_type, #component_count> });
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -936,22 +902,22 @@ impl ToTokens for Serializer<'_, TypeMatrix> {
|
||||
// This can't overflow because the stride must be at least the vector size.
|
||||
let padding = self.0.stride - self.0.vector_size();
|
||||
|
||||
match self.1.linalg_type {
|
||||
// cgmath only has column-major matrices. It also only has square matrices, and its 3x3
|
||||
// matrix is not padded right. Fall back to std for anything else.
|
||||
LinAlgType::CgMath
|
||||
match self.1.linalg_types {
|
||||
// cgmath only supports column-major square matrices, and its 3x3 matrix is not padded
|
||||
// correctly.
|
||||
LinAlgTypes::Cgmath
|
||||
if majorness == MatrixMajorness::ColumnMajor
|
||||
&& padding == 0
|
||||
&& vector_count == component_count =>
|
||||
&& vector_count == component_count
|
||||
&& padding == 0 =>
|
||||
{
|
||||
let matrix = format_ident!("Matrix{}", component_count);
|
||||
tokens.extend(quote! { ::cgmath::#matrix<#component_type> });
|
||||
}
|
||||
// nalgebra only has column-major matrices, and its 3xN matrices are not padded right.
|
||||
// Fall back to std for anything else.
|
||||
LinAlgType::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => {
|
||||
// nalgebra only supports column-major matrices, and its 3xN matrices are not padded
|
||||
// correctly.
|
||||
LinAlgTypes::Nalgebra if majorness == MatrixMajorness::ColumnMajor && padding == 0 => {
|
||||
tokens.extend(quote! {
|
||||
::nalgebra::base::SMatrix<#component_type, #component_count, #vector_count>
|
||||
::nalgebra::SMatrix<#component_type, #component_count, #vector_count>
|
||||
});
|
||||
}
|
||||
_ => {
|
||||
@ -964,7 +930,7 @@ impl ToTokens for Serializer<'_, TypeMatrix> {
|
||||
|
||||
impl ToTokens for Serializer<'_, TypeArray> {
|
||||
fn to_tokens(&self, tokens: &mut TokenStream) {
|
||||
let element_type = &*self.0.element_type;
|
||||
let element_type = self.0.element_type.as_ref();
|
||||
// This can't panic because array elements must be sized.
|
||||
let element_size = element_type.size().unwrap();
|
||||
// This can't overflow because the stride must be at least the element size.
|
||||
@ -1047,7 +1013,8 @@ impl ToTokens for Serializer<'_, TypeStruct> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Helper for wrapping tokens in `Padded`. Doesn't wrap if the padding is `0`.
|
||||
/// Helper for wrapping tokens in [Padded][struct@vulkano::padded::Padded].
|
||||
/// Doesn't wrap if the padding is `0`.
|
||||
struct Padded<T>(T, usize);
|
||||
|
||||
impl<T: ToTokens> ToTokens for Padded<T> {
|
||||
|
Loading…
Reference in New Issue
Block a user