mirror of
https://github.com/gfx-rs/wgpu.git
synced 2025-02-14 16:02:47 +00:00
Add optional SPIR-V shader validation
# Conflicts: # Cargo.lock
This commit is contained in:
parent
54c6f6751b
commit
f70f32af87
24
Cargo.lock
generated
24
Cargo.lock
generated
@ -744,6 +744,18 @@ dependencies = [
|
||||
"ws2_32-sys",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "naga"
|
||||
version = "0.1.0"
|
||||
source = "git+https://github.com/gfx-rs/naga?rev=bce6358eb1026c13d2f1c6d365af37afe8869a86#bce6358eb1026c13d2f1c6d365af37afe8869a86"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"fxhash",
|
||||
"log",
|
||||
"num-traits",
|
||||
"spirv_headers",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "net2"
|
||||
version = "0.2.33"
|
||||
@ -1133,6 +1145,16 @@ dependencies = [
|
||||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "spirv_headers"
|
||||
version = "1.4.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3f1418983d16481227ffa3ab3cf44ef92eebc9a76c092fbcd4c51a64ff032622"
|
||||
dependencies = [
|
||||
"bitflags",
|
||||
"num-traits",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "stb_truetype"
|
||||
version = "0.3.1"
|
||||
@ -1376,12 +1398,14 @@ dependencies = [
|
||||
"gfx-memory",
|
||||
"log",
|
||||
"loom",
|
||||
"naga",
|
||||
"parking_lot",
|
||||
"peek-poke",
|
||||
"raw-window-handle",
|
||||
"ron",
|
||||
"serde",
|
||||
"smallvec",
|
||||
"spirv_headers",
|
||||
"vec_map",
|
||||
"wgpu-types",
|
||||
]
|
||||
|
@ -38,8 +38,13 @@ raw-window-handle = { version = "0.3", optional = true }
|
||||
ron = { version = "0.5", optional = true }
|
||||
serde = { version = "1.0", features = ["serde_derive"], optional = true }
|
||||
smallvec = "1"
|
||||
spirv_headers = { version = "1.4.2" }
|
||||
vec_map = "0.8"
|
||||
|
||||
[dependencies.naga]
|
||||
git = "https://github.com/gfx-rs/naga"
|
||||
rev = "bce6358eb1026c13d2f1c6d365af37afe8869a86"
|
||||
|
||||
[dependencies.wgt]
|
||||
path = "../wgpu-types"
|
||||
package = "wgpu-types"
|
||||
|
@ -30,6 +30,8 @@ use std::{
|
||||
sync::atomic::Ordering,
|
||||
};
|
||||
|
||||
use spirv_headers::ExecutionModel;
|
||||
|
||||
mod life;
|
||||
#[cfg(any(feature = "trace", feature = "replay"))]
|
||||
pub mod trace;
|
||||
@ -1511,12 +1513,26 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
|
||||
let spv = unsafe { slice::from_raw_parts(desc.code.bytes, desc.code.length) };
|
||||
let raw = unsafe { device.raw.create_shader_module(spv).unwrap() };
|
||||
|
||||
let module = {
|
||||
// Parse the given shader code and store its representation.
|
||||
let spv_iter = spv.into_iter().cloned();
|
||||
let mut parser = naga::front::spirv::Parser::new(spv_iter);
|
||||
parser
|
||||
.parse()
|
||||
.map_err(|err| {
|
||||
log::warn!("Failed to parse shader SPIR-V code: {:?}", err);
|
||||
log::warn!("Shader module will not be validated");
|
||||
})
|
||||
.ok()
|
||||
};
|
||||
let shader = pipeline::ShaderModule {
|
||||
raw,
|
||||
device_id: Stored {
|
||||
value: device_id,
|
||||
ref_count: device.life_guard.add_ref(),
|
||||
},
|
||||
module,
|
||||
};
|
||||
|
||||
let id = hub
|
||||
@ -2015,23 +2031,55 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
}
|
||||
};
|
||||
|
||||
let vertex = hal::pso::EntryPoint::<B> {
|
||||
entry: unsafe { ffi::CStr::from_ptr(desc.vertex_stage.entry_point) }
|
||||
.to_str()
|
||||
.to_owned()
|
||||
.unwrap(), // TODO
|
||||
module: &shader_module_guard[desc.vertex_stage.module].raw,
|
||||
specialization: hal::pso::Specialization::EMPTY,
|
||||
};
|
||||
let fragment =
|
||||
unsafe { desc.fragment_stage.as_ref() }.map(|stage| hal::pso::EntryPoint::<B> {
|
||||
entry: unsafe { ffi::CStr::from_ptr(stage.entry_point) }
|
||||
let vertex = {
|
||||
let entry_point_name =
|
||||
unsafe { ffi::CStr::from_ptr(desc.vertex_stage.entry_point) }
|
||||
.to_str()
|
||||
.to_owned()
|
||||
.unwrap(), // TODO
|
||||
module: &shader_module_guard[stage.module].raw,
|
||||
.unwrap();
|
||||
|
||||
let shader_module = &shader_module_guard[desc.vertex_stage.module];
|
||||
|
||||
if let Some(ref module) = shader_module.module {
|
||||
if let Err(e) =
|
||||
validate_shader(module, entry_point_name, ExecutionModel::Vertex)
|
||||
{
|
||||
log::error!("Failed validating vertex shader module: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
hal::pso::EntryPoint::<B> {
|
||||
entry: entry_point_name, // TODO
|
||||
module: &shader_module.raw,
|
||||
specialization: hal::pso::Specialization::EMPTY,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let fragment = {
|
||||
let fragment_stage = unsafe { desc.fragment_stage.as_ref() };
|
||||
fragment_stage.map(|stage| {
|
||||
let entry_point_name = unsafe { ffi::CStr::from_ptr(stage.entry_point) }
|
||||
.to_str()
|
||||
.to_owned()
|
||||
.unwrap();
|
||||
|
||||
let shader_module = &shader_module_guard[stage.module];
|
||||
|
||||
if let Some(ref module) = shader_module.module {
|
||||
if let Err(e) =
|
||||
validate_shader(module, entry_point_name, ExecutionModel::Fragment)
|
||||
{
|
||||
log::error!("Failed validating fragment shader module: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
hal::pso::EntryPoint::<B> {
|
||||
entry: entry_point_name, // TODO
|
||||
module: &shader_module.raw,
|
||||
specialization: hal::pso::Specialization::EMPTY,
|
||||
}
|
||||
})
|
||||
};
|
||||
|
||||
let shaders = hal::pso::GraphicsShaderSet {
|
||||
vertex,
|
||||
@ -2196,12 +2244,23 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
let pipeline_stage = &desc.compute_stage;
|
||||
let (shader_module_guard, _) = hub.shader_modules.read(&mut token);
|
||||
|
||||
let entry_point_name = unsafe { ffi::CStr::from_ptr(pipeline_stage.entry_point) }
|
||||
.to_str()
|
||||
.to_owned()
|
||||
.unwrap();
|
||||
|
||||
let shader_module = &shader_module_guard[pipeline_stage.module];
|
||||
|
||||
if let Some(ref module) = shader_module.module {
|
||||
if let Err(e) = validate_shader(module, entry_point_name, ExecutionModel::GLCompute)
|
||||
{
|
||||
log::error!("Failed validating compute shader module: {:?}", e);
|
||||
}
|
||||
}
|
||||
|
||||
let shader = hal::pso::EntryPoint::<B> {
|
||||
entry: unsafe { ffi::CStr::from_ptr(pipeline_stage.entry_point) }
|
||||
.to_str()
|
||||
.to_owned()
|
||||
.unwrap(), // TODO
|
||||
module: &shader_module_guard[pipeline_stage.module].raw,
|
||||
entry: entry_point_name, // TODO
|
||||
module: &shader_module.raw,
|
||||
specialization: hal::pso::Specialization::EMPTY,
|
||||
};
|
||||
|
||||
@ -2576,3 +2635,27 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
buffer.map_state = resource::BufferMapState::Idle;
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors produced when validating the shader modules of a pipeline.
|
||||
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
|
||||
enum ShaderValidationError {
|
||||
/// Unable to find an entry point matching the specified execution model.
|
||||
MissingEntryPoint(ExecutionModel),
|
||||
}
|
||||
|
||||
fn validate_shader(
|
||||
module: &naga::Module,
|
||||
entry_point_name: &str,
|
||||
execution_model: ExecutionModel,
|
||||
) -> Result<(), ShaderValidationError> {
|
||||
// Since a shader module can have multiple entry points with the same name,
|
||||
// we need to look for one with the right execution model.
|
||||
let entry_point = module.entry_points.iter().find(|entry_point| {
|
||||
entry_point.name == entry_point_name && entry_point.exec_model == execution_model
|
||||
});
|
||||
|
||||
match entry_point {
|
||||
Some(_) => Ok(()),
|
||||
None => Err(ShaderValidationError::MissingEntryPoint(execution_model)),
|
||||
}
|
||||
}
|
||||
|
@ -40,6 +40,7 @@ pub struct ShaderModuleDescriptor {
|
||||
pub struct ShaderModule<B: hal::Backend> {
|
||||
pub(crate) raw: B::ShaderModule,
|
||||
pub(crate) device_id: Stored<DeviceId>,
|
||||
pub(crate) module: Option<naga::Module>,
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
|
Loading…
Reference in New Issue
Block a user