Rewrite VertexDefinition (#2487)

This commit is contained in:
Rua 2024-03-04 22:58:27 +01:00 committed by GitHub
parent 960e554ca0
commit b190d6fb1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
38 changed files with 325 additions and 575 deletions

View File

@ -432,9 +432,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = MyVertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = MyVertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -222,9 +222,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -87,9 +87,7 @@ impl AmbientLightingSystem {
.expect("failed to create shader module")
.entry_point("main")
.expect("shader entry point not found");
let vertex_input_state = LightingVertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = LightingVertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -88,9 +88,7 @@ impl DirectionalLightingSystem {
.expect("failed to create shader module")
.entry_point("main")
.expect("shader entry point not found");
let vertex_input_state = LightingVertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = LightingVertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -87,9 +87,7 @@ impl PointLightingSystem {
.expect("failed to create shader module")
.entry_point("main")
.expect("shader entry point not found");
let vertex_input_state = LightingVertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = LightingVertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -77,9 +77,7 @@ impl TriangleDrawSystem {
.expect("failed to create shader module")
.entry_point("main")
.expect("shader entry point not found");
let vertex_input_state = TriangleVertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = TriangleVertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -694,9 +694,7 @@ mod linux {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = MyVertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = MyVertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -343,9 +343,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -290,9 +290,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -296,9 +296,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -314,9 +314,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -290,7 +290,7 @@ fn main() -> Result<(), impl Error> {
.entry_point("main")
.unwrap();
let vertex_input_state = [TriangleVertex::per_vertex(), InstanceData::per_instance()]
.definition(&vs.info().input_interface)
.definition(&vs)
.unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),

View File

@ -313,9 +313,7 @@ fn main() {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -259,9 +259,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -260,9 +260,7 @@ fn main() {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -300,9 +300,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -206,9 +206,7 @@ fn main() {
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),

View File

@ -281,9 +281,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -400,9 +400,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -187,9 +187,7 @@ fn main() -> Result<(), impl Error> {
module.entry_point("main").unwrap()
};
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -460,9 +460,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -497,7 +497,7 @@ fn window_size_dependent_setup(
// https://computergraphics.stackexchange.com/questions/5742/vulkan-best-way-of-updating-pipeline-viewport
let pipeline = {
let vertex_input_state = [Position::per_vertex(), Normal::per_vertex()]
.definition(&vs.info().input_interface)
.definition(&vs)
.unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),

View File

@ -340,9 +340,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(tcs),

View File

@ -301,9 +301,7 @@ fn main() -> Result<(), impl Error> {
.unwrap()
.entry_point("main")
.unwrap();
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
let stages = [
PipelineShaderStageCreateInfo::new(vs),
PipelineShaderStageCreateInfo::new(fs),

View File

@ -197,9 +197,7 @@ fn main() -> Result<(), impl Error> {
// Automatically generate a vertex input state from the vertex shader's input interface,
// that takes a single vertex buffer containing `Vertex` structs.
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
// Make a list of the shader stages that the pipeline will have.
let stages = [

View File

@ -385,9 +385,7 @@ fn main() -> Result<(), impl Error> {
// Automatically generate a vertex input state from the vertex shader's input interface,
// that takes a single vertex buffer containing `Vertex` structs.
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
// Make a list of the shader stages that the pipeline will have.
let stages = [

View File

@ -390,9 +390,7 @@ fn main() -> Result<(), impl Error> {
// Automatically generate a vertex input state from the vertex shader's input interface,
// that takes a single vertex buffer containing `Vertex` structs.
let vertex_input_state = Vertex::per_vertex()
.definition(&vs.info().input_interface)
.unwrap();
let vertex_input_state = Vertex::per_vertex().definition(&vs).unwrap();
// Make a list of the shader stages that the pipeline will have.
let stages = [

View File

@ -61,7 +61,7 @@ pub fn derive_vertex(crate_ident: &Ident, ast: syn::DeriveInput) -> Result<Token
let field_size = ::std::mem::size_of::<#field_ty>();
let format = #format;
let format_size = format.block_size() as usize;
let format_size = usize::try_from(format.block_size()).unwrap();
let num_elements = field_size / format_size;
let remainder = field_size % format_size;
::std::assert!(
@ -76,6 +76,7 @@ pub fn derive_vertex(crate_ident: &Ident, ast: syn::DeriveInput) -> Result<Token
offset: offset.try_into().unwrap(),
format,
num_elements: num_elements.try_into().unwrap(),
stride: format_size.try_into().unwrap(),
},
);

View File

@ -143,6 +143,19 @@ fn formats_output(members: &[FormatMember]) -> TokenStream {
}
},
);
let locations_items = members.iter().filter_map(
|FormatMember {
name, components, ..
}| {
if components.starts_with(&[64, 64, 64]) {
Some(quote! {
Self::#name => 2,
})
} else {
None
}
},
);
let compression_items = members.iter().filter_map(
|FormatMember {
name, compression, ..
@ -418,6 +431,15 @@ fn formats_output(members: &[FormatMember]) -> TokenStream {
}
}
/// Returns the number of shader input/output locations that a single element of this
/// format takes up.
pub fn locations(self) -> u32 {
match self {
#(#locations_items)*
_ => 1,
}
}
/// Returns the block compression scheme used for this format, if any. Returns `None` if
/// the format does not use compression.
pub fn compression(self) -> Option<CompressionType> {

View File

@ -1,7 +1,7 @@
use super::VertexBufferDescription;
use super::{definition::VertexDefinition, VertexBufferDescription};
use crate::{
pipeline::graphics::vertex_input::{Vertex, VertexDefinition, VertexInputState},
shader::ShaderInterface,
pipeline::graphics::vertex_input::{Vertex, VertexInputState},
shader::EntryPoint,
ValidationError,
};
@ -55,8 +55,8 @@ unsafe impl VertexDefinition for BuffersDefinition {
#[inline]
fn definition(
&self,
interface: &ShaderInterface,
entry_point: &EntryPoint,
) -> Result<VertexInputState, Box<ValidationError>> {
self.0.definition(interface)
self.0.definition(entry_point)
}
}

View File

@ -1,17 +1,30 @@
use super::{
VertexBufferDescription, VertexInputAttributeDescription, VertexInputBindingDescription,
VertexMemberInfo,
};
use crate::{
pipeline::graphics::vertex_input::VertexInputState, shader::ShaderInterface, DeviceSize,
pipeline::{
graphics::vertex_input::VertexInputState,
inout_interface::{
input_output_map, InputOutputData, InputOutputKey, InputOutputUserKey,
InputOutputVariableBlock,
},
},
shader::{
spirv::{ExecutionModel, Instruction, StorageClass},
EntryPoint,
},
ValidationError,
};
use ahash::HashMap;
use std::{borrow::Cow, collections::hash_map::Entry};
/// Trait for types that can create a [`VertexInputState`] from a [`ShaderInterface`].
/// Trait for types that can create a [`VertexInputState`] from an [`EntryPoint`].
pub unsafe trait VertexDefinition {
/// Builds the `VertexInputState` for the provided `interface`.
/// Builds the `VertexInputState` for the provided `entry_point`.
fn definition(
&self,
interface: &ShaderInterface,
entry_point: &EntryPoint,
) -> Result<VertexInputState, Box<ValidationError>>;
}
@ -19,89 +32,198 @@ unsafe impl VertexDefinition for &[VertexBufferDescription] {
#[inline]
fn definition(
&self,
interface: &ShaderInterface,
entry_point: &EntryPoint,
) -> Result<VertexInputState, Box<ValidationError>> {
let bindings = self.iter().enumerate().map(|(binding, buffer)| {
let spirv = entry_point.module().spirv();
let Some(&Instruction::EntryPoint {
execution_model,
ref interface,
..
}) = spirv.function(entry_point.id()).entry_point()
else {
unreachable!()
};
if execution_model != ExecutionModel::Vertex {
return Err(Box::new(ValidationError {
context: "entry_point".into(),
problem: "is not a vertex shader".into(),
..Default::default()
}));
}
let bindings = self
.iter()
.enumerate()
.map(|(binding, buffer_description)| {
let &VertexBufferDescription {
members: _,
stride,
input_rate,
} = buffer_description;
(
binding as u32,
binding.try_into().unwrap(),
VertexInputBindingDescription {
stride: buffer.stride,
input_rate: buffer.input_rate,
stride,
input_rate,
..Default::default()
},
)
});
let mut attributes: Vec<(u32, VertexInputAttributeDescription)> = Vec::new();
for element in interface.elements() {
let name = element.name.as_ref().unwrap();
let (infos, binding) = self
.iter()
.enumerate()
.find_map(|(binding, buffer)| {
buffer
.members
.get(name.as_ref())
.map(|infos| (infos.clone(), binding as u32))
})
.ok_or_else(|| {
Box::new(ValidationError {
problem: format!(
"the shader interface contains a variable named \"{}\", \
but no such attribute exists in the vertex definition",
name,
)
.into(),
..Default::default()
})
})?;
.collect();
let mut attributes: HashMap<u32, VertexInputAttributeDescription> = HashMap::default();
// TODO: ShaderInterfaceEntryType does not properly support 64bit.
// Once it does the below logic around num_elements and num_locations
// might have to be updated.
if infos.num_components() != element.ty.num_components
|| infos.num_elements != element.ty.num_locations()
{
for variable_id in interface.iter().copied() {
input_output_map(
spirv,
execution_model,
variable_id,
StorageClass::Input,
|key, data| -> Result<(), Box<ValidationError>> {
let InputOutputKey::User(key) = key else {
return Ok(());
};
let InputOutputUserKey {
mut location,
component,
index: _,
} = key;
// TODO: can we make this work somehow?
if component != 0 {
return Err(Box::new(ValidationError {
problem: format!(
"for the variable \"{}\", the number of locations and components \
required by the shader don't match the number of locations and components \
of the type provided in the vertex definition",
name,
"the shader interface contains an input variable (location {}) \
with a non-zero component decoration ({}), which is not yet \
supported by `VertexDefinition` in Vulkano",
location, component,
)
.into(),
..Default::default()
}));
}
let mut offset = infos.offset as DeviceSize;
let block_size = infos.format.block_size();
// Double precision formats can exceed a single location.
// R64B64G64A64_SFLOAT requires two locations, so we need to adapt how we bind
let location_range = if block_size > 16 {
(element.location..element.location + 2 * element.ty.num_locations()).step_by(2)
let InputOutputData {
variable_id,
pointer_type_id: _,
block,
type_id: _,
} = data;
// Find the name of the variable defined in the shader,
// or use a default placeholder.
let names = if let Some(block) = block {
let InputOutputVariableBlock {
type_id,
member_index,
} = block;
spirv.id(type_id).members()[member_index].names()
} else {
(element.location..element.location + element.ty.num_locations()).step_by(1)
spirv.id(variable_id).names()
};
let name = names
.iter()
.find_map(|instruction| match *instruction {
Instruction::Name { ref name, .. }
| Instruction::MemberName { ref name, .. } => {
Some(Cow::Borrowed(name.as_str()))
}
_ => None,
})
.unwrap_or_else(|| Cow::Owned(format!("vertex_input_{}", location)));
for location in location_range {
attributes.push((
location,
VertexInputAttributeDescription {
binding,
format: infos.format,
offset: offset as u32,
// Find a vertex member whose name matches the one in the shader.
let (vertex_member_info, binding) = self
.iter()
.enumerate()
.find_map(|(binding, buffer)| {
buffer
.members
.get(name.as_ref())
.map(|info| (info, binding.try_into().unwrap()))
})
.ok_or_else(|| {
Box::new(ValidationError {
problem: format!(
"the shader interface contains an input variable named \"{}\" \
(location {}, component {}), but no such attribute exists in \
the vertex definition",
name, location, component,
)
.into(),
..Default::default()
},
));
offset += block_size;
})
})?;
let &VertexMemberInfo {
mut offset,
format,
num_elements,
mut stride,
} = vertex_member_info;
let locations_per_element;
if num_elements > 1 {
locations_per_element = format.locations();
if u64::from(stride) < format.block_size() {
return Err(Box::new(ValidationError {
problem: format!(
"in the vertex member named \"{}\" in buffer {}, the `stride` is \
less than the block size of `format`",
name, binding,
)
.into(),
..Default::default()
}));
}
} else {
stride = 0;
locations_per_element = 0;
}
// Add an attribute description for every element in the member.
for _ in 0..num_elements {
match attributes.entry(location) {
Entry::Occupied(_) => {
return Err(Box::new(ValidationError {
problem: format!(
"the vertex definition specifies a variable at \
location {}, but that location is already occupied by \
another variable",
location,
)
.into(),
..Default::default()
}));
}
Entry::Vacant(entry) => {
entry.insert(VertexInputAttributeDescription {
binding,
format,
offset,
..Default::default()
});
}
}
Ok(VertexInputState::new()
.bindings(bindings)
.attributes(attributes))
location = location.checked_add(locations_per_element).unwrap();
offset = offset.checked_add(stride).unwrap();
}
Ok(())
},
)?;
}
Ok(VertexInputState {
bindings,
attributes,
..Default::default()
})
}
}
@ -109,9 +231,9 @@ unsafe impl<const N: usize> VertexDefinition for [VertexBufferDescription; N] {
#[inline]
fn definition(
&self,
interface: &ShaderInterface,
entry_point: &EntryPoint,
) -> Result<VertexInputState, Box<ValidationError>> {
self.as_slice().definition(interface)
self.as_slice().definition(entry_point)
}
}
@ -119,9 +241,9 @@ unsafe impl VertexDefinition for Vec<VertexBufferDescription> {
#[inline]
fn definition(
&self,
interface: &ShaderInterface,
entry_point: &EntryPoint,
) -> Result<VertexInputState, Box<ValidationError>> {
self.as_slice().definition(interface)
self.as_slice().definition(entry_point)
}
}
@ -129,8 +251,8 @@ unsafe impl VertexDefinition for VertexBufferDescription {
#[inline]
fn definition(
&self,
interface: &ShaderInterface,
entry_point: &EntryPoint,
) -> Result<VertexInputState, Box<ValidationError>> {
std::slice::from_ref(self).definition(interface)
std::slice::from_ref(self).definition(entry_point)
}
}

View File

@ -57,9 +57,10 @@ macro_rules! impl_vertex {
let member_ptr = (&dummy.$member) as *const _;
members.insert(stringify!($member).to_string(), VertexMemberInfo {
offset: member_ptr as usize - dummy_ptr as usize,
offset: u32::try_from(member_ptr as usize - dummy_ptr as usize).unwrap(),
format,
num_elements,
stride: format_size,
});
}
)*

View File

@ -304,7 +304,7 @@ impl VertexInputState {
// the location following it needs to be empty.
let unassigned_locations = attributes
.iter()
.filter(|&(_, attribute_desc)| attribute_desc.format.block_size() > 16)
.filter(|&(_, attribute_desc)| attribute_desc.format.locations() == 2)
.map(|(location, _)| location + 1);
for location in unassigned_locations {
@ -342,12 +342,7 @@ impl VertexInputState {
location.checked_sub(1).and_then(|location| {
self.attributes
.get(&location)
.filter(|attribute_desc| {
attribute_desc
.format
.components()
.starts_with(&[64, 64, 64])
})
.filter(|attribute_desc| attribute_desc.format.locations() == 2)
.map(|d| (true, d))
})
})

View File

@ -83,13 +83,19 @@ impl VertexBufferDescription {
/// Information about a member of a vertex struct.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct VertexMemberInfo {
/// Offset of the member in bytes from the start of the struct.
pub offset: usize,
/// Attribute format of the member. Implicitly provides number of components.
/// The offset of the member in bytes from the start of the struct.
pub offset: u32,
/// The attribute format of the member. Implicitly provides the number of components.
pub format: Format,
/// Number of consecutive array elements or matrix columns using format. The corresponding
/// number of locations might defer depending on the size of the format.
/// The number of consecutive array elements or matrix columns using `format`.
/// The corresponding number of locations might differ depending on the size of the format.
pub num_elements: u32,
/// If `num_elements` is greater than 1, the stride in bytes between the start of consecutive
/// elements.
pub stride: u32,
}
impl VertexMemberInfo {

View File

@ -8,7 +8,7 @@ use crate::{
ValidationError,
};
use ahash::HashMap;
use std::collections::hash_map::Entry;
use std::{collections::hash_map::Entry, convert::Infallible};
pub(crate) fn validate_interfaces_compatible(
out_spirv: &Spirv,
@ -208,7 +208,7 @@ fn get_variables_by_key<'a>(
execution_model,
variable_id,
filter_storage_class,
|key, data| {
|key, data| -> Result<(), Infallible> {
if let InputOutputKey::User(InputOutputUserKey {
location,
component,
@ -241,8 +241,11 @@ fn get_variables_by_key<'a>(
},
);
}
Ok(())
},
);
)
.unwrap();
}
variables_by_key
@ -723,13 +726,16 @@ pub(crate) fn shader_interface_location_info(
execution_model,
variable_id,
filter_storage_class,
|key, data| {
|key, data| -> Result<(), Infallible> {
if let InputOutputKey::User(key) = key {
let InputOutputData { type_id, .. } = data;
shader_interface_analyze_type(spirv, type_id, key, &mut scalar_func);
}
Ok(())
},
);
)
.unwrap();
}
locations
@ -864,13 +870,13 @@ pub(crate) struct InputOutputVariableBlock {
pub(crate) member_index: usize,
}
pub(crate) fn input_output_map(
pub(crate) fn input_output_map<E>(
spirv: &Spirv,
execution_model: ExecutionModel,
variable_id: Id,
filter_storage_class: StorageClass,
mut func: impl FnMut(InputOutputKey, InputOutputData),
) {
mut func: impl FnMut(InputOutputKey, InputOutputData) -> Result<(), E>,
) -> Result<(), E> {
let variable_id_info = spirv.id(variable_id);
let (pointer_type_id, storage_class) = match *variable_id_info.instruction() {
Instruction::Variable {
@ -878,7 +884,7 @@ pub(crate) fn input_output_map(
storage_class,
..
} if storage_class == filter_storage_class => (result_type_id, storage_class),
_ => return,
_ => return Ok(()),
};
let pointer_type_id_info = spirv.id(pointer_type_id);
let type_id = match *pointer_type_id_info.instruction() {
@ -918,7 +924,7 @@ pub(crate) fn input_output_map(
block: None,
type_id,
},
);
)
} else if let Some(built_in) = built_in {
func(
InputOutputKey::BuiltIn(built_in),
@ -928,13 +934,13 @@ pub(crate) fn input_output_map(
block: None,
type_id,
},
);
)
} else {
let block_type_id = type_id;
let block_type_id_info = spirv.id(block_type_id);
let member_types = match block_type_id_info.instruction() {
Instruction::TypeStruct { member_types, .. } => member_types,
_ => return,
_ => return Ok(()),
};
for (member_index, (&type_id, member_info)) in member_types
@ -975,7 +981,7 @@ pub(crate) fn input_output_map(
}),
type_id,
},
);
)?;
} else if let Some(built_in) = built_in {
func(
InputOutputKey::BuiltIn(built_in),
@ -988,9 +994,11 @@ pub(crate) fn input_output_map(
}),
type_id,
},
);
)?;
}
}
Ok(())
}
}

View File

@ -18,7 +18,7 @@ use crate::{
DeviceSize, Requires, RequiresAllOf, RequiresOneOf, ValidationError, Version,
};
use ahash::HashMap;
use std::cmp::max;
use std::{cmp::max, convert::Infallible};
pub(crate) fn validate_runtime(
device: &Device,
@ -1366,7 +1366,7 @@ impl<'a> RuntimeValidator<'a> {
self.execution_model,
result_id,
storage_class,
|key, data| {
|key, data| -> Result<(), Infallible> {
let InputOutputData { type_id, .. } = data;
match key {
@ -1386,8 +1386,11 @@ impl<'a> RuntimeValidator<'a> {
// https://github.com/KhronosGroup/Vulkan-Docs/issues/2293
}
}
Ok(())
},
);
)
.unwrap();
}
if is_in_interface && storage_class == StorageClass::Output {

View File

@ -433,7 +433,6 @@ use half::f16;
use smallvec::SmallVec;
use spirv::ExecutionModel;
use std::{
borrow::Cow,
collections::hash_map::Entry,
mem::{discriminant, size_of_val, MaybeUninit},
num::NonZeroU64,
@ -1150,8 +1149,6 @@ pub struct EntryPointInfo {
pub execution_model: ExecutionModel,
pub descriptor_binding_requirements: HashMap<(u32, u32), DescriptorBindingRequirements>,
pub push_constant_requirements: Option<PushConstantRange>,
pub input_interface: ShaderInterface,
pub output_interface: ShaderInterface,
}
/// Represents a shader entry point in a shader module.
@ -1372,86 +1369,6 @@ impl DescriptorRequirements {
}
}
/// Type that contains the definition of an interface between two shader stages, or between
/// the outside and a shader stage.
#[derive(Clone, Debug)]
pub struct ShaderInterface {
elements: Vec<ShaderInterfaceEntry>,
}
impl ShaderInterface {
/// Constructs a new `ShaderInterface`.
///
/// # Safety
///
/// - Must only provide one entry per location.
/// - The format of each element must not be larger than 128 bits.
// TODO: 4x64 bit formats are possible, but they require special handling.
// TODO: could this be made safe?
#[inline]
pub unsafe fn new_unchecked(elements: Vec<ShaderInterfaceEntry>) -> ShaderInterface {
ShaderInterface { elements }
}
/// Creates a description of an empty shader interface.
#[inline]
pub const fn empty() -> ShaderInterface {
ShaderInterface {
elements: Vec::new(),
}
}
/// Returns a slice containing the elements of the interface.
#[inline]
pub fn elements(&self) -> &[ShaderInterfaceEntry] {
self.elements.as_ref()
}
}
/// Entry of a shader interface definition.
#[derive(Debug, Clone)]
pub struct ShaderInterfaceEntry {
/// The location slot that the variable starts at.
pub location: u32,
/// The index within the location slot that the variable is located.
/// Only meaningful for fragment outputs.
pub index: u32,
/// The component slot that the variable starts at. Must be in the range 0..=3.
pub component: u32,
/// Name of the element, or `None` if the name is unknown.
pub name: Option<Cow<'static, str>>,
/// The type of the variable.
pub ty: ShaderInterfaceEntryType,
}
/// The type of a variable in a shader interface.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct ShaderInterfaceEntryType {
/// The base numeric type.
pub base_type: NumericType,
/// The number of vector components. Must be in the range 1..=4.
pub num_components: u32,
/// The number of array elements or matrix columns.
pub num_elements: u32,
/// Whether the base type is 64 bits wide. If true, each item of the base type takes up two
/// component slots instead of one.
pub is_64bit: bool,
}
impl ShaderInterfaceEntryType {
pub(crate) fn num_locations(&self) -> u32 {
assert!(!self.is_64bit); // TODO: implement
self.num_elements
}
}
vulkan_bitflags_enum! {
#[non_exhaustive]

View File

@ -6,17 +6,15 @@ use crate::{
image::view::ImageViewType,
pipeline::layout::PushConstantRange,
shader::{
spirv::{Decoration, Dim, ExecutionModel, Id, Instruction, Spirv, StorageClass},
DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, NumericType, ShaderInterface,
ShaderInterfaceEntry, ShaderInterfaceEntryType, ShaderStage, ShaderStages,
SpecializationConstant,
spirv::{Decoration, Dim, Id, Instruction, Spirv, StorageClass},
DescriptorIdentifier, DescriptorRequirements, EntryPointInfo, NumericType, ShaderStage,
ShaderStages, SpecializationConstant,
},
DeviceSize, Version,
};
use ahash::{HashMap, HashSet};
use half::f16;
use smallvec::{smallvec, SmallVec};
use std::borrow::Cow;
/// Returns an iterator over all entry points in `spirv`, with information about the entry point.
#[inline]
@ -24,15 +22,14 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = (Id, EntryPointInfo)>
let interface_variables = interface_variables(spirv);
spirv.entry_points().iter().filter_map(move |instruction| {
let (execution_model, function_id, entry_point_name, interface) = match *instruction {
Instruction::EntryPoint {
let &Instruction::EntryPoint {
execution_model,
entry_point,
ref name,
ref interface,
..
} => (execution_model, entry_point, name, interface),
_ => return None,
} = instruction
else {
return None;
};
let stage = ShaderStage::from(execution_model);
@ -41,44 +38,22 @@ pub fn entry_points(spirv: &Spirv) -> impl Iterator<Item = (Id, EntryPointInfo)>
&interface_variables.descriptor_binding,
spirv,
stage,
function_id,
entry_point,
);
let push_constant_requirements = push_constant_requirements(
&interface_variables.push_constant,
spirv,
stage,
function_id,
);
let input_interface = shader_interface(
spirv,
interface,
StorageClass::Input,
matches!(
execution_model,
ExecutionModel::TessellationControl
| ExecutionModel::TessellationEvaluation
| ExecutionModel::Geometry
),
);
let output_interface = shader_interface(
spirv,
interface,
StorageClass::Output,
matches!(
execution_model,
ExecutionModel::TessellationControl | ExecutionModel::MeshEXT
),
entry_point,
);
Some((
function_id,
entry_point,
EntryPointInfo {
name: entry_point_name.clone(),
name: name.clone(),
execution_model,
descriptor_binding_requirements,
push_constant_requirements,
input_interface,
output_interface,
},
))
})
@ -1066,118 +1041,6 @@ pub(super) fn specialization_constants(spirv: &Spirv) -> HashMap<u32, Specializa
.collect()
}
/// Extracts the `ShaderInterface` with the given storage class from `spirv`.
fn shader_interface(
spirv: &Spirv,
interface: &[Id],
filter_storage_class: StorageClass,
ignore_first_array: bool,
) -> ShaderInterface {
let elements: Vec<_> = interface
.iter()
.filter_map(|&id| {
let (result_type_id, result_id) = match *spirv.id(id).instruction() {
Instruction::Variable {
result_type_id,
result_id,
storage_class,
..
} if storage_class == filter_storage_class => (result_type_id, result_id),
_ => return None,
};
if is_builtin(spirv, result_id) {
return None;
}
let id_info = spirv.id(result_id);
let name = id_info
.names()
.iter()
.find_map(|instruction| match *instruction {
Instruction::Name { ref name, .. } => Some(Cow::Owned(name.clone())),
_ => None,
});
let location = id_info
.decorations()
.iter()
.find_map(|instruction| match *instruction {
Instruction::Decorate {
decoration: Decoration::Location { location },
..
} => Some(location),
_ => None,
})
.unwrap_or_else(|| {
panic!(
"Input/output variable with id {} (name {:?}) is missing a location",
result_id, name,
)
});
let component = id_info
.decorations()
.iter()
.find_map(|instruction| match *instruction {
Instruction::Decorate {
decoration: Decoration::Component { component },
..
} => Some(component),
_ => None,
})
.unwrap_or(0);
let index = id_info
.decorations()
.iter()
.find_map(|instruction| match *instruction {
Instruction::Decorate {
decoration: Decoration::Index { index },
..
} => Some(index),
_ => None,
})
.unwrap_or(0);
let ty = shader_interface_type_of(spirv, result_type_id, ignore_first_array);
assert!(ty.num_elements >= 1);
Some(ShaderInterfaceEntry {
location,
index,
component,
ty,
name,
})
})
.collect();
// Checking for overlapping elements.
for (offset, element1) in elements.iter().enumerate() {
for element2 in elements.iter().skip(offset + 1) {
if element1.index == element2.index
&& (element1.location == element2.location
|| (element1.location < element2.location
&& element1.location + element1.ty.num_locations() > element2.location)
|| (element2.location < element1.location
&& element2.location + element2.ty.num_locations() > element1.location))
{
panic!(
"The locations of attributes `{:?}` ({}..{}) and `{:?}` ({}..{}) overlap",
element1.name,
element1.location,
element1.location + element1.ty.num_locations(),
element2.name,
element2.location,
element2.location + element2.ty.num_locations(),
);
}
}
}
ShaderInterface { elements }
}
/// Returns the size of a type, or `None` if its size cannot be determined.
pub(crate) fn size_of_type(spirv: &Spirv, id: Id) -> Option<DeviceSize> {
let id_info = spirv.id(id);
@ -1302,144 +1165,6 @@ fn offset_of_struct(spirv: &Spirv, id: Id) -> u32 {
.unwrap_or(0)
}
/// If `ignore_first_array` is true, the function expects the outermost instruction to be
/// `OpTypeArray`. If it's the case, the OpTypeArray will be ignored. If not, the function will
/// panic.
fn shader_interface_type_of(
spirv: &Spirv,
id: Id,
ignore_first_array: bool,
) -> ShaderInterfaceEntryType {
match *spirv.id(id).instruction() {
Instruction::TypeInt {
width, signedness, ..
} => {
assert!(!ignore_first_array);
ShaderInterfaceEntryType {
base_type: match signedness {
0 => NumericType::Uint,
1 => NumericType::Int,
_ => unreachable!(),
},
num_components: 1,
num_elements: 1,
is_64bit: match width {
8 | 16 | 32 => false,
64 => true,
_ => unimplemented!(),
},
}
}
Instruction::TypeFloat { width, .. } => {
assert!(!ignore_first_array);
ShaderInterfaceEntryType {
base_type: NumericType::Float,
num_components: 1,
num_elements: 1,
is_64bit: match width {
16 | 32 => false,
64 => true,
_ => unimplemented!(),
},
}
}
Instruction::TypeVector {
component_type,
component_count,
..
} => {
assert!(!ignore_first_array);
ShaderInterfaceEntryType {
num_components: component_count,
..shader_interface_type_of(spirv, component_type, false)
}
}
Instruction::TypeMatrix {
column_type,
column_count,
..
} => {
assert!(!ignore_first_array);
ShaderInterfaceEntryType {
num_elements: column_count,
..shader_interface_type_of(spirv, column_type, false)
}
}
Instruction::TypeArray {
element_type,
length,
..
} => {
if ignore_first_array {
shader_interface_type_of(spirv, element_type, false)
} else {
let mut ty = shader_interface_type_of(spirv, element_type, false);
let length = get_constant(spirv, length).expect("failed to find array length");
ty.num_elements *= length as u32;
ty
}
}
Instruction::TypePointer { ty, .. } => {
shader_interface_type_of(spirv, ty, ignore_first_array)
}
Instruction::TypeStruct { .. } => {
panic!("Structs are not yet supported in shader in/out interface!");
}
_ => panic!("Type {} not found or invalid", id),
}
}
/// Returns true if a `BuiltIn` decorator is applied on an id.
fn is_builtin(spirv: &Spirv, id: Id) -> bool {
let id_info = spirv.id(id);
if id_info.decorations().iter().any(|instruction| {
matches!(
instruction,
Instruction::Decorate {
decoration: Decoration::BuiltIn { .. },
..
}
)
}) {
return true;
}
if id_info
.members()
.iter()
.flat_map(|member_info| member_info.decorations())
.any(|instruction| {
matches!(
instruction,
Instruction::MemberDecorate {
decoration: Decoration::BuiltIn { .. },
..
}
)
})
{
return true;
}
match id_info.instruction() {
Instruction::Variable {
result_type_id: ty, ..
}
| Instruction::TypeArray {
element_type: ty, ..
}
| Instruction::TypeRuntimeArray {
element_type: ty, ..
}
| Instruction::TypePointer { ty, .. } => is_builtin(spirv, *ty),
Instruction::TypeStruct { member_types, .. } => {
member_types.iter().any(|ty| is_builtin(spirv, *ty))
}
_ => false,
}
}
pub(crate) fn get_constant(spirv: &Spirv, id: Id) -> Option<u64> {
match spirv.id(id).instruction() {
Instruction::Constant { value, .. } => match value.len() {