Add optional SPIR-V shader validation

# Conflicts:
#	Cargo.lock
This commit is contained in:
Gabriel Majeri 2020-05-09 19:55:10 +03:00
parent 54c6f6751b
commit f70f32af87
4 changed files with 132 additions and 19 deletions

24
Cargo.lock generated
View File

@ -744,6 +744,18 @@ dependencies = [
"ws2_32-sys",
]
[[package]]
name = "naga"
version = "0.1.0"
source = "git+https://github.com/gfx-rs/naga?rev=bce6358eb1026c13d2f1c6d365af37afe8869a86#bce6358eb1026c13d2f1c6d365af37afe8869a86"
dependencies = [
"bitflags",
"fxhash",
"log",
"num-traits",
"spirv_headers",
]
[[package]]
name = "net2"
version = "0.2.33"
@ -1133,6 +1145,16 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "spirv_headers"
version = "1.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3f1418983d16481227ffa3ab3cf44ef92eebc9a76c092fbcd4c51a64ff032622"
dependencies = [
"bitflags",
"num-traits",
]
[[package]]
name = "stb_truetype"
version = "0.3.1"
@ -1376,12 +1398,14 @@ dependencies = [
"gfx-memory",
"log",
"loom",
"naga",
"parking_lot",
"peek-poke",
"raw-window-handle",
"ron",
"serde",
"smallvec",
"spirv_headers",
"vec_map",
"wgpu-types",
]

View File

@ -38,8 +38,13 @@ raw-window-handle = { version = "0.3", optional = true }
ron = { version = "0.5", optional = true }
serde = { version = "1.0", features = ["serde_derive"], optional = true }
smallvec = "1"
spirv_headers = { version = "1.4.2" }
vec_map = "0.8"
[dependencies.naga]
git = "https://github.com/gfx-rs/naga"
rev = "bce6358eb1026c13d2f1c6d365af37afe8869a86"
[dependencies.wgt]
path = "../wgpu-types"
package = "wgpu-types"

View File

@ -30,6 +30,8 @@ use std::{
sync::atomic::Ordering,
};
use spirv_headers::ExecutionModel;
mod life;
#[cfg(any(feature = "trace", feature = "replay"))]
pub mod trace;
@ -1511,12 +1513,26 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let spv = unsafe { slice::from_raw_parts(desc.code.bytes, desc.code.length) };
let raw = unsafe { device.raw.create_shader_module(spv).unwrap() };
let module = {
// Parse the given shader code and store its representation.
let spv_iter = spv.into_iter().cloned();
let mut parser = naga::front::spirv::Parser::new(spv_iter);
parser
.parse()
.map_err(|err| {
log::warn!("Failed to parse shader SPIR-V code: {:?}", err);
log::warn!("Shader module will not be validated");
})
.ok()
};
let shader = pipeline::ShaderModule {
raw,
device_id: Stored {
value: device_id,
ref_count: device.life_guard.add_ref(),
},
module,
};
let id = hub
@ -2015,23 +2031,55 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
}
};
let vertex = hal::pso::EntryPoint::<B> {
entry: unsafe { ffi::CStr::from_ptr(desc.vertex_stage.entry_point) }
.to_str()
.to_owned()
.unwrap(), // TODO
module: &shader_module_guard[desc.vertex_stage.module].raw,
specialization: hal::pso::Specialization::EMPTY,
};
let fragment =
unsafe { desc.fragment_stage.as_ref() }.map(|stage| hal::pso::EntryPoint::<B> {
entry: unsafe { ffi::CStr::from_ptr(stage.entry_point) }
let vertex = {
let entry_point_name =
unsafe { ffi::CStr::from_ptr(desc.vertex_stage.entry_point) }
.to_str()
.to_owned()
.unwrap(), // TODO
module: &shader_module_guard[stage.module].raw,
.unwrap();
let shader_module = &shader_module_guard[desc.vertex_stage.module];
if let Some(ref module) = shader_module.module {
if let Err(e) =
validate_shader(module, entry_point_name, ExecutionModel::Vertex)
{
log::error!("Failed validating vertex shader module: {:?}", e);
}
}
hal::pso::EntryPoint::<B> {
entry: entry_point_name, // TODO
module: &shader_module.raw,
specialization: hal::pso::Specialization::EMPTY,
});
}
};
let fragment = {
let fragment_stage = unsafe { desc.fragment_stage.as_ref() };
fragment_stage.map(|stage| {
let entry_point_name = unsafe { ffi::CStr::from_ptr(stage.entry_point) }
.to_str()
.to_owned()
.unwrap();
let shader_module = &shader_module_guard[stage.module];
if let Some(ref module) = shader_module.module {
if let Err(e) =
validate_shader(module, entry_point_name, ExecutionModel::Fragment)
{
log::error!("Failed validating fragment shader module: {:?}", e);
}
}
hal::pso::EntryPoint::<B> {
entry: entry_point_name, // TODO
module: &shader_module.raw,
specialization: hal::pso::Specialization::EMPTY,
}
})
};
let shaders = hal::pso::GraphicsShaderSet {
vertex,
@ -2196,12 +2244,23 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
let pipeline_stage = &desc.compute_stage;
let (shader_module_guard, _) = hub.shader_modules.read(&mut token);
let entry_point_name = unsafe { ffi::CStr::from_ptr(pipeline_stage.entry_point) }
.to_str()
.to_owned()
.unwrap();
let shader_module = &shader_module_guard[pipeline_stage.module];
if let Some(ref module) = shader_module.module {
if let Err(e) = validate_shader(module, entry_point_name, ExecutionModel::GLCompute)
{
log::error!("Failed validating compute shader module: {:?}", e);
}
}
let shader = hal::pso::EntryPoint::<B> {
entry: unsafe { ffi::CStr::from_ptr(pipeline_stage.entry_point) }
.to_str()
.to_owned()
.unwrap(), // TODO
module: &shader_module_guard[pipeline_stage.module].raw,
entry: entry_point_name, // TODO
module: &shader_module.raw,
specialization: hal::pso::Specialization::EMPTY,
};
@ -2576,3 +2635,27 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
buffer.map_state = resource::BufferMapState::Idle;
}
}
/// Errors produced when validating the shader modules of a pipeline.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum ShaderValidationError {
/// Unable to find an entry point matching the specified execution model.
MissingEntryPoint(ExecutionModel),
}
fn validate_shader(
module: &naga::Module,
entry_point_name: &str,
execution_model: ExecutionModel,
) -> Result<(), ShaderValidationError> {
// Since a shader module can have multiple entry points with the same name,
// we need to look for one with the right execution model.
let entry_point = module.entry_points.iter().find(|entry_point| {
entry_point.name == entry_point_name && entry_point.exec_model == execution_model
});
match entry_point {
Some(_) => Ok(()),
None => Err(ShaderValidationError::MissingEntryPoint(execution_model)),
}
}

View File

@ -40,6 +40,7 @@ pub struct ShaderModuleDescriptor {
pub struct ShaderModule<B: hal::Backend> {
pub(crate) raw: B::ShaderModule,
pub(crate) device_id: Stored<DeviceId>,
pub(crate) module: Option<naga::Module>,
}
#[repr(C)]