Shader improvements (#1753)

* add Eq, PartialEq to ShaderExecution and thus GeometryShaderExecution

* use (String, ExecutionModel) to describe each entry point instead

* update tests

* Revert "update tests"

This reverts commit 2bd07d1ef4.

* keep old entry_point interface but introduce new entry_point_with_execution for fine grained selection

* update tests

* move traits to autogen

* oops
This commit is contained in:
Will Song 2021-11-24 08:03:10 -05:00 committed by GitHub
parent 124305a191
commit 9041858430
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 67 additions and 18 deletions

View File

@ -242,7 +242,7 @@ where
});
let spirv_extensions = reflect::spirv_extensions(&spirv);
let entry_points = reflect::entry_points(&spirv, exact_entrypoint_interface)
.map(|(name, info)| entry_point::write_entry_point(&name, &info));
.map(|(name, model, info)| entry_point::write_entry_point(&name, model, &info));
let specialization_constants = structs::write_specialization_constants(
prefix,
@ -708,7 +708,7 @@ mod tests {
let spirv = Spirv::new(&instructions).unwrap();
let mut descriptors = Vec::new();
for (_, info) in reflect::entry_points(&spirv, true) {
for (_, _, info) in reflect::entry_points(&spirv, true) {
descriptors.push(info.descriptor_requirements);
}
@ -779,7 +779,7 @@ mod tests {
.unwrap();
let spirv = Spirv::new(comp.as_binary()).unwrap();
for (_, info) in reflect::entry_points(&spirv, true) {
for (_, _, info) in reflect::entry_points(&spirv, true) {
let mut bindings = Vec::new();
for (loc, _reqs) in info.descriptor_requirements {
bindings.push(loc);

View File

@ -15,9 +15,11 @@ use vulkano::shader::{
SpecializationConstantRequirements,
};
use vulkano::shader::{EntryPointInfo, ShaderInterface, ShaderStages};
use vulkano::shader::spirv::ExecutionModel;
pub(super) fn write_entry_point(name: &str, info: &EntryPointInfo) -> TokenStream {
pub(super) fn write_entry_point(name: &str, model: ExecutionModel, info: &EntryPointInfo) -> TokenStream {
let execution = write_shader_execution(&info.execution);
let model = syn::parse_str::<syn::Path>(&format!("vulkano::shader::spirv::ExecutionModel::{:?}", model)).unwrap();
let descriptor_requirements = write_descriptor_requirements(&info.descriptor_requirements);
let push_constant_requirements =
write_push_constant_requirements(&info.push_constant_requirements);
@ -29,6 +31,7 @@ pub(super) fn write_entry_point(name: &str, info: &EntryPointInfo) -> TokenStrea
quote! {
(
#name.to_owned(),
#model,
EntryPointInfo {
execution: #execution,
descriptor_requirements: std::array::IntoIter::new(#descriptor_requirements).collect(),

View File

@ -530,8 +530,13 @@ fn value_enum_output(enums: &[(Ident, Vec<KindEnumMember>)]) -> TokenStream {
);
let name_string = name.to_string();
let derives = match name_string.as_str() {
"ExecutionModel" => quote! { #[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)] },
_ => quote! { #[derive(Clone, Debug, PartialEq)] },
};
quote! {
#[derive(Clone, Debug, PartialEq)]
#derives
#[allow(non_camel_case_types)]
pub enum #name {
#(#members_items)*

View File

@ -32,7 +32,7 @@ use crate::Version;
use crate::VulkanObject;
use fnv::FnvHashMap;
use std::borrow::Cow;
use std::collections::HashMap;
use std::collections::{HashMap, HashSet};
use std::error;
use std::error::Error;
use std::ffi::CStr;
@ -49,6 +49,8 @@ use std::sync::Arc;
pub mod reflect;
pub mod spirv;
use spirv::ExecutionModel;
// Generated by build.rs
include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs"));
@ -57,7 +59,7 @@ include!(concat!(env!("OUT_DIR"), "/spirv_reqs.rs"));
pub struct ShaderModule {
handle: ash::vk::ShaderModule,
device: Arc<Device>,
entry_points: HashMap<String, EntryPointInfo>,
entry_points: HashMap<String, HashMap<ExecutionModel, EntryPointInfo>>,
}
impl ShaderModule {
@ -116,7 +118,7 @@ impl ShaderModule {
spirv_version: Version,
spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
spirv_extensions: impl IntoIterator<Item = &'a str>,
entry_points: impl IntoIterator<Item = (String, EntryPointInfo)>,
entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
) -> Result<Arc<ShaderModule>, ShaderCreationError> {
if let Err(reason) = check_spirv_version(&device, spirv_version) {
return Err(ShaderCreationError::SpirvVersionNotSupported {
@ -162,10 +164,29 @@ impl ShaderModule {
output.assume_init()
};
let entries = entry_points.into_iter().collect::<Vec<_>>();
let entry_points = entries
.iter()
.filter_map(|(name, _, _)| Some(name))
.collect::<HashSet<_>>()
.iter()
.map(|name| {
((*name).clone(),
entries.iter().filter_map(|(entry_name, entry_model, info)| {
if &entry_name == name {
Some((*entry_model, info.clone()))
} else {
None
}
}).collect::<HashMap<_, _>>()
)
})
.collect();
Ok(Arc::new(ShaderModule {
handle,
device,
entry_points: entry_points.into_iter().collect(),
entry_points,
}))
}
@ -180,7 +201,7 @@ impl ShaderModule {
spirv_version: Version,
spirv_capabilities: impl IntoIterator<Item = &'a Capability>,
spirv_extensions: impl IntoIterator<Item = &'a str>,
entry_points: impl IntoIterator<Item = (String, EntryPointInfo)>,
entry_points: impl IntoIterator<Item = (String, ExecutionModel, EntryPointInfo)>,
) -> Result<Arc<ShaderModule>, ShaderCreationError> {
assert!((bytes.len() % 4) == 0);
Self::from_words_with_data(
@ -197,12 +218,31 @@ impl ShaderModule {
}
/// Returns information about the entry point with the provided name. Returns `None` if no entry
/// point with that name exists in the shader module.
/// point with that name exists in the shader module or if multiple entry points with the same
/// name exist.
pub fn entry_point<'a>(&'a self, name: &str) -> Option<EntryPoint<'a>> {
self.entry_points.get(name).map(|info| EntryPoint {
module: self,
name: CString::new(name).unwrap(),
info,
self.entry_points.get(name).and_then(|infos| {
if infos.len() == 1 {
infos.iter().next().map(|(_, info)| EntryPoint {
module: self,
name: CString::new(name).unwrap(),
info,
})
} else {
None
}
})
}
/// Returns information about the entry point with the provided name and execution model. Returns
/// `None` if no entry and execution model exists in the shader module.
pub fn entry_point_with_execution<'a>(&'a self, name: &str, execution: ExecutionModel) -> Option<EntryPoint<'a>> {
self.entry_points.get(name).and_then(|infos| {
infos.get(&execution).map(|info| EntryPoint {
module: self,
name: CString::new(name).unwrap(),
info,
})
})
}
}
@ -400,7 +440,7 @@ impl<'a> EntryPoint<'a> {
/// The mode in which a shader executes. This includes both information about the shader type/stage,
/// and additional data relevant to particular shader types.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ShaderExecution {
Vertex,
TessellationControl,
@ -425,7 +465,7 @@ pub enum TessellationShaderSubdivision {
}*/
/// The mode in which a geometry shader executes.
#[derive(Clone, Copy, Debug)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct GeometryShaderExecution {
pub input: GeometryShaderInput,
/*pub max_output_vertices: u32,

View File

@ -53,7 +53,7 @@ pub fn spirv_extensions<'a>(spirv: &'a Spirv) -> impl Iterator<Item = &'a str> {
pub fn entry_points<'a>(
spirv: &'a Spirv,
exact_interface: bool,
) -> impl Iterator<Item = (String, EntryPointInfo)> + 'a {
) -> impl Iterator<Item = (String, ExecutionModel, EntryPointInfo)> + 'a {
spirv.iter_entry_point().filter_map(move |instruction| {
let (execution_model, function_id, entry_point_name, interface) = match instruction {
&Instruction::EntryPoint {
@ -92,6 +92,7 @@ pub fn entry_points<'a>(
Some((
entry_point_name.clone(),
*execution_model,
EntryPointInfo {
execution,
descriptor_requirements,