[metal] Metal compute shader passthrough (#7326)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
Sylvain Benner 2025-04-09 10:25:41 -04:00 committed by GitHub
parent 3ba97bc3b7
commit 14690470bb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
24 changed files with 369 additions and 123 deletions

View File

@ -181,6 +181,32 @@ layout(location = 0, index = 1) out vec4 output1;
By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)
#### Unify interface for SpirV shader passthrough
Replace device `create_shader_module_spirv` function with a generic `create_shader_module_passthrough` function
taking a `ShaderModuleDescriptorPassthrough` enum as parameter.
Update your calls to `create_shader_module_spirv` and use `create_shader_module_passthrough` instead:
```diff
- device.create_shader_module_spirv(
- wgpu::ShaderModuleDescriptorSpirV {
- label: Some(&name),
- source: Cow::Borrowed(&source),
- }
- )
+ device.create_shader_module_passthrough(
+ wgpu::ShaderModuleDescriptorPassthrough::SpirV(
+ wgpu::ShaderModuleDescriptorSpirV {
+ label: Some(&name),
+ source: Cow::Borrowed(&source),
+ },
+ ),
+ )
```
By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).
### New Features
- Added mesh shader support to `wgpu_hal`. By @SupaMaggie70Incorporated in [#7089](https://github.com/gfx-rs/wgpu/pull/7089)
@ -203,6 +229,9 @@ By @wumpf in [#7144](https://github.com/gfx-rs/wgpu/pull/7144)
- Support getting vertices of the hit triangle when raytracing. By @Vecvec in [#7183](https://github.com/gfx-rs/wgpu/pull/7183) .
- Add Metal compute shader passthrough. Use `create_shader_module_passthrough` on device. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).
- new `Features::MSL_SHADER_PASSTHROUGH` run-time feature allows providing pass-through MSL Metal shaders. By @syl20bnr in [#7326](https://github.com/gfx-rs/wgpu/pull/7326).
#### Naga

View File

@ -417,6 +417,8 @@ pub enum GPUFeatureName {
VertexWritableStorage,
#[webidl(rename = "clear-texture")]
ClearTexture,
#[webidl(rename = "msl-shader-passthrough")]
MslShaderPassthrough,
#[webidl(rename = "spirv-shader-passthrough")]
SpirvShaderPassthrough,
#[webidl(rename = "multiview")]
@ -477,6 +479,7 @@ pub fn feature_names_to_features(names: Vec<GPUFeatureName>) -> wgpu_types::Feat
GPUFeatureName::ConservativeRasterization => Features::CONSERVATIVE_RASTERIZATION,
GPUFeatureName::VertexWritableStorage => Features::VERTEX_WRITABLE_STORAGE,
GPUFeatureName::ClearTexture => Features::CLEAR_TEXTURE,
GPUFeatureName::MslShaderPassthrough => Features::MSL_SHADER_PASSTHROUGH,
GPUFeatureName::SpirvShaderPassthrough => Features::SPIRV_SHADER_PASSTHROUGH,
GPUFeatureName::Multiview => Features::MULTIVIEW,
GPUFeatureName::VertexAttribute64Bit => Features::VERTEX_ATTRIBUTE_64BIT,

View File

@ -126,9 +126,9 @@ impl DeviceInterface for CustomDevice {
DispatchShaderModule::custom(CustomShaderModule(self.0.clone()))
}
unsafe fn create_shader_module_spirv(
unsafe fn create_shader_module_passthrough(
&self,
_desc: &wgpu::ShaderModuleDescriptorSpirV<'_>,
_desc: &wgpu::ShaderModuleDescriptorPassthrough<'_>,
) -> DispatchShaderModule {
unimplemented!()
}

View File

@ -237,7 +237,7 @@ bitflags::bitflags! {
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "serialize", derive(serde::Serialize))]
#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))]
pub struct ModuleInfo {

View File

@ -511,12 +511,14 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne
fail(
&ctx.device,
|| unsafe {
let _ = ctx
.device
.create_shader_module_spirv(&wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: std::borrow::Cow::Borrowed(&[]),
});
let _ = ctx.device.create_shader_module_passthrough(
wgpu::ShaderModuleDescriptorPassthrough::SpirV(
wgpu::ShaderModuleDescriptorSpirV {
label: None,
source: std::borrow::Cow::Borrowed(&[]),
},
),
);
},
Some("device with '' label is invalid"),
);

View File

@ -938,23 +938,21 @@ impl Global {
(id, Some(error))
}
// Unsafe-ness of internal calls has little to do with unsafe-ness of this.
#[allow(unused_unsafe)]
/// # Safety
///
/// This function passes SPIR-V binary to the backend as-is and can potentially result in a
/// This function passes source code or binary to the backend as-is and can potentially result in a
/// driver crash.
pub unsafe fn device_create_shader_module_spirv(
pub unsafe fn device_create_shader_module_passthrough(
&self,
device_id: DeviceId,
desc: &pipeline::ShaderModuleDescriptor,
source: Cow<[u32]>,
desc: &pipeline::ShaderModuleDescriptorPassthrough<'_>,
id_in: Option<id::ShaderModuleId>,
) -> (
id::ShaderModuleId,
Option<pipeline::CreateShaderModuleError>,
) {
profiling::scope!("Device::create_shader_module");
profiling::scope!("Device::create_shader_module_passthrough");
let hub = &self.hub;
let fid = hub.shader_modules.prepare(id_in);
@ -964,15 +962,30 @@ impl Global {
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
let data = trace.make_binary("spv", bytemuck::cast_slice(&source));
let data = trace.make_binary(desc.trace_binary_ext(), desc.trace_data());
trace.add(trace::Action::CreateShaderModule {
id: fid.id(),
desc: desc.clone(),
desc: match desc {
pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
pipeline::ShaderModuleDescriptor {
label: inner.label.clone(),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
}
}
},
data,
});
};
let shader = match unsafe { device.create_shader_module_spirv(desc, &source) } {
let result = unsafe { device.create_shader_module_passthrough(desc) };
let shader = match result {
Ok(shader) => shader,
Err(e) => break 'error e,
};
@ -981,7 +994,7 @@ impl Global {
return (id, None);
};
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label.to_string())));
let id = fid.assign(Fallible::Invalid(Arc::new(desc.label().to_string())));
(id, Some(error))
}

View File

@ -1750,19 +1750,31 @@ impl Device {
}
#[allow(unused_unsafe)]
pub(crate) unsafe fn create_shader_module_spirv<'a>(
pub(crate) unsafe fn create_shader_module_passthrough<'a>(
self: &Arc<Self>,
desc: &pipeline::ShaderModuleDescriptor<'a>,
source: &'a [u32],
descriptor: &pipeline::ShaderModuleDescriptorPassthrough<'a>,
) -> Result<Arc<pipeline::ShaderModule>, pipeline::CreateShaderModuleError> {
self.check_is_valid()?;
self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?;
let hal_desc = hal::ShaderModuleDescriptor {
label: desc.label.to_hal(self.instance_flags),
runtime_checks: desc.runtime_checks,
let hal_shader = match descriptor {
pipeline::ShaderModuleDescriptorPassthrough::SpirV(inner) => {
self.require_features(wgt::Features::SPIRV_SHADER_PASSTHROUGH)?;
hal::ShaderInput::SpirV(&inner.source)
}
pipeline::ShaderModuleDescriptorPassthrough::Msl(inner) => {
self.require_features(wgt::Features::MSL_SHADER_PASSTHROUGH)?;
hal::ShaderInput::Msl {
shader: inner.source.to_string(),
entry_point: inner.entry_point.to_string(),
num_workgroups: inner.num_workgroups,
}
}
};
let hal_shader = hal::ShaderInput::SpirV(source);
let hal_desc = hal::ShaderModuleDescriptor {
label: descriptor.label().to_hal(self.instance_flags),
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
};
let raw = match unsafe { self.raw().create_shader_module(&hal_desc, hal_shader) } {
Ok(raw) => raw,
Err(error) => {
@ -1782,12 +1794,10 @@ impl Device {
raw: ManuallyDrop::new(raw),
device: self.clone(),
interface: None,
label: desc.label.to_string(),
label: descriptor.label().to_string(),
};
let module = Arc::new(module);
Ok(module)
Ok(Arc::new(module))
}
pub(crate) fn create_command_encoder(

View File

@ -53,6 +53,9 @@ pub struct ShaderModuleDescriptor<'a> {
pub runtime_checks: wgt::ShaderRuntimeChecks,
}
pub type ShaderModuleDescriptorPassthrough<'a> =
wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>;
#[derive(Debug)]
pub struct ShaderModule {
pub(crate) raw: ManuallyDrop<Box<dyn hal::DynShaderModule>>,

View File

@ -1641,6 +1641,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
}
}
unsafe fn destroy_shader_module(&self, _module: super::ShaderModule) {

View File

@ -1329,6 +1329,9 @@ impl crate::Device for super::Device {
crate::ShaderInput::SpirV(_) => {
panic!("`Features::SPIRV_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Msl { .. } => {
panic!("`Features::MSL_SHADER_PASSTHROUGH` is not enabled")
}
crate::ShaderInput::Naga(naga) => naga,
},
label: desc.label.map(|str| str.to_string()),

View File

@ -2069,6 +2069,7 @@ pub struct CommandEncoderDescriptor<'a, Q: DynQueue + ?Sized> {
}
/// Naga shader module.
#[derive(Default)]
pub struct NagaShader {
/// Shader module IR.
pub module: Cow<'static, naga::Module>,
@ -2090,6 +2091,11 @@ impl fmt::Debug for NagaShader {
#[allow(clippy::large_enum_variant)]
pub enum ShaderInput<'a> {
Naga(NagaShader),
Msl {
shader: String,
entry_point: String,
num_workgroups: (u32, u32, u32),
},
SpirV(&'a [u32]),
}

View File

@ -904,6 +904,7 @@ impl super::PrivateCapabilities {
use wgt::Features as F;
let mut features = F::empty()
| F::MSL_SHADER_PASSTHROUGH
| F::MAPPABLE_PRIMARY_BUFFERS
| F::VERTEX_WRITABLE_STORAGE
| F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES

View File

@ -4,8 +4,9 @@ use std::{thread, time};
use parking_lot::Mutex;
use super::conv;
use super::{conv, PassthroughShader};
use crate::auxil::map_naga_stage;
use crate::metal::ShaderModuleSource;
use crate::TlasInstance;
use metal::foreign_types::ForeignType;
@ -122,11 +123,15 @@ impl super::Device {
primitive_class: metal::MTLPrimitiveTopologyClass,
naga_stage: naga::ShaderStage,
) -> Result<CompiledShader, crate::PipelineError> {
let naga_shader = if let ShaderModuleSource::Naga(naga) = &stage.module.source {
naga
} else {
panic!("load_shader required a naga shader");
};
let stage_bit = map_naga_stage(naga_stage);
let (module, module_info) = naga::back::pipeline_constants::process_overrides(
&stage.module.naga.module,
&stage.module.naga.info,
&naga_shader.module,
&naga_shader.info,
stage.constants,
)
.map_err(|e| crate::PipelineError::PipelineConstants(stage_bit, format!("MSL: {:?}", e)))?;
@ -989,9 +994,37 @@ impl crate::Device for super::Device {
match shader {
crate::ShaderInput::Naga(naga) => Ok(super::ShaderModule {
naga,
source: ShaderModuleSource::Naga(naga),
bounds_checks: desc.runtime_checks,
}),
crate::ShaderInput::Msl {
shader: source,
entry_point,
num_workgroups,
} => {
let options = metal::CompileOptions::new();
// Obtain the locked device from shared
let device = self.shared.device.lock();
let library = device
.new_library_with_source(&source, &options)
.map_err(|e| crate::ShaderError::Compilation(format!("MSL: {:?}", e)))?;
let function = library.get_function(&entry_point, None).map_err(|_| {
crate::ShaderError::Compilation(format!(
"Entry point '{}' not found",
entry_point
))
})?;
Ok(super::ShaderModule {
source: ShaderModuleSource::Passthrough(PassthroughShader {
library,
function,
entry_point,
num_workgroups,
}),
bounds_checks: desc.runtime_checks,
})
}
crate::ShaderInput::SpirV(_) => {
panic!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
}
@ -1299,13 +1332,30 @@ impl crate::Device for super::Device {
objc::rc::autoreleasepool(|| {
let descriptor = metal::ComputePipelineDescriptor::new();
let cs = self.load_shader(
&desc.stage,
&[],
desc.layout,
metal::MTLPrimitiveTopologyClass::Unspecified,
naga::ShaderStage::Compute,
)?;
let module = desc.stage.module;
let cs = if let ShaderModuleSource::Passthrough(desc) = &module.source {
CompiledShader {
library: desc.library.clone(),
function: desc.function.clone(),
wg_size: metal::MTLSize::new(
desc.num_workgroups.0 as u64,
desc.num_workgroups.1 as u64,
desc.num_workgroups.2 as u64,
),
wg_memory_sizes: vec![],
sized_bindings: vec![],
immutable_buffer_mask: 0,
}
} else {
self.load_shader(
&desc.stage,
&[],
desc.layout,
metal::MTLPrimitiveTopologyClass::Unspecified,
naga::ShaderStage::Compute,
)?
};
descriptor.set_compute_function(Some(&cs.function));
if self.shared.private_caps.supports_mutability {

View File

@ -781,9 +781,23 @@ impl crate::DynBindGroup for BindGroup {}
unsafe impl Send for BindGroup {}
unsafe impl Sync for BindGroup {}
#[derive(Debug)]
pub enum ShaderModuleSource {
Naga(crate::NagaShader),
Passthrough(PassthroughShader),
}
#[derive(Debug)]
pub struct PassthroughShader {
pub library: metal::Library,
pub function: metal::Function,
pub entry_point: String,
pub num_workgroups: (u32, u32, u32),
}
#[derive(Debug)]
pub struct ShaderModule {
naga: crate::NagaShader,
source: ShaderModuleSource,
bounds_checks: wgt::ShaderRuntimeChecks,
}

View File

@ -1862,6 +1862,9 @@ impl crate::Device for super::Device {
.map_err(|e| crate::ShaderError::Compilation(format!("{e}")))?,
)
}
crate::ShaderInput::Msl { .. } => {
panic!("MSL_SHADER_PASSTHROUGH is not enabled for this backend")
}
crate::ShaderInput::SpirV(spv) => Cow::Borrowed(spv),
};

View File

@ -920,6 +920,15 @@ bitflags_array! {
///
/// This is a native only feature.
const CLEAR_TEXTURE = 1 << 23;
/// Enables creating shader modules from Metal MSL computer shaders (unsafe).
///
/// Metal data is not parsed or interpreted in any way
///
/// Supported platforms:
/// - Metal
///
/// This is a native only feature.
const MSL_SHADER_PASSTHROUGH = 1 << 24;
/// Enables creating shader modules from SPIR-V binary data (unsafe).
///
/// SPIR-V data is not parsed or interpreted in any way; you can use
@ -933,7 +942,7 @@ bitflags_array! {
/// This is a native only feature.
///
/// [`wgpu::make_spirv_raw!`]: https://docs.rs/wgpu/latest/wgpu/macro.include_spirv_raw.html
const SPIRV_SHADER_PASSTHROUGH = 1 << 24;
const SPIRV_SHADER_PASSTHROUGH = 1 << 25;
/// Enables multiview render passes and `builtin(view_index)` in vertex shaders.
///
/// Supported platforms:
@ -941,7 +950,7 @@ bitflags_array! {
/// - OpenGL (web only)
///
/// This is a native only feature.
const MULTIVIEW = 1 << 25;
const MULTIVIEW = 1 << 26;
/// Enables using 64-bit types for vertex attributes.
///
/// Requires SHADER_FLOAT64.
@ -949,7 +958,7 @@ bitflags_array! {
/// Supported Platforms: N/A
///
/// This is a native only feature.
const VERTEX_ATTRIBUTE_64BIT = 1 << 26;
const VERTEX_ATTRIBUTE_64BIT = 1 << 27;
/// Enables image atomic fetch add, and, xor, or, min, and max for R32Uint and R32Sint textures.
///
/// Supported platforms:
@ -958,7 +967,7 @@ bitflags_array! {
/// - Metal (with MSL 3.1+)
///
/// This is a native only feature.
const TEXTURE_ATOMIC = 1 << 27;
const TEXTURE_ATOMIC = 1 << 28;
/// Allows for creation of textures of format [`TextureFormat::NV12`]
///
/// Supported platforms:
@ -968,7 +977,7 @@ bitflags_array! {
/// This is a native only feature.
///
/// [`TextureFormat::NV12`]: super::TextureFormat::NV12
const TEXTURE_FORMAT_NV12 = 1 << 28;
const TEXTURE_FORMAT_NV12 = 1 << 29;
/// ***THIS IS EXPERIMENTAL:*** Features enabled by this may have
/// major bugs in them and are expected to be subject to breaking changes, suggestions
/// for the API exposed by this should be posted on [the ray-tracing issue](https://github.com/gfx-rs/wgpu/issues/1040)
@ -980,7 +989,7 @@ bitflags_array! {
/// - Vulkan
///
/// This is a native-only feature.
const EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE = 1 << 29;
const EXPERIMENTAL_RAY_TRACING_ACCELERATION_STRUCTURE = 1 << 30;
// Shader:
@ -994,7 +1003,7 @@ bitflags_array! {
/// - Vulkan
///
/// This is a native-only feature.
const EXPERIMENTAL_RAY_QUERY = 1 << 30;
const EXPERIMENTAL_RAY_QUERY = 1 << 31;
/// Enables 64-bit floating point types in SPIR-V shaders.
///
/// Note: even when supported by GPU hardware, 64-bit floating point operations are
@ -1004,14 +1013,14 @@ bitflags_array! {
/// - Vulkan
///
/// This is a native only feature.
const SHADER_F64 = 1 << 31;
const SHADER_F64 = 1 << 32;
/// Allows shaders to use i16. Not currently supported in `naga`, only available through `spirv-passthrough`.
///
/// Supported platforms:
/// - Vulkan
///
/// This is a native only feature.
const SHADER_I16 = 1 << 32;
const SHADER_I16 = 1 << 33;
/// Enables `builtin(primitive_index)` in fragment shaders.
///
/// Note: enables geometry processing for pipelines using the builtin.
@ -1025,14 +1034,14 @@ bitflags_array! {
/// - OpenGL (some)
///
/// This is a native only feature.
const SHADER_PRIMITIVE_INDEX = 1 << 33;
const SHADER_PRIMITIVE_INDEX = 1 << 34;
/// Allows shaders to use the `early_depth_test` attribute.
///
/// Supported platforms:
/// - GLES 3.1+
///
/// This is a native only feature.
const SHADER_EARLY_DEPTH_TEST = 1 << 34;
const SHADER_EARLY_DEPTH_TEST = 1 << 35;
/// Allows shaders to use i64 and u64.
///
/// Supported platforms:
@ -1041,7 +1050,7 @@ bitflags_array! {
/// - Metal (with MSL 2.3+)
///
/// This is a native only feature.
const SHADER_INT64 = 1 << 35;
const SHADER_INT64 = 1 << 36;
/// Allows compute and fragment shaders to use the subgroup operation built-ins
///
/// Supported Platforms:
@ -1050,14 +1059,14 @@ bitflags_array! {
/// - Metal
///
/// This is a native only feature.
const SUBGROUP = 1 << 36;
const SUBGROUP = 1 << 37;
/// Allows vertex shaders to use the subgroup operation built-ins
///
/// Supported Platforms:
/// - Vulkan
///
/// This is a native only feature.
const SUBGROUP_VERTEX = 1 << 37;
const SUBGROUP_VERTEX = 1 << 38;
/// Allows shaders to use the subgroup barrier
///
/// Supported Platforms:
@ -1065,7 +1074,7 @@ bitflags_array! {
/// - Metal
///
/// This is a native only feature.
const SUBGROUP_BARRIER = 1 << 38;
const SUBGROUP_BARRIER = 1 << 39;
/// Allows the use of pipeline cache objects
///
/// Supported platforms:
@ -1074,7 +1083,7 @@ bitflags_array! {
/// Unimplemented Platforms:
/// - DX12
/// - Metal
const PIPELINE_CACHE = 1 << 39;
const PIPELINE_CACHE = 1 << 40;
/// Allows shaders to use i64 and u64 atomic min and max.
///
/// Supported platforms:
@ -1083,7 +1092,7 @@ bitflags_array! {
/// - Metal (with MSL 2.4+)
///
/// This is a native only feature.
const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 40;
const SHADER_INT64_ATOMIC_MIN_MAX = 1 << 41;
/// Allows shaders to use all i64 and u64 atomic operations.
///
/// Supported platforms:
@ -1091,7 +1100,7 @@ bitflags_array! {
/// - DX12 (with SM 6.6+)
///
/// This is a native only feature.
const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 41;
const SHADER_INT64_ATOMIC_ALL_OPS = 1 << 42;
/// Allows using the [VK_GOOGLE_display_timing] Vulkan extension.
///
/// This is used for frame pacing to reduce latency, and is generally only available on Android.
@ -1107,7 +1116,7 @@ bitflags_array! {
///
/// [VK_GOOGLE_display_timing]: https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VK_GOOGLE_display_timing.html
/// [`Surface::as_hal()`]: https://docs.rs/wgpu/latest/wgpu/struct.Surface.html#method.as_hal
const VULKAN_GOOGLE_DISPLAY_TIMING = 1 << 42;
const VULKAN_GOOGLE_DISPLAY_TIMING = 1 << 43;
/// Allows using the [VK_KHR_external_memory_win32] Vulkan extension.
///
@ -1117,7 +1126,7 @@ bitflags_array! {
/// This is a native only feature.
///
/// [VK_KHR_external_memory_win32]: https://registry.khronos.org/vulkan/specs/latest/man/html/VK_KHR_external_memory_win32.html
const VULKAN_EXTERNAL_MEMORY_WIN32 = 1 << 43;
const VULKAN_EXTERNAL_MEMORY_WIN32 = 1 << 44;
/// Enables R64Uint image atomic min and max.
///
@ -1127,7 +1136,7 @@ bitflags_array! {
/// - Metal (with MSL 3.1+)
///
/// This is a native only feature.
const TEXTURE_INT64_ATOMIC = 1 << 44;
const TEXTURE_INT64_ATOMIC = 1 << 45;
/// Allows uniform buffers to be bound as binding arrays.
///
@ -1144,7 +1153,7 @@ bitflags_array! {
/// - Vulkan 1.2+ (or VK_EXT_descriptor_indexing)'s `shaderUniformBufferArrayNonUniformIndexing` feature)
///
/// This is a native only feature.
const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 45;
const UNIFORM_BUFFER_BINDING_ARRAYS = 1 << 46;
/// Enables mesh shaders and task shaders in mesh shader pipelines.
///
@ -1156,7 +1165,7 @@ bitflags_array! {
/// - Metal
///
/// This is a native only feature.
const EXPERIMENTAL_MESH_SHADER = 1 << 46;
const EXPERIMENTAL_MESH_SHADER = 1 << 47;
/// ***THIS IS EXPERIMENTAL:*** Features enabled by this may have
/// major bugs in them and are expected to be subject to breaking changes, suggestions
@ -1171,7 +1180,7 @@ bitflags_array! {
/// This is a native only feature
///
/// [`AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN`]: super::AccelerationStructureFlags::ALLOW_RAY_HIT_VERTEX_RETURN
const EXPERIMENTAL_RAY_HIT_VERTEX_RETURN = 1 << 47;
const EXPERIMENTAL_RAY_HIT_VERTEX_RETURN = 1 << 48;
/// Enables multiview in mesh shader pipelines
///
@ -1183,7 +1192,7 @@ bitflags_array! {
/// - Metal
///
/// This is a native only feature.
const EXPERIMENTAL_MESH_SHADER_MULTIVIEW = 1 << 48;
const EXPERIMENTAL_MESH_SHADER_MULTIVIEW = 1 << 49;
}
/// Features that are not guaranteed to be supported.

View File

@ -14,6 +14,7 @@ extern crate std;
extern crate alloc;
use alloc::borrow::Cow;
use alloc::{string::String, vec, vec::Vec};
use core::{
fmt,
@ -7654,3 +7655,98 @@ pub enum DeviceLostReason {
/// After `Device::destroy`
Destroyed = 1,
}
/// Descriptor for creating a shader module.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub enum CreateShaderModuleDescriptorPassthrough<'a, L> {
/// Passthrough for SPIR-V binaries.
SpirV(ShaderModuleDescriptorSpirV<'a, L>),
/// Passthrough for MSL source code.
Msl(ShaderModuleDescriptorMsl<'a, L>),
}
impl<'a, L> CreateShaderModuleDescriptorPassthrough<'a, L> {
/// Takes a closure and maps the label of the shader module descriptor into another.
pub fn map_label<K>(
&self,
fun: impl FnOnce(&L) -> K,
) -> CreateShaderModuleDescriptorPassthrough<'_, K> {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::SpirV(
ShaderModuleDescriptorSpirV {
label: fun(&inner.label),
source: inner.source.clone(),
},
)
}
CreateShaderModuleDescriptorPassthrough::Msl(inner) => {
CreateShaderModuleDescriptorPassthrough::<'_, K>::Msl(ShaderModuleDescriptorMsl {
entry_point: inner.entry_point.clone(),
label: fun(&inner.label),
num_workgroups: inner.num_workgroups,
source: inner.source.clone(),
})
}
}
}
/// Returns the label of shader module passthrough descriptor.
pub fn label(&'a self) -> &'a L {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => &inner.label,
CreateShaderModuleDescriptorPassthrough::Msl(inner) => &inner.label,
}
}
#[cfg(feature = "trace")]
/// Returns the source data for tracing purpose.
pub fn trace_data(&self) -> &[u8] {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(inner) => {
bytemuck::cast_slice(&inner.source)
}
CreateShaderModuleDescriptorPassthrough::Msl(inner) => inner.source.as_bytes(),
}
}
#[cfg(feature = "trace")]
/// Returns the binary file extension for tracing purpose.
pub fn trace_binary_ext(&self) -> &'static str {
match self {
CreateShaderModuleDescriptorPassthrough::SpirV(..) => "spv",
CreateShaderModuleDescriptorPassthrough::Msl(..) => "msl",
}
}
}
/// Descriptor for a shader module given by Metal MSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorMsl<'a, L> {
/// Entrypoint.
pub entry_point: String,
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Number of workgroups in each dimension x, y and z.
pub num_workgroups: (u32, u32, u32),
/// Shader MSL source.
pub source: Cow<'a, str>,
}
/// Descriptor for a shader module given by SPIR-V binary.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug, Clone)]
pub struct ShaderModuleDescriptorSpirV<'a, L> {
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: L,
/// Binary SPIR-V data, in 4-byte words.
pub source: Cow<'a, [u32]>,
}

View File

@ -172,20 +172,18 @@ impl Device {
ShaderModule { inner: module }
}
/// Creates a shader module from SPIR-V binary directly.
/// Creates a shader module which will bypass wgpu's shader tooling and validation and be used directly by the backend.
///
/// # Safety
///
/// This function passes binary data to the backend as-is and can potentially result in a
/// driver crash or bogus behaviour. No attempt is made to ensure that data is valid SPIR-V.
///
/// See also [`include_spirv_raw!`] and [`util::make_spirv_raw`].
/// This function passes data to the backend as-is and can potentially result in a
/// driver crash or bogus behaviour. No attempt is made to ensure that data is valid.
#[must_use]
pub unsafe fn create_shader_module_spirv(
pub unsafe fn create_shader_module_passthrough(
&self,
desc: &ShaderModuleDescriptorSpirV<'_>,
desc: ShaderModuleDescriptorPassthrough<'_>,
) -> ShaderModule {
let module = unsafe { self.inner.create_shader_module_spirv(desc) };
let module = unsafe { self.inner.create_shader_module_passthrough(&desc) };
ShaderModule { inner: module }
}

View File

@ -1,5 +1,4 @@
use alloc::{
borrow::Cow,
string::{String, ToString as _},
vec,
vec::Vec,
@ -11,9 +10,9 @@ use crate::*;
/// Handle to a compiled shader module.
///
/// A `ShaderModule` represents a compiled shader module on the GPU. It can be created by passing
/// source code to [`Device::create_shader_module`] or valid SPIR-V binary to
/// [`Device::create_shader_module_spirv`]. Shader modules are used to define programmable stages
/// of a pipeline.
/// source code to [`Device::create_shader_module`]. MSL shader or SPIR-V binary can also be passed
/// directly using [`Device::create_shader_module_passthrough`]. Shader modules are used to define
/// programmable stages of a pipeline.
///
/// Corresponds to [WebGPU `GPUShaderModule`](https://gpuweb.github.io/gpuweb/#shader-module).
#[derive(Debug, Clone)]
@ -182,14 +181,14 @@ pub enum ShaderSource<'a> {
///
/// See also: [`util::make_spirv`], [`include_spirv`]
#[cfg(feature = "spirv")]
SpirV(Cow<'a, [u32]>),
SpirV(alloc::borrow::Cow<'a, [u32]>),
/// GLSL module as a string slice.
///
/// Note: GLSL is not yet fully supported and must be a specific ShaderStage.
#[cfg(feature = "glsl")]
Glsl {
/// The source code of the shader.
shader: Cow<'a, str>,
shader: alloc::borrow::Cow<'a, str>,
/// The shader stage that the shader targets. For example, `naga::ShaderStage::Vertex`
stage: naga::ShaderStage,
/// Key-value pairs to represent defines sent to the glsl preprocessor.
@ -199,10 +198,10 @@ pub enum ShaderSource<'a> {
},
/// WGSL module as a string slice.
#[cfg(feature = "wgsl")]
Wgsl(Cow<'a, str>),
Wgsl(alloc::borrow::Cow<'a, str>),
/// Naga module.
#[cfg(feature = "naga-ir")]
Naga(Cow<'static, naga::Module>),
Naga(alloc::borrow::Cow<'static, naga::Module>),
/// Dummy variant because `Naga` doesn't have a lifetime and without enough active features it
/// could be the last one active.
#[doc(hidden)]
@ -223,16 +222,22 @@ pub struct ShaderModuleDescriptor<'a> {
}
static_assertions::assert_impl_all!(ShaderModuleDescriptor<'_>: Send, Sync);
/// Descriptor for a shader module given by SPIR-V binary, for use with
/// [`Device::create_shader_module_spirv`].
/// Descriptor for a shader module that will bypass wgpu's shader tooling, for use with
/// [`Device::create_shader_module_passthrough`].
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
#[derive(Debug)]
pub struct ShaderModuleDescriptorSpirV<'a> {
/// Debug label of the shader module. This will show up in graphics debuggers for easy identification.
pub label: Label<'a>,
/// Binary SPIR-V data, in 4-byte words.
pub source: Cow<'a, [u32]>,
}
static_assertions::assert_impl_all!(ShaderModuleDescriptorSpirV<'_>: Send, Sync);
pub type ShaderModuleDescriptorPassthrough<'a> =
wgt::CreateShaderModuleDescriptorPassthrough<'a, Label<'a>>;
/// Descriptor for a shader module given by Metal MSL source.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorMsl<'a> = wgt::ShaderModuleDescriptorMsl<'a, Label<'a>>;
/// Descriptor for a shader module given by SPIR-V binary.
///
/// This type is unique to the Rust API of `wgpu`. In the WebGPU specification,
/// only WGSL source code strings are accepted.
pub type ShaderModuleDescriptorSpirV<'a> = wgt::ShaderModuleDescriptorSpirV<'a, Label<'a>>;

View File

@ -1843,11 +1843,11 @@ impl dispatch::DeviceInterface for WebDevice {
.into()
}
unsafe fn create_shader_module_spirv(
unsafe fn create_shader_module_passthrough(
&self,
_desc: &crate::ShaderModuleDescriptorSpirV<'_>,
_desc: &crate::ShaderModuleDescriptorPassthrough<'_>,
) -> dispatch::DispatchShaderModule {
unreachable!("SPIRV_SHADER_PASSTHROUGH is not enabled for this backend")
unreachable!("No XXX_SHADER_PASSTHROUGH feature enabled for this backend")
}
fn create_bind_group_layout(

View File

@ -1,5 +1,5 @@
use alloc::{
borrow::Cow::Borrowed,
borrow::Cow::{self, Borrowed},
boxed::Box,
format,
string::{String, ToString as _},
@ -1043,36 +1043,30 @@ impl dispatch::DeviceInterface for CoreDevice {
.into()
}
unsafe fn create_shader_module_spirv(
unsafe fn create_shader_module_passthrough(
&self,
desc: &crate::ShaderModuleDescriptorSpirV<'_>,
desc: &crate::ShaderModuleDescriptorPassthrough<'_>,
) -> dispatch::DispatchShaderModule {
let descriptor = wgc::pipeline::ShaderModuleDescriptor {
label: desc.label.map(Borrowed),
// Doesn't matter the value since spirv shaders aren't mutated to include
// runtime checks
runtime_checks: wgt::ShaderRuntimeChecks::unchecked(),
};
let desc = desc.map_label(|l| l.map(Cow::from));
let (id, error) = unsafe {
self.context.0.device_create_shader_module_spirv(
self.id,
&descriptor,
Borrowed(&desc.source),
None,
)
self.context
.0
.device_create_shader_module_passthrough(self.id, &desc, None)
};
let compilation_info = match error {
Some(cause) => {
self.context.handle_error(
&self.error_sink,
cause.clone(),
desc.label,
"Device::create_shader_module_spirv",
desc.label().as_deref(),
"Device::create_shader_module_passthrough",
);
CompilationInfo::from(cause)
}
None => CompilationInfo { messages: vec![] },
};
CoreShaderModule {
context: self.context.clone(),
id,

View File

@ -111,10 +111,12 @@ pub trait DeviceInterface: CommonTraits {
desc: crate::ShaderModuleDescriptor<'_>,
shader_bound_checks: crate::ShaderRuntimeChecks,
) -> DispatchShaderModule;
unsafe fn create_shader_module_spirv(
unsafe fn create_shader_module_passthrough(
&self,
desc: &crate::ShaderModuleDescriptorSpirV<'_>,
desc: &crate::ShaderModuleDescriptorPassthrough<'_>,
) -> DispatchShaderModule;
fn create_bind_group_layout(
&self,
desc: &crate::BindGroupLayoutDescriptor<'_>,

View File

@ -106,9 +106,11 @@ macro_rules! include_spirv_raw {
($($token:tt)*) => {
{
//log::info!("including '{}'", $($token)*);
$crate::ShaderModuleDescriptorSpirV {
label: $crate::__macro_helpers::Some($($token)*),
source: $crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*)),
$crate::ShaderModuleDescriptorPassthrough::SpirV {
$crate::ShaderModuleDescriptorSpirV {
label: $crate::__macro_helpers::Some($($token)*),
source: $crate::util::make_spirv_raw($crate::__macro_helpers::include_bytes!($($token)*)),
}
}
}
};

View File

@ -39,10 +39,10 @@ pub fn make_spirv(data: &[u8]) -> super::ShaderSource<'_> {
super::ShaderSource::SpirV(make_spirv_raw(data))
}
/// Version of `make_spirv` intended for use with [`Device::create_shader_module_spirv`].
/// Version of `make_spirv` intended for use with [`Device::create_shader_module_passthrough`].
/// Returns a raw slice instead of [`ShaderSource`](super::ShaderSource).
///
/// [`Device::create_shader_module_spirv`]: crate::Device::create_shader_module_spirv
/// [`Device::create_shader_module_passthrough`]: crate::Device::create_shader_module_passthrough
pub fn make_spirv_raw(data: &[u8]) -> Cow<'_, [u32]> {
const MAGIC_NUMBER: u32 = 0x0723_0203;
assert_eq!(