diff --git a/vulkano-shaders/Cargo.toml b/vulkano-shaders/Cargo.toml index 4f714c71..30f19541 100644 --- a/vulkano-shaders/Cargo.toml +++ b/vulkano-shaders/Cargo.toml @@ -24,5 +24,7 @@ syn = { version = "1.0", features = ["full", "extra-traits"] } vulkano = { version = "0.32.0", path = "../vulkano" } [features] +cgmath = [] +nalgebra = [] shaderc-build-from-source = ["shaderc/build-from-source"] shaderc-debug = [] diff --git a/vulkano-shaders/src/codegen.rs b/vulkano-shaders/src/codegen.rs index 90b2b6c7..de58654d 100644 --- a/vulkano-shaders/src/codegen.rs +++ b/vulkano-shaders/src/codegen.rs @@ -7,7 +7,7 @@ // notice may not be copied, modified, or distributed except // according to those terms. -use crate::{entry_point, read_file_to_string, structs, RegisteredType, TypesMeta}; +use crate::{entry_point, read_file_to_string, structs, LinAlgType, RegisteredType, TypesMeta}; use ahash::HashMap; use proc_macro2::TokenStream; pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind}; @@ -205,7 +205,7 @@ pub fn compile( Ok((content, includes)) } -pub(super) fn reflect<'a>( +pub(super) fn reflect<'a, L: LinAlgType>( prefix: &'a str, words: &[u32], types_meta: &TypesMeta, @@ -243,8 +243,12 @@ pub(super) fn reflect<'a>( let entry_points = reflect::entry_points(&spirv) .map(|(name, model, info)| entry_point::write_entry_point(&name, model, &info)); - let specialization_constants = - structs::write_specialization_constants(prefix, &spirv, shared_constants, types_registry); + let specialization_constants = structs::write_specialization_constants::( + prefix, + &spirv, + shared_constants, + types_registry, + ); let load_name = if prefix.is_empty() { format_ident!("load") @@ -278,7 +282,7 @@ pub(super) fn reflect<'a>( #specialization_constants }; - let structs = structs::write_structs(prefix, &spirv, types_meta, types_registry); + let structs = structs::write_structs::(prefix, &spirv, types_meta, types_registry); Ok((shader_code, structs)) } @@ -304,7 +308,7 @@ impl From for Error { #[cfg(test)] mod tests { use super::*; - use crate::codegen::compile; + use crate::{codegen::compile, StdArray}; use shaderc::ShaderKind; use std::path::{Path, PathBuf}; use vulkano::shader::{reflect, spirv::Spirv}; @@ -371,7 +375,12 @@ mod tests { .unwrap(); let spirv = Spirv::new(comp.as_binary()).unwrap(); let res = std::panic::catch_unwind(|| { - structs::write_structs("", &spirv, &TypesMeta::default(), &mut HashMap::default()) + structs::write_structs::( + "", + &spirv, + &TypesMeta::default(), + &mut HashMap::default(), + ) }); assert!(res.is_err()); } @@ -400,7 +409,12 @@ mod tests { ) .unwrap(); let spirv = Spirv::new(comp.as_binary()).unwrap(); - structs::write_structs("", &spirv, &TypesMeta::default(), &mut HashMap::default()); + structs::write_structs::( + "", + &spirv, + &TypesMeta::default(), + &mut HashMap::default(), + ); } #[test] fn test_wrap_alignment() { @@ -432,7 +446,12 @@ mod tests { ) .unwrap(); let spirv = Spirv::new(comp.as_binary()).unwrap(); - structs::write_structs("", &spirv, &TypesMeta::default(), &mut HashMap::default()); + structs::write_structs::( + "", + &spirv, + &TypesMeta::default(), + &mut HashMap::default(), + ); } #[test] diff --git a/vulkano-shaders/src/lib.rs b/vulkano-shaders/src/lib.rs index 55d03279..4fca5792 100644 --- a/vulkano-shaders/src/lib.rs +++ b/vulkano-shaders/src/lib.rs @@ -226,6 +226,7 @@ extern crate syn; use crate::codegen::ShaderKind; use ahash::HashMap; +use proc_macro2::TokenStream; use shaderc::{EnvVersion, SpirvVersion}; use std::{ borrow::Cow, @@ -245,6 +246,160 @@ mod codegen; mod entry_point; mod structs; +/// Generates vectors and matrices using standard Rust arrays. +#[proc_macro] +pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + shader_inner::(input) +} + +/// Generates vectors and matrices using the [`cgmath`] library where possible, falling back to +/// standard Rust arrays otherwise. +/// +/// [`cgmath`]: https://crates.io/crates/cgmath +#[cfg(feature = "cgmath")] +#[proc_macro] +pub fn shader_cgmath(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + shader_inner::(input) +} + +/// Generates vectors and matrices using the [`nalgebra`] library. +/// +/// [`nalgebra`]: https://crates.io/crates/nalgebra +#[cfg(feature = "nalgebra")] +#[proc_macro] +pub fn shader_nalgebra(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + shader_inner::(input) +} + +fn shader_inner(input: proc_macro::TokenStream) -> proc_macro::TokenStream { + let input = parse_macro_input!(input as MacroInput); + + let is_single = input.shaders.len() == 1; + let root = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".into()); + let root_path = Path::new(&root); + + let mut shaders_code = Vec::with_capacity(input.shaders.len()); + let mut types_code = Vec::with_capacity(input.shaders.len()); + let mut types_registry = HashMap::default(); + + for (prefix, (shader_kind, shader_source)) in input.shaders { + let (code, types) = if let SourceKind::Bytes(path) = shader_source { + let full_path = root_path.join(&path); + + let bytes = if full_path.is_file() { + fs::read(full_path) + .unwrap_or_else(|_| panic!("Error reading source from {:?}", path)) + } else { + panic!( + "File {:?} was not found; note that the path must be relative to your Cargo.toml", + path + ); + }; + + // The SPIR-V specification essentially guarantees that + // a shader will always be an integer number of words + assert_eq!(0, bytes.len() % 4); + codegen::reflect::( + prefix.as_str(), + unsafe { from_raw_parts(bytes.as_slice().as_ptr() as *const u32, bytes.len() / 4) }, + &input.types_meta, + empty(), + input.shared_constants, + &mut types_registry, + ) + .unwrap() + } else { + let (path, full_path, source_code) = match shader_source { + SourceKind::Src(source) => (None, None, source), + SourceKind::Path(path) => { + let full_path = root_path.join(&path); + let source_code = read_file_to_string(&full_path) + .unwrap_or_else(|_| panic!("Error reading source from {:?}", path)); + + if full_path.is_file() { + (Some(path.clone()), Some(full_path), source_code) + } else { + panic!("File {:?} was not found; note that the path must be relative to your Cargo.toml", path); + } + } + SourceKind::Bytes(_) => unreachable!(), + }; + + let include_paths = input + .include_directories + .iter() + .map(|include_directory| { + let include_path = Path::new(include_directory); + let mut full_include_path = root_path.to_owned(); + full_include_path.push(include_path); + full_include_path + }) + .collect::>(); + + let (content, includes) = match codegen::compile( + path, + &root_path, + &source_code, + shader_kind, + &include_paths, + &input.macro_defines, + input.vulkan_version, + input.spirv_version, + ) { + Ok(ok) => ok, + Err(e) => { + if is_single { + panic!("{}", e.replace("(s): ", "(s):\n")) + } else { + panic!("Shader {:?} {}", prefix, e.replace("(s): ", "(s):\n")) + } + } + }; + + let input_paths = includes + .iter() + .map(|s| s.as_ref()) + .chain(full_path.as_deref().map(codegen::path_to_str)); + + codegen::reflect::( + prefix.as_str(), + content.as_binary(), + &input.types_meta, + input_paths, + input.shared_constants, + &mut types_registry, + ) + .unwrap() + }; + + shaders_code.push(code); + types_code.push(types); + } + + let uses = &input.types_meta.uses; + + let result = quote! { + #( + #shaders_code + )* + + pub mod ty { + #( #uses )* + + #( + #types_code + )* + } + }; + + if input.dump { + println!("{}", result); + panic!("`shader!` rust codegen dumped") // TODO: use span from dump + } + + proc_macro::TokenStream::from(result) +} + enum SourceKind { Src(String), Path(String), @@ -794,132 +949,57 @@ pub(self) fn read_file_to_string(full_path: &Path) -> IoResult { Ok(buf) } -#[proc_macro] -pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(input as MacroInput); - - let is_single = input.shaders.len() == 1; - let root = env::var("CARGO_MANIFEST_DIR").unwrap_or_else(|_| ".".into()); - let root_path = Path::new(&root); - - let mut shaders_code = Vec::with_capacity(input.shaders.len()); - let mut types_code = Vec::with_capacity(input.shaders.len()); - let mut types_registry = HashMap::default(); - - for (prefix, (shader_kind, shader_source)) in input.shaders { - let (code, types) = if let SourceKind::Bytes(path) = shader_source { - let full_path = root_path.join(&path); - - let bytes = if full_path.is_file() { - fs::read(full_path) - .unwrap_or_else(|_| panic!("Error reading source from {:?}", path)) - } else { - panic!( - "File {:?} was not found; note that the path must be relative to your Cargo.toml", - path - ); - }; - - // The SPIR-V specification essentially guarantees that - // a shader will always be an integer number of words - assert_eq!(0, bytes.len() % 4); - codegen::reflect( - prefix.as_str(), - unsafe { from_raw_parts(bytes.as_slice().as_ptr() as *const u32, bytes.len() / 4) }, - &input.types_meta, - empty(), - input.shared_constants, - &mut types_registry, - ) - .unwrap() - } else { - let (path, full_path, source_code) = match shader_source { - SourceKind::Src(source) => (None, None, source), - SourceKind::Path(path) => { - let full_path = root_path.join(&path); - let source_code = read_file_to_string(&full_path) - .unwrap_or_else(|_| panic!("Error reading source from {:?}", path)); - - if full_path.is_file() { - (Some(path.clone()), Some(full_path), source_code) - } else { - panic!("File {:?} was not found; note that the path must be relative to your Cargo.toml", path); - } - } - SourceKind::Bytes(_) => unreachable!(), - }; - - let include_paths = input - .include_directories - .iter() - .map(|include_directory| { - let include_path = Path::new(include_directory); - let mut full_include_path = root_path.to_owned(); - full_include_path.push(include_path); - full_include_path - }) - .collect::>(); - - let (content, includes) = match codegen::compile( - path, - &root_path, - &source_code, - shader_kind, - &include_paths, - &input.macro_defines, - input.vulkan_version, - input.spirv_version, - ) { - Ok(ok) => ok, - Err(e) => { - if is_single { - panic!("{}", e.replace("(s): ", "(s):\n")) - } else { - panic!("Shader {:?} {}", prefix, e.replace("(s): ", "(s):\n")) - } - } - }; - - let input_paths = includes - .iter() - .map(|s| s.as_ref()) - .chain(full_path.as_deref().map(codegen::path_to_str)); - - codegen::reflect( - prefix.as_str(), - content.as_binary(), - &input.types_meta, - input_paths, - input.shared_constants, - &mut types_registry, - ) - .unwrap() - }; - - shaders_code.push(code); - types_code.push(types); - } - - let uses = &input.types_meta.uses; - - let result = quote! { - #( - #shaders_code - )* - - pub mod ty { - #( #uses )* - - #( - #types_code - )* - } - }; - - if input.dump { - println!("{}", result); - panic!("`shader!` rust codegen dumped") // TODO: use span from dump - } - - proc_macro::TokenStream::from(result) +trait LinAlgType { + fn vector(component_type: &TokenStream, component_count: usize) -> TokenStream; + fn matrix(component_type: &TokenStream, row_count: usize, column_count: usize) -> TokenStream; +} + +struct StdArray; + +impl LinAlgType for StdArray { + fn vector(component_type: &TokenStream, component_count: usize) -> TokenStream { + quote! { [#component_type; #component_count] } + } + + fn matrix(component_type: &TokenStream, row_count: usize, column_count: usize) -> TokenStream { + quote! { [[#component_type; #row_count]; #column_count] } + } +} + +struct CGMath; + +impl LinAlgType for CGMath { + fn vector(component_type: &TokenStream, component_count: usize) -> TokenStream { + // cgmath only has 1, 2, 3 and 4-component vector types. + // Fall back to arrays for anything else. + if matches!(component_count, 1 | 2 | 3 | 4) { + let ty = format_ident!("{}", format!("Vector{}", component_count)); + quote! { cgmath::#ty<#component_type> } + } else { + StdArray::vector(component_type, component_count) + } + } + + fn matrix(component_type: &TokenStream, row_count: usize, column_count: usize) -> TokenStream { + // cgmath only has square 2x2, 3x3 and 4x4 matrix types. + // Fall back to arrays for anything else. + if row_count == column_count && matches!(column_count, 2 | 3 | 4) { + let ty = format_ident!("{}", format!("Matrix{}", column_count)); + quote! { cgmath::#ty<#component_type> } + } else { + StdArray::matrix(component_type, row_count, column_count) + } + } +} + +struct Nalgebra; + +impl LinAlgType for Nalgebra { + fn vector(component_type: &TokenStream, component_count: usize) -> TokenStream { + quote! { nalgebra::base::SVector<#component_type, #component_count> } + } + + fn matrix(component_type: &TokenStream, row_count: usize, column_count: usize) -> TokenStream { + quote! { nalgebra::base::SMatrix<#component_type, #row_count, #column_count> } + } } diff --git a/vulkano-shaders/src/structs.rs b/vulkano-shaders/src/structs.rs index a607c4e5..4410c2e6 100644 --- a/vulkano-shaders/src/structs.rs +++ b/vulkano-shaders/src/structs.rs @@ -7,7 +7,7 @@ // notice may not be copied, modified, or distributed except // according to those terms. -use crate::{RegisteredType, TypesMeta}; +use crate::{LinAlgType, RegisteredType, TypesMeta}; use ahash::HashMap; use heck::ToUpperCamelCase; use proc_macro2::{Span, TokenStream}; @@ -16,7 +16,7 @@ use syn::{Ident, LitStr}; use vulkano::shader::spirv::{Decoration, Id, Instruction, Spirv}; /// Translates all the structs that are contained in the SPIR-V document as Rust structs. -pub(super) fn write_structs<'a>( +pub(super) fn write_structs<'a, L: LinAlgType>( shader: &'a str, spirv: &'a Spirv, types_meta: &'a TypesMeta, @@ -33,7 +33,8 @@ pub(super) fn write_structs<'a>( }) .filter(|&(struct_id, _member_types)| has_defined_layout(spirv, struct_id)) .filter_map(|(struct_id, member_types)| { - let (rust_members, is_sized) = write_struct_members(spirv, struct_id, member_types); + let (rust_members, is_sized) = + write_struct_members::(spirv, struct_id, member_types); let struct_name = spirv .id(struct_id) @@ -85,7 +86,11 @@ struct Member { signature: Cow<'static, str>, } -fn write_struct_members(spirv: &Spirv, struct_id: Id, members: &[Id]) -> (Vec, bool) { +fn write_struct_members( + spirv: &Spirv, + struct_id: Id, + members: &[Id], +) -> (Vec, bool) { let mut rust_members = Vec::with_capacity(members.len()); // Dummy members will be named `_dummyN` where `N` is determined by this variable. @@ -101,7 +106,7 @@ fn write_struct_members(spirv: &Spirv, struct_id: Id, members: &[Id]) -> (Vec(spirv, member); let member_name = member_info .iter_name() .find_map(|instruction| match instruction { @@ -397,7 +402,7 @@ fn struct_size_from_array_stride(spirv: &Spirv, type_id: Id) -> Option { /// Returns the type name to put in the Rust struct, and its size and alignment. /// /// The size can be `None` if it's only known at runtime. -pub(super) fn type_from_id( +pub(super) fn type_from_id( spirv: &Spirv, type_id: Id, ) -> (TokenStream, Cow<'static, str>, Option, usize) { @@ -551,32 +556,53 @@ pub(super) fn type_from_id( .. } => { debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::()); - let (ty, item, t_size, t_align) = type_from_id(spirv, component_type); - let array_length = component_count as usize; - let size = t_size.map(|s| s * component_count as usize); - ( - quote! { [#ty; #array_length] }, - Cow::from(format!("[{}; {}]", item, array_length)), - size, - t_align, - ) + let component_count = component_count as usize; + let (element_ty, element_item, element_size, align) = + type_from_id::(spirv, component_type); + + let ty = L::vector(&element_ty, component_count); + let item = Cow::from(format!("[{}; {}]", element_item, component_count)); + let size = element_size.map(|s| s * component_count); + + (ty, item, size, align) } &Instruction::TypeMatrix { column_type, column_count, .. } => { - // FIXME: row-major or column-major debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::()); - let (ty, item, t_size, t_align) = type_from_id(spirv, column_type); - let array_length = column_count as usize; - let size = t_size.map(|s| s * column_count as usize); - ( - quote! { [#ty; #array_length] }, - Cow::from(format!("[{}; {}]", item, array_length)), - size, - t_align, - ) + let column_count = column_count as usize; + + // FIXME: row-major or column-major + let (row_count, element, element_item, element_size, align) = + match spirv.id(column_type).instruction() { + &Instruction::TypeVector { + component_type, + component_count, + .. + } => { + let (element, element_item, element_size, align) = + type_from_id::(spirv, component_type); + ( + component_count as usize, + element, + element_item, + element_size, + align, + ) + } + _ => unreachable!(), + }; + + let ty = L::matrix(&element, row_count, column_count); + let size = element_size.map(|s| s * row_count * column_count); + let item = Cow::from(format!( + "[[{}; {}]; {}]", + element_item, row_count, column_count + )); + + (ty, item, size, align) } &Instruction::TypeArray { element_type, @@ -586,7 +612,7 @@ pub(super) fn type_from_id( debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::()); let (element_type, element_type_string, element_size, element_align) = - type_from_id(spirv, element_type); + type_from_id::(spirv, element_type); let element_size = element_size.expect("array components must be sized"); let array_length = match spirv.id(length).instruction() { @@ -624,7 +650,7 @@ pub(super) fn type_from_id( debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::()); let (element_type, element_type_string, _, element_align) = - type_from_id(spirv, element_type); + type_from_id::(spirv, element_type); ( quote! { [#element_type] }, @@ -660,7 +686,7 @@ pub(super) fn type_from_id( _ => None, }) .unwrap(); - let (_, _, rust_size, _) = type_from_id(spirv, member); + let (_, _, rust_size, _) = type_from_id::(spirv, member); rust_size.map(|rust_size| spirv_offset + rust_size) }) }) @@ -668,7 +694,7 @@ pub(super) fn type_from_id( let align = member_types .iter() - .map(|&t| type_from_id(spirv, t).3) + .map(|&t| type_from_id::(spirv, t).3) .max() .unwrap_or(1); @@ -692,7 +718,7 @@ pub(super) fn type_from_id( /// Writes the `SpecializationConstants` struct that contains the specialization constants and /// implements the `Default` and the `vulkano::shader::SpecializationConstants` traits. -pub(super) fn write_specialization_constants<'a>( +pub(super) fn write_specialization_constants<'a, L: LinAlgType>( shader: &'a str, spirv: &Spirv, shared_constants: bool, @@ -757,7 +783,7 @@ pub(super) fn write_specialization_constants<'a>( Some(mem::size_of::()), mem::align_of::(), ), - _ => type_from_id(spirv, result_type_id), + _ => type_from_id::(spirv, result_type_id), }; let rust_size = rust_size.expect("Found runtime-sized specialization constant"); diff --git a/vulkano/Cargo.toml b/vulkano/Cargo.toml index df039f3b..27851e40 100644 --- a/vulkano/Cargo.toml +++ b/vulkano/Cargo.toml @@ -19,6 +19,7 @@ ahash = "0.8" # All versions of vk.xml can be found at https://github.com/KhronosGroup/Vulkan-Headers/commits/main/registry/vk.xml. ash = "^0.37.1" bytemuck = { version = "1.7", features = ["derive", "extern_crate_std", "min_const_generics"] } +cgmath = { version = "0.18.0", optional = true } crossbeam-queue = "0.3" half = "2" libloading = "0.7" diff --git a/vulkano/autogen/formats.rs b/vulkano/autogen/formats.rs index b4761418..b5697152 100644 --- a/vulkano/autogen/formats.rs +++ b/vulkano/autogen/formats.rs @@ -51,12 +51,15 @@ struct FormatMember { components: [u8; 4], compression: Option, planes: Vec, - rust_type: Option, texels_per_block: u8, type_color: Option, type_depth: Option, type_stencil: Option, ycbcr_chroma_sampling: Option, + + type_std_array: Option, + type_cgmath: Option, + type_nalgebra: Option, } fn formats_output(members: &[FormatMember]) -> TokenStream { @@ -208,15 +211,43 @@ fn formats_output(members: &[FormatMember]) -> TokenStream { let try_from_items = members.iter().map(|FormatMember { name, ffi_name, .. }| { quote! { ash::vk::Format::#ffi_name => Ok(Self::#name), } }); + let type_for_format_items = members.iter().filter_map( |FormatMember { - name, rust_type, .. + name, + type_std_array, + .. }| { - rust_type.as_ref().map(|rust_type| { - quote! { (#name) => { #rust_type }; } + type_std_array.as_ref().map(|ty| { + quote! { (#name) => { #ty }; } }) }, ); + let type_for_format_cgmath_items = members.iter().filter_map( + |FormatMember { + name, + type_std_array, + type_cgmath, + .. + }| { + (type_cgmath.as_ref().or(type_std_array.as_ref())).map(|ty| { + quote! { (#name) => { #ty }; } + }) + }, + ); + let type_for_format_nalgebra_items = members.iter().filter_map( + |FormatMember { + name, + type_std_array, + type_nalgebra, + .. + }| { + (type_nalgebra.as_ref().or(type_std_array.as_ref())).map(|ty| { + quote! { (#name) => { #ty }; } + }) + }, + ); + let validate_device_items = members.iter().map(|FormatMember { name, requires, .. }| { let requires_items = requires.iter().map( |RequiresOneOf { @@ -502,13 +533,12 @@ fn formats_output(members: &[FormatMember]) -> TokenStream { } } - /// Converts a format enum identifier to a type that is suitable for representing the format - /// in a buffer or image. + /// Converts a format enum identifier to a standard Rust type that is suitable for + /// representing the format in a buffer or image. /// /// This macro returns one possible suitable representation, but there are usually other - /// possibilities for a given format, including those provided by external libraries like - /// `cmath` or `nalgebra`. A compile error occurs for formats that have no well-defined size - /// (the `size` method returns `None`). + /// possibilities for a given format. A compile error occurs for formats that have no + /// well-defined size (the `size` method returns `None`). /// /// - For regular unpacked formats with one component, this returns a single floating point, /// signed or unsigned integer with the appropriate number of bits. For formats with @@ -533,6 +563,77 @@ fn formats_output(members: &[FormatMember]) -> TokenStream { macro_rules! type_for_format { #(#type_for_format_items)* } + + /// Converts a format enum identifier to a [`cgmath`] or standard Rust type that is + /// suitable for representing the format in a buffer or image. + /// + /// This macro returns one possible suitable representation, but there are usually other + /// possibilities for a given format. A compile error occurs for formats that have no + /// well-defined size (the `size` method returns `None`). + /// + /// - For regular unpacked formats with one component, this returns a single floating point, + /// signed or unsigned integer with the appropriate number of bits. For formats with + /// multiple components, a [`cgmath`] `Vector` is returned. + /// - For packed formats, this returns an unsigned integer with the size of the packed + /// element. For multi-packed formats (such as `2PACK16`), an array is returned. + /// - For compressed formats, this returns `[u8; N]` where N is the size of a block. + /// + /// Note: for 16-bit floating point values, you need to import the [`half::f16`] type. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate vulkano; + /// # fn main() { + /// let pixel: type_for_format_cgmath!(R32G32B32A32_SFLOAT); + /// # } + /// ``` + /// + /// The type of `pixel` will be [`Vector4`]. + /// + /// [`cgmath`]: https://crates.io/crates/cgmath + /// [`Vector4`]: https://docs.rs/cgmath/latest/cgmath/struct.Vector4.html + #[cfg(feature = "cgmath")] + #[macro_export] + macro_rules! type_for_format_cgmath { + #(#type_for_format_cgmath_items)* + } + + /// Converts a format enum identifier to a [`nalgebra`] or standard Rust type that is + /// suitable for representing the format in a buffer or image. + /// + /// This macro returns one possible suitable representation, but there are usually other + /// possibilities for a given format. A compile error occurs for formats that have no + /// well-defined size (the `size` method returns `None`). + /// + /// - For regular unpacked formats with one component, this returns a single floating point, + /// signed or unsigned integer with the appropriate number of bits. For formats with + /// multiple components, a [`nalgebra`] [`SVector`] is returned. + /// - For packed formats, this returns an unsigned integer with the size of the packed + /// element. For multi-packed formats (such as `2PACK16`), an array is returned. + /// - For compressed formats, this returns `[u8; N]` where N is the size of a block. + /// + /// Note: for 16-bit floating point values, you need to import the [`half::f16`] type. + /// + /// # Examples + /// + /// ``` + /// # #[macro_use] extern crate vulkano; + /// # fn main() { + /// let pixel: type_for_format_nalgebra!(R32G32B32A32_SFLOAT); + /// # } + /// ``` + /// + /// The type of `pixel` will be [`Vector4`]. + /// + /// [`nalgebra`]: https://crates.io/crates/nalgebra + /// [`SVector`]: https://docs.rs/nalgebra/latest/nalgebra/base/type.SVector.html + /// [`Vector4`]: https://docs.rs/nalgebra/latest/nalgebra/base/type.Vector4.html + #[cfg(feature = "nalgebra")] + #[macro_export] + macro_rules! type_for_format_nalgebra { + #(#type_for_format_nalgebra_items)* + } } } @@ -582,12 +683,15 @@ fn formats_members( .as_ref() .map(|c| format_ident!("{}", c.replace(' ', "_"))), planes: vec![], - rust_type: None, texels_per_block: format.texelsPerBlock, type_color: None, type_depth: None, type_stencil: None, ycbcr_chroma_sampling: None, + + type_std_array: None, + type_cgmath: None, + type_nalgebra: None, }; for child in &format.children { @@ -682,7 +786,7 @@ fn formats_members( member.block_size = Some(format.blockSize as u64); if format.compressed.is_some() { - member.rust_type = Some({ + member.type_std_array = Some({ let block_size = Literal::usize_unsuffixed(format.blockSize as usize); quote! { [u8; #block_size] } }); @@ -690,7 +794,7 @@ fn formats_members( let pack_elements = format.blockSize * 8 / pack_bits; let element_type = format_ident!("u{}", pack_bits); - member.rust_type = Some(if pack_elements > 1 { + member.type_std_array = Some(if pack_elements > 1 { let elements = Literal::usize_unsuffixed(pack_elements as usize); quote! { [#element_type; #elements] } } else { @@ -704,9 +808,9 @@ fn formats_members( _ => unreachable!(), }; let bits = member.components[0]; - let element_type = format_ident!("{}{}", prefix, bits); + let component_type = format_ident!("{}{}", prefix, bits); - let elements = if member.components[1] == 2 * bits { + let component_count = if member.components[1] == 2 * bits { // 422 format with repeated G component 4 } else { @@ -725,12 +829,23 @@ fn formats_members( .count() }; - member.rust_type = Some(if elements > 1 { - let elements = Literal::usize_unsuffixed(elements); - quote! { [#element_type; #elements] } + if component_count > 1 { + let elements = Literal::usize_unsuffixed(component_count); + member.type_std_array = Some(quote! { [#component_type; #elements] }); + + // cgmath only has 1, 2, 3 and 4-component vector types. + // Fall back to arrays for anything else. + if matches!(component_count, 1 | 2 | 3 | 4) { + let ty = format_ident!("{}", format!("Vector{}", component_count)); + member.type_cgmath = Some(quote! { cgmath::#ty<#component_type> }); + } + + member.type_nalgebra = Some(quote! { + nalgebra::base::SVector<#component_type, #component_count> + }); } else { - quote! { #element_type } - }); + member.type_std_array = Some(quote! { #component_type }); + } } } diff --git a/vulkano/src/pipeline/graphics/vertex_input/impl_vertex.rs b/vulkano/src/pipeline/graphics/vertex_input/impl_vertex.rs index 24739305..8a8d8104 100644 --- a/vulkano/src/pipeline/graphics/vertex_input/impl_vertex.rs +++ b/vulkano/src/pipeline/graphics/vertex_input/impl_vertex.rs @@ -162,6 +162,81 @@ where } } +#[cfg(feature = "cgmath")] +unsafe impl VertexMember for cgmath::Vector1 +where + T: VertexMember, +{ + fn format() -> (VertexMemberTy, usize) { + ::format() + } +} + +#[cfg(feature = "cgmath")] +unsafe impl VertexMember for cgmath::Vector2 +where + T: VertexMember, +{ + fn format() -> (VertexMemberTy, usize) { + let (ty, sz) = ::format(); + (ty, sz * 2) + } +} + +#[cfg(feature = "cgmath")] +unsafe impl VertexMember for cgmath::Vector3 +where + T: VertexMember, +{ + fn format() -> (VertexMemberTy, usize) { + let (ty, sz) = ::format(); + (ty, sz * 3) + } +} + +#[cfg(feature = "cgmath")] +unsafe impl VertexMember for cgmath::Vector4 +where + T: VertexMember, +{ + fn format() -> (VertexMemberTy, usize) { + let (ty, sz) = ::format(); + (ty, sz * 4) + } +} + +#[cfg(feature = "cgmath")] +unsafe impl VertexMember for cgmath::Point1 +where + T: VertexMember, +{ + fn format() -> (VertexMemberTy, usize) { + ::format() + } +} + +#[cfg(feature = "cgmath")] +unsafe impl VertexMember for cgmath::Point2 +where + T: VertexMember, +{ + fn format() -> (VertexMemberTy, usize) { + let (ty, sz) = ::format(); + (ty, sz * 2) + } +} + +#[cfg(feature = "cgmath")] +unsafe impl VertexMember for cgmath::Point3 +where + T: VertexMember, +{ + fn format() -> (VertexMemberTy, usize) { + let (ty, sz) = ::format(); + (ty, sz * 3) + } +} + #[cfg(feature = "nalgebra")] unsafe impl VertexMember for nalgebra::Vector1 where