Add Naga bypass to allow feeding raw SPIR-V shader data to the backend.

While Naga checking is undoubtedly very useful, it currently lags behind
what is possible in SPIR-V and even what is promised by WGPU (ie binding
arrays). This adds an unsafe method to wgpu::Device to allow feeding
raw SPIR-V data into the backend, and adds a feature flag to request a
backend supporting this operation.
This commit is contained in:
Alex S 2021-06-19 22:34:57 +03:00
parent a8be371acf
commit 6d2e6e5a56
13 changed files with 259 additions and 59 deletions

View File

@ -902,7 +902,7 @@ impl<A: HalApi> Device<A> {
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
.validate(&module)?;
let interface = validation::Interface::new(&module, &info);
let hal_shader = hal::NagaShader { module, info };
let hal_shader = hal::ShaderInput::NagaShader(hal::NagaShader { module, info });
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.borrow_option(),
@ -928,7 +928,47 @@ impl<A: HalApi> Device<A> {
value: id::Valid(self_id),
ref_count: self.life_guard.add_ref(),
},
interface,
interface: Some(interface),
#[cfg(debug_assertions)]
label: desc.label.borrow_or_default().to_string(),
})
}
#[allow(unused_unsafe)]
unsafe fn create_shader_module_spirv<'a>(
&self,
self_id: id::DeviceId,
desc: &pipeline::ShaderModuleDescriptor<'a>,
source: Cow<'a, [u32]>,
) -> Result<pipeline::ShaderModule<A>, pipeline::CreateShaderModuleError> {
self.require_features(wgt::Features::SPIR_V_SHADER_MODULES)
.map_err(pipeline::CreateShaderModuleError::MissingFeatures)?;
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.borrow_option(),
};
let hal_shader = hal::ShaderInput::SpirVShader(source);
let raw = match unsafe { self.raw.create_shader_module(&hal_desc, hal_shader) } {
Ok(raw) => raw,
Err(error) => {
return Err(match error {
hal::ShaderError::Device(error) => {
pipeline::CreateShaderModuleError::Device(error.into())
}
hal::ShaderError::Compilation(ref msg) => {
log::error!("Shader error: {}", msg);
pipeline::CreateShaderModuleError::Generation
}
})
}
};
Ok(pipeline::ShaderModule {
raw,
device_id: Stored {
value: id::Valid(self_id),
ref_count: self.life_guard.add_ref(),
},
interface: None,
#[cfg(debug_assertions)]
label: desc.label.borrow_or_default().to_string(),
})
@ -1748,13 +1788,15 @@ impl<A: HalApi> Device<A> {
None
}
};
let _ = shader_module.interface.check_stage(
provided_layouts.as_ref().map(|p| p.as_slice()),
&mut derived_group_layouts,
&desc.stage.entry_point,
flag,
io,
)?;
if let Some(ref interface) = shader_module.interface {
let _ = interface.check_stage(
provided_layouts.as_ref().map(|p| p.as_slice()),
&mut derived_group_layouts,
&desc.stage.entry_point,
flag,
io,
)?;
}
}
let pipeline_layout_id = match desc.layout {
@ -2033,20 +2075,21 @@ impl<A: HalApi> Device<A> {
None => None,
};
io = shader_module
.interface
.check_stage(
provided_layouts.as_ref().map(|p| p.as_slice()),
&mut derived_group_layouts,
&stage.entry_point,
flag,
io,
)
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
stage: flag,
error,
})?;
validated_stages |= flag;
if let Some(ref interface) = shader_module.interface {
io = interface
.check_stage(
provided_layouts.as_ref().map(|p| p.as_slice()),
&mut derived_group_layouts,
&stage.entry_point,
flag,
io,
)
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
stage: flag,
error,
})?;
validated_stages |= flag;
}
hal::ProgrammableStage {
module: &shader_module.raw,
@ -2077,20 +2120,21 @@ impl<A: HalApi> Device<A> {
};
if validated_stages == wgt::ShaderStage::VERTEX {
io = shader_module
.interface
.check_stage(
provided_layouts.as_ref().map(|p| p.as_slice()),
&mut derived_group_layouts,
&fragment.stage.entry_point,
flag,
io,
)
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
stage: flag,
error,
})?;
validated_stages |= flag;
if let Some(ref interface) = shader_module.interface {
io = interface
.check_stage(
provided_layouts.as_ref().map(|p| p.as_slice()),
&mut derived_group_layouts,
&fragment.stage.entry_point,
flag,
io,
)
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
stage: flag,
error,
})?;
validated_stages |= flag;
}
}
Some(hal::ProgrammableStage {
@ -3463,6 +3507,58 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
(id, Some(error))
}
#[allow(unused_unsafe)] // Unsafe-ness of internal calls has little to do with unsafe-ness of this.
/// # Safety
///
/// This function passes SPIR-V binary to the backend as-is and can potentially result in a
/// driver crash.
pub unsafe fn device_create_shader_module_spirv<A: HalApi>(
&self,
device_id: id::DeviceId,
desc: &pipeline::ShaderModuleDescriptor,
source: Cow<[u32]>,
id_in: Input<G, id::ShaderModuleId>,
) -> (
id::ShaderModuleId,
Option<pipeline::CreateShaderModuleError>,
) {
profiling::scope!("create_shader_module", "Device");
let hub = A::hub(self);
let mut token = Token::root();
let fid = hub.shader_modules.prepare(id_in);
let (device_guard, mut token) = hub.devices.read(&mut token);
let error = loop {
let device = match device_guard.get(device_id) {
Ok(device) => device,
Err(_) => break DeviceError::Invalid.into(),
};
#[cfg(feature = "trace")]
if let Some(ref trace) = device.trace {
let mut trace = trace.lock();
let data = trace.make_binary("spv", unsafe {
std::slice::from_raw_parts(source.as_ptr() as *const u8, source.len() * 4)
});
trace.add(trace::Action::CreateShaderModule {
id: fid.id(),
desc: desc.clone(),
data,
});
};
let shader = match device.create_shader_module_spirv(device_id, desc, source) {
Ok(shader) => shader,
Err(e) => break e,
};
let id = fid.assign(shader, &mut token);
return (id.0, None);
};
let id = fid.assign_error(desc.label.borrow_or_default(), &mut token);
(id, Some(error))
}
pub fn shader_module_label<A: HalApi>(&self, id: id::ShaderModuleId) -> String {
A::hub(self).shader_modules.label_for_resource(id)
}

View File

@ -29,7 +29,7 @@ pub struct ShaderModuleDescriptor<'a> {
pub struct ShaderModule<A: hal::Api> {
pub(crate) raw: A::ShaderModule,
pub(crate) device_id: Stored<DeviceId>,
pub(crate) interface: validation::Interface,
pub(crate) interface: Option<validation::Interface>,
#[cfg(debug_assertions)]
pub(crate) label: String,
}

View File

@ -138,7 +138,7 @@ impl<A: hal::Api> Example<A> {
let shader_desc = hal::ShaderModuleDescriptor { label: None };
let shader = unsafe {
device
.create_shader_module(&shader_desc, naga_shader)
.create_shader_module(&shader_desc, hal::ShaderInput::NagaShader(naga_shader))
.unwrap()
};

View File

@ -174,7 +174,7 @@ impl crate::Device<Api> for Context {
unsafe fn create_shader_module(
&self,
desc: &crate::ShaderModuleDescriptor,
shader: crate::NagaShader,
shader: crate::ShaderInput,
) -> Result<Resource, crate::ShaderError> {
Ok(Resource)
}

View File

@ -64,7 +64,7 @@ pub mod api {
}
use std::{
borrow::Borrow,
borrow::{Borrow, Cow},
fmt,
num::NonZeroU8,
ops::{Range, RangeInclusive},
@ -257,7 +257,7 @@ pub trait Device<A: Api>: Send + Sync {
unsafe fn create_shader_module(
&self,
desc: &ShaderModuleDescriptor,
shader: NagaShader,
shader: ShaderInput,
) -> Result<A::ShaderModule, ShaderError>;
unsafe fn destroy_shader_module(&self, module: A::ShaderModule);
unsafe fn create_render_pipeline(
@ -839,6 +839,12 @@ pub struct NagaShader {
pub info: naga::valid::ModuleInfo,
}
/// Shader input.
pub enum ShaderInput<'a> {
NagaShader(NagaShader),
SpirVShader(Cow<'a, [u32]>),
}
pub struct ShaderModuleDescriptor<'a> {
pub label: Label<'a>,
}

View File

@ -630,9 +630,14 @@ impl crate::Device<super::Api> for super::Device {
unsafe fn create_shader_module(
&self,
_desc: &crate::ShaderModuleDescriptor,
shader: crate::NagaShader,
shader: crate::ShaderInput,
) -> Result<super::ShaderModule, crate::ShaderError> {
Ok(super::ShaderModule { raw: shader })
match shader {
crate::ShaderInput::NagaShader(raw) => Ok(super::ShaderModule { raw }),
crate::ShaderInput::SpirVShader(_) => Err(crate::ShaderError::Compilation(
"SPIR-V shaders are not supported for Metal".to_string(),
)),
}
}
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {}

View File

@ -5,7 +5,7 @@ use ash::{extensions::khr, version::DeviceV1_0, vk};
use inplace_it::inplace_or_alloc_from_iter;
use parking_lot::Mutex;
use std::{cmp, collections::hash_map::Entry, ffi::CString, ptr, sync::Arc};
use std::{borrow::Cow, cmp, collections::hash_map::Entry, ffi::CString, ptr, sync::Arc};
impl super::DeviceShared {
pub(super) unsafe fn set_object_name(
@ -1059,10 +1059,19 @@ impl crate::Device<super::Api> for super::Device {
unsafe fn create_shader_module(
&self,
desc: &crate::ShaderModuleDescriptor,
shader: crate::NagaShader,
shader: crate::ShaderInput,
) -> Result<super::ShaderModule, crate::ShaderError> {
let spv = naga::back::spv::write_vec(&shader.module, &shader.info, &self.naga_options)
.map_err(|e| crate::ShaderError::Compilation(format!("{}", e)))?;
let spv = match shader {
crate::ShaderInput::NagaShader(naga_shader) => Cow::Owned(
naga::back::spv::write_vec(
&naga_shader.module,
&naga_shader.info,
&self.naga_options,
)
.map_err(|e| crate::ShaderError::Compilation(format!("{}", e)))?,
),
crate::ShaderInput::SpirVShader(spv) => spv,
};
let vk_info = vk::ShaderModuleCreateInfo::builder()
.flags(vk::ShaderModuleCreateFlags::empty())

View File

@ -594,6 +594,14 @@ bitflags::bitflags! {
///
/// This is a native only feature.
const CLEAR_COMMANDS = 0x0000_0001_0000_0000;
/// Enables creating shader modules from SPIR-V binary data (unsafe).
///
/// Supported platforms:
/// - Vulkan, in case shader's requested capabilities and extensions agree with
/// Vulkan implementation.
///
/// This is a native only feature.
const SPIR_V_SHADER_MODULES = 0x0000_0002_0000_0000;
/// Features which are part of the upstream WebGPU standard.
const ALL_WEBGPU = 0x0000_0000_0000_FFFF;

View File

@ -77,7 +77,7 @@ impl framework::Example for Example {
| wgpu::Features::PUSH_CONSTANTS
}
fn required_features() -> wgpu::Features {
wgpu::Features::SAMPLED_TEXTURE_BINDING_ARRAY
wgpu::Features::SAMPLED_TEXTURE_BINDING_ARRAY | wgpu::Features::SPIR_V_SHADER_MODULES
}
fn required_limits() -> wgpu::Limits {
wgpu::Limits {
@ -94,22 +94,22 @@ impl framework::Example for Example {
let mut uniform_workaround = false;
let vs_module = device.create_shader_module(&wgpu::include_spirv!("shader.vert.spv"));
let fs_source = match device.features() {
f if f.contains(wgpu::Features::UNSIZED_BINDING_ARRAY) => {
wgpu::include_spirv!("unsized-non-uniform.frag.spv")
}
//f if f.contains(wgpu::Features::UNSIZED_BINDING_ARRAY) => {
// wgpu::include_spirv_raw!("unsized-non-uniform.frag.spv")
//}
f if f.contains(wgpu::Features::SAMPLED_TEXTURE_ARRAY_NON_UNIFORM_INDEXING) => {
wgpu::include_spirv!("non-uniform.frag.spv")
wgpu::include_spirv_raw!("non-uniform.frag.spv")
}
f if f.contains(wgpu::Features::SAMPLED_TEXTURE_ARRAY_DYNAMIC_INDEXING) => {
uniform_workaround = true;
wgpu::include_spirv!("uniform.frag.spv")
wgpu::include_spirv_raw!("uniform.frag.spv")
}
f if f.contains(wgpu::Features::SAMPLED_TEXTURE_BINDING_ARRAY) => {
wgpu::include_spirv!("constant.frag.spv")
wgpu::include_spirv_raw!("constant.frag.spv")
}
_ => unreachable!(),
};
let fs_module = device.create_shader_module(&fs_source);
let fs_module = unsafe { device.create_shader_module_spirv(&fs_source) };
let vertex_size = std::mem::size_of::<Vertex>();
let vertex_data = create_vertices();

View File

@ -4,8 +4,8 @@ use crate::{
CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor,
DownlevelCapabilities, Features, Label, Limits, LoadOp, MapMode, Operations,
PipelineLayoutDescriptor, RenderBundleEncoderDescriptor, RenderPipelineDescriptor,
SamplerDescriptor, ShaderModuleDescriptor, ShaderSource, SwapChainStatus, TextureDescriptor,
TextureFormat, TextureViewDescriptor,
SamplerDescriptor, ShaderModuleDescriptor, ShaderModuleDescriptorSpirV, ShaderSource,
SwapChainStatus, TextureDescriptor, TextureFormat, TextureViewDescriptor,
};
use arrayvec::ArrayVec;
@ -828,6 +828,30 @@ impl crate::Context for Context {
id
}
unsafe fn device_create_shader_module_spirv(
&self,
device: &Self::DeviceId,
desc: &ShaderModuleDescriptorSpirV,
) -> Self::ShaderModuleId {
let global = &self.0;
let descriptor = wgc::pipeline::ShaderModuleDescriptor {
label: desc.label.map(Borrowed),
};
let (id, error) = wgc::gfx_select!(
device.id => global.device_create_shader_module_spirv(device.id, &descriptor, Borrowed(&desc.source), PhantomData)
);
if let Some(cause) = error {
self.handle_error(
&device.error_sink,
cause,
LABEL,
desc.label,
"Device::create_shader_module_spirv",
);
}
id
}
fn device_create_bind_group_layout(
&self,
device: &Self::DeviceId,

View File

@ -228,6 +228,11 @@ trait Context: Debug + Send + Sized + Sync {
device: &Self::DeviceId,
desc: &ShaderModuleDescriptor,
) -> Self::ShaderModuleId;
unsafe fn device_create_shader_module_spirv(
&self,
device: &Self::DeviceId,
desc: &ShaderModuleDescriptorSpirV,
) -> Self::ShaderModuleId;
fn device_create_bind_group_layout(
&self,
device: &Self::DeviceId,
@ -752,6 +757,14 @@ pub struct ShaderModuleDescriptor<'a> {
pub source: ShaderSource<'a>,
}
/// Descriptor for a shader module given by SPIR-V binary.
pub struct ShaderModuleDescriptorSpirV<'a> {
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: Label<'a>,
/// Binary SPIR-V data, in 4-byte words.
pub source: Cow<'a, [u32]>,
}
/// Handle to a pipeline layout.
///
/// A `PipelineLayout` object describes the available binding groups of a pipeline.
@ -1552,6 +1565,22 @@ impl Device {
}
}
/// Creates a shader module from SPIR-V binary directly.
///
/// # Safety
///
/// This function passes SPIR-V binary to the backend as-is and can potentially result in a
/// driver crash.
pub unsafe fn create_shader_module_spirv(
&self,
desc: &ShaderModuleDescriptorSpirV,
) -> ShaderModule {
ShaderModule {
context: Arc::clone(&self.context),
id: Context::device_create_shader_module_spirv(&*self.context, &self.id, desc),
}
}
/// Creates an empty [`CommandEncoder`].
pub fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> CommandEncoder {
CommandEncoder {

View File

@ -57,6 +57,22 @@ macro_rules! include_spirv {
};
}
/// Macro to load raw SPIR-V data statically, for use with [`wgpu::Features::SPIR_V_SHADER_MODULES`].
///
/// It ensures the word alignment as well as the magic number.
#[macro_export]
macro_rules! include_spirv_raw {
($($token:tt)*) => {
{
//log::info!("including '{}'", $($token)*);
$crate::ShaderModuleDescriptorSpirV {
label: Some($($token)*),
source: $crate::util::make_spirv_raw(include_bytes!($($token)*)),
}
}
};
}
/// Macro to load a WGSL module statically.
#[macro_export]
macro_rules! include_wgsl {

View File

@ -25,6 +25,12 @@ pub use encoder::RenderEncoder;
/// - Input is longer than [`usize::max_value`]
/// - SPIR-V magic number is missing from beginning of stream
pub fn make_spirv(data: &[u8]) -> super::ShaderSource {
super::ShaderSource::SpirV(make_spirv_raw(data))
}
/// Version of [`make_spirv`] intended for use with [`Device::create_shader_module_spirv`].
/// Returns raw slice instead of ShaderSource.
pub fn make_spirv_raw(data: &[u8]) -> Cow<[u32]> {
const MAGIC_NUMBER: u32 = 0x0723_0203;
assert_eq!(
@ -53,7 +59,8 @@ pub fn make_spirv(data: &[u8]) -> super::ShaderSource {
"wrong magic word {:x}. Make sure you are using a binary SPIRV file.",
words[0]
);
super::ShaderSource::SpirV(words)
words
}
/// CPU accessible buffer used to download data back from the GPU.