mirror of
https://github.com/vulkano-rs/vulkano.git
synced 2024-11-25 00:04:15 +00:00
Sharing generated Rust types between shaders (#1694)
* Vulkano-shader mutual compilation mode * Tests fixes * Shared constants option for multi-shader * Cargo fmt * Fix doc typos
This commit is contained in:
parent
588f91dc59
commit
ddbde190ff
273
examples/src/bin/shader-types-sharing.rs
Normal file
273
examples/src/bin/shader-types-sharing.rs
Normal file
@ -0,0 +1,273 @@
|
||||
// Copyright (c) 2017 The vulkano developers
|
||||
// Licensed under the Apache License, Version 2.0
|
||||
// <LICENSE-APACHE or
|
||||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT
|
||||
// license <LICENSE-MIT or https://opensource.org/licenses/MIT>,
|
||||
// at your option. All files in the project carrying such
|
||||
// notice may not be copied, modified, or distributed except
|
||||
// according to those terms.
|
||||
|
||||
// This example demonstrates how to compile several shaders together using vulkano-shaders macro,
|
||||
// such that the macro generates unique Shader types per each compiled shader, but generates common
|
||||
// shareable set of Rust structs representing corresponding structs in the source glsl code.
|
||||
//
|
||||
// Normally, each vulkano-shaders macro invocation among other things generates a `ty` submodule
|
||||
// containing all Rust types per each "struct" declaration of glsl code. Using this submodule
|
||||
// the user can organize type-safe interoperability between Rust code and the shader interface
|
||||
// input/output data tied to these structs. However, if the user compiles several shaders in
|
||||
// independent Rust modules, each of these modules would contain independent `ty` submodule with
|
||||
// each own set of Rust types. So, even if both shaders contain the same(or partially intersecting)
|
||||
// glsl structs they will be duplicated in each generated `ty` submodule and treated by Rust as
|
||||
// independent types. As such it would be tricky to organize interoperability between shader
|
||||
// interfaces in Rust.
|
||||
//
|
||||
// To solve this problem the user can use "shared" generation mode of the macro. In this mode the
|
||||
// user declares all shaders that possibly share common layout interfaces in a single macro
|
||||
// invocation. The macro will check that there is no inconsistency between declared glsl structs
|
||||
// with the same names, and it will put all generated Rust structs for all shaders in just a single
|
||||
// `ty` submodule.
|
||||
|
||||
use std::sync::Arc;
|
||||
use vulkano::buffer::{BufferUsage, CpuAccessibleBuffer};
|
||||
use vulkano::command_buffer::{AutoCommandBufferBuilder, CommandBufferUsage};
|
||||
use vulkano::descriptor_set::PersistentDescriptorSet;
|
||||
use vulkano::device::physical::{PhysicalDevice, PhysicalDeviceType};
|
||||
use vulkano::device::{Device, DeviceExtensions, Features, Queue};
|
||||
use vulkano::instance::{Instance, InstanceExtensions};
|
||||
use vulkano::pipeline::{ComputePipeline, PipelineBindPoint};
|
||||
use vulkano::sync;
|
||||
use vulkano::sync::GpuFuture;
|
||||
use vulkano::Version;
|
||||
|
||||
fn main() {
|
||||
let instance = Instance::new(None, Version::V1_1, &InstanceExtensions::none(), None).unwrap();
|
||||
|
||||
let device_extensions = DeviceExtensions {
|
||||
khr_storage_buffer_storage_class: true,
|
||||
..DeviceExtensions::none()
|
||||
};
|
||||
let (physical_device, queue_family) = PhysicalDevice::enumerate(&instance)
|
||||
.filter(|&p| p.supported_extensions().is_superset_of(&device_extensions))
|
||||
.filter_map(|p| {
|
||||
p.queue_families()
|
||||
.find(|&q| q.supports_compute())
|
||||
.map(|q| (p, q))
|
||||
})
|
||||
.min_by_key(|(p, _)| match p.properties().device_type {
|
||||
PhysicalDeviceType::DiscreteGpu => 0,
|
||||
PhysicalDeviceType::IntegratedGpu => 1,
|
||||
PhysicalDeviceType::VirtualGpu => 2,
|
||||
PhysicalDeviceType::Cpu => 3,
|
||||
PhysicalDeviceType::Other => 4,
|
||||
})
|
||||
.unwrap();
|
||||
|
||||
println!(
|
||||
"Using device: {} (type: {:?})",
|
||||
physical_device.properties().device_name,
|
||||
physical_device.properties().device_type
|
||||
);
|
||||
|
||||
let (device, mut queues) = Device::new(
|
||||
physical_device,
|
||||
&Features::none(),
|
||||
&physical_device
|
||||
.required_extensions()
|
||||
.union(&device_extensions),
|
||||
[(queue_family, 0.5)].iter().cloned(),
|
||||
)
|
||||
.unwrap();
|
||||
let queue = queues.next().unwrap();
|
||||
|
||||
mod shaders {
|
||||
vulkano_shaders::shader! {
|
||||
// We declaring two simple compute shaders with push and specialization constants in
|
||||
// their layout interfaces.
|
||||
//
|
||||
// First one is just multiplying each value from the input array of ints to provided
|
||||
// value in push constants struct. And the second one in turn adds this value instead of
|
||||
// multiplying.
|
||||
//
|
||||
// However both shaders declare glsl struct `Parameters` for push constants in each
|
||||
// shader. Since each of the struct has exactly the same interface, they will be
|
||||
// treated by the macro as "shared".
|
||||
//
|
||||
// Also, note that glsl code duplications between shader sources is not necessary too.
|
||||
// In more complex system the user may want to declare independent glsl file with
|
||||
// such types, and include it in each shader entry-point files using "#include"
|
||||
// directive.
|
||||
shaders: {
|
||||
// Generate single unique `SpecializationConstants` struct for all shaders since
|
||||
// their specialization interfaces are the same. This option is turned off
|
||||
// by default and the macro by default producing unique
|
||||
// structs(`MultSpecializationConstants`, `AddSpecializationConstants`)
|
||||
shared_constants: true,
|
||||
Mult: {
|
||||
ty: "compute",
|
||||
src: "
|
||||
#version 450
|
||||
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(constant_id = 0) const bool enabled = true;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
int value;
|
||||
} pc;
|
||||
|
||||
layout(set = 0, binding = 0) buffer Data {
|
||||
uint data[];
|
||||
} data;
|
||||
|
||||
void main() {
|
||||
if (!enabled) {
|
||||
return;
|
||||
}
|
||||
uint idx = gl_GlobalInvocationID.x;
|
||||
data.data[idx] *= pc.value;
|
||||
}
|
||||
"
|
||||
},
|
||||
Add: {
|
||||
ty: "compute",
|
||||
src: "
|
||||
#version 450
|
||||
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
layout(constant_id = 0) const bool enabled = true;
|
||||
|
||||
layout(push_constant) uniform Parameters {
|
||||
int value;
|
||||
} pc;
|
||||
|
||||
layout(set = 0, binding = 0) buffer Data {
|
||||
uint data[];
|
||||
} data;
|
||||
|
||||
void main() {
|
||||
if (!enabled) {
|
||||
return;
|
||||
}
|
||||
uint idx = gl_GlobalInvocationID.x;
|
||||
data.data[idx] += pc.value;
|
||||
}
|
||||
"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// The macro will create the following things in this module:
|
||||
// - `ShaderMult` for the first shader loader/entry-point.
|
||||
// - `ShaderAdd` for the second shader loader/entry-point.
|
||||
// `SpecializationConstants` Rust struct for both shader's specialization constants.
|
||||
// `ty` submodule with `Parameters` Rust struct common for both shaders.
|
||||
}
|
||||
|
||||
// We introducing generic function responsible for running any of the shaders above with
|
||||
// provided Push Constants parameter.
|
||||
// Note that shader's interface `parameter` here is shader-independent.
|
||||
fn run_shader(
|
||||
pipeline: Arc<ComputePipeline>,
|
||||
queue: Arc<Queue>,
|
||||
data_buffer: Arc<CpuAccessibleBuffer<[u32]>>,
|
||||
parameters: shaders::ty::Parameters,
|
||||
) {
|
||||
let layout = pipeline.layout().descriptor_set_layouts().get(0).unwrap();
|
||||
let mut set_builder = PersistentDescriptorSet::start(layout.clone());
|
||||
|
||||
set_builder.add_buffer(data_buffer.clone()).unwrap();
|
||||
|
||||
let set = Arc::new(set_builder.build().unwrap());
|
||||
|
||||
let mut builder = AutoCommandBufferBuilder::primary(
|
||||
queue.device().clone(),
|
||||
queue.family(),
|
||||
CommandBufferUsage::OneTimeSubmit,
|
||||
)
|
||||
.unwrap();
|
||||
builder
|
||||
.bind_pipeline_compute(pipeline.clone())
|
||||
.bind_descriptor_sets(
|
||||
PipelineBindPoint::Compute,
|
||||
pipeline.layout().clone(),
|
||||
0,
|
||||
set.clone(),
|
||||
)
|
||||
.push_constants(pipeline.layout().clone(), 0, parameters)
|
||||
.dispatch([1024, 1, 1])
|
||||
.unwrap();
|
||||
let command_buffer = builder.build().unwrap();
|
||||
|
||||
let future = sync::now(queue.device().clone())
|
||||
.then_execute(queue.clone(), command_buffer)
|
||||
.unwrap()
|
||||
.then_signal_fence_and_flush()
|
||||
.unwrap();
|
||||
|
||||
future.wait(None).unwrap();
|
||||
}
|
||||
|
||||
// Preparing test data array `[0, 1, 2, 3....]`
|
||||
let data_buffer = {
|
||||
let data_iter = (0..65536u32).map(|n| n);
|
||||
CpuAccessibleBuffer::from_iter(device.clone(), BufferUsage::all(), false, data_iter)
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
// Loading the first shader, and creating a Pipeline for the shader
|
||||
let mult_pipeline = Arc::new(
|
||||
ComputePipeline::new(
|
||||
device.clone(),
|
||||
&shaders::MultShader::load(device.clone())
|
||||
.unwrap()
|
||||
.main_entry_point(),
|
||||
&shaders::SpecializationConstants { enabled: 1 },
|
||||
None,
|
||||
|_| {},
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Loading the second shader, and creating a Pipeline for the shader
|
||||
let add_pipeline = Arc::new(
|
||||
ComputePipeline::new(
|
||||
device.clone(),
|
||||
&shaders::AddShader::load(device.clone())
|
||||
.unwrap()
|
||||
.main_entry_point(),
|
||||
&shaders::SpecializationConstants { enabled: 1 },
|
||||
None,
|
||||
|_| {},
|
||||
)
|
||||
.unwrap(),
|
||||
);
|
||||
|
||||
// Multiply each value by 2
|
||||
run_shader(
|
||||
mult_pipeline.clone(),
|
||||
queue.clone(),
|
||||
data_buffer.clone(),
|
||||
shaders::ty::Parameters { value: 2 },
|
||||
);
|
||||
|
||||
// Then add 1 to each value
|
||||
run_shader(
|
||||
add_pipeline,
|
||||
queue.clone(),
|
||||
data_buffer.clone(),
|
||||
shaders::ty::Parameters { value: 1 },
|
||||
);
|
||||
|
||||
// Then multiply each value by 3
|
||||
run_shader(
|
||||
mult_pipeline,
|
||||
queue.clone(),
|
||||
data_buffer.clone(),
|
||||
shaders::ty::Parameters { value: 3 },
|
||||
);
|
||||
|
||||
let data_buffer_content = data_buffer.read().unwrap();
|
||||
for n in 0..65536u32 {
|
||||
assert_eq!(data_buffer_content[n as usize], (n * 2 + 1) * 3);
|
||||
}
|
||||
println!("Success");
|
||||
}
|
@ -14,10 +14,12 @@ pub use crate::parse::ParseError;
|
||||
use crate::read_file_to_string;
|
||||
use crate::spec_consts;
|
||||
use crate::structs;
|
||||
use crate::RegisteredType;
|
||||
use crate::TypesMeta;
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
pub use shaderc::{CompilationArtifact, IncludeType, ResolvedInclude, ShaderKind};
|
||||
use shaderc::{CompileOptions, Compiler, EnvVersion, SpirvVersion, TargetEnv};
|
||||
use std::collections::HashMap;
|
||||
use spirv::{Capability, StorageClass};
|
||||
use std::iter::Iterator;
|
||||
use std::path::Path;
|
||||
@ -205,17 +207,18 @@ pub fn compile(
|
||||
}
|
||||
|
||||
pub(super) fn reflect<'a, I>(
|
||||
name: &str,
|
||||
prefix: &'a str,
|
||||
spirv: &[u32],
|
||||
types_meta: TypesMeta,
|
||||
types_meta: &TypesMeta,
|
||||
input_paths: I,
|
||||
exact_entrypoint_interface: bool,
|
||||
dump: bool,
|
||||
) -> Result<TokenStream, Error>
|
||||
shared_constants: bool,
|
||||
types_registry: &'a mut HashMap<String, RegisteredType>,
|
||||
) -> Result<(TokenStream, TokenStream), Error>
|
||||
where
|
||||
I: Iterator<Item = &'a str>,
|
||||
{
|
||||
let struct_name = Ident::new(&name, Span::call_site());
|
||||
let struct_name = Ident::new(&format!("{}Shader", prefix), Span::call_site());
|
||||
let doc = parse::parse_spirv(spirv)?;
|
||||
|
||||
// checking whether each required capability is enabled in the Vulkan device
|
||||
@ -310,10 +313,12 @@ where
|
||||
for instruction in doc.instructions.iter() {
|
||||
if let &Instruction::EntryPoint { .. } = instruction {
|
||||
let entry_point = entry_point::write_entry_point(
|
||||
prefix,
|
||||
&doc,
|
||||
instruction,
|
||||
&types_meta,
|
||||
types_meta,
|
||||
exact_entrypoint_interface,
|
||||
shared_constants,
|
||||
);
|
||||
entry_points_inside_impl.push(entry_point);
|
||||
}
|
||||
@ -327,46 +332,15 @@ where
|
||||
}
|
||||
});
|
||||
|
||||
let structs = structs::write_structs(&doc, &types_meta);
|
||||
let specialization_constants = spec_consts::write_specialization_constants(&doc, &types_meta);
|
||||
let uses = &types_meta.uses;
|
||||
let ast = quote! {
|
||||
#[allow(unused_imports)]
|
||||
use std::sync::Arc;
|
||||
#[allow(unused_imports)]
|
||||
use std::vec::IntoIter as VecIntoIter;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::device::Device;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorDesc;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorDescTy;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorDescImage;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorSetDesc;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorSetLayout;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::DescriptorSet;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::format::Format;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::image::view::ImageViewType;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::layout::PipelineLayout;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::layout::PipelineLayoutPcRange;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::shader::ShaderStages;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::shader::SpecializationConstants as SpecConstsTrait;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::shader::SpecializationMapEntry;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::Version;
|
||||
|
||||
let structs = structs::write_structs(prefix, &doc, types_meta, types_registry);
|
||||
let specialization_constants = spec_consts::write_specialization_constants(
|
||||
prefix,
|
||||
&doc,
|
||||
types_meta,
|
||||
shared_constants,
|
||||
types_registry,
|
||||
);
|
||||
let shader_code = quote! {
|
||||
pub struct #struct_name {
|
||||
shader: ::std::sync::Arc<::vulkano::pipeline::shader::ShaderModule>,
|
||||
}
|
||||
@ -400,20 +374,10 @@ where
|
||||
#( #entry_points_inside_impl )*
|
||||
}
|
||||
|
||||
pub mod ty {
|
||||
#( #uses )*
|
||||
#structs
|
||||
}
|
||||
|
||||
#specialization_constants
|
||||
};
|
||||
|
||||
if dump {
|
||||
println!("{}", ast.to_string());
|
||||
panic!("`shader!` rust codegen dumped") // TODO: use span from dump
|
||||
}
|
||||
|
||||
Ok(ast)
|
||||
Ok((shader_code, structs))
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
@ -841,7 +805,9 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
let doc = parse::parse_spirv(comp.as_binary()).unwrap();
|
||||
let res = std::panic::catch_unwind(|| structs::write_structs(&doc, &TypesMeta::default()));
|
||||
let res = std::panic::catch_unwind(|| {
|
||||
structs::write_structs("", &doc, &TypesMeta::default(), &mut HashMap::new())
|
||||
});
|
||||
assert!(res.is_err());
|
||||
}
|
||||
#[test]
|
||||
@ -869,7 +835,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
let doc = parse::parse_spirv(comp.as_binary()).unwrap();
|
||||
structs::write_structs(&doc, &TypesMeta::default());
|
||||
structs::write_structs("", &doc, &TypesMeta::default(), &mut HashMap::new());
|
||||
}
|
||||
#[test]
|
||||
fn test_wrap_alignment() {
|
||||
@ -901,7 +867,7 @@ mod tests {
|
||||
)
|
||||
.unwrap();
|
||||
let doc = parse::parse_spirv(comp.as_binary()).unwrap();
|
||||
structs::write_structs(&doc, &TypesMeta::default());
|
||||
structs::write_structs("", &doc, &TypesMeta::default(), &mut HashMap::new());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
@ -88,6 +88,7 @@ pub(super) fn write_descriptor_set_layout_descs(
|
||||
}
|
||||
|
||||
pub(super) fn write_push_constant_ranges(
|
||||
shader: &str,
|
||||
doc: &Spirv,
|
||||
stage: &TokenStream,
|
||||
types_meta: &TypesMeta,
|
||||
@ -106,7 +107,7 @@ pub(super) fn write_push_constant_ranges(
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let (_, size, _) = crate::structs::type_from_id(doc, type_id, types_meta);
|
||||
let (_, _, size, _) = crate::structs::type_from_id(shader, doc, type_id, types_meta);
|
||||
let size = size.expect("Found runtime-sized push constants") as u32;
|
||||
push_constants_size = cmp::max(push_constants_size, size);
|
||||
}
|
||||
@ -562,7 +563,7 @@ fn descriptor_infos(
|
||||
.next()
|
||||
.expect("failed to find array length");
|
||||
let len = len.iter().rev().fold(0, |a, &b| (a << 32) | b as u64);
|
||||
|
||||
|
||||
Some((desc, mutable, len, false))
|
||||
}
|
||||
|
||||
|
@ -15,10 +15,12 @@ use spirv::{Decoration, ExecutionMode, ExecutionModel, StorageClass};
|
||||
use syn::Ident;
|
||||
|
||||
pub(super) fn write_entry_point(
|
||||
shader: &str,
|
||||
doc: &Spirv,
|
||||
instruction: &Instruction,
|
||||
types_meta: &TypesMeta,
|
||||
exact_entrypoint_interface: bool,
|
||||
shared_constants: bool,
|
||||
) -> TokenStream {
|
||||
let (execution, id, ep_name, interface) = match instruction {
|
||||
&Instruction::EntryPoint {
|
||||
@ -83,10 +85,17 @@ pub(super) fn write_entry_point(
|
||||
|
||||
let descriptor_set_layout_descs =
|
||||
write_descriptor_set_layout_descs(&doc, id, interface, exact_entrypoint_interface, &stage);
|
||||
let push_constant_ranges = write_push_constant_ranges(&doc, &stage, &types_meta);
|
||||
let push_constant_ranges = write_push_constant_ranges(shader, &doc, &stage, &types_meta);
|
||||
|
||||
let spec_consts_struct = if crate::spec_consts::has_specialization_constants(doc) {
|
||||
quote! { SpecializationConstants }
|
||||
let spec_consts_struct_name = Ident::new(
|
||||
&format!(
|
||||
"{}SpecializationConstants",
|
||||
if shared_constants { "" } else { shader }
|
||||
),
|
||||
Span::call_site(),
|
||||
);
|
||||
quote! { #spec_consts_struct_name }
|
||||
} else {
|
||||
quote! { () }
|
||||
};
|
||||
|
@ -51,7 +51,7 @@
|
||||
//! return the various entry point structs that can be found in the
|
||||
//! [vulkano::pipeline::shader][pipeline::shader] module.
|
||||
//! * A Rust struct translated from each struct contained in the shader data.
|
||||
//! By default each structure has a `Clone` and a `Copy` implemenetations. This
|
||||
//! By default each structure has a `Clone` and a `Copy` implementations. This
|
||||
//! behavior could be customized through the `types_meta` macro option(see below
|
||||
//! for details).
|
||||
//! * The `SpecializationConstants` struct. This contains a field for every
|
||||
@ -59,9 +59,8 @@
|
||||
//! `Default` and [`SpecializationConstants`][SpecializationConstants] are also
|
||||
//! generated for the struct.
|
||||
//!
|
||||
//! All of these generated items will be accessed through the module specified
|
||||
//! by `mod_name: foo` If you wanted to store the `Shader` in a struct of your own,
|
||||
//! you could do something like this:
|
||||
//! All of these generated items will be accessed through the module when the macro was invoked.
|
||||
//! If you wanted to store the `Shader` in a struct of your own, you could do something like this:
|
||||
//!
|
||||
//! ```
|
||||
//! # fn main() {}
|
||||
@ -134,6 +133,22 @@
|
||||
//! **Note**: If your shader contains multiple entrypoints with different
|
||||
//! descriptor sets, you may also need to enable `exact_entrypoint_interface`.
|
||||
//!
|
||||
//! ## `shaders: { First: {src: "...", ty: "..."}, ... }`
|
||||
//!
|
||||
//! With these options the user can compile several shaders at a single macro invocation.
|
||||
//! Each entry key is a prefix that will be put in front of generated `Shader`
|
||||
//! struct(`FirstShader` in this case), and `SpecializationConstants`
|
||||
//! struct(`FirstSpecializationConstants` in this case). However all other Rust structs
|
||||
//! translated from the shader source will be shared between shaders. The macro checks that the
|
||||
//! source structs with the same names between different shaders have the same declaration
|
||||
//! signature, and throws a compile-time error if they don't.
|
||||
//!
|
||||
//! Each entry values expecting `src`, `path`, `bytes`, and `ty` pairs same as above.
|
||||
//!
|
||||
//! Also `SpecializationConstants` can all be shared between shaders by specifying
|
||||
//! `shared_constants: true,` entry-flag of the `shaders` map. This feature is turned-off by
|
||||
//! default.
|
||||
//!
|
||||
//! ## `include: ["...", "...", ..., "..."]`
|
||||
//!
|
||||
//! Specifies the standard include directories to be searched through when using the
|
||||
@ -219,6 +234,8 @@ extern crate proc_macro;
|
||||
|
||||
use crate::codegen::ShaderKind;
|
||||
use shaderc::{EnvVersion, SpirvVersion};
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::fs;
|
||||
use std::fs::File;
|
||||
use std::io::{Read, Result as IoResult};
|
||||
@ -290,13 +307,73 @@ impl TypesMeta {
|
||||
}
|
||||
}
|
||||
|
||||
struct RegisteredType {
|
||||
shader: String,
|
||||
signature: Vec<(String, Cow<'static, str>)>,
|
||||
}
|
||||
|
||||
impl RegisteredType {
|
||||
#[inline]
|
||||
fn assert_signatures(&self, type_name: &str, target_type: &Self) {
|
||||
if self.signature.len() > target_type.signature.len() {
|
||||
panic!(
|
||||
"Shaders {shader_a:} and {shader_b:} declare structs with the \
|
||||
same name \"`{type_name:}\", but the struct from {shader_a:} shader \
|
||||
contains extra field \"{field:}\"",
|
||||
shader_a = self.shader,
|
||||
shader_b = target_type.shader,
|
||||
type_name = type_name,
|
||||
field = self.signature[target_type.signature.len()].0
|
||||
);
|
||||
}
|
||||
|
||||
if self.signature.len() < target_type.signature.len() {
|
||||
panic!(
|
||||
"Shaders {shader_a:} and {shader_b:} declare structs with the \
|
||||
same name \"{type_name:}\", but the struct from {shader_b:} shader \
|
||||
contains extra field \"{field:}\"",
|
||||
shader_a = self.shader,
|
||||
shader_b = target_type.shader,
|
||||
type_name = type_name,
|
||||
field = target_type.signature[self.signature.len()].0
|
||||
);
|
||||
}
|
||||
|
||||
let comparison = self
|
||||
.signature
|
||||
.iter()
|
||||
.zip(target_type.signature.iter())
|
||||
.enumerate();
|
||||
|
||||
for (index, ((a_name, a_type), (b_name, b_type))) in comparison {
|
||||
if a_name != b_name || a_type != b_type {
|
||||
panic!(
|
||||
"Shaders {shader_a:} and {shader_b:} declare structs with the \
|
||||
same name \"{type_name:}\", but the struct from {shader_a:} shader \
|
||||
contains field \"{a_name:}\" of type \"{a_type:}\" in position {index:}, \
|
||||
whereas the same struct from {shader_b:} contains field \"{b_name:}\" \
|
||||
of type \"{b_type:}\" in the same position",
|
||||
shader_a = self.shader,
|
||||
shader_b = target_type.shader,
|
||||
type_name = type_name,
|
||||
index = index,
|
||||
a_name = a_name,
|
||||
a_type = a_type,
|
||||
b_name = b_name,
|
||||
b_type = b_type,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct MacroInput {
|
||||
dump: bool,
|
||||
exact_entrypoint_interface: bool,
|
||||
include_directories: Vec<String>,
|
||||
macro_defines: Vec<(String, String)>,
|
||||
shader_kind: ShaderKind,
|
||||
source_kind: SourceKind,
|
||||
shared_constants: bool,
|
||||
shaders: HashMap<String, (ShaderKind, SourceKind)>,
|
||||
spirv_version: Option<SpirvVersion>,
|
||||
types_meta: TypesMeta,
|
||||
vulkan_version: Option<EnvVersion>,
|
||||
@ -308,24 +385,168 @@ impl Parse for MacroInput {
|
||||
let mut exact_entrypoint_interface = None;
|
||||
let mut include_directories = Vec::new();
|
||||
let mut macro_defines = Vec::new();
|
||||
let mut shader_kind = None;
|
||||
let mut source_kind = None;
|
||||
let mut shared_constants = None;
|
||||
let mut shaders = HashMap::new();
|
||||
let mut spirv_version = None;
|
||||
let mut types_meta = None;
|
||||
let mut vulkan_version = None;
|
||||
|
||||
while !input.is_empty() {
|
||||
let name: Ident = input.parse()?;
|
||||
input.parse::<Token![:]>()?;
|
||||
fn parse_shader_fields<'k>(
|
||||
output: &mut (Option<ShaderKind>, Option<SourceKind>),
|
||||
name: &'k str,
|
||||
input: ParseStream,
|
||||
) -> Result<()> {
|
||||
match name {
|
||||
"ty" => {
|
||||
if output.0.is_some() {
|
||||
panic!("Only one `ty` can be defined")
|
||||
}
|
||||
|
||||
let ty: LitStr = input.parse()?;
|
||||
let ty = match ty.value().as_ref() {
|
||||
"vertex" => ShaderKind::Vertex,
|
||||
"fragment" => ShaderKind::Fragment,
|
||||
"geometry" => ShaderKind::Geometry,
|
||||
"tess_ctrl" => ShaderKind::TessControl,
|
||||
"tess_eval" => ShaderKind::TessEvaluation,
|
||||
"compute" => ShaderKind::Compute,
|
||||
_ => panic!("Unexpected shader type, valid values: vertex, fragment, geometry, tess_ctrl, tess_eval, compute")
|
||||
};
|
||||
|
||||
output.0 = Some(ty);
|
||||
}
|
||||
|
||||
match name.to_string().as_ref() {
|
||||
"bytes" => {
|
||||
if source_kind.is_some() {
|
||||
panic!("Only one of `src`, `path`, or `bytes` can be defined")
|
||||
if output.1.is_some() {
|
||||
panic!(
|
||||
"Only one of `src`, `path`, or `bytes` can be defined per Shader entry"
|
||||
)
|
||||
}
|
||||
|
||||
let path: LitStr = input.parse()?;
|
||||
source_kind = Some(SourceKind::Bytes(path.value()));
|
||||
output.1 = Some(SourceKind::Bytes(path.value()));
|
||||
}
|
||||
|
||||
"path" => {
|
||||
if output.1.is_some() {
|
||||
panic!(
|
||||
"Only one of `src`, `path`, or `bytes` can be defined per Shader entry"
|
||||
)
|
||||
}
|
||||
|
||||
let path: LitStr = input.parse()?;
|
||||
output.1 = Some(SourceKind::Path(path.value()));
|
||||
}
|
||||
|
||||
"src" => {
|
||||
if output.1.is_some() {
|
||||
panic!("Only one of `src`, `path`, `bytes` can be defined per Shader entry")
|
||||
}
|
||||
|
||||
let src: LitStr = input.parse()?;
|
||||
output.1 = Some(SourceKind::Src(src.value()));
|
||||
}
|
||||
|
||||
other => unreachable!("Unexpected entry key {:?}", other),
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
while !input.is_empty() {
|
||||
let name: Ident = input.parse()?;
|
||||
input.parse::<Token![:]>()?;
|
||||
let name = name.to_string();
|
||||
|
||||
match name.as_str() {
|
||||
"bytes" | "src" | "path" | "ty" => {
|
||||
if shaders.len() > 1 || (shaders.len() == 1 && !shaders.contains_key("")) {
|
||||
panic!("Only one of `shaders`, `src`, `path`, or `bytes` can be defined");
|
||||
}
|
||||
|
||||
parse_shader_fields(
|
||||
shaders
|
||||
.entry("".to_string())
|
||||
.or_insert_with(Default::default),
|
||||
name.as_str(),
|
||||
input,
|
||||
)?;
|
||||
}
|
||||
"shaders" => {
|
||||
if !shaders.is_empty() {
|
||||
panic!("Only one of `shaders`, `src`, `path`, or `bytes` can be defined");
|
||||
}
|
||||
|
||||
let in_braces;
|
||||
braced!(in_braces in input);
|
||||
|
||||
while !in_braces.is_empty() {
|
||||
let prefix: Ident = in_braces.parse()?;
|
||||
let prefix = prefix.to_string();
|
||||
|
||||
if prefix.to_string().as_str() == "shared_constants" {
|
||||
in_braces.parse::<Token![:]>()?;
|
||||
|
||||
if shared_constants.is_some() {
|
||||
panic!("Only one `shared_constants` can be defined")
|
||||
}
|
||||
let independent_constants_lit: LitBool = in_braces.parse()?;
|
||||
shared_constants = Some(independent_constants_lit.value);
|
||||
|
||||
if !in_braces.is_empty() {
|
||||
in_braces.parse::<Token![,]>()?;
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
if shaders.contains_key(&prefix) {
|
||||
panic!("Shader entry {:?} already defined", prefix);
|
||||
}
|
||||
|
||||
in_braces.parse::<Token![:]>()?;
|
||||
|
||||
let in_shader_definition;
|
||||
braced!(in_shader_definition in in_braces);
|
||||
|
||||
while !in_shader_definition.is_empty() {
|
||||
let name: Ident = in_shader_definition.parse()?;
|
||||
in_shader_definition.parse::<Token![:]>()?;
|
||||
let name = name.to_string();
|
||||
|
||||
match name.as_ref() {
|
||||
"bytes" | "src" | "path" | "ty" => {
|
||||
parse_shader_fields(
|
||||
shaders
|
||||
.entry(prefix.clone())
|
||||
.or_insert_with(Default::default),
|
||||
name.as_str(),
|
||||
&in_shader_definition,
|
||||
)?;
|
||||
}
|
||||
|
||||
name => panic!("Unknown Shader definition field {:?}", name),
|
||||
}
|
||||
|
||||
if !in_shader_definition.is_empty() {
|
||||
in_shader_definition.parse::<Token![,]>()?;
|
||||
}
|
||||
}
|
||||
|
||||
if !in_braces.is_empty() {
|
||||
in_braces.parse::<Token![,]>()?;
|
||||
}
|
||||
|
||||
match shaders.get(&prefix).unwrap() {
|
||||
(None, _) => panic!("Please specify shader's {} type e.g. `ty: \"vertex\"`", prefix),
|
||||
(_, None) => panic!("Please specify shader's {} source e.g. `path: \"entry_point.glsl\"`", prefix),
|
||||
_ => ()
|
||||
}
|
||||
}
|
||||
|
||||
if shaders.is_empty() {
|
||||
panic!("At least one Shader entry must be defined");
|
||||
}
|
||||
}
|
||||
"define" => {
|
||||
let array_input;
|
||||
@ -373,14 +594,6 @@ impl Parse for MacroInput {
|
||||
}
|
||||
}
|
||||
}
|
||||
"path" => {
|
||||
if source_kind.is_some() {
|
||||
panic!("Only one of `src`, `path`, or `bytes` can be defined")
|
||||
}
|
||||
|
||||
let path: LitStr = input.parse()?;
|
||||
source_kind = Some(SourceKind::Path(path.value()));
|
||||
}
|
||||
"spirv_version" => {
|
||||
let version: LitStr = input.parse()?;
|
||||
spirv_version = Some(match version.value().as_ref() {
|
||||
@ -393,31 +606,6 @@ impl Parse for MacroInput {
|
||||
_ => panic!("Unknown SPIR-V version: {}", version.value()),
|
||||
});
|
||||
}
|
||||
"src" => {
|
||||
if source_kind.is_some() {
|
||||
panic!("Only one of `src`, `path`, or `bytes` can be defined")
|
||||
}
|
||||
|
||||
let src: LitStr = input.parse()?;
|
||||
source_kind = Some(SourceKind::Src(src.value()));
|
||||
}
|
||||
"ty" => {
|
||||
if shader_kind.is_some() {
|
||||
panic!("Only one `ty` can be defined")
|
||||
}
|
||||
|
||||
let ty: LitStr = input.parse()?;
|
||||
let ty = match ty.value().as_ref() {
|
||||
"vertex" => ShaderKind::Vertex,
|
||||
"fragment" => ShaderKind::Fragment,
|
||||
"geometry" => ShaderKind::Geometry,
|
||||
"tess_ctrl" => ShaderKind::TessControl,
|
||||
"tess_eval" => ShaderKind::TessEvaluation,
|
||||
"compute" => ShaderKind::Compute,
|
||||
_ => panic!("Unexpected shader type, valid values: vertex, fragment, geometry, tess_ctrl, tess_eval, compute")
|
||||
};
|
||||
shader_kind = Some(ty);
|
||||
}
|
||||
"types_meta" => {
|
||||
let in_braces;
|
||||
braced!(in_braces in input);
|
||||
@ -565,7 +753,7 @@ impl Parse for MacroInput {
|
||||
_ => panic!("Unknown Vulkan version: {}", version.value()),
|
||||
});
|
||||
}
|
||||
name => panic!("Unknown field name: {}", name),
|
||||
name => panic!("Unknown field {:?}", name),
|
||||
}
|
||||
|
||||
if !input.is_empty() {
|
||||
@ -573,25 +761,30 @@ impl Parse for MacroInput {
|
||||
}
|
||||
}
|
||||
|
||||
let shader_kind = match shader_kind {
|
||||
Some(shader_kind) => shader_kind,
|
||||
None => panic!("Please provide a shader type e.g. `ty: \"vertex\"`"),
|
||||
};
|
||||
if shaders.is_empty() {
|
||||
panic!("Please specify at least one shader e.g. `ty: \"vertex\", src: \"glsl source code\"`");
|
||||
}
|
||||
|
||||
let source_kind = match source_kind {
|
||||
Some(source_kind) => source_kind,
|
||||
None => panic!("Please provide a source e.g. `path: \"foo.glsl\"` or `src: \"glsl source code here ...\"`")
|
||||
};
|
||||
|
||||
let dump = dump.unwrap_or(false);
|
||||
match shaders.get("") {
|
||||
Some((None, _)) => panic!("Please specify shader's type e.g. `ty: \"vertex\"`"),
|
||||
Some((_, None)) => {
|
||||
panic!("Please specify shader's source e.g. `src: \"glsl source code\"`")
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
dump,
|
||||
dump: dump.unwrap_or(false),
|
||||
exact_entrypoint_interface: exact_entrypoint_interface.unwrap_or(false),
|
||||
include_directories,
|
||||
macro_defines,
|
||||
shader_kind,
|
||||
source_kind,
|
||||
shared_constants: shared_constants.unwrap_or(false),
|
||||
shaders: shaders
|
||||
.into_iter()
|
||||
.map(|(key, (shader_kind, shader_source))| {
|
||||
(key, (shader_kind.unwrap(), shader_source.unwrap()))
|
||||
})
|
||||
.collect(),
|
||||
spirv_version,
|
||||
types_meta: types_meta.unwrap_or_else(|| TypesMeta::default()),
|
||||
vulkan_version,
|
||||
@ -609,92 +802,169 @@ pub(self) fn read_file_to_string(full_path: &Path) -> IoResult<String> {
|
||||
pub fn shader(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
|
||||
let input = parse_macro_input!(input as MacroInput);
|
||||
|
||||
let is_single = input.shaders.len() == 1;
|
||||
let root = env::var("CARGO_MANIFEST_DIR").unwrap_or(".".into());
|
||||
let root_path = Path::new(&root);
|
||||
|
||||
if let SourceKind::Bytes(path) = input.source_kind {
|
||||
let full_path = root_path.join(&path);
|
||||
let mut shaders_code = Vec::with_capacity(input.shaders.len());
|
||||
let mut types_code = Vec::with_capacity(input.shaders.len());
|
||||
let mut types_registry = HashMap::new();
|
||||
|
||||
let bytes = if full_path.is_file() {
|
||||
fs::read(full_path).expect(&format!("Error reading source from {:?}", path))
|
||||
for (prefix, (shader_kind, shader_source)) in input.shaders {
|
||||
let (code, types) = if let SourceKind::Bytes(path) = shader_source {
|
||||
let full_path = root_path.join(&path);
|
||||
|
||||
let bytes = if full_path.is_file() {
|
||||
fs::read(full_path).expect(&format!("Error reading source from {:?}", path))
|
||||
} else {
|
||||
panic!(
|
||||
"File {:?} was not found; note that the path must be relative to your Cargo.toml",
|
||||
path
|
||||
);
|
||||
};
|
||||
|
||||
// The SPIR-V specification essentially guarantees that
|
||||
// a shader will always be an integer number of words
|
||||
assert_eq!(0, bytes.len() % 4);
|
||||
codegen::reflect(
|
||||
prefix.as_str(),
|
||||
unsafe { from_raw_parts(bytes.as_slice().as_ptr() as *const u32, bytes.len() / 4) },
|
||||
&input.types_meta,
|
||||
empty(),
|
||||
input.exact_entrypoint_interface,
|
||||
input.shared_constants,
|
||||
&mut types_registry,
|
||||
)
|
||||
.unwrap()
|
||||
.into()
|
||||
} else {
|
||||
panic!(
|
||||
"File {:?} was not found ; note that the path must be relative to your Cargo.toml",
|
||||
path
|
||||
);
|
||||
};
|
||||
let (path, full_path, source_code) = match shader_source {
|
||||
SourceKind::Src(source) => (None, None, source),
|
||||
SourceKind::Path(path) => {
|
||||
let full_path = root_path.join(&path);
|
||||
let source_code = read_file_to_string(&full_path)
|
||||
.expect(&format!("Error reading source from {:?}", path));
|
||||
|
||||
// The SPIR-V specification essentially guarantees that
|
||||
// a shader will always be an integer number of words
|
||||
assert_eq!(0, bytes.len() % 4);
|
||||
codegen::reflect(
|
||||
"Shader",
|
||||
unsafe { from_raw_parts(bytes.as_slice().as_ptr() as *const u32, bytes.len() / 4) },
|
||||
input.types_meta,
|
||||
empty(),
|
||||
input.exact_entrypoint_interface,
|
||||
input.dump,
|
||||
)
|
||||
.unwrap()
|
||||
.into()
|
||||
} else {
|
||||
let (path, full_path, source_code) = match input.source_kind {
|
||||
SourceKind::Src(source) => (None, None, source),
|
||||
SourceKind::Path(path) => {
|
||||
let full_path = root_path.join(&path);
|
||||
let source_code = read_file_to_string(&full_path)
|
||||
.expect(&format!("Error reading source from {:?}", path));
|
||||
|
||||
if full_path.is_file() {
|
||||
(Some(path.clone()), Some(full_path), source_code)
|
||||
} else {
|
||||
panic!("File {:?} was not found ; note that the path must be relative to your Cargo.toml", path);
|
||||
if full_path.is_file() {
|
||||
(Some(path.clone()), Some(full_path), source_code)
|
||||
} else {
|
||||
panic!("File {:?} was not found; note that the path must be relative to your Cargo.toml", path);
|
||||
}
|
||||
}
|
||||
}
|
||||
SourceKind::Bytes(_) => unreachable!(),
|
||||
SourceKind::Bytes(_) => unreachable!(),
|
||||
};
|
||||
|
||||
let include_paths = input
|
||||
.include_directories
|
||||
.iter()
|
||||
.map(|include_directory| {
|
||||
let include_path = Path::new(include_directory);
|
||||
let mut full_include_path = root_path.to_owned();
|
||||
full_include_path.push(include_path);
|
||||
full_include_path
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let (content, includes) = match codegen::compile(
|
||||
path,
|
||||
&root_path,
|
||||
&source_code,
|
||||
shader_kind,
|
||||
&include_paths,
|
||||
&input.macro_defines,
|
||||
input.vulkan_version,
|
||||
input.spirv_version,
|
||||
) {
|
||||
Ok(ok) => ok,
|
||||
Err(e) => {
|
||||
if is_single {
|
||||
panic!("{}", e.replace("(s): ", "(s):\n"))
|
||||
} else {
|
||||
panic!("Shader {:?} {}", prefix, e.replace("(s): ", "(s):\n"))
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
let input_paths = includes.iter().map(|s| s.as_ref()).chain(
|
||||
full_path
|
||||
.as_ref()
|
||||
.map(|p| p.as_path())
|
||||
.map(codegen::path_to_str),
|
||||
);
|
||||
|
||||
codegen::reflect(
|
||||
prefix.as_str(),
|
||||
content.as_binary(),
|
||||
&input.types_meta,
|
||||
input_paths,
|
||||
input.exact_entrypoint_interface,
|
||||
input.shared_constants,
|
||||
&mut types_registry,
|
||||
)
|
||||
.unwrap()
|
||||
.into()
|
||||
};
|
||||
|
||||
let include_paths = input
|
||||
.include_directories
|
||||
.iter()
|
||||
.map(|include_directory| {
|
||||
let include_path = Path::new(include_directory);
|
||||
let mut full_include_path = root_path.to_owned();
|
||||
full_include_path.push(include_path);
|
||||
full_include_path
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let (content, includes) = match codegen::compile(
|
||||
path,
|
||||
&root_path,
|
||||
&source_code,
|
||||
input.shader_kind,
|
||||
&include_paths,
|
||||
&input.macro_defines,
|
||||
input.vulkan_version,
|
||||
input.spirv_version,
|
||||
) {
|
||||
Ok(ok) => ok,
|
||||
Err(e) => panic!("{}", e.replace("(s): ", "(s):\n")),
|
||||
};
|
||||
|
||||
let input_paths = includes.iter().map(|s| s.as_ref()).chain(
|
||||
full_path
|
||||
.as_ref()
|
||||
.map(|p| p.as_path())
|
||||
.map(codegen::path_to_str),
|
||||
);
|
||||
|
||||
codegen::reflect(
|
||||
"Shader",
|
||||
content.as_binary(),
|
||||
input.types_meta,
|
||||
input_paths,
|
||||
input.exact_entrypoint_interface,
|
||||
input.dump,
|
||||
)
|
||||
.unwrap()
|
||||
.into()
|
||||
shaders_code.push(code);
|
||||
types_code.push(types);
|
||||
}
|
||||
|
||||
let uses = &input.types_meta.uses;
|
||||
|
||||
let result = quote! {
|
||||
#[allow(unused_imports)]
|
||||
use std::sync::Arc;
|
||||
#[allow(unused_imports)]
|
||||
use std::vec::IntoIter as VecIntoIter;
|
||||
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::device::Device;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorDesc;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorDescTy;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorDescImage;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorSetDesc;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::layout::DescriptorSetLayout;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::descriptor_set::DescriptorSet;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::format::Format;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::image::view::ImageViewType;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::layout::PipelineLayout;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::layout::PipelineLayoutPcRange;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::shader::ShaderStages;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::shader::SpecializationConstants as SpecConstsTrait;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::pipeline::shader::SpecializationMapEntry;
|
||||
#[allow(unused_imports)]
|
||||
use vulkano::Version;
|
||||
|
||||
#(
|
||||
#shaders_code
|
||||
)*
|
||||
|
||||
pub mod ty {
|
||||
#( #uses )*
|
||||
|
||||
#(
|
||||
#types_code
|
||||
)*
|
||||
}
|
||||
};
|
||||
|
||||
if input.dump {
|
||||
println!("{}", result.to_string());
|
||||
panic!("`shader!` rust codegen dumped") // TODO: use span from dump
|
||||
}
|
||||
|
||||
proc_macro::TokenStream::from(result)
|
||||
}
|
||||
|
@ -8,10 +8,12 @@
|
||||
// according to those terms.
|
||||
|
||||
use crate::parse::{Instruction, Spirv};
|
||||
use crate::structs;
|
||||
use crate::{spirv_search, TypesMeta};
|
||||
use crate::{structs, RegisteredType};
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use spirv::Decoration;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::mem;
|
||||
use syn::Ident;
|
||||
|
||||
@ -32,11 +34,18 @@ pub fn has_specialization_constants(doc: &Spirv) -> bool {
|
||||
|
||||
/// Writes the `SpecializationConstants` struct that contains the specialization constants and
|
||||
/// implements the `Default` and the `vulkano::pipeline::shader::SpecializationConstants` traits.
|
||||
pub(super) fn write_specialization_constants(doc: &Spirv, types_meta: &TypesMeta) -> TokenStream {
|
||||
pub(super) fn write_specialization_constants<'a>(
|
||||
shader: &'a str,
|
||||
doc: &Spirv,
|
||||
types_meta: &TypesMeta,
|
||||
shared_constants: bool,
|
||||
types_registry: &'a mut HashMap<String, RegisteredType>,
|
||||
) -> TokenStream {
|
||||
struct SpecConst {
|
||||
name: String,
|
||||
constant_id: u32,
|
||||
rust_ty: TokenStream,
|
||||
rust_signature: Cow<'static, str>,
|
||||
rust_size: usize,
|
||||
rust_alignment: u32,
|
||||
default_value: TokenStream,
|
||||
@ -79,8 +88,8 @@ pub(super) fn write_specialization_constants(doc: &Spirv, types_meta: &TypesMeta
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let (rust_ty, rust_size, rust_alignment) =
|
||||
spec_const_type_from_id(doc, type_id, types_meta);
|
||||
let (rust_ty, rust_signature, rust_size, rust_alignment) =
|
||||
spec_const_type_from_id(shader, doc, type_id, types_meta);
|
||||
let rust_size = rust_size.expect("Found runtime-sized specialization constant");
|
||||
|
||||
let constant_id = doc.get_decoration_params(result_id, Decoration::SpecId);
|
||||
@ -98,6 +107,7 @@ pub(super) fn write_specialization_constants(doc: &Spirv, types_meta: &TypesMeta
|
||||
name,
|
||||
constant_id,
|
||||
rust_ty,
|
||||
rust_signature,
|
||||
rust_size,
|
||||
rust_alignment: rust_alignment as u32,
|
||||
default_value,
|
||||
@ -105,6 +115,38 @@ pub(super) fn write_specialization_constants(doc: &Spirv, types_meta: &TypesMeta
|
||||
}
|
||||
}
|
||||
|
||||
let struct_name = Ident::new(
|
||||
&format!(
|
||||
"{}SpecializationConstants",
|
||||
if shared_constants { "" } else { shader }
|
||||
),
|
||||
Span::call_site(),
|
||||
);
|
||||
|
||||
// For multi-constants mode registration mechanism skipped
|
||||
if shared_constants {
|
||||
let target_type = RegisteredType {
|
||||
shader: shader.to_string(),
|
||||
signature: spec_consts
|
||||
.iter()
|
||||
.map(|member| (member.name.to_string(), member.rust_signature.clone()))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let name = struct_name.to_string();
|
||||
|
||||
// Checking with Registry if this struct already registered by another shader, and if their
|
||||
// signatures match.
|
||||
if let Some(registered) = types_registry.get(name.as_str()) {
|
||||
registered.assert_signatures(name.as_str(), &target_type);
|
||||
|
||||
// If the struct already registered and matches this one, skip duplicate.
|
||||
return quote! {};
|
||||
}
|
||||
|
||||
debug_assert!(types_registry.insert(name, target_type).is_none());
|
||||
}
|
||||
|
||||
let map_entries = {
|
||||
let mut map_entries = Vec::new();
|
||||
let mut curr_offset = 0;
|
||||
@ -143,19 +185,19 @@ pub(super) fn write_specialization_constants(doc: &Spirv, types_meta: &TypesMeta
|
||||
#[derive(Debug, Copy, Clone)]
|
||||
#[allow(non_snake_case)]
|
||||
#[repr(C)]
|
||||
pub struct SpecializationConstants {
|
||||
pub struct #struct_name {
|
||||
#( #struct_members ),*
|
||||
}
|
||||
|
||||
impl Default for SpecializationConstants {
|
||||
fn default() -> SpecializationConstants {
|
||||
SpecializationConstants {
|
||||
impl Default for #struct_name {
|
||||
fn default() -> #struct_name {
|
||||
#struct_name {
|
||||
#( #struct_member_defaults ),*
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
unsafe impl SpecConstsTrait for SpecializationConstants {
|
||||
unsafe impl SpecConstsTrait for #struct_name {
|
||||
fn descriptors() -> &'static [SpecializationMapEntry] {
|
||||
static DESCRIPTORS: [SpecializationMapEntry; #num_map_entries] = [
|
||||
#( #map_entries ),*
|
||||
@ -168,15 +210,17 @@ pub(super) fn write_specialization_constants(doc: &Spirv, types_meta: &TypesMeta
|
||||
|
||||
// Wrapper around `type_from_id` that also handles booleans.
|
||||
fn spec_const_type_from_id(
|
||||
shader: &str,
|
||||
doc: &Spirv,
|
||||
searched: u32,
|
||||
types_meta: &TypesMeta,
|
||||
) -> (TokenStream, Option<usize>, usize) {
|
||||
) -> (TokenStream, Cow<'static, str>, Option<usize>, usize) {
|
||||
for instruction in doc.instructions.iter() {
|
||||
match instruction {
|
||||
&Instruction::TypeBool { result_id } if result_id == searched => {
|
||||
return (
|
||||
quote! {u32},
|
||||
Cow::from("u32"),
|
||||
Some(mem::size_of::<u32>()),
|
||||
mem::align_of::<u32>(),
|
||||
);
|
||||
@ -185,5 +229,5 @@ fn spec_const_type_from_id(
|
||||
}
|
||||
}
|
||||
|
||||
structs::type_from_id(doc, searched, types_meta)
|
||||
structs::type_from_id(shader, doc, searched, types_meta)
|
||||
}
|
||||
|
@ -8,22 +8,39 @@
|
||||
// according to those terms.
|
||||
|
||||
use crate::parse::{Instruction, Spirv};
|
||||
use crate::{spirv_search, TypesMeta};
|
||||
use crate::{spirv_search, RegisteredType, TypesMeta};
|
||||
use proc_macro2::{Span, TokenStream};
|
||||
use spirv::Decoration;
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::mem;
|
||||
use syn::Ident;
|
||||
use syn::LitStr;
|
||||
|
||||
/// Translates all the structs that are contained in the SPIR-V document as Rust structs.
|
||||
pub(super) fn write_structs(doc: &Spirv, types_meta: &TypesMeta) -> TokenStream {
|
||||
pub(super) fn write_structs<'a>(
|
||||
shader: &'a str,
|
||||
doc: &Spirv,
|
||||
types_meta: &TypesMeta,
|
||||
types_registry: &'a mut HashMap<String, RegisteredType>,
|
||||
) -> TokenStream {
|
||||
let mut structs = vec![];
|
||||
for instruction in &doc.instructions {
|
||||
match *instruction {
|
||||
Instruction::TypeStruct {
|
||||
result_id,
|
||||
ref member_types,
|
||||
} => structs.push(write_struct(doc, result_id, member_types, types_meta).0),
|
||||
} => structs.push(
|
||||
write_struct(
|
||||
shader,
|
||||
doc,
|
||||
result_id,
|
||||
member_types,
|
||||
types_meta,
|
||||
Some(types_registry),
|
||||
)
|
||||
.0,
|
||||
),
|
||||
_ => (),
|
||||
}
|
||||
}
|
||||
@ -34,11 +51,13 @@ pub(super) fn write_structs(doc: &Spirv, types_meta: &TypesMeta) -> TokenStream
|
||||
}
|
||||
|
||||
/// Analyzes a single struct, returns a string containing its Rust definition, plus its size.
|
||||
fn write_struct(
|
||||
fn write_struct<'a>(
|
||||
shader: &'a str,
|
||||
doc: &Spirv,
|
||||
struct_id: u32,
|
||||
members: &[u32],
|
||||
types_meta: &TypesMeta,
|
||||
types_registry: Option<&'a mut HashMap<String, RegisteredType>>,
|
||||
) -> (TokenStream, Option<usize>) {
|
||||
let name = Ident::new(
|
||||
&spirv_search::name_from_id(doc, struct_id),
|
||||
@ -47,9 +66,10 @@ fn write_struct(
|
||||
|
||||
// The members of this struct.
|
||||
struct Member {
|
||||
pub name: Ident,
|
||||
pub dummy: bool,
|
||||
pub ty: TokenStream,
|
||||
name: Ident,
|
||||
dummy: bool,
|
||||
ty: TokenStream,
|
||||
signature: Cow<'static, str>,
|
||||
}
|
||||
let mut rust_members = Vec::with_capacity(members.len());
|
||||
|
||||
@ -62,7 +82,7 @@ fn write_struct(
|
||||
|
||||
for (num, &member) in members.iter().enumerate() {
|
||||
// Compute infos about the member.
|
||||
let (ty, rust_size, rust_align) = type_from_id(doc, member, types_meta);
|
||||
let (ty, signature, rust_size, rust_align) = type_from_id(shader, doc, member, types_meta);
|
||||
let member_name = spirv_search::member_name_from_id(doc, struct_id, num as u32);
|
||||
|
||||
// Ignore the whole struct is a member is built in, which includes
|
||||
@ -107,6 +127,7 @@ fn write_struct(
|
||||
name: Ident::new(&format!("_dummy{}", padding_num), Span::call_site()),
|
||||
dummy: true,
|
||||
ty: quote! { [u8; #diff] },
|
||||
signature: Cow::from(format!("[u8; {}]", diff)),
|
||||
});
|
||||
*current_rust_offset += diff;
|
||||
}
|
||||
@ -123,6 +144,7 @@ fn write_struct(
|
||||
name: Ident::new(&member_name, Span::call_site()),
|
||||
dummy: false,
|
||||
ty,
|
||||
signature,
|
||||
});
|
||||
}
|
||||
|
||||
@ -156,10 +178,39 @@ fn write_struct(
|
||||
name: Ident::new(&format!("_dummy{}", next_padding_num), Span::call_site()),
|
||||
dummy: true,
|
||||
ty: quote! { [u8; #diff as usize] },
|
||||
signature: Cow::from(format!("[u8; {}]", diff)),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
let total_size = spirv_req_total_size
|
||||
.map(|sz| sz as usize)
|
||||
.or(current_rust_offset);
|
||||
|
||||
// For single shader-mode registration mechanism skipped.
|
||||
if let Some(types_registry) = types_registry {
|
||||
let target_type = RegisteredType {
|
||||
shader: shader.to_string(),
|
||||
signature: rust_members
|
||||
.iter()
|
||||
.map(|member| (member.name.to_string(), member.signature.clone()))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
let name = name.to_string();
|
||||
|
||||
// Checking with Registry if this struct already registered by another shader, and if their
|
||||
// signatures match.
|
||||
if let Some(registered) = types_registry.get(name.as_str()) {
|
||||
registered.assert_signatures(name.as_str(), &target_type);
|
||||
|
||||
// If the struct already registered and matches this one, skip duplicate.
|
||||
return (quote! {}, total_size);
|
||||
}
|
||||
|
||||
debug_assert!(types_registry.insert(name, target_type).is_none());
|
||||
}
|
||||
|
||||
// We can only implement Clone if there's no unsized member in the struct.
|
||||
let (clone_impl, copy_derive) =
|
||||
if current_rust_offset.is_some() && (types_meta.clone || types_meta.copy) {
|
||||
@ -325,22 +376,18 @@ fn write_struct(
|
||||
#custom_impls
|
||||
};
|
||||
|
||||
(
|
||||
ast,
|
||||
spirv_req_total_size
|
||||
.map(|sz| sz as usize)
|
||||
.or(current_rust_offset),
|
||||
)
|
||||
(ast, total_size)
|
||||
}
|
||||
|
||||
/// Returns the type name to put in the Rust struct, and its size and alignment.
|
||||
///
|
||||
/// The size can be `None` if it's only known at runtime.
|
||||
pub(super) fn type_from_id(
|
||||
shader: &str,
|
||||
doc: &Spirv,
|
||||
searched: u32,
|
||||
types_meta: &TypesMeta,
|
||||
) -> (TokenStream, Option<usize>, usize) {
|
||||
) -> (TokenStream, Cow<'static, str>, Option<usize>, usize) {
|
||||
for instruction in doc.instructions.iter() {
|
||||
match instruction {
|
||||
&Instruction::TypeBool { result_id } if result_id == searched => {
|
||||
@ -359,6 +406,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {i8},
|
||||
Cow::from("i8"),
|
||||
Some(std::mem::size_of::<i8>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -371,6 +419,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {u8},
|
||||
Cow::from("u8"),
|
||||
Some(std::mem::size_of::<u8>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -383,6 +432,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {i16},
|
||||
Cow::from("i16"),
|
||||
Some(std::mem::size_of::<i16>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -395,6 +445,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {u16},
|
||||
Cow::from("u16"),
|
||||
Some(std::mem::size_of::<u16>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -407,6 +458,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {i32},
|
||||
Cow::from("i32"),
|
||||
Some(std::mem::size_of::<i32>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -419,6 +471,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {u32},
|
||||
Cow::from("u32"),
|
||||
Some(std::mem::size_of::<u32>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -431,6 +484,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {i64},
|
||||
Cow::from("i64"),
|
||||
Some(std::mem::size_of::<i64>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -443,6 +497,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {u64},
|
||||
Cow::from("u64"),
|
||||
Some(std::mem::size_of::<u64>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -458,6 +513,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {f32},
|
||||
Cow::from("f32"),
|
||||
Some(std::mem::size_of::<f32>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -470,6 +526,7 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
return (
|
||||
quote! {f64},
|
||||
Cow::from("f64"),
|
||||
Some(std::mem::size_of::<f64>()),
|
||||
mem::align_of::<Foo>(),
|
||||
);
|
||||
@ -482,10 +539,16 @@ pub(super) fn type_from_id(
|
||||
count,
|
||||
} if result_id == searched => {
|
||||
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
|
||||
let (ty, t_size, t_align) = type_from_id(doc, component_id, types_meta);
|
||||
let (ty, item, t_size, t_align) =
|
||||
type_from_id(shader, doc, component_id, types_meta);
|
||||
let array_length = count as usize;
|
||||
let size = t_size.map(|s| s * count as usize);
|
||||
return (quote! { [#ty; #array_length] }, size, t_align);
|
||||
return (
|
||||
quote! { [#ty; #array_length] },
|
||||
Cow::from(format!("[{}; {}]", item, array_length)),
|
||||
size,
|
||||
t_align,
|
||||
);
|
||||
}
|
||||
&Instruction::TypeMatrix {
|
||||
result_id,
|
||||
@ -494,10 +557,16 @@ pub(super) fn type_from_id(
|
||||
} if result_id == searched => {
|
||||
// FIXME: row-major or column-major
|
||||
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
|
||||
let (ty, t_size, t_align) = type_from_id(doc, column_type_id, types_meta);
|
||||
let (ty, item, t_size, t_align) =
|
||||
type_from_id(shader, doc, column_type_id, types_meta);
|
||||
let array_length = column_count as usize;
|
||||
let size = t_size.map(|s| s * column_count as usize);
|
||||
return (quote! { [#ty; #array_length] }, size, t_align);
|
||||
return (
|
||||
quote! { [#ty; #array_length] },
|
||||
Cow::from(format!("[{}; {}]", item, array_length)),
|
||||
size,
|
||||
t_align,
|
||||
);
|
||||
}
|
||||
&Instruction::TypeArray {
|
||||
result_id,
|
||||
@ -505,7 +574,7 @@ pub(super) fn type_from_id(
|
||||
length_id,
|
||||
} if result_id == searched => {
|
||||
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
|
||||
let (ty, t_size, t_align) = type_from_id(doc, type_id, types_meta);
|
||||
let (ty, item, t_size, t_align) = type_from_id(shader, doc, type_id, types_meta);
|
||||
let t_size = t_size.expect("array components must be sized");
|
||||
let len = doc
|
||||
.instructions
|
||||
@ -532,30 +601,39 @@ pub(super) fn type_from_id(
|
||||
}
|
||||
let array_length = len as usize;
|
||||
let size = Some(t_size * len as usize);
|
||||
return (quote! { [#ty; #array_length] }, size, t_align);
|
||||
return (
|
||||
quote! { [#ty; #array_length] },
|
||||
Cow::from(format!("[{}; {}]", item, array_length)),
|
||||
size,
|
||||
t_align,
|
||||
);
|
||||
}
|
||||
&Instruction::TypeRuntimeArray { result_id, type_id } if result_id == searched => {
|
||||
debug_assert_eq!(mem::align_of::<[u32; 3]>(), mem::align_of::<u32>());
|
||||
let (ty, _, t_align) = type_from_id(doc, type_id, types_meta);
|
||||
return (quote! { [#ty] }, None, t_align);
|
||||
let (ty, name, _, t_align) = type_from_id(shader, doc, type_id, types_meta);
|
||||
return (
|
||||
quote! { [#ty] },
|
||||
Cow::from(format!("[{}]", name)),
|
||||
None,
|
||||
t_align,
|
||||
);
|
||||
}
|
||||
&Instruction::TypeStruct {
|
||||
result_id,
|
||||
ref member_types,
|
||||
} if result_id == searched => {
|
||||
// TODO: take the Offset member decorate into account?
|
||||
let name = Ident::new(
|
||||
&spirv_search::name_from_id(doc, result_id),
|
||||
Span::call_site(),
|
||||
);
|
||||
let name_string = spirv_search::name_from_id(doc, result_id);
|
||||
let name = Ident::new(&name_string, Span::call_site());
|
||||
let ty = quote! { #name };
|
||||
let (_, size) = write_struct(doc, result_id, member_types, types_meta);
|
||||
let (_, size) =
|
||||
write_struct(shader, doc, result_id, member_types, types_meta, None);
|
||||
let align = member_types
|
||||
.iter()
|
||||
.map(|&t| type_from_id(doc, t, types_meta).2)
|
||||
.map(|&t| type_from_id(shader, doc, t, types_meta).3)
|
||||
.max()
|
||||
.unwrap_or(1);
|
||||
return (ty, size, align);
|
||||
return (ty, Cow::from(name_string), size, align);
|
||||
}
|
||||
_ => (),
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user