From 14690470bba29a04fc6bffb0f9c0a45ca68f9ba2 Mon Sep 17 00:00:00 2001 From: Sylvain Benner Date: Wed, 9 Apr 2025 10:25:41 -0400 Subject: [PATCH] [metal] Metal compute shader passthrough (#7326) Co-authored-by: Connor Fitzgerald --- CHANGELOG.md | 29 ++++++ deno_webgpu/webidl.rs | 3 + .../standalone/custom_backend/src/custom.rs | 4 +- naga/src/valid/mod.rs | 2 +- tests/tests/wgpu-gpu/device.rs | 14 +-- wgpu-core/src/device/global.rs | 33 +++++-- wgpu-core/src/device/resource.rs | 36 ++++--- wgpu-core/src/pipeline.rs | 3 + wgpu-hal/src/dx12/device.rs | 3 + wgpu-hal/src/gles/device.rs | 3 + wgpu-hal/src/lib.rs | 6 ++ wgpu-hal/src/metal/adapter.rs | 1 + wgpu-hal/src/metal/device.rs | 74 +++++++++++--- wgpu-hal/src/metal/mod.rs | 16 +++- wgpu-hal/src/vulkan/device.rs | 3 + wgpu-types/src/features.rs | 59 +++++++----- wgpu-types/src/lib.rs | 96 +++++++++++++++++++ wgpu/src/api/device.rs | 14 ++- wgpu/src/api/shader_module.rs | 41 ++++---- wgpu/src/backend/webgpu.rs | 6 +- wgpu/src/backend/wgpu_core.rs | 28 +++--- wgpu/src/dispatch.rs | 6 +- wgpu/src/macros.rs | 8 +- wgpu/src/util/mod.rs | 4 +- 24 files changed, 369 insertions(+), 123 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cdb29a35b..1537ad433 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -181,6 +181,32 @@ layout(location = 0, index = 1) out vec4 output1; By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144) +#### Unify interface for SpirV shader passthrough + +Replace device `create_shader_module_spirv` function with a generic `create_shader_module_passthrough` function +taking a `ShaderModuleDescriptorPassthrough` enum as parameter. + +Update your calls to `create_shader_module_spirv` and use `create_shader_module_passthrough` instead: + +```diff +- device.create_shader_module_spirv( +- wgpu::ShaderModuleDescriptorSpirV { +- label: Some(&name), +- source: Cow::Borrowed(&source), +- } +- ) ++ device.create_shader_module_passthrough( ++ wgpu::ShaderModuleDescriptorPassthrough::SpirV( ++ wgpu::ShaderModuleDescriptorSpirV { ++ label: Some(&name), ++ source: Cow::Borrowed(&source), ++ }, ++ ), ++ ) +``` + +By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326). + ### New Features - Added mesh shader support to `wgpu_hal`. By @SupaMaggie70Incorporated in [#7089](https://github.com/gfx-rs/wgpu/pull/7089) @@ -203,6 +229,9 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144) - Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183) . +- Add Metal compute shader passthrough. Use `create_shader_module_passthrough` on device. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326). + +- new `Features::MSL_SHADER_PASSTHROUGH` run-time feature allows providing pass-through MSL Metal shaders. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326). #### Naga diff --git a/deno_webgpu/webidl.rs b/deno_webgpu/webidl.rs index 60ab54706..97476a7ac 100644 --- a/deno_webgpu/webidl.rs +++ b/deno_webgpu/webidl.rs @@ -417,6 +417,8 @@ pub enum GPUFeatureName { VertexWritableStorage, #[webidl(rename = "clear-texture")] ClearTexture, + #[webidl(rename = "msl-shader-passthrough")] + MslShaderPassthrough, #[webidl(rename = "spirv-shader-passthrough")] SpirvShaderPassthrough, #[webidl(rename = "multiview")] @@ -477,6 +479,7 @@ pub fn feature_names_to_features(names: Vec) -> wgpu_types::Feat GPUFeatureName::ConservativeRasterization => Features::CONSERVATIVE_RASTERIZATION, GPUFeatureName::VertexWritableStorage => Features::VERTEX_WRITABLE_STORAGE, GPUFeatureName::ClearTexture => Features::CLEAR_TEXTURE, + GPUFeatureName::MslShaderPassthrough => Features::MSL_SHADER_PASSTHROUGH, GPUFeatureName::SpirvShaderPassthrough => Features::SPIRV_SHADER_PASSTHROUGH, GPUFeatureName::Multiview => Features::MULTIVIEW, GPUFeatureName::VertexAttribute64Bit => Features::VERTEX_ATTRIBUTE_64BIT, diff --git a/examples/standalone/custom_backend/src/custom.rs b/examples/standalone/custom_backend/src/custom.rs index 1c04ab0f4..c3815fbba 100644 --- a/examples/standalone/custom_backend/src/custom.rs +++ b/examples/standalone/custom_backend/src/custom.rs @@ -126,9 +126,9 @@ impl DeviceInterface for CustomDevice { DispatchShaderModule::custom(CustomShaderModule(self.0.clone())) } - unsafe fn create_shader_module_spirv( + unsafe fn create_shader_module_passthrough( &self, - _desc: &wgpu::ShaderModuleDescriptorSpirV<'_>, + _desc: &wgpu::ShaderModuleDescriptorPassthrough<'_>, ) -> DispatchShaderModule { unimplemented!() } diff --git a/naga/src/valid/mod.rs b/naga/src/valid/mod.rs index 1580e7964..3d5039036 100644 --- a/naga/src/valid/mod.rs +++ b/naga/src/valid/mod.rs @@ -237,7 +237,7 @@ bitflags::bitflags! { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Default)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] #[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] pub struct ModuleInfo { diff --git a/tests/tests/wgpu-gpu/device.rs b/tests/tests/wgpu-gpu/device.rs index 6fc8e04f0..54e3adc8e 100644 --- a/tests/tests/wgpu-gpu/device.rs +++ b/tests/tests/wgpu-gpu/device.rs @@ -511,12 +511,14 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne fail( &ctx.device, || unsafe { - let _ = ctx - .device - .create_shader_module_spirv(&wgpu::ShaderModuleDescriptorSpirV { - label: None, - source: std::borrow::Cow::Borrowed(&[]), - }); + let _ = ctx.device.create_shader_module_passthrough( + wgpu::ShaderModuleDescriptorPassthrough::SpirV( + wgpu::ShaderModuleDescriptorSpirV { + label: None, + source: std::borrow::Cow::Borrowed(&[]), + }, + ), + ); }, Some("device with '' label is invalid"), ); diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 5d19bd362..e9db7d569 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -938,23 +938,21 @@ impl Global { (id, Some(error)) } - // Unsafe-ness of internal calls has little to do with unsafe-ness of this. #[allow(unused_unsafe)] /// # Safety /// - /// This function passes SPIR-V binary to the backend as-is and can potentially result in a + /// This function passes source code or binary to the backend as-is and can potentially result in a /// driver crash. - pub unsafe fn device_create_shader_module_spirv( + pub unsafe fn device_create_shader_module_passthrough( &self, device_id: DeviceId, - desc: &pipeline::ShaderModuleDescriptor, - source: Cow<[u32]>, + desc: &pipeline::ShaderModuleDescriptorPassthrough<'_>, id_in: Option, ) -> ( id::ShaderModuleId, Option, ) { - profiling::scope!("Device::create_shader_module"); + profiling::scope!("Device::create_shader_module_passthrough"); let hub = &self.hub; let fid = hub.shader_modules.prepare(id_in); @@ -964,15 +962,30 @@ impl Global { #[cfg(feature = "trace")] if let Some(ref mut trace) = *device.trace.lock() { - let data = trace.make_binary("spv", bytemuck::cast_slice(&source)); + let data = trace.make_binary(desc.trace_binary_ext(), desc.trace_data()); trace.add(trace::Action::CreateShaderModule { id: fid.id(), - desc: desc.clone(), + desc: match desc { + pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => { + pipeline::ShaderModuleDescriptor { + label: inner.label.clone(), + runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), + } + } + pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => { + pipeline::ShaderModuleDescriptor { + label: inner.label.clone(), + runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), + } + } + }, data, }); }; - let shader = match unsafe { device.create_shader_module_spirv(desc, &source) } { + let result = unsafe { device.create_shader_module_passthrough(desc) }; + + let shader = match result { Ok(shader) => shader, Err(e) => break 'error e, }; @@ -981,7 +994,7 @@ impl Global { return (id, None); }; - let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string()))); + let id = fid.assign(Fallible::Invalid(Arc::new(desc.label().to_string()))); (id, Some(error)) } diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index bdc4a2775..66898538a 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -1750,19 +1750,31 @@ impl Device { } #[allow(unused_unsafe)] - pub(crate) unsafe fn create_shader_module_spirv<'a>( + pub(crate) unsafe fn create_shader_module_passthrough<'a>( self: &Arc, - desc: &pipeline::ShaderModuleDescriptor<'a>, - source: &'a [u32], + descriptor: &pipeline::ShaderModuleDescriptorPassthrough<'a>, ) -> Result, pipeline::CreateShaderModuleError> { self.check_is_valid()?; - - self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?; - let hal_desc = hal::ShaderModuleDescriptor { - label: desc.label.to_hal(self.instance_flags), - runtime_checks: desc.runtime_checks, + let hal_shader = match descriptor { + pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => { + self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?; + hal::ShaderInput::SpirV(&inner.source) + } + pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => { + self.require_features(wgt::Features::MSL_SHADER_PASSTHROUGH)?; + hal::ShaderInput::Msl { + shader: inner.source.to_string(), + entry_point: inner.entry_point.to_string(), + num_workgroups: inner.num_workgroups, + } + } }; - let hal_shader = hal::ShaderInput::SpirV(source); + + let hal_desc = hal::ShaderModuleDescriptor { + label: descriptor.label().to_hal(self.instance_flags), + runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), + }; + let raw = match unsafe { self.raw().create_shader_module(&hal_desc, hal_shader) } { Ok(raw) => raw, Err(error) => { @@ -1782,12 +1794,10 @@ impl Device { raw: ManuallyDrop::new(raw), device: self.clone(), interface: None, - label: desc.label.to_string(), + label: descriptor.label().to_string(), }; - let module = Arc::new(module); - - Ok(module) + Ok(Arc::new(module)) } pub(crate) fn create_command_encoder( diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 44171c370..09b3ba462 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -53,6 +53,9 @@ pub struct ShaderModuleDescriptor<'a> { pub runtime_checks: wgt::ShaderRuntimeChecks, } +pub type ShaderModuleDescriptorPassthrough<'a> = + wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>; + #[derive(Debug)] pub struct ShaderModule { pub(crate) raw: ManuallyDrop>, diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index b9663f5ee..7de69e368 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -1641,6 +1641,9 @@ impl crate::Device for super::Device { crate::ShaderInput::SpirV(_) => { panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend") } + crate::ShaderInput::Msl { .. } => { + panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") + } } } unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) { diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index 45c7b1b92..cc23b981f 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -1329,6 +1329,9 @@ impl crate::Device for super::Device { crate::ShaderInput::SpirV(_) => { panic!("`Features::SPIRV_SHADER_PASSTHROUGH` is not enabled") } + crate::ShaderInput::Msl { .. } => { + panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled") + } crate::ShaderInput::Naga(naga) => naga, }, label: desc.label.map(|str| str.to_string()), diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index 2a37c2507..444d89680 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -2069,6 +2069,7 @@ pub struct CommandEncoderDescriptor<'a, Q: DynQueue + ?Sized> { } /// Naga shader module. +#[derive(Default)] pub struct NagaShader { /// Shader module IR. pub module: Cow<'static, naga::Module>, @@ -2090,6 +2091,11 @@ impl fmt::Debug for NagaShader { #[allow(clippy::large_enum_variant)] pub enum ShaderInput<'a> { Naga(NagaShader), + Msl { + shader: String, + entry_point: String, + num_workgroups: (u32, u32, u32), + }, SpirV(&'a [u32]), } diff --git a/wgpu-hal/src/metal/adapter.rs b/wgpu-hal/src/metal/adapter.rs index f9e02cbde..19719b7a9 100644 --- a/wgpu-hal/src/metal/adapter.rs +++ b/wgpu-hal/src/metal/adapter.rs @@ -904,6 +904,7 @@ impl super::PrivateCapabilities { use wgt::Features as F; let mut features = F::empty() + | F::MSL_SHADER_PASSTHROUGH | F::MAPPABLE_PRIMARY_BUFFERS | F::VERTEX_WRITABLE_STORAGE | F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 3d4e42105..8bc359560 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -4,8 +4,9 @@ use std::{thread, time}; use parking_lot::Mutex; -use super::conv; +use super::{conv, PassthroughShader}; use crate::auxil::map_naga_stage; +use crate::metal::ShaderModuleSource; use crate::TlasInstance; use metal::foreign_types::ForeignType; @@ -122,11 +123,15 @@ impl super::Device { primitive_class: metal::MTLPrimitiveTopologyClass, naga_stage: naga::ShaderStage, ) -> Result { + let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source { + naga + } else { + panic!("load_shader required a naga shader"); + }; let stage_bit = map_naga_stage(naga_stage); - let (module, module_info) = naga::back::pipeline_constants::process_overrides( - &stage.module.naga.module, - &stage.module.naga.info, + &naga_shader.module, + &naga_shader.info, stage.constants, ) .map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {:?}", e)))?; @@ -989,9 +994,37 @@ impl crate::Device for super::Device { match shader { crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule { - naga, + source: ShaderModuleSource::Naga(naga), bounds_checks: desc.runtime_checks, }), + crate::ShaderInput::Msl { + shader: source, + entry_point, + num_workgroups, + } => { + let options = metal::CompileOptions::new(); + // Obtain the locked device from shared + let device = self.shared.device.lock(); + let library = device + .new_library_with_source(&source, &options) + .map_err(|e| crate::ShaderError::Compilation(format!("MSL: {:?}", e)))?; + let function = library.get_function(&entry_point, None).map_err(|_| { + crate::ShaderError::Compilation(format!( + "Entry point '{}' not found", + entry_point + )) + })?; + + Ok(super::ShaderModule { + source: ShaderModuleSource::Passthrough(PassthroughShader { + library, + function, + entry_point, + num_workgroups, + }), + bounds_checks: desc.runtime_checks, + }) + } crate::ShaderInput::SpirV(_) => { panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend") } @@ -1299,13 +1332,30 @@ impl crate::Device for super::Device { objc::rc::autoreleasepool(|| { let descriptor = metal::ComputePipelineDescriptor::new(); - let cs = self.load_shader( - &desc.stage, - &[], - desc.layout, - metal::MTLPrimitiveTopologyClass::Unspecified, - naga::ShaderStage::Compute, - )?; + let module = desc.stage.module; + let cs = if let ShaderModuleSource::Passthrough(desc) = &module.source { + CompiledShader { + library: desc.library.clone(), + function: desc.function.clone(), + wg_size: metal::MTLSize::new( + desc.num_workgroups.0 as u64, + desc.num_workgroups.1 as u64, + desc.num_workgroups.2 as u64, + ), + wg_memory_sizes: vec![], + sized_bindings: vec![], + immutable_buffer_mask: 0, + } + } else { + self.load_shader( + &desc.stage, + &[], + desc.layout, + metal::MTLPrimitiveTopologyClass::Unspecified, + naga::ShaderStage::Compute, + )? + }; + descriptor.set_compute_function(Some(&cs.function)); if self.shared.private_caps.supports_mutability { diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index bcec1c95e..08581736e 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -781,9 +781,23 @@ impl crate::DynBindGroup for BindGroup {} unsafe impl Send for BindGroup {} unsafe impl Sync for BindGroup {} +#[derive(Debug)] +pub enum ShaderModuleSource { + Naga(crate::NagaShader), + Passthrough(PassthroughShader), +} + +#[derive(Debug)] +pub struct PassthroughShader { + pub library: metal::Library, + pub function: metal::Function, + pub entry_point: String, + pub num_workgroups: (u32, u32, u32), +} + #[derive(Debug)] pub struct ShaderModule { - naga: crate::NagaShader, + source: ShaderModuleSource, bounds_checks: wgt::ShaderRuntimeChecks, } diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index b2734b763..9cd0fa36c 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1862,6 +1862,9 @@ impl crate::Device for super::Device { .map_err(|e| crate::ShaderError::Compilation(format!("{e}")))?, ) } + crate::ShaderInput::Msl { .. } => { + panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend") + } crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv), }; diff --git a/wgpu-types/src/features.rs b/wgpu-types/src/features.rs index b375b3c26..a3bd0498e 100644 --- a/wgpu-types/src/features.rs +++ b/wgpu-types/src/features.rs @@ -920,6 +920,15 @@ bitflags_array! { /// /// This is a native only feature. const CLEAR_TEXTURE = 1 << 23; + /// Enables creating shader modules from Metal MSL computer shaders (unsafe). + /// + /// Metal data is not parsed or interpreted in any way + /// + /// Supported platforms: + /// - Metal + /// + /// This is a native only feature. + const MSL_SHADER_PASSTHROUGH = 1 << 24; /// Enables creating shader modules from SPIR-V binary data (unsafe). /// /// SPIR-V data is not parsed or interpreted in any way; you can use @@ -933,7 +942,7 @@ bitflags_array! { /// This is a native only feature. /// /// [`wgpu::make_spirv_raw!`]: https://docs.rs/wgpu/latest/wgpu/macro.include_spirv_raw.html - const SPIRV_SHADER_PASSTHROUGH = 1 << 24; + const SPIRV_SHADER_PASSTHROUGH = 1 << 25; /// Enables multiview render passes and `builtin(view_index)` in vertex shaders. /// /// Supported platforms: @@ -941,7 +950,7 @@ bitflags_array! { /// - OpenGL (web only) /// /// This is a native only feature. - const MULTIVIEW = 1 << 25; + const MULTIVIEW = 1 << 26; /// Enables using 64-bit types for vertex attributes. /// /// Requires SHADER_FLOAT64. @@ -949,7 +958,7 @@ bitflags_array! { /// Supported Platforms: N/A /// /// This is a native only feature. - const VERTEX_ATTRIBUTE_64BIT = 1 << 26; + const VERTEX_ATTRIBUTE_64BIT = 1 << 27; /// Enables image atomic fetch add, and, xor, or, min, and max for R32Uint and R32Sint textures. /// /// Supported platforms: @@ -958,7 +967,7 @@ bitflags_array! { /// - Metal (with MSL 3.1+) /// /// This is a native only feature. - const TEXTURE_ATOMIC = 1 << 27; + const TEXTURE_ATOMIC = 1 << 28; /// Allows for creation of textures of format [`TextureFormat::NV12`] /// /// Supported platforms: @@ -968,7 +977,7 @@ bitflags_array! { /// This is a native only feature. /// /// [`TextureFormat::NV12`]: super::TextureFormat::NV12 - const TEXTURE_FORMAT_NV12 = 1 << 28; + const TEXTURE_FORMAT_NV12 = 1 << 29; /// ***THIS IS EXPERIMENTAL:*** Features enabled by this may have /// major bugs in them and are expected to be subject to breaking changes, suggestions /// for the API exposed by this should be posted on [the ray-tracing issue](https://github.com/gfx-rs/wgpu/issues/1040) @@ -980,7 +989,7 @@ bitflags_array! { /// - Vulkan /// /// This is a native-only feature. - const EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE = 1 << 29; + const EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE = 1 << 30; // Shader: @@ -994,7 +1003,7 @@ bitflags_array! { /// - Vulkan /// /// This is a native-only feature. - const EXPERIMENTAL_RAY_QUERY = 1 << 30; + const EXPERIMENTAL_RAY_QUERY = 1 << 31; /// Enables 64-bit floating point types in SPIR-V shaders. /// /// Note: even when supported by GPU hardware, 64-bit floating point operations are @@ -1004,14 +1013,14 @@ bitflags_array! { /// - Vulkan /// /// This is a native only feature. - const SHADER_F64 = 1 << 31; + const SHADER_F64 = 1 << 32; /// Allows shaders to use i16. Not currently supported in `naga`, only available through `spirv-passthrough`. /// /// Supported platforms: /// - Vulkan /// /// This is a native only feature. - const SHADER_I16 = 1 << 32; + const SHADER_I16 = 1 << 33; /// Enables `builtin(primitive_index)` in fragment shaders. /// /// Note: enables geometry processing for pipelines using the builtin. @@ -1025,14 +1034,14 @@ bitflags_array! { /// - OpenGL (some) /// /// This is a native only feature. - const SHADER_PRIMITIVE_INDEX = 1 << 33; + const SHADER_PRIMITIVE_INDEX = 1 << 34; /// Allows shaders to use the `early_depth_test` attribute. /// /// Supported platforms: /// - GLES 3.1+ /// /// This is a native only feature. - const SHADER_EARLY_DEPTH_TEST = 1 << 34; + const SHADER_EARLY_DEPTH_TEST = 1 << 35; /// Allows shaders to use i64 and u64. /// /// Supported platforms: @@ -1041,7 +1050,7 @@ bitflags_array! { /// - Metal (with MSL 2.3+) /// /// This is a native only feature. - const SHADER_INT64 = 1 << 35; + const SHADER_INT64 = 1 << 36; /// Allows compute and fragment shaders to use the subgroup operation built-ins /// /// Supported Platforms: @@ -1050,14 +1059,14 @@ bitflags_array! { /// - Metal /// /// This is a native only feature. - const SUBGROUP = 1 << 36; + const SUBGROUP = 1 << 37; /// Allows vertex shaders to use the subgroup operation built-ins /// /// Supported Platforms: /// - Vulkan /// /// This is a native only feature. - const SUBGROUP_VERTEX = 1 << 37; + const SUBGROUP_VERTEX = 1 << 38; /// Allows shaders to use the subgroup barrier /// /// Supported Platforms: @@ -1065,7 +1074,7 @@ bitflags_array! { /// - Metal /// /// This is a native only feature. - const SUBGROUP_BARRIER = 1 << 38; + const SUBGROUP_BARRIER = 1 << 39; /// Allows the use of pipeline cache objects /// /// Supported platforms: @@ -1074,7 +1083,7 @@ bitflags_array! { /// Unimplemented Platforms: /// - DX12 /// - Metal - const PIPELINE_CACHE = 1 << 39; + const PIPELINE_CACHE = 1 << 40; /// Allows shaders to use i64 and u64 atomic min and max. /// /// Supported platforms: @@ -1083,7 +1092,7 @@ bitflags_array! { /// - Metal (with MSL 2.4+) /// /// This is a native only feature. - const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 40; + const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 41; /// Allows shaders to use all i64 and u64 atomic operations. /// /// Supported platforms: @@ -1091,7 +1100,7 @@ bitflags_array! { /// - DX12 (with SM 6.6+) /// /// This is a native only feature. - const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 41; + const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 42; /// Allows using the [VK_GOOGLE_display_timing] Vulkan extension. /// /// This is used for frame pacing to reduce latency, and is generally only available on Android. @@ -1107,7 +1116,7 @@ bitflags_array! { /// /// [VK_GOOGLE_display_timing]: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html /// [`Surface::as_hal()`]: https://docs.rs/wgpu/latest/wgpu/struct.Surface.html#method.as_hal - const VULKAN_GOOGLE_DISPLAY_TIMING = 1 << 42; + const VULKAN_GOOGLE_DISPLAY_TIMING = 1 << 43; /// Allows using the [VK_KHR_external_memory_win32] Vulkan extension. /// @@ -1117,7 +1126,7 @@ bitflags_array! { /// This is a native only feature. /// /// [VK_KHR_external_memory_win32]: https://registry.khronos.org/vulkan/specs/latest/man/html/VK_KHR_external_memory_win32.html - const VULKAN_EXTERNAL_MEMORY_WIN32 = 1 << 43; + const VULKAN_EXTERNAL_MEMORY_WIN32 = 1 << 44; /// Enables R64Uint image atomic min and max. /// @@ -1127,7 +1136,7 @@ bitflags_array! { /// - Metal (with MSL 3.1+) /// /// This is a native only feature. - const TEXTURE_INT64_ATOMIC = 1 << 44; + const TEXTURE_INT64_ATOMIC = 1 << 45; /// Allows uniform buffers to be bound as binding arrays. /// @@ -1144,7 +1153,7 @@ bitflags_array! { /// - Vulkan 1.2+ (or VK_EXT_descriptor_indexing)'s `shaderUniformBufferArrayNonUniformIndexing` feature) /// /// This is a native only feature. - const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 45; + const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 46; /// Enables mesh shaders and task shaders in mesh shader pipelines. /// @@ -1156,7 +1165,7 @@ bitflags_array! { /// - Metal /// /// This is a native only feature. - const EXPERIMENTAL_MESH_SHADER = 1 << 46; + const EXPERIMENTAL_MESH_SHADER = 1 << 47; /// ***THIS IS EXPERIMENTAL:*** Features enabled by this may have /// major bugs in them and are expected to be subject to breaking changes, suggestions @@ -1171,7 +1180,7 @@ bitflags_array! { /// This is a native only feature /// /// [`AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN`]: super::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN - const EXPERIMENTAL_RAY_HIT_VERTEX_RETURN = 1 << 47; + const EXPERIMENTAL_RAY_HIT_VERTEX_RETURN = 1 << 48; /// Enables multiview in mesh shader pipelines /// @@ -1183,7 +1192,7 @@ bitflags_array! { /// - Metal /// /// This is a native only feature. - const EXPERIMENTAL_MESH_SHADER_MULTIVIEW = 1 << 48; + const EXPERIMENTAL_MESH_SHADER_MULTIVIEW = 1 << 49; } /// Features that are not guaranteed to be supported. diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index f3d8b6a2f..34b7a26ca 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -14,6 +14,7 @@ extern crate std; extern crate alloc; +use alloc::borrow::Cow; use alloc::{string::String, vec, vec::Vec}; use core::{ fmt, @@ -7654,3 +7655,98 @@ pub enum DeviceLostReason { /// After `Device::destroy` Destroyed = 1, } + +/// Descriptor for creating a shader module. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +#[derive(Debug, Clone)] +pub enum CreateShaderModuleDescriptorPassthrough<'a, L> { + /// Passthrough for SPIR-V binaries. + SpirV(ShaderModuleDescriptorSpirV<'a, L>), + /// Passthrough for MSL source code. + Msl(ShaderModuleDescriptorMsl<'a, L>), +} + +impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> { + /// Takes a closure and maps the label of the shader module descriptor into another. + pub fn map_label( + &self, + fun: impl FnOnce(&L) -> K, + ) -> CreateShaderModuleDescriptorPassthrough<'_, K> { + match self { + CreateShaderModuleDescriptorPassthrough::SpirV(inner) => { + CreateShaderModuleDescriptorPassthrough::<'_, K>::SpirV( + ShaderModuleDescriptorSpirV { + label: fun(&inner.label), + source: inner.source.clone(), + }, + ) + } + CreateShaderModuleDescriptorPassthrough::Msl(inner) => { + CreateShaderModuleDescriptorPassthrough::<'_, K>::Msl(ShaderModuleDescriptorMsl { + entry_point: inner.entry_point.clone(), + label: fun(&inner.label), + num_workgroups: inner.num_workgroups, + source: inner.source.clone(), + }) + } + } + } + + /// Returns the label of shader module passthrough descriptor. + pub fn label(&'a self) -> &'a L { + match self { + CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label, + CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label, + } + } + + #[cfg(feature = "trace")] + /// Returns the source data for tracing purpose. + pub fn trace_data(&self) -> &[u8] { + match self { + CreateShaderModuleDescriptorPassthrough::SpirV(inner) => { + bytemuck::cast_slice(&inner.source) + } + CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(), + } + } + + #[cfg(feature = "trace")] + /// Returns the binary file extension for tracing purpose. + pub fn trace_binary_ext(&self) -> &'static str { + match self { + CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv", + CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl", + } + } +} + +/// Descriptor for a shader module given by Metal MSL source. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +#[derive(Debug, Clone)] +pub struct ShaderModuleDescriptorMsl<'a, L> { + /// Entrypoint. + pub entry_point: String, + /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. + pub label: L, + /// Number of workgroups in each dimension x, y and z. + pub num_workgroups: (u32, u32, u32), + /// Shader MSL source. + pub source: Cow<'a, str>, +} + +/// Descriptor for a shader module given by SPIR-V binary. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +#[derive(Debug, Clone)] +pub struct ShaderModuleDescriptorSpirV<'a, L> { + /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. + pub label: L, + /// Binary SPIR-V data, in 4-byte words. + pub source: Cow<'a, [u32]>, +} diff --git a/wgpu/src/api/device.rs b/wgpu/src/api/device.rs index 05ca2020e..44fb837a2 100644 --- a/wgpu/src/api/device.rs +++ b/wgpu/src/api/device.rs @@ -172,20 +172,18 @@ impl Device { ShaderModule { inner: module } } - /// Creates a shader module from SPIR-V binary directly. + /// Creates a shader module which will bypass wgpu's shader tooling and validation and be used directly by the backend. /// /// # Safety /// - /// This function passes binary data to the backend as-is and can potentially result in a - /// driver crash or bogus behaviour. No attempt is made to ensure that data is valid SPIR-V. - /// - /// See also [`include_spirv_raw!`] and [`util::make_spirv_raw`]. + /// This function passes data to the backend as-is and can potentially result in a + /// driver crash or bogus behaviour. No attempt is made to ensure that data is valid. #[must_use] - pub unsafe fn create_shader_module_spirv( + pub unsafe fn create_shader_module_passthrough( &self, - desc: &ShaderModuleDescriptorSpirV<'_>, + desc: ShaderModuleDescriptorPassthrough<'_>, ) -> ShaderModule { - let module = unsafe { self.inner.create_shader_module_spirv(desc) }; + let module = unsafe { self.inner.create_shader_module_passthrough(&desc) }; ShaderModule { inner: module } } diff --git a/wgpu/src/api/shader_module.rs b/wgpu/src/api/shader_module.rs index 241d19b09..940d63632 100644 --- a/wgpu/src/api/shader_module.rs +++ b/wgpu/src/api/shader_module.rs @@ -1,5 +1,4 @@ use alloc::{ - borrow::Cow, string::{String, ToString as _}, vec, vec::Vec, @@ -11,9 +10,9 @@ use crate::*; /// Handle to a compiled shader module. /// /// A `ShaderModule` represents a compiled shader module on the GPU. It can be created by passing -/// source code to [`Device::create_shader_module`] or valid SPIR-V binary to -/// [`Device::create_shader_module_spirv`]. Shader modules are used to define programmable stages -/// of a pipeline. +/// source code to [`Device::create_shader_module`]. MSL shader or SPIR-V binary can also be passed +/// directly using [`Device::create_shader_module_passthrough`]. Shader modules are used to define +/// programmable stages of a pipeline. /// /// Corresponds to [WebGPU `GPUShaderModule`](https://gpuweb.github.io/gpuweb/#shader-module). #[derive(Debug, Clone)] @@ -182,14 +181,14 @@ pub enum ShaderSource<'a> { /// /// See also: [`util::make_spirv`], [`include_spirv`] #[cfg(feature = "spirv")] - SpirV(Cow<'a, [u32]>), + SpirV(alloc::borrow::Cow<'a, [u32]>), /// GLSL module as a string slice. /// /// Note: GLSL is not yet fully supported and must be a specific ShaderStage. #[cfg(feature = "glsl")] Glsl { /// The source code of the shader. - shader: Cow<'a, str>, + shader: alloc::borrow::Cow<'a, str>, /// The shader stage that the shader targets. For example, `naga::ShaderStage::Vertex` stage: naga::ShaderStage, /// Key-value pairs to represent defines sent to the glsl preprocessor. @@ -199,10 +198,10 @@ pub enum ShaderSource<'a> { }, /// WGSL module as a string slice. #[cfg(feature = "wgsl")] - Wgsl(Cow<'a, str>), + Wgsl(alloc::borrow::Cow<'a, str>), /// Naga module. #[cfg(feature = "naga-ir")] - Naga(Cow<'static, naga::Module>), + Naga(alloc::borrow::Cow<'static, naga::Module>), /// Dummy variant because `Naga` doesn't have a lifetime and without enough active features it /// could be the last one active. #[doc(hidden)] @@ -223,16 +222,22 @@ pub struct ShaderModuleDescriptor<'a> { } static_assertions::assert_impl_all!(ShaderModuleDescriptor<'_>: Send, Sync); -/// Descriptor for a shader module given by SPIR-V binary, for use with -/// [`Device::create_shader_module_spirv`]. +/// Descriptor for a shader module that will bypass wgpu's shader tooling, for use with +/// [`Device::create_shader_module_passthrough`]. /// /// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, /// only WGSL source code strings are accepted. -#[derive(Debug)] -pub struct ShaderModuleDescriptorSpirV<'a> { - /// Debug label of the shader module. This will show up in graphics debuggers for easy identification. - pub label: Label<'a>, - /// Binary SPIR-V data, in 4-byte words. - pub source: Cow<'a, [u32]>, -} -static_assertions::assert_impl_all!(ShaderModuleDescriptorSpirV<'_>: Send, Sync); +pub type ShaderModuleDescriptorPassthrough<'a> = + wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>; + +/// Descriptor for a shader module given by Metal MSL source. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Label<'a>>; + +/// Descriptor for a shader module given by SPIR-V binary. +/// +/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification, +/// only WGSL source code strings are accepted. +pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>; diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index eb97554dd..86d1eaca5 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -1843,11 +1843,11 @@ impl dispatch::DeviceInterface for WebDevice { .into() } - unsafe fn create_shader_module_spirv( + unsafe fn create_shader_module_passthrough( &self, - _desc: &crate::ShaderModuleDescriptorSpirV<'_>, + _desc: &crate::ShaderModuleDescriptorPassthrough<'_>, ) -> dispatch::DispatchShaderModule { - unreachable!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend") + unreachable!("No XXX_SHADER_PASSTHROUGH feature enabled for this backend") } fn create_bind_group_layout( diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 98c1ba8d3..7a69a24ab 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1,5 +1,5 @@ use alloc::{ - borrow::Cow::Borrowed, + borrow::Cow::{self, Borrowed}, boxed::Box, format, string::{String, ToString as _}, @@ -1043,36 +1043,30 @@ impl dispatch::DeviceInterface for CoreDevice { .into() } - unsafe fn create_shader_module_spirv( + unsafe fn create_shader_module_passthrough( &self, - desc: &crate::ShaderModuleDescriptorSpirV<'_>, + desc: &crate::ShaderModuleDescriptorPassthrough<'_>, ) -> dispatch::DispatchShaderModule { - let descriptor = wgc::pipeline::ShaderModuleDescriptor { - label: desc.label.map(Borrowed), - // Doesn't matter the value since spirv shaders aren't mutated to include - // runtime checks - runtime_checks: wgt::ShaderRuntimeChecks::unchecked(), - }; + let desc = desc.map_label(|l| l.map(Cow::from)); let (id, error) = unsafe { - self.context.0.device_create_shader_module_spirv( - self.id, - &descriptor, - Borrowed(&desc.source), - None, - ) + self.context + .0 + .device_create_shader_module_passthrough(self.id, &desc, None) }; + let compilation_info = match error { Some(cause) => { self.context.handle_error( &self.error_sink, cause.clone(), - desc.label, - "Device::create_shader_module_spirv", + desc.label().as_deref(), + "Device::create_shader_module_passthrough", ); CompilationInfo::from(cause) } None => CompilationInfo { messages: vec![] }, }; + CoreShaderModule { context: self.context.clone(), id, diff --git a/wgpu/src/dispatch.rs b/wgpu/src/dispatch.rs index 0e4bbc70a..a088a57ca 100644 --- a/wgpu/src/dispatch.rs +++ b/wgpu/src/dispatch.rs @@ -111,10 +111,12 @@ pub trait DeviceInterface: CommonTraits { desc: crate::ShaderModuleDescriptor<'_>, shader_bound_checks: crate::ShaderRuntimeChecks, ) -> DispatchShaderModule; - unsafe fn create_shader_module_spirv( + + unsafe fn create_shader_module_passthrough( &self, - desc: &crate::ShaderModuleDescriptorSpirV<'_>, + desc: &crate::ShaderModuleDescriptorPassthrough<'_>, ) -> DispatchShaderModule; + fn create_bind_group_layout( &self, desc: &crate::BindGroupLayoutDescriptor<'_>, diff --git a/wgpu/src/macros.rs b/wgpu/src/macros.rs index 85371ec57..1afb099da 100644 --- a/wgpu/src/macros.rs +++ b/wgpu/src/macros.rs @@ -106,9 +106,11 @@ macro_rules! include_spirv_raw { ($($token:tt)*) => { { //log::info!("including '{}'", $($token)*); - $crate::ShaderModuleDescriptorSpirV { - label: $crate::__macro_helpers::Some($($token)*), - source: $crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*)), + $crate::ShaderModuleDescriptorPassthrough::SpirV { + $crate::ShaderModuleDescriptorSpirV { + label: $crate::__macro_helpers::Some($($token)*), + source: $crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*)), + } } } }; diff --git a/wgpu/src/util/mod.rs b/wgpu/src/util/mod.rs index 7f7908053..358563c70 100644 --- a/wgpu/src/util/mod.rs +++ b/wgpu/src/util/mod.rs @@ -39,10 +39,10 @@ pub fn make_spirv(data: &[u8]) -> super::ShaderSource<'_> { super::ShaderSource::SpirV(make_spirv_raw(data)) } -/// Version of `make_spirv` intended for use with [`Device::create_shader_module_spirv`]. +/// Version of `make_spirv` intended for use with [`Device::create_shader_module_passthrough`]. /// Returns a raw slice instead of [`ShaderSource`](super::ShaderSource). /// -/// [`Device::create_shader_module_spirv`]: crate::Device::create_shader_module_spirv +/// [`Device::create_shader_module_passthrough`]: crate::Device::create_shader_module_passthrough pub fn make_spirv_raw(data: &[u8]) -> Cow<'_, [u32]> { const MAGIC_NUMBER: u32 = 0x0723_0203; assert_eq!(