mirror of
https://github.com/vulkano-rs/vulkano.git
synced 2024-11-21 22:34:43 +00:00
Shader improvements (#1753)
* add Eq, PartialEq to ShaderExecution and thus GeometryShaderExecution
* use (String, ExecutionModel) to describe each entry point instead
* update tests
* Revert "update tests"
This reverts commit 2bd07d1ef4
.
* keep old entry_point interface but introduce new entry_point_with_execution for fine grained selection
* update tests
* move traits to autogen
* oops
This commit is contained in:
parent
124305a191
commit
9041858430
@ -242,7 +242,7 @@ where
|
||||
});
|
||||
let spirv_extensions = reflect::spirv_extensions(&spirv);
|
||||
let entry_points = reflect::entry_points(&spirv, exact_entrypoint_interface)
|
||||
.map(|(name, info)| entry_point::write_entry_point(&name, &info));
|
||||
.map(|(name, model, info)| entry_point::write_entry_point(&name, model, &info));
|
||||
|
||||
let specialization_constants = structs::write_specialization_constants(
|
||||
prefix,
|
||||
@ -708,7 +708,7 @@ mod tests {
|
||||
let spirv = Spirv::new(&instructions).unwrap();
|
||||
|
||||
let mut descriptors = Vec::new();
|
||||
for (_, info) in reflect::entry_points(&spirv, true) {
|
||||
for (_, _, info) in reflect::entry_points(&spirv, true) {
|
||||
descriptors.push(info.descriptor_requirements);
|
||||
}
|
||||
|
||||
@ -779,7 +779,7 @@ mod tests {
|
||||
.unwrap();
|
||||
let spirv = Spirv::new(comp.as_binary()).unwrap();
|
||||
|
||||
for (_, info) in reflect::entry_points(&spirv, true) {
|
||||
for (_, _, info) in reflect::entry_points(&spirv, true) {
|
||||
let mut bindings = Vec::new();
|
||||
for (loc, _reqs) in info.descriptor_requirements {
|
||||
bindings.push(loc);
|
||||
|
@ -15,9 +15,11 @@ use vulkano::shader::{
|
||||
SpecializationConstantRequirements,
|
||||
};
|
||||
use vulkano::shader::{EntryPointInfo, ShaderInterface, ShaderStages};
|
||||
use vulkano::shader::spirv::ExecutionModel;
|
||||
|
||||
pub(super) fn write_entry_point(name: &str, info: &EntryPointInfo) -> TokenStream {
|
||||
pub(super) fn write_entry_point(name: &str, model: ExecutionModel, info: &EntryPointInfo) -> TokenStream {
|
||||
let execution = write_shader_execution(&info.execution);
|
||||
let model = syn::parse_str::<syn::Path>(&format!("vulkano::shader::spirv::ExecutionModel::{:?}", model)).unwrap();
|
||||
let descriptor_requirements = write_descriptor_requirements(&info.descriptor_requirements);
|
||||
let push_constant_requirements =
|
||||
write_push_constant_requirements(&info.push_constant_requirements);
|
||||
@ -29,6 +31,7 @@ pub(super) fn write_entry_point(name: &str, info: &EntryPointInfo) -> TokenStrea
|
||||
quote! {
|
||||
(
|
||||
#name.to_owned(),
|
||||
#model,
|
||||
EntryPointInfo {
|
||||
execution: #execution,
|
||||
descriptor_requirements: std::array::IntoIter::new(#descriptor_requirements).collect(),
|
||||
|
@ -530,8 +530,13 @@ fn value_enum_output(enums: &[(Ident, Vec<KindEnumMember>)]) -> TokenStream {
|
||||
);
|
||||
let name_string = name.to_string();
|
||||
|
||||
let derives = match name_string.as_str() {
|
||||
"ExecutionModel" => quote! { #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] },
|
||||
_ => quote! { #[derive(Clone, Debug, PartialEq)] },
|
||||
};
|
||||
|
||||
quote! {
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#derives
|
||||
#[allow(non_camel_case_types)]
|
||||
pub enum #name {
|
||||
#(#members_items)*
|
||||
|
@ -32,7 +32,7 @@ use crate::Version;
|
||||
use crate::VulkanObject;
|
||||
use fnv::FnvHashMap;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::error;
|
||||
use std::error::Error;
|
||||
use std::ffi::CStr;
|
||||
@ -49,6 +49,8 @@ use std::sync::Arc;
|
||||
pub mod reflect;
|
||||
pub mod spirv;
|
||||
|
||||
use spirv::ExecutionModel;
|
||||
|
||||
// Generated by build.rs
|
||||
include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs"));
|
||||
|
||||
@ -57,7 +59,7 @@ include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs"));
|
||||
pub struct ShaderModule {
|
||||
handle: ash::vk::ShaderModule,
|
||||
device: Arc<Device>,
|
||||
entry_points: HashMap<String, EntryPointInfo>,
|
||||
entry_points: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>>,
|
||||
}
|
||||
|
||||
impl ShaderModule {
|
||||
@ -116,7 +118,7 @@ impl ShaderModule {
|
||||
spirv_version: Version,
|
||||
spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
|
||||
spirv_extensions: impl IntoIterator<Item = &'a str>,
|
||||
entry_points: impl IntoIterator<Item = (String, EntryPointInfo)>,
|
||||
entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
|
||||
) -> Result<Arc<ShaderModule>, ShaderCreationError> {
|
||||
if let Err(reason) = check_spirv_version(&device, spirv_version) {
|
||||
return Err(ShaderCreationError::SpirvVersionNotSupported {
|
||||
@ -162,10 +164,29 @@ impl ShaderModule {
|
||||
output.assume_init()
|
||||
};
|
||||
|
||||
let entries = entry_points.into_iter().collect::<Vec<_>>();
|
||||
let entry_points = entries
|
||||
.iter()
|
||||
.filter_map(|(name, _, _)| Some(name))
|
||||
.collect::<HashSet<_>>()
|
||||
.iter()
|
||||
.map(|name| {
|
||||
((*name).clone(),
|
||||
entries.iter().filter_map(|(entry_name, entry_model, info)| {
|
||||
if &entry_name == name {
|
||||
Some((*entry_model, info.clone()))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}).collect::<HashMap<_, _>>()
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(Arc::new(ShaderModule {
|
||||
handle,
|
||||
device,
|
||||
entry_points: entry_points.into_iter().collect(),
|
||||
entry_points,
|
||||
}))
|
||||
}
|
||||
|
||||
@ -180,7 +201,7 @@ impl ShaderModule {
|
||||
spirv_version: Version,
|
||||
spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
|
||||
spirv_extensions: impl IntoIterator<Item = &'a str>,
|
||||
entry_points: impl IntoIterator<Item = (String, EntryPointInfo)>,
|
||||
entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
|
||||
) -> Result<Arc<ShaderModule>, ShaderCreationError> {
|
||||
assert!((bytes.len() % 4) == 0);
|
||||
Self::from_words_with_data(
|
||||
@ -197,12 +218,31 @@ impl ShaderModule {
|
||||
}
|
||||
|
||||
/// Returns information about the entry point with the provided name. Returns `None` if no entry
|
||||
/// point with that name exists in the shader module.
|
||||
/// point with that name exists in the shader module or if multiple entry points with the same
|
||||
/// name exist.
|
||||
pub fn entry_point<'a>(&'a self, name: &str) -> Option<EntryPoint<'a>> {
|
||||
self.entry_points.get(name).map(|info| EntryPoint {
|
||||
module: self,
|
||||
name: CString::new(name).unwrap(),
|
||||
info,
|
||||
self.entry_points.get(name).and_then(|infos| {
|
||||
if infos.len() == 1 {
|
||||
infos.iter().next().map(|(_, info)| EntryPoint {
|
||||
module: self,
|
||||
name: CString::new(name).unwrap(),
|
||||
info,
|
||||
})
|
||||
} else {
|
||||
None
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns information about the entry point with the provided name and execution model. Returns
|
||||
/// `None` if no entry and execution model exists in the shader module.
|
||||
pub fn entry_point_with_execution<'a>(&'a self, name: &str, execution: ExecutionModel) -> Option<EntryPoint<'a>> {
|
||||
self.entry_points.get(name).and_then(|infos| {
|
||||
infos.get(&execution).map(|info| EntryPoint {
|
||||
module: self,
|
||||
name: CString::new(name).unwrap(),
|
||||
info,
|
||||
})
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -400,7 +440,7 @@ impl<'a> EntryPoint<'a> {
|
||||
|
||||
/// The mode in which a shader executes. This includes both information about the shader type/stage,
|
||||
/// and additional data relevant to particular shader types.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub enum ShaderExecution {
|
||||
Vertex,
|
||||
TessellationControl,
|
||||
@ -425,7 +465,7 @@ pub enum TessellationShaderSubdivision {
|
||||
}*/
|
||||
|
||||
/// The mode in which a geometry shader executes.
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
|
||||
pub struct GeometryShaderExecution {
|
||||
pub input: GeometryShaderInput,
|
||||
/*pub max_output_vertices: u32,
|
||||
|
@ -53,7 +53,7 @@ pub fn spirv_extensions<'a>(spirv: &'a Spirv) -> impl Iterator<Item = &'a str> {
|
||||
pub fn entry_points<'a>(
|
||||
spirv: &'a Spirv,
|
||||
exact_interface: bool,
|
||||
) -> impl Iterator<Item = (String, EntryPointInfo)> + 'a {
|
||||
) -> impl Iterator<Item = (String, ExecutionModel, EntryPointInfo)> + 'a {
|
||||
spirv.iter_entry_point().filter_map(move |instruction| {
|
||||
let (execution_model, function_id, entry_point_name, interface) = match instruction {
|
||||
&Instruction::EntryPoint {
|
||||
@ -92,6 +92,7 @@ pub fn entry_points<'a>(
|
||||
|
||||
Some((
|
||||
entry_point_name.clone(),
|
||||
*execution_model,
|
||||
EntryPointInfo {
|
||||
execution,
|
||||
descriptor_requirements,
|
||||
|
Loading…
Reference in New Issue
Block a user