mirror of
https://github.com/gfx-rs/wgpu.git
synced 2024-11-25 16:24:24 +00:00
Add Naga bypass to allow feeding raw SPIR-V shader data to the backend.
While Naga checking is undoubtedly very useful, it currently lags behind what is possible in SPIR-V and even what is promised by WGPU (ie binding arrays). This adds an unsafe method to wgpu::Device to allow feeding raw SPIR-V data into the backend, and adds a feature flag to request a backend supporting this operation.
This commit is contained in:
parent
a8be371acf
commit
6d2e6e5a56
@ -902,7 +902,7 @@ impl<A: HalApi> Device<A> {
|
||||
let info = naga::valid::Validator::new(naga::valid::ValidationFlags::all(), caps)
|
||||
.validate(&module)?;
|
||||
let interface = validation::Interface::new(&module, &info);
|
||||
let hal_shader = hal::NagaShader { module, info };
|
||||
let hal_shader = hal::ShaderInput::NagaShader(hal::NagaShader { module, info });
|
||||
|
||||
let hal_desc = hal::ShaderModuleDescriptor {
|
||||
label: desc.label.borrow_option(),
|
||||
@ -928,7 +928,47 @@ impl<A: HalApi> Device<A> {
|
||||
value: id::Valid(self_id),
|
||||
ref_count: self.life_guard.add_ref(),
|
||||
},
|
||||
interface,
|
||||
interface: Some(interface),
|
||||
#[cfg(debug_assertions)]
|
||||
label: desc.label.borrow_or_default().to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
#[allow(unused_unsafe)]
|
||||
unsafe fn create_shader_module_spirv<'a>(
|
||||
&self,
|
||||
self_id: id::DeviceId,
|
||||
desc: &pipeline::ShaderModuleDescriptor<'a>,
|
||||
source: Cow<'a, [u32]>,
|
||||
) -> Result<pipeline::ShaderModule<A>, pipeline::CreateShaderModuleError> {
|
||||
self.require_features(wgt::Features::SPIR_V_SHADER_MODULES)
|
||||
.map_err(pipeline::CreateShaderModuleError::MissingFeatures)?;
|
||||
let hal_desc = hal::ShaderModuleDescriptor {
|
||||
label: desc.label.borrow_option(),
|
||||
};
|
||||
let hal_shader = hal::ShaderInput::SpirVShader(source);
|
||||
let raw = match unsafe { self.raw.create_shader_module(&hal_desc, hal_shader) } {
|
||||
Ok(raw) => raw,
|
||||
Err(error) => {
|
||||
return Err(match error {
|
||||
hal::ShaderError::Device(error) => {
|
||||
pipeline::CreateShaderModuleError::Device(error.into())
|
||||
}
|
||||
hal::ShaderError::Compilation(ref msg) => {
|
||||
log::error!("Shader error: {}", msg);
|
||||
pipeline::CreateShaderModuleError::Generation
|
||||
}
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(pipeline::ShaderModule {
|
||||
raw,
|
||||
device_id: Stored {
|
||||
value: id::Valid(self_id),
|
||||
ref_count: self.life_guard.add_ref(),
|
||||
},
|
||||
interface: None,
|
||||
#[cfg(debug_assertions)]
|
||||
label: desc.label.borrow_or_default().to_string(),
|
||||
})
|
||||
@ -1748,13 +1788,15 @@ impl<A: HalApi> Device<A> {
|
||||
None
|
||||
}
|
||||
};
|
||||
let _ = shader_module.interface.check_stage(
|
||||
provided_layouts.as_ref().map(|p| p.as_slice()),
|
||||
&mut derived_group_layouts,
|
||||
&desc.stage.entry_point,
|
||||
flag,
|
||||
io,
|
||||
)?;
|
||||
if let Some(ref interface) = shader_module.interface {
|
||||
let _ = interface.check_stage(
|
||||
provided_layouts.as_ref().map(|p| p.as_slice()),
|
||||
&mut derived_group_layouts,
|
||||
&desc.stage.entry_point,
|
||||
flag,
|
||||
io,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
|
||||
let pipeline_layout_id = match desc.layout {
|
||||
@ -2033,20 +2075,21 @@ impl<A: HalApi> Device<A> {
|
||||
None => None,
|
||||
};
|
||||
|
||||
io = shader_module
|
||||
.interface
|
||||
.check_stage(
|
||||
provided_layouts.as_ref().map(|p| p.as_slice()),
|
||||
&mut derived_group_layouts,
|
||||
&stage.entry_point,
|
||||
flag,
|
||||
io,
|
||||
)
|
||||
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
|
||||
stage: flag,
|
||||
error,
|
||||
})?;
|
||||
validated_stages |= flag;
|
||||
if let Some(ref interface) = shader_module.interface {
|
||||
io = interface
|
||||
.check_stage(
|
||||
provided_layouts.as_ref().map(|p| p.as_slice()),
|
||||
&mut derived_group_layouts,
|
||||
&stage.entry_point,
|
||||
flag,
|
||||
io,
|
||||
)
|
||||
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
|
||||
stage: flag,
|
||||
error,
|
||||
})?;
|
||||
validated_stages |= flag;
|
||||
}
|
||||
|
||||
hal::ProgrammableStage {
|
||||
module: &shader_module.raw,
|
||||
@ -2077,20 +2120,21 @@ impl<A: HalApi> Device<A> {
|
||||
};
|
||||
|
||||
if validated_stages == wgt::ShaderStage::VERTEX {
|
||||
io = shader_module
|
||||
.interface
|
||||
.check_stage(
|
||||
provided_layouts.as_ref().map(|p| p.as_slice()),
|
||||
&mut derived_group_layouts,
|
||||
&fragment.stage.entry_point,
|
||||
flag,
|
||||
io,
|
||||
)
|
||||
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
|
||||
stage: flag,
|
||||
error,
|
||||
})?;
|
||||
validated_stages |= flag;
|
||||
if let Some(ref interface) = shader_module.interface {
|
||||
io = interface
|
||||
.check_stage(
|
||||
provided_layouts.as_ref().map(|p| p.as_slice()),
|
||||
&mut derived_group_layouts,
|
||||
&fragment.stage.entry_point,
|
||||
flag,
|
||||
io,
|
||||
)
|
||||
.map_err(|error| pipeline::CreateRenderPipelineError::Stage {
|
||||
stage: flag,
|
||||
error,
|
||||
})?;
|
||||
validated_stages |= flag;
|
||||
}
|
||||
}
|
||||
|
||||
Some(hal::ProgrammableStage {
|
||||
@ -3463,6 +3507,58 @@ impl<G: GlobalIdentityHandlerFactory> Global<G> {
|
||||
(id, Some(error))
|
||||
}
|
||||
|
||||
#[allow(unused_unsafe)] // Unsafe-ness of internal calls has little to do with unsafe-ness of this.
|
||||
/// # Safety
|
||||
///
|
||||
/// This function passes SPIR-V binary to the backend as-is and can potentially result in a
|
||||
/// driver crash.
|
||||
pub unsafe fn device_create_shader_module_spirv<A: HalApi>(
|
||||
&self,
|
||||
device_id: id::DeviceId,
|
||||
desc: &pipeline::ShaderModuleDescriptor,
|
||||
source: Cow<[u32]>,
|
||||
id_in: Input<G, id::ShaderModuleId>,
|
||||
) -> (
|
||||
id::ShaderModuleId,
|
||||
Option<pipeline::CreateShaderModuleError>,
|
||||
) {
|
||||
profiling::scope!("create_shader_module", "Device");
|
||||
|
||||
let hub = A::hub(self);
|
||||
let mut token = Token::root();
|
||||
let fid = hub.shader_modules.prepare(id_in);
|
||||
|
||||
let (device_guard, mut token) = hub.devices.read(&mut token);
|
||||
let error = loop {
|
||||
let device = match device_guard.get(device_id) {
|
||||
Ok(device) => device,
|
||||
Err(_) => break DeviceError::Invalid.into(),
|
||||
};
|
||||
#[cfg(feature = "trace")]
|
||||
if let Some(ref trace) = device.trace {
|
||||
let mut trace = trace.lock();
|
||||
let data = trace.make_binary("spv", unsafe {
|
||||
std::slice::from_raw_parts(source.as_ptr() as *const u8, source.len() * 4)
|
||||
});
|
||||
trace.add(trace::Action::CreateShaderModule {
|
||||
id: fid.id(),
|
||||
desc: desc.clone(),
|
||||
data,
|
||||
});
|
||||
};
|
||||
|
||||
let shader = match device.create_shader_module_spirv(device_id, desc, source) {
|
||||
Ok(shader) => shader,
|
||||
Err(e) => break e,
|
||||
};
|
||||
let id = fid.assign(shader, &mut token);
|
||||
return (id.0, None);
|
||||
};
|
||||
|
||||
let id = fid.assign_error(desc.label.borrow_or_default(), &mut token);
|
||||
(id, Some(error))
|
||||
}
|
||||
|
||||
pub fn shader_module_label<A: HalApi>(&self, id: id::ShaderModuleId) -> String {
|
||||
A::hub(self).shader_modules.label_for_resource(id)
|
||||
}
|
||||
|
@ -29,7 +29,7 @@ pub struct ShaderModuleDescriptor<'a> {
|
||||
pub struct ShaderModule<A: hal::Api> {
|
||||
pub(crate) raw: A::ShaderModule,
|
||||
pub(crate) device_id: Stored<DeviceId>,
|
||||
pub(crate) interface: validation::Interface,
|
||||
pub(crate) interface: Option<validation::Interface>,
|
||||
#[cfg(debug_assertions)]
|
||||
pub(crate) label: String,
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ impl<A: hal::Api> Example<A> {
|
||||
let shader_desc = hal::ShaderModuleDescriptor { label: None };
|
||||
let shader = unsafe {
|
||||
device
|
||||
.create_shader_module(&shader_desc, naga_shader)
|
||||
.create_shader_module(&shader_desc, hal::ShaderInput::NagaShader(naga_shader))
|
||||
.unwrap()
|
||||
};
|
||||
|
||||
|
@ -174,7 +174,7 @@ impl crate::Device<Api> for Context {
|
||||
unsafe fn create_shader_module(
|
||||
&self,
|
||||
desc: &crate::ShaderModuleDescriptor,
|
||||
shader: crate::NagaShader,
|
||||
shader: crate::ShaderInput,
|
||||
) -> Result<Resource, crate::ShaderError> {
|
||||
Ok(Resource)
|
||||
}
|
||||
|
@ -64,7 +64,7 @@ pub mod api {
|
||||
}
|
||||
|
||||
use std::{
|
||||
borrow::Borrow,
|
||||
borrow::{Borrow, Cow},
|
||||
fmt,
|
||||
num::NonZeroU8,
|
||||
ops::{Range, RangeInclusive},
|
||||
@ -257,7 +257,7 @@ pub trait Device<A: Api>: Send + Sync {
|
||||
unsafe fn create_shader_module(
|
||||
&self,
|
||||
desc: &ShaderModuleDescriptor,
|
||||
shader: NagaShader,
|
||||
shader: ShaderInput,
|
||||
) -> Result<A::ShaderModule, ShaderError>;
|
||||
unsafe fn destroy_shader_module(&self, module: A::ShaderModule);
|
||||
unsafe fn create_render_pipeline(
|
||||
@ -839,6 +839,12 @@ pub struct NagaShader {
|
||||
pub info: naga::valid::ModuleInfo,
|
||||
}
|
||||
|
||||
/// Shader input.
|
||||
pub enum ShaderInput<'a> {
|
||||
NagaShader(NagaShader),
|
||||
SpirVShader(Cow<'a, [u32]>),
|
||||
}
|
||||
|
||||
pub struct ShaderModuleDescriptor<'a> {
|
||||
pub label: Label<'a>,
|
||||
}
|
||||
|
@ -630,9 +630,14 @@ impl crate::Device<super::Api> for super::Device {
|
||||
unsafe fn create_shader_module(
|
||||
&self,
|
||||
_desc: &crate::ShaderModuleDescriptor,
|
||||
shader: crate::NagaShader,
|
||||
shader: crate::ShaderInput,
|
||||
) -> Result<super::ShaderModule, crate::ShaderError> {
|
||||
Ok(super::ShaderModule { raw: shader })
|
||||
match shader {
|
||||
crate::ShaderInput::NagaShader(raw) => Ok(super::ShaderModule { raw }),
|
||||
crate::ShaderInput::SpirVShader(_) => Err(crate::ShaderError::Compilation(
|
||||
"SPIR-V shaders are not supported for Metal".to_string(),
|
||||
)),
|
||||
}
|
||||
}
|
||||
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {}
|
||||
|
||||
|
@ -5,7 +5,7 @@ use ash::{extensions::khr, version::DeviceV1_0, vk};
|
||||
use inplace_it::inplace_or_alloc_from_iter;
|
||||
use parking_lot::Mutex;
|
||||
|
||||
use std::{cmp, collections::hash_map::Entry, ffi::CString, ptr, sync::Arc};
|
||||
use std::{borrow::Cow, cmp, collections::hash_map::Entry, ffi::CString, ptr, sync::Arc};
|
||||
|
||||
impl super::DeviceShared {
|
||||
pub(super) unsafe fn set_object_name(
|
||||
@ -1059,10 +1059,19 @@ impl crate::Device<super::Api> for super::Device {
|
||||
unsafe fn create_shader_module(
|
||||
&self,
|
||||
desc: &crate::ShaderModuleDescriptor,
|
||||
shader: crate::NagaShader,
|
||||
shader: crate::ShaderInput,
|
||||
) -> Result<super::ShaderModule, crate::ShaderError> {
|
||||
let spv = naga::back::spv::write_vec(&shader.module, &shader.info, &self.naga_options)
|
||||
.map_err(|e| crate::ShaderError::Compilation(format!("{}", e)))?;
|
||||
let spv = match shader {
|
||||
crate::ShaderInput::NagaShader(naga_shader) => Cow::Owned(
|
||||
naga::back::spv::write_vec(
|
||||
&naga_shader.module,
|
||||
&naga_shader.info,
|
||||
&self.naga_options,
|
||||
)
|
||||
.map_err(|e| crate::ShaderError::Compilation(format!("{}", e)))?,
|
||||
),
|
||||
crate::ShaderInput::SpirVShader(spv) => spv,
|
||||
};
|
||||
|
||||
let vk_info = vk::ShaderModuleCreateInfo::builder()
|
||||
.flags(vk::ShaderModuleCreateFlags::empty())
|
||||
|
@ -594,6 +594,14 @@ bitflags::bitflags! {
|
||||
///
|
||||
/// This is a native only feature.
|
||||
const CLEAR_COMMANDS = 0x0000_0001_0000_0000;
|
||||
/// Enables creating shader modules from SPIR-V binary data (unsafe).
|
||||
///
|
||||
/// Supported platforms:
|
||||
/// - Vulkan, in case shader's requested capabilities and extensions agree with
|
||||
/// Vulkan implementation.
|
||||
///
|
||||
/// This is a native only feature.
|
||||
const SPIR_V_SHADER_MODULES = 0x0000_0002_0000_0000;
|
||||
|
||||
/// Features which are part of the upstream WebGPU standard.
|
||||
const ALL_WEBGPU = 0x0000_0000_0000_FFFF;
|
||||
|
@ -77,7 +77,7 @@ impl framework::Example for Example {
|
||||
| wgpu::Features::PUSH_CONSTANTS
|
||||
}
|
||||
fn required_features() -> wgpu::Features {
|
||||
wgpu::Features::SAMPLED_TEXTURE_BINDING_ARRAY
|
||||
wgpu::Features::SAMPLED_TEXTURE_BINDING_ARRAY | wgpu::Features::SPIR_V_SHADER_MODULES
|
||||
}
|
||||
fn required_limits() -> wgpu::Limits {
|
||||
wgpu::Limits {
|
||||
@ -94,22 +94,22 @@ impl framework::Example for Example {
|
||||
let mut uniform_workaround = false;
|
||||
let vs_module = device.create_shader_module(&wgpu::include_spirv!("shader.vert.spv"));
|
||||
let fs_source = match device.features() {
|
||||
f if f.contains(wgpu::Features::UNSIZED_BINDING_ARRAY) => {
|
||||
wgpu::include_spirv!("unsized-non-uniform.frag.spv")
|
||||
}
|
||||
//f if f.contains(wgpu::Features::UNSIZED_BINDING_ARRAY) => {
|
||||
// wgpu::include_spirv_raw!("unsized-non-uniform.frag.spv")
|
||||
//}
|
||||
f if f.contains(wgpu::Features::SAMPLED_TEXTURE_ARRAY_NON_UNIFORM_INDEXING) => {
|
||||
wgpu::include_spirv!("non-uniform.frag.spv")
|
||||
wgpu::include_spirv_raw!("non-uniform.frag.spv")
|
||||
}
|
||||
f if f.contains(wgpu::Features::SAMPLED_TEXTURE_ARRAY_DYNAMIC_INDEXING) => {
|
||||
uniform_workaround = true;
|
||||
wgpu::include_spirv!("uniform.frag.spv")
|
||||
wgpu::include_spirv_raw!("uniform.frag.spv")
|
||||
}
|
||||
f if f.contains(wgpu::Features::SAMPLED_TEXTURE_BINDING_ARRAY) => {
|
||||
wgpu::include_spirv!("constant.frag.spv")
|
||||
wgpu::include_spirv_raw!("constant.frag.spv")
|
||||
}
|
||||
_ => unreachable!(),
|
||||
};
|
||||
let fs_module = device.create_shader_module(&fs_source);
|
||||
let fs_module = unsafe { device.create_shader_module_spirv(&fs_source) };
|
||||
|
||||
let vertex_size = std::mem::size_of::<Vertex>();
|
||||
let vertex_data = create_vertices();
|
||||
|
@ -4,8 +4,8 @@ use crate::{
|
||||
CommandEncoderDescriptor, ComputePassDescriptor, ComputePipelineDescriptor,
|
||||
DownlevelCapabilities, Features, Label, Limits, LoadOp, MapMode, Operations,
|
||||
PipelineLayoutDescriptor, RenderBundleEncoderDescriptor, RenderPipelineDescriptor,
|
||||
SamplerDescriptor, ShaderModuleDescriptor, ShaderSource, SwapChainStatus, TextureDescriptor,
|
||||
TextureFormat, TextureViewDescriptor,
|
||||
SamplerDescriptor, ShaderModuleDescriptor, ShaderModuleDescriptorSpirV, ShaderSource,
|
||||
SwapChainStatus, TextureDescriptor, TextureFormat, TextureViewDescriptor,
|
||||
};
|
||||
|
||||
use arrayvec::ArrayVec;
|
||||
@ -828,6 +828,30 @@ impl crate::Context for Context {
|
||||
id
|
||||
}
|
||||
|
||||
unsafe fn device_create_shader_module_spirv(
|
||||
&self,
|
||||
device: &Self::DeviceId,
|
||||
desc: &ShaderModuleDescriptorSpirV,
|
||||
) -> Self::ShaderModuleId {
|
||||
let global = &self.0;
|
||||
let descriptor = wgc::pipeline::ShaderModuleDescriptor {
|
||||
label: desc.label.map(Borrowed),
|
||||
};
|
||||
let (id, error) = wgc::gfx_select!(
|
||||
device.id => global.device_create_shader_module_spirv(device.id, &descriptor, Borrowed(&desc.source), PhantomData)
|
||||
);
|
||||
if let Some(cause) = error {
|
||||
self.handle_error(
|
||||
&device.error_sink,
|
||||
cause,
|
||||
LABEL,
|
||||
desc.label,
|
||||
"Device::create_shader_module_spirv",
|
||||
);
|
||||
}
|
||||
id
|
||||
}
|
||||
|
||||
fn device_create_bind_group_layout(
|
||||
&self,
|
||||
device: &Self::DeviceId,
|
||||
|
@ -228,6 +228,11 @@ trait Context: Debug + Send + Sized + Sync {
|
||||
device: &Self::DeviceId,
|
||||
desc: &ShaderModuleDescriptor,
|
||||
) -> Self::ShaderModuleId;
|
||||
unsafe fn device_create_shader_module_spirv(
|
||||
&self,
|
||||
device: &Self::DeviceId,
|
||||
desc: &ShaderModuleDescriptorSpirV,
|
||||
) -> Self::ShaderModuleId;
|
||||
fn device_create_bind_group_layout(
|
||||
&self,
|
||||
device: &Self::DeviceId,
|
||||
@ -752,6 +757,14 @@ pub struct ShaderModuleDescriptor<'a> {
|
||||
pub source: ShaderSource<'a>,
|
||||
}
|
||||
|
||||
/// Descriptor for a shader module given by SPIR-V binary.
|
||||
pub struct ShaderModuleDescriptorSpirV<'a> {
|
||||
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
|
||||
pub label: Label<'a>,
|
||||
/// Binary SPIR-V data, in 4-byte words.
|
||||
pub source: Cow<'a, [u32]>,
|
||||
}
|
||||
|
||||
/// Handle to a pipeline layout.
|
||||
///
|
||||
/// A `PipelineLayout` object describes the available binding groups of a pipeline.
|
||||
@ -1552,6 +1565,22 @@ impl Device {
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates a shader module from SPIR-V binary directly.
|
||||
///
|
||||
/// # Safety
|
||||
///
|
||||
/// This function passes SPIR-V binary to the backend as-is and can potentially result in a
|
||||
/// driver crash.
|
||||
pub unsafe fn create_shader_module_spirv(
|
||||
&self,
|
||||
desc: &ShaderModuleDescriptorSpirV,
|
||||
) -> ShaderModule {
|
||||
ShaderModule {
|
||||
context: Arc::clone(&self.context),
|
||||
id: Context::device_create_shader_module_spirv(&*self.context, &self.id, desc),
|
||||
}
|
||||
}
|
||||
|
||||
/// Creates an empty [`CommandEncoder`].
|
||||
pub fn create_command_encoder(&self, desc: &CommandEncoderDescriptor) -> CommandEncoder {
|
||||
CommandEncoder {
|
||||
|
@ -57,6 +57,22 @@ macro_rules! include_spirv {
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro to load raw SPIR-V data statically, for use with [`wgpu::Features::SPIR_V_SHADER_MODULES`].
|
||||
///
|
||||
/// It ensures the word alignment as well as the magic number.
|
||||
#[macro_export]
|
||||
macro_rules! include_spirv_raw {
|
||||
($($token:tt)*) => {
|
||||
{
|
||||
//log::info!("including '{}'", $($token)*);
|
||||
$crate::ShaderModuleDescriptorSpirV {
|
||||
label: Some($($token)*),
|
||||
source: $crate::util::make_spirv_raw(include_bytes!($($token)*)),
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
/// Macro to load a WGSL module statically.
|
||||
#[macro_export]
|
||||
macro_rules! include_wgsl {
|
||||
|
@ -25,6 +25,12 @@ pub use encoder::RenderEncoder;
|
||||
/// - Input is longer than [`usize::max_value`]
|
||||
/// - SPIR-V magic number is missing from beginning of stream
|
||||
pub fn make_spirv(data: &[u8]) -> super::ShaderSource {
|
||||
super::ShaderSource::SpirV(make_spirv_raw(data))
|
||||
}
|
||||
|
||||
/// Version of [`make_spirv`] intended for use with [`Device::create_shader_module_spirv`].
|
||||
/// Returns raw slice instead of ShaderSource.
|
||||
pub fn make_spirv_raw(data: &[u8]) -> Cow<[u32]> {
|
||||
const MAGIC_NUMBER: u32 = 0x0723_0203;
|
||||
|
||||
assert_eq!(
|
||||
@ -53,7 +59,8 @@ pub fn make_spirv(data: &[u8]) -> super::ShaderSource {
|
||||
"wrong magic word {:x}. Make sure you are using a binary SPIRV file.",
|
||||
words[0]
|
||||
);
|
||||
super::ShaderSource::SpirV(words)
|
||||
|
||||
words
|
||||
}
|
||||
|
||||
/// CPU accessible buffer used to download data back from the GPU.
|
||||
|
Loading…
Reference in New Issue
Block a user