From 9b7a9656670d74d364c359b918b062f814cb5f01 Mon Sep 17 00:00:00 2001 From: Brad Werth Date: Tue, 6 Feb 2024 16:35:17 -0800 Subject: [PATCH] Add an experimental vertex pulling flag to Metal pipelines. This proves a flag in msl::PipelineOptions that attempts to write all Metal vertex entry points to use a vertex pulling technique. It does this by: 1) Forcing the _buffer_sizes structure to be generated for all vertex entry points. The structure has additional buffer_size members that contain the byte sizes of the vertex buffers. 2) Adding new args to vertex entry points for the vertex id and/or the instance id and for the bound buffers. If there is an existing @builtin(vertex_index) or @builtin(instance_index) param, then no duplicate arg is created. 3) Adding code at the beginning of the function for vertex entry points to compare the vertex id or instance id against the lengths of all the bound buffers, and force an early-exit if the bounds are violated. 4) Extracting the raw bytes from the vertex buffer(s) and unpacking those bytes into the bound attributes with the expected types. 5) Replacing the varyings input and instead using the unpacked attributes to fill any structs-as-args that are rebuilt in the entry point. A new naga test is added which exercises this flag and demonstrates the effect of the transform. The msl generated by this test passes validation. Eventually this transformation will be the default, always-on behavior for Metal pipelines, though the flag may remain so that naga translation tests can be run with and without the tranformation. --- deno_webgpu/pipeline.rs | 3 + naga/CHANGELOG.md | 1 + naga/src/back/msl/mod.rs | 118 ++ naga/src/back/msl/writer.rs | 1131 ++++++++++++++++- naga/tests/in/interface.param.ron | 2 + .../in/vertex-pulling-transform.param.ron | 31 + naga/tests/in/vertex-pulling-transform.wgsl | 32 + .../out/msl/vertex-pulling-transform.msl | 76 ++ naga/tests/snapshots.rs | 1 + player/tests/data/bind-group.ron | 1 + .../tests/data/pipeline-statistics-query.ron | 1 + player/tests/data/quad.ron | 2 + player/tests/data/zero-init-buffer.ron | 1 + .../tests/data/zero-init-texture-binding.ron | 1 + tests/tests/vertex_indices/mod.rs | 53 +- wgpu-core/src/device/resource.rs | 3 + wgpu-core/src/pipeline.rs | 2 + wgpu-hal/examples/halmark/main.rs | 2 + wgpu-hal/examples/ray-traced-triangle/main.rs | 1 + wgpu-hal/src/lib.rs | 3 + wgpu-hal/src/metal/command.rs | 32 + wgpu-hal/src/metal/device.rs | 87 +- wgpu-hal/src/metal/mod.rs | 18 + wgpu/src/backend/wgpu_core.rs | 6 + wgpu/src/lib.rs | 3 + 25 files changed, 1540 insertions(+), 71 deletions(-) create mode 100644 naga/tests/in/vertex-pulling-transform.param.ron create mode 100644 naga/tests/in/vertex-pulling-transform.wgsl create mode 100644 naga/tests/out/msl/vertex-pulling-transform.msl diff --git a/deno_webgpu/pipeline.rs b/deno_webgpu/pipeline.rs index c82b6a97c..992365245 100644 --- a/deno_webgpu/pipeline.rs +++ b/deno_webgpu/pipeline.rs @@ -114,6 +114,7 @@ pub fn op_webgpu_create_compute_pipeline( entry_point: compute.entry_point.map(Cow::from), constants: Cow::Owned(compute.constants.unwrap_or_default()), zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, }, cache: None, }; @@ -363,6 +364,7 @@ pub fn op_webgpu_create_render_pipeline( constants: Cow::Owned(fragment.constants.unwrap_or_default()), // Required to be true for WebGPU zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, }, targets: Cow::Owned(fragment.targets), }) @@ -388,6 +390,7 @@ pub fn op_webgpu_create_render_pipeline( constants: Cow::Owned(args.vertex.constants.unwrap_or_default()), // Required to be true for WebGPU zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, }, buffers: Cow::Owned(vertex_buffers), }, diff --git a/naga/CHANGELOG.md b/naga/CHANGELOG.md index a92d0c4f9..d2e0515eb 100644 --- a/naga/CHANGELOG.md +++ b/naga/CHANGELOG.md @@ -79,6 +79,7 @@ For changelogs after v0.14, see [the wgpu changelog](../CHANGELOG.md). - Add and fix minimum Metal version checks for optional functionality. ([#2486](https://github.com/gfx-rs/naga/pull/2486)) **@teoxoy** - Make varyings' struct members unique. ([#2521](https://github.com/gfx-rs/naga/pull/2521)) **@evahop** +- Add experimental vertex pulling transform flag. ([#5254](https://github.com/gfx-rs/wgpu/pull/5254)) **@bradwerth** #### GLSL-OUT diff --git a/naga/src/back/msl/mod.rs b/naga/src/back/msl/mod.rs index d7a06d774..d80d012ad 100644 --- a/naga/src/back/msl/mod.rs +++ b/naga/src/back/msl/mod.rs @@ -222,6 +222,113 @@ impl Default for Options { } } +/// Corresponds to [WebGPU `GPUVertexFormat`]( +/// https://gpuweb.github.io/gpuweb/#enumdef-gpuvertexformat). +#[repr(u32)] +#[derive(Copy, Clone, Debug, Hash, Eq, PartialEq)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub enum VertexFormat { + /// Two unsigned bytes (u8). `vec2` in shaders. + Uint8x2 = 0, + /// Four unsigned bytes (u8). `vec4` in shaders. + Uint8x4 = 1, + /// Two signed bytes (i8). `vec2` in shaders. + Sint8x2 = 2, + /// Four signed bytes (i8). `vec4` in shaders. + Sint8x4 = 3, + /// Two unsigned bytes (u8). [0, 255] converted to float [0, 1] `vec2` in shaders. + Unorm8x2 = 4, + /// Four unsigned bytes (u8). [0, 255] converted to float [0, 1] `vec4` in shaders. + Unorm8x4 = 5, + /// Two signed bytes (i8). [-127, 127] converted to float [-1, 1] `vec2` in shaders. + Snorm8x2 = 6, + /// Four signed bytes (i8). [-127, 127] converted to float [-1, 1] `vec4` in shaders. + Snorm8x4 = 7, + /// Two unsigned shorts (u16). `vec2` in shaders. + Uint16x2 = 8, + /// Four unsigned shorts (u16). `vec4` in shaders. + Uint16x4 = 9, + /// Two signed shorts (i16). `vec2` in shaders. + Sint16x2 = 10, + /// Four signed shorts (i16). `vec4` in shaders. + Sint16x4 = 11, + /// Two unsigned shorts (u16). [0, 65535] converted to float [0, 1] `vec2` in shaders. + Unorm16x2 = 12, + /// Four unsigned shorts (u16). [0, 65535] converted to float [0, 1] `vec4` in shaders. + Unorm16x4 = 13, + /// Two signed shorts (i16). [-32767, 32767] converted to float [-1, 1] `vec2` in shaders. + Snorm16x2 = 14, + /// Four signed shorts (i16). [-32767, 32767] converted to float [-1, 1] `vec4` in shaders. + Snorm16x4 = 15, + /// Two half-precision floats (no Rust equiv). `vec2` in shaders. + Float16x2 = 16, + /// Four half-precision floats (no Rust equiv). `vec4` in shaders. + Float16x4 = 17, + /// One single-precision float (f32). `f32` in shaders. + Float32 = 18, + /// Two single-precision floats (f32). `vec2` in shaders. + Float32x2 = 19, + /// Three single-precision floats (f32). `vec3` in shaders. + Float32x3 = 20, + /// Four single-precision floats (f32). `vec4` in shaders. + Float32x4 = 21, + /// One unsigned int (u32). `u32` in shaders. + Uint32 = 22, + /// Two unsigned ints (u32). `vec2` in shaders. + Uint32x2 = 23, + /// Three unsigned ints (u32). `vec3` in shaders. + Uint32x3 = 24, + /// Four unsigned ints (u32). `vec4` in shaders. + Uint32x4 = 25, + /// One signed int (i32). `i32` in shaders. + Sint32 = 26, + /// Two signed ints (i32). `vec2` in shaders. + Sint32x2 = 27, + /// Three signed ints (i32). `vec3` in shaders. + Sint32x3 = 28, + /// Four signed ints (i32). `vec4` in shaders. + Sint32x4 = 29, + /// Three unsigned 10-bit integers and one 2-bit integer, packed into a 32-bit integer (u32). [0, 1024] converted to float [0, 1] `vec4` in shaders. + #[cfg_attr(feature = "serde", serde(rename = "unorm10-10-10-2"))] + Unorm10_10_10_2 = 34, +} + +/// A mapping of vertex buffers and their attributes to shader +/// locations. +#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct AttributeMapping { + /// Shader location associated with this attribute + pub shader_location: u32, + /// Offset in bytes from start of vertex buffer structure + pub offset: u32, + /// Format code to help us unpack the attribute into the type + /// used by the shader. Codes correspond to a 0-based index of + /// . + /// The conversion process is described by + /// . + pub format: VertexFormat, +} + +/// A description of a vertex buffer with all the information we +/// need to address the attributes within it. +#[derive(Debug, Default, Clone, PartialEq, Eq, Hash)] +#[cfg_attr(feature = "serialize", derive(serde::Serialize))] +#[cfg_attr(feature = "deserialize", derive(serde::Deserialize))] +pub struct VertexBufferMapping { + /// Shader location associated with this buffer + pub id: u32, + /// Size of the structure in bytes + pub stride: u32, + /// True if the buffer is indexed by vertex, false if indexed + /// by instance. + pub indexed_by_vertex: bool, + /// Vec of the attributes within the structure + pub attributes: Vec, +} + /// A subset of options that are meant to be changed per pipeline. #[derive(Debug, Default, Clone)] #[cfg_attr(feature = "serialize", derive(serde::Serialize))] @@ -234,6 +341,17 @@ pub struct PipelineOptions { /// /// Enable this for vertex shaders with point primitive topologies. pub allow_and_force_point_size: bool, + + /// If set, when generating the Metal vertex shader, transform it + /// to receive the vertex buffers, lengths, and vertex id as args, + /// and bounds-check the vertex id and use the index into the + /// vertex buffers to access attributes, rather than using Metal's + /// [[stage-in]] assembled attribute data. + pub vertex_pulling_transform: bool, + + /// vertex_buffer_mappings are used during shader translation to + /// support vertex pulling. + pub vertex_buffer_mappings: Vec, } impl Options { diff --git a/naga/src/back/msl/writer.rs b/naga/src/back/msl/writer.rs index 389785992..ba2876713 100644 --- a/naga/src/back/msl/writer.rs +++ b/naga/src/back/msl/writer.rs @@ -75,6 +75,14 @@ fn put_numeric_type( } } +const fn scalar_is_int(scalar: crate::Scalar) -> bool { + use crate::ScalarKind::*; + match scalar.kind { + Sint | Uint | AbstractInt | Bool => true, + Float | AbstractFloat => false, + } +} + /// Prefix for cached clamped level-of-detail values for `ImageLoad` expressions. const CLAMPED_LOD_LOAD_PREFIX: &str = "clamped_lod_e"; @@ -87,6 +95,22 @@ struct TypeContext<'a> { first_time: bool, } +impl<'a> TypeContext<'a> { + fn scalar(&self) -> Option { + let ty = &self.gctx.types[self.handle]; + ty.inner.scalar() + } + + fn vertex_input_dimension(&self) -> u32 { + let ty = &self.gctx.types[self.handle]; + match ty.inner { + crate::TypeInner::Scalar(_) => 1, + crate::TypeInner::Vector { size, .. } => size as u32, + _ => unreachable!(), + } + } +} + impl<'a> Display for TypeContext<'a> { fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> { let ty = &self.gctx.types[self.handle]; @@ -3001,7 +3025,7 @@ impl Writer { // follow-up with any global resources used let mut separate = !arguments.is_empty(); let fun_info = &context.expression.mod_info[function]; - let mut supports_array_length = false; + let mut needs_buffer_sizes = false; for (handle, var) in context.expression.module.global_variables.iter() { if fun_info[handle].is_empty() { continue; @@ -3015,10 +3039,10 @@ impl Writer { } write!(self.out, "{name}")?; } - supports_array_length |= + needs_buffer_sizes |= needs_array_length(var.ty, &context.expression.module.types); } - if supports_array_length { + if needs_buffer_sizes { if separate { write!(self.out, ", ")?; } @@ -3417,13 +3441,22 @@ impl Writer { } } - if !indices.is_empty() { + let mut buffer_indices = vec![]; + for vbm in &pipeline_options.vertex_buffer_mappings { + buffer_indices.push(vbm.id); + } + + if !indices.is_empty() || !buffer_indices.is_empty() { writeln!(self.out, "struct _mslBufferSizes {{")?; for idx in indices { writeln!(self.out, "{}uint size{};", back::INDENT, idx)?; } + for idx in buffer_indices { + writeln!(self.out, "{}uint buffer_size{};", back::INDENT, idx)?; + } + writeln!(self.out, "}};")?; writeln!(self.out)?; } @@ -3764,6 +3797,672 @@ impl Writer { Ok(()) } + fn write_unpacking_function( + &mut self, + format: back::msl::VertexFormat, + ) -> Result<(String, u32, u32), Error> { + use back::msl::VertexFormat::*; + match format { + Uint8x2 => { + let name = self.namer.call("unpackUint8x2"); + writeln!( + self.out, + "metal::uint2 {name}(metal::uchar b0, \ + metal::uchar b1) {{" + )?; + writeln!(self.out, "{}return metal::uint2(b0, b1);", back::INDENT)?; + writeln!(self.out, "}}")?; + Ok((name, 2, 2)) + } + Uint8x4 => { + let name = self.namer.call("unpackUint8x4"); + writeln!( + self.out, + "metal::uint4 {name}(metal::uchar b0, \ + metal::uchar b1, \ + metal::uchar b2, \ + metal::uchar b3) {{" + )?; + writeln!( + self.out, + "{}return metal::uint4(b0, b1, b2, b3);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 4)) + } + Sint8x2 => { + let name = self.namer.call("unpackSint8x2"); + writeln!( + self.out, + "metal::int2 {name}(metal::uchar b0, \ + metal::uchar b1) {{" + )?; + writeln!( + self.out, + "{}return metal::int2(as_type(b0), \ + as_type(b1));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 2, 2)) + } + Sint8x4 => { + let name = self.namer.call("unpackSint8x4"); + writeln!( + self.out, + "metal::int4 {name}(metal::uchar b0, \ + metal::uchar b1, \ + metal::uchar b2, \ + metal::uchar b3) {{" + )?; + writeln!( + self.out, + "{}return metal::int4(as_type(b0), \ + as_type(b1), \ + as_type(b2), \ + as_type(b3));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 4)) + } + Unorm8x2 => { + let name = self.namer.call("unpackUnorm8x2"); + writeln!( + self.out, + "metal::float2 {name}(metal::uchar b0, \ + metal::uchar b1) {{" + )?; + writeln!( + self.out, + "{}return metal::float2(float(b0) / 255.0f, \ + float(b1) / 255.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 2, 2)) + } + Unorm8x4 => { + let name = self.namer.call("unpackUnorm8x4"); + writeln!( + self.out, + "metal::float4 {name}(metal::uchar b0, \ + metal::uchar b1, \ + metal::uchar b2, \ + metal::uchar b3) {{" + )?; + writeln!( + self.out, + "{}return metal::float4(float(b0) / 255.0f, \ + float(b1) / 255.0f, \ + float(b2) / 255.0f, \ + float(b3) / 255.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 4)) + } + Snorm8x2 => { + let name = self.namer.call("unpackSnorm8x2"); + writeln!( + self.out, + "metal::float2 {name}(metal::uchar b0, \ + metal::uchar b1) {{" + )?; + writeln!( + self.out, + "{}return metal::float2((float(b0) - 128.0f) / 255.0f, \ + (float(b1) - 128.0f) / 255.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 2, 2)) + } + Snorm8x4 => { + let name = self.namer.call("unpackSnorm8x4"); + writeln!( + self.out, + "metal::float4 {name}(metal::uchar b0, \ + metal::uchar b1, \ + metal::uchar b2, \ + metal::uchar b3) {{" + )?; + writeln!( + self.out, + "{}return metal::float4((float(b0) - 128.0f) / 255.0f, \ + (float(b1) - 128.0f) / 255.0f, \ + (float(b2) - 128.0f) / 255.0f, \ + (float(b3) - 128.0f) / 255.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 4)) + } + Uint16x2 => { + let name = self.namer.call("unpackUint16x2"); + writeln!( + self.out, + "metal::uint2 {name}(metal::uint b0, \ + metal::uint b1, \ + metal::uint b2, \ + metal::uint b3) {{" + )?; + writeln!( + self.out, + "{}return metal::uint2(b1 << 8 | b0, \ + b3 << 8 | b2);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 2)) + } + Uint16x4 => { + let name = self.namer.call("unpackUint16x4"); + writeln!( + self.out, + "metal::uint4 {name}(metal::uint b0, \ + metal::uint b1, \ + metal::uint b2, \ + metal::uint b3, \ + metal::uint b4, \ + metal::uint b5, \ + metal::uint b6, \ + metal::uint b7) {{" + )?; + writeln!( + self.out, + "{}return metal::uint4(b1 << 8 | b0, \ + b3 << 8 | b2, \ + b5 << 8 | b4, \ + b7 << 8 | b6);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 4)) + } + Sint16x2 => { + let name = self.namer.call("unpackSint16x2"); + writeln!( + self.out, + "metal::int2 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3) {{" + )?; + writeln!( + self.out, + "{}return metal::int2(as_type(b1 << 8 | b0), \ + as_type(b3 << 8 | b2));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 2)) + } + Sint16x4 => { + let name = self.namer.call("unpackSint16x4"); + writeln!( + self.out, + "metal::int4 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3, \ + metal::ushort b4, \ + metal::ushort b5, \ + metal::ushort b6, \ + metal::ushort b7) {{" + )?; + writeln!( + self.out, + "{}return metal::int4(as_type(b1 << 8 | b0), \ + as_type(b3 << 8 | b2), \ + as_type(b5 << 8 | b4), \ + as_type(b7 << 8 | b6));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 4)) + } + Unorm16x2 => { + let name = self.namer.call("unpackUnorm16x2"); + writeln!( + self.out, + "metal::float2 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3) {{" + )?; + writeln!( + self.out, + "{}return metal::float2(float(b1 << 8 | b0) / 65535.0f, \ + float(b3 << 8 | b2) / 65535.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 2)) + } + Unorm16x4 => { + let name = self.namer.call("unpackUnorm16x4"); + writeln!( + self.out, + "metal::float4 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3, \ + metal::ushort b4, \ + metal::ushort b5, \ + metal::ushort b6, \ + metal::ushort b7) {{" + )?; + writeln!( + self.out, + "{}return metal::float4(float(b1 << 8 | b0) / 65535.0f, \ + float(b3 << 8 | b2) / 65535.0f, \ + float(b5 << 8 | b4) / 65535.0f, \ + float(b7 << 8 | b6) / 65535.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 4)) + } + Snorm16x2 => { + let name = self.namer.call("unpackSnorm16x2"); + writeln!( + self.out, + "metal::float2 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3) {{" + )?; + writeln!( + self.out, + "{}return metal::float2((float(b1 << 8 | b0) - 32767.0f) / 65535.0f, \ + (float(b3 << 8 | b2) - 32767.0f) / 65535.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 2)) + } + Snorm16x4 => { + let name = self.namer.call("unpackSnorm16x4"); + writeln!( + self.out, + "metal::float4 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3, \ + metal::ushort b4, \ + metal::ushort b5, \ + metal::ushort b6, \ + metal::ushort b7) {{" + )?; + writeln!( + self.out, + "{}return metal::float4((float(b1 << 8 | b0) - 32767.0f) / 65535.0f, \ + (float(b3 << 8 | b2) - 32767.0f) / 65535.0f, \ + (float(b5 << 8 | b4) - 32767.0f) / 65535.0f, \ + (float(b7 << 8 | b6) - 32767.0f) / 65535.0f);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 4)) + } + Float16x2 => { + let name = self.namer.call("unpackFloat16x2"); + writeln!( + self.out, + "metal::float2 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3) {{" + )?; + writeln!( + self.out, + "{}return metal::float2(as_type(b1 << 8 | b0), \ + as_type(b3 << 8 | b2));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 2)) + } + Float16x4 => { + let name = self.namer.call("unpackFloat16x4"); + writeln!( + self.out, + "metal::int4 {name}(metal::ushort b0, \ + metal::ushort b1, \ + metal::ushort b2, \ + metal::ushort b3, \ + metal::ushort b4, \ + metal::ushort b5, \ + metal::ushort b6, \ + metal::ushort b7) {{" + )?; + writeln!( + self.out, + "{}return metal::int4(as_type(b1 << 8 | b0), \ + as_type(b3 << 8 | b2), \ + as_type(b5 << 8 | b4), \ + as_type(b7 << 8 | b6));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 4)) + } + Float32 => { + let name = self.namer.call("unpackFloat32"); + writeln!( + self.out, + "float {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3) {{" + )?; + writeln!( + self.out, + "{}return as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 1)) + } + Float32x2 => { + let name = self.namer.call("unpackFloat32x2"); + writeln!( + self.out, + "metal::float2 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7) {{" + )?; + writeln!( + self.out, + "{}return metal::float2(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 2)) + } + Float32x3 => { + let name = self.namer.call("unpackFloat32x3"); + writeln!( + self.out, + "metal::float3 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7, \ + uint b8, \ + uint b9, \ + uint b10, \ + uint b11) {{" + )?; + writeln!( + self.out, + "{}return metal::float3(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ + as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 12, 3)) + } + Float32x4 => { + let name = self.namer.call("unpackFloat32x4"); + writeln!( + self.out, + "metal::float4 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7, \ + uint b8, \ + uint b9, \ + uint b10, \ + uint b11, \ + uint b12, \ + uint b13, \ + uint b14, \ + uint b15) {{" + )?; + writeln!( + self.out, + "{}return metal::float4(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ + as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8), \ + as_type(b15 << 24 | b14 << 16 | b13 << 8 | b12));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 16, 4)) + } + Uint32 => { + let name = self.namer.call("unpackUint32"); + writeln!( + self.out, + "uint {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3) {{" + )?; + writeln!( + self.out, + "{}return (b3 << 24 | b2 << 16 | b1 << 8 | b0);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 1)) + } + Uint32x2 => { + let name = self.namer.call("unpackUint32x2"); + writeln!( + self.out, + "uint2 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7) {{" + )?; + writeln!( + self.out, + "{}return uint2((b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + (b7 << 24 | b6 << 16 | b5 << 8 | b4));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 2)) + } + Uint32x3 => { + let name = self.namer.call("unpackUint32x3"); + writeln!( + self.out, + "uint3 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7, \ + uint b8, \ + uint b9, \ + uint b10, \ + uint b11) {{" + )?; + writeln!( + self.out, + "{}return uint3((b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + (b7 << 24 | b6 << 16 | b5 << 8 | b4), \ + (b11 << 24 | b10 << 16 | b9 << 8 | b8));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 12, 3)) + } + Uint32x4 => { + let name = self.namer.call("unpackUint32x4"); + writeln!( + self.out, + "uint4 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7, \ + uint b8, \ + uint b9, \ + uint b10, \ + uint b11, \ + uint b12, \ + uint b13, \ + uint b14, \ + uint b15) {{" + )?; + writeln!( + self.out, + "{}return uint4((b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + (b7 << 24 | b6 << 16 | b5 << 8 | b4), \ + (b11 << 24 | b10 << 16 | b9 << 8 | b8), \ + (b15 << 24 | b14 << 16 | b13 << 8 | b12));", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 16, 4)) + } + Sint32 => { + let name = self.namer.call("unpackSint32"); + writeln!( + self.out, + "metal::int {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3) {{" + )?; + writeln!( + self.out, + "{}return as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 1)) + } + Sint32x2 => { + let name = self.namer.call("unpackSint32x2"); + writeln!( + self.out, + "metal::int2 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7) {{" + )?; + writeln!( + self.out, + "{}return metal::int2(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 8, 2)) + } + Sint32x3 => { + let name = self.namer.call("unpackSint32x3"); + writeln!( + self.out, + "metal::int3 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7, \ + uint b8, \ + uint b9, \ + uint b10, \ + uint b11) {{" + )?; + writeln!( + self.out, + "{}return metal::int3(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ + as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 12, 3)) + } + Sint32x4 => { + let name = self.namer.call("unpackSint32x4"); + writeln!( + self.out, + "metal::int4 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3, \ + uint b4, \ + uint b5, \ + uint b6, \ + uint b7, \ + uint b8, \ + uint b9, \ + uint b10, \ + uint b11, \ + uint b12, \ + uint b13, \ + uint b14, \ + uint b15) {{" + )?; + writeln!( + self.out, + "{}return metal::int4(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), \ + as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), \ + as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8), \ + as_type(b15 << 24 | b14 << 16 | b13 << 8 | b12);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 16, 4)) + } + Unorm10_10_10_2 => { + let name = self.namer.call("unpackUnorm10_10_10_2"); + writeln!( + self.out, + "metal::float4 {name}(uint b0, \ + uint b1, \ + uint b2, \ + uint b3) {{" + )?; + writeln!( + self.out, + "{}return unpack_unorm10a2_to_float(b3 << 24 | b2 << 16 | b1 << 8 | b0);", + back::INDENT + )?; + writeln!(self.out, "}}")?; + Ok((name, 4, 4)) + } + } + } + // Returns the array of mapped entry point names. fn write_functions( &mut self, @@ -3772,6 +4471,101 @@ impl Writer { options: &Options, pipeline_options: &PipelineOptions, ) -> Result { + use back::msl::VertexFormat; + + // Define structs to hold resolved/generated data for vertex buffers and + // their attributes. + struct AttributeMappingResolved { + ty_name: String, + dimension: u32, + ty_is_int: bool, + name: String, + } + let mut am_resolved = FastHashMap::::default(); + + struct VertexBufferMappingResolved<'a> { + id: u32, + stride: u32, + indexed_by_vertex: bool, + ty_name: String, + param_name: String, + elem_name: String, + attributes: &'a Vec, + } + let mut vbm_resolved = Vec::::new(); + + // Define a struct to hold a named reference to a byte-unpacking function. + struct UnpackingFunction { + name: String, + byte_count: u32, + dimension: u32, + } + let mut unpacking_functions = FastHashMap::::default(); + + // Check if we are attempting vertex pulling. If we are, generate some + // names we'll need, and iterate the vertex buffer mappings to output + // all the conversion functions we'll need to unpack the attribute data. + // We can re-use these names for all entry points that need them, since + // those entry points also use self.namer. + let mut needs_vertex_id = false; + let v_id = self.namer.call("v_id"); + + let mut needs_instance_id = false; + let i_id = self.namer.call("i_id"); + if pipeline_options.vertex_pulling_transform { + for vbm in &pipeline_options.vertex_buffer_mappings { + let buffer_id = vbm.id; + let buffer_stride = vbm.stride; + + assert!( + buffer_stride > 0, + "Vertex pulling requires a non-zero buffer stride." + ); + + if vbm.indexed_by_vertex { + needs_vertex_id = true; + } else { + needs_instance_id = true; + } + + let buffer_ty = self.namer.call(format!("vb_{buffer_id}_type").as_str()); + let buffer_param = self.namer.call(format!("vb_{buffer_id}_in").as_str()); + let buffer_elem = self.namer.call(format!("vb_{buffer_id}_elem").as_str()); + + vbm_resolved.push(VertexBufferMappingResolved { + id: buffer_id, + stride: buffer_stride, + indexed_by_vertex: vbm.indexed_by_vertex, + ty_name: buffer_ty, + param_name: buffer_param, + elem_name: buffer_elem, + attributes: &vbm.attributes, + }); + + // Iterate the attributes and generate needed unpacking functions. + for attribute in &vbm.attributes { + if unpacking_functions.contains_key(&attribute.format) { + continue; + } + let (name, byte_count, dimension) = + match self.write_unpacking_function(attribute.format) { + Ok((name, byte_count, dimension)) => (name, byte_count, dimension), + _ => { + continue; + } + }; + unpacking_functions.insert( + attribute.format, + UnpackingFunction { + name, + byte_count, + dimension, + }, + ); + } + } + } + let mut pass_through_globals = Vec::new(); for (fun_handle, fun) in module.functions.iter() { log::trace!( @@ -3782,13 +4576,13 @@ impl Writer { let fun_info = &mod_info[fun_handle]; pass_through_globals.clear(); - let mut supports_array_length = false; + let mut needs_buffer_sizes = false; for (handle, var) in module.global_variables.iter() { if !fun_info[handle].is_empty() { if var.space.needs_pass_through() { pass_through_globals.push(handle); } - supports_array_length |= needs_array_length(var.ty, &module.types); + needs_buffer_sizes |= needs_array_length(var.ty, &module.types); } } @@ -3825,7 +4619,7 @@ impl Writer { let separator = separate( !pass_through_globals.is_empty() || index + 1 != fun.arguments.len() - || supports_array_length, + || needs_buffer_sizes, ); writeln!( self.out, @@ -3846,13 +4640,13 @@ impl Writer { reference: true, }; let separator = - separate(index + 1 != pass_through_globals.len() || supports_array_length); + separate(index + 1 != pass_through_globals.len() || needs_buffer_sizes); write!(self.out, "{}", back::INDENT)?; tyvar.try_fmt(&mut self.out)?; writeln!(self.out, "{separator}")?; } - if supports_array_length { + if needs_buffer_sizes { writeln!( self.out, "{}constant _mslBufferSizes& _buffer_sizes", @@ -3917,18 +4711,51 @@ impl Writer { let fun_info = mod_info.get_entry_point(ep_index); let mut ep_error = None; + // For vertex_id and instance_id arguments, presume that we'll + // use our generated names, but switch to the name of an + // existing @builtin param, if we find one. + let mut v_existing_id = None; + let mut i_existing_id = None; + log::trace!( "entry point {:?}, index {:?}", fun.name.as_deref().unwrap_or("(anonymous)"), ep_index ); + let (em_str, in_mode, out_mode, can_vertex_pull) = match ep.stage { + crate::ShaderStage::Vertex => ( + "vertex", + LocationMode::VertexInput, + LocationMode::VertexOutput, + true, + ), + crate::ShaderStage::Fragment { .. } => ( + "fragment", + LocationMode::FragmentInput, + LocationMode::FragmentOutput, + false, + ), + crate::ShaderStage::Compute { .. } => ( + "kernel", + LocationMode::Uniform, + LocationMode::Uniform, + false, + ), + }; + + // Should this entry point be modified to do vertex pulling? + let do_vertex_pulling = can_vertex_pull + && pipeline_options.vertex_pulling_transform + && !pipeline_options.vertex_buffer_mappings.is_empty(); + // Is any global variable used by this entry point dynamically sized? - let supports_array_length = module - .global_variables - .iter() - .filter(|&(handle, _)| !fun_info[handle].is_empty()) - .any(|(_, var)| needs_array_length(var.ty, &module.types)); + let needs_buffer_sizes = do_vertex_pulling + || module + .global_variables + .iter() + .filter(|&(handle, _)| !fun_info[handle].is_empty()) + .any(|(_, var)| needs_array_length(var.ty, &module.types)); // skip this entry point if any global bindings are missing, // or their types are incompatible. @@ -3986,7 +4813,7 @@ impl Writer { | crate::AddressSpace::WorkGroup => {} } } - if supports_array_length { + if needs_buffer_sizes { if let Err(err) = options.resolve_sizes_buffer(ep) { ep_error = Some(err); } @@ -4002,22 +4829,6 @@ impl Writer { writeln!(self.out)?; - let (em_str, in_mode, out_mode) = match ep.stage { - crate::ShaderStage::Vertex => ( - "vertex", - LocationMode::VertexInput, - LocationMode::VertexOutput, - ), - crate::ShaderStage::Fragment { .. } => ( - "fragment", - LocationMode::FragmentInput, - LocationMode::FragmentOutput, - ), - crate::ShaderStage::Compute { .. } => { - ("kernel", LocationMode::Uniform, LocationMode::Uniform) - } - }; - // Since `Namer.reset` wasn't expecting struct members to be // suddenly injected into another namespace like this, // `self.names` doesn't keep them distinct from other variables. @@ -4045,7 +4856,11 @@ impl Writer { let name_key = NameKey::StructMember(arg.ty, member_index); let name = match member.binding { Some(crate::Binding::Location { .. }) => { - varyings_namer.call(&self.names[&name_key]) + if do_vertex_pulling { + self.namer.call(&self.names[&name_key]) + } else { + varyings_namer.call(&self.names[&name_key]) + } } _ => self.namer.call(&self.names[&name_key]), }; @@ -4060,19 +4875,24 @@ impl Writer { } } - // Identify the varyings among the argument values, and emit a - // struct type named `Input` to hold them. + // Identify the varyings among the argument values, and maybe emit + // a struct type named `Input` to hold them. If we are doing + // vertex pulling, we instead update our attribute mapping to + // note the types, names, and zero values of the attributes. let stage_in_name = format!("{fun_name}Input"); let varyings_member_name = self.namer.call("varyings"); let mut has_varyings = false; if !flattened_arguments.is_empty() { - writeln!(self.out, "struct {stage_in_name} {{")?; + if !do_vertex_pulling { + writeln!(self.out, "struct {stage_in_name} {{")?; + } for &(ref name_key, ty, binding) in flattened_arguments.iter() { - let binding = match binding { - Some(ref binding @ &crate::Binding::Location { .. }) => binding, + let (binding, location) = match binding { + Some(ref binding @ &crate::Binding::Location { location, .. }) => { + (binding, location) + } _ => continue, }; - has_varyings = true; let name = match *name_key { NameKey::StructMember(..) => &flattened_member_names[name_key], _ => &self.names[name_key], @@ -4086,11 +4906,27 @@ impl Writer { first_time: false, }; let resolved = options.resolve_local_binding(binding, in_mode)?; - write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; - resolved.try_fmt(&mut self.out)?; - writeln!(self.out, ";")?; + if do_vertex_pulling { + // Update our attribute mapping. + am_resolved.insert( + location, + AttributeMappingResolved { + ty_name: ty_name.to_string(), + dimension: ty_name.vertex_input_dimension(), + ty_is_int: ty_name.scalar().map_or(false, scalar_is_int), + name: name.to_string(), + }, + ); + } else { + has_varyings = true; + write!(self.out, "{}{} {}", back::INDENT, ty_name, name)?; + resolved.try_fmt(&mut self.out)?; + writeln!(self.out, ";")?; + } + } + if !do_vertex_pulling { + writeln!(self.out, "}};")?; } - writeln!(self.out, "}};")?; } // Define a struct type named for the return value, if any, named @@ -4173,6 +5009,23 @@ impl Writer { None => "void", }; + // If we're doing a vertex pulling transform, define the buffer + // structure types. + if do_vertex_pulling { + for vbm in &vbm_resolved { + let buffer_stride = vbm.stride; + let buffer_ty = &vbm.ty_name; + + // Define a structure of bytes of the appropriate size. + // When we access the attributes, we'll be unpacking these + // bytes at some offset. + writeln!( + self.out, + "struct {buffer_ty} {{ metal::uchar data[{buffer_stride}]; }};" + )?; + } + } + // Write the entry point function's name, and begin its argument list. writeln!(self.out, "{em_str} {result_type_name} {fun_name}(")?; let mut is_first_argument = true; @@ -4213,6 +5066,17 @@ impl Writer { binding: None, first_time: false, }; + + match *binding { + crate::Binding::BuiltIn(crate::BuiltIn::VertexIndex) => { + v_existing_id = Some(name.clone()); + } + crate::Binding::BuiltIn(crate::BuiltIn::InstanceIndex) => { + i_existing_id = Some(name.clone()); + } + _ => {} + }; + let resolved = options.resolve_local_binding(binding, in_mode)?; let separator = if is_first_argument { is_first_argument = false; @@ -4405,16 +5269,45 @@ impl Writer { writeln!(self.out)?; } - // If this entry uses any variable-length arrays, their sizes are - // passed as a final struct-typed argument. - if supports_array_length { - // this is checked earlier - let resolved = options.resolve_sizes_buffer(ep).unwrap(); - let separator = if module.global_variables.is_empty() { + if do_vertex_pulling { + assert!(needs_vertex_id || needs_instance_id); + + let mut separator = if is_first_argument { + is_first_argument = false; ' ' } else { ',' }; + + if needs_vertex_id && v_existing_id.is_none() { + // Write the [[vertex_id]] argument. + writeln!(self.out, "{separator} uint {v_id} [[vertex_id]]")?; + separator = ','; + } + + if needs_instance_id && i_existing_id.is_none() { + writeln!(self.out, "{separator} uint {i_id} [[instance_id]]")?; + } + + // Iterate vbm_resolved, output one argument for every vertex buffer, + // using the names we generated earlier. + for vbm in &vbm_resolved { + let id = &vbm.id; + let ty_name = &vbm.ty_name; + let param_name = &vbm.param_name; + writeln!( + self.out, + ", const device {ty_name}* {param_name} [[buffer({id})]]" + )?; + } + } + + // If this entry uses any variable-length arrays, their sizes are + // passed as a final struct-typed argument. + if needs_buffer_sizes { + // this is checked earlier + let resolved = options.resolve_sizes_buffer(ep).unwrap(); + let separator = if is_first_argument { ' ' } else { ',' }; write!( self.out, "{separator} constant _mslBufferSizes& _buffer_sizes", @@ -4426,6 +5319,126 @@ impl Writer { // end of the entry point argument list writeln!(self.out, ") {{")?; + // Starting the function body. + if do_vertex_pulling { + // Provide zero values for all the attributes, which we will overwrite with + // real data from the vertex attribute buffers, if the indices are in-bounds. + for vbm in &vbm_resolved { + for attribute in vbm.attributes { + let location = attribute.shader_location; + let am_option = am_resolved.get(&location); + if am_option.is_none() { + // This bound attribute isn't used in this entry point, so + // don't bother zero-initializing it. + continue; + } + let am = am_option.unwrap(); + let attribute_ty_name = &am.ty_name; + let attribute_name = &am.name; + + writeln!( + self.out, + "{}{attribute_ty_name} {attribute_name} = {{}};", + back::Level(1) + )?; + } + + // Output a bounds check block that will set real values for the + // attributes, if the bounds are satisfied. + write!(self.out, "{}if (", back::Level(1))?; + + let idx = &vbm.id; + let stride = &vbm.stride; + let index_name = if vbm.indexed_by_vertex { + if let Some(ref name) = v_existing_id { + name + } else { + &v_id + } + } else if let Some(ref name) = i_existing_id { + name + } else { + &i_id + }; + write!( + self.out, + "{index_name} < (_buffer_sizes.buffer_size{idx} / {stride})" + )?; + + writeln!(self.out, ") {{")?; + + // Pull the bytes out of the vertex buffer. + let ty_name = &vbm.ty_name; + let elem_name = &vbm.elem_name; + let param_name = &vbm.param_name; + + writeln!( + self.out, + "{}const {ty_name} {elem_name} = {param_name}[{index_name}];", + back::Level(2), + )?; + + // Now set real values for each of the attributes, by unpacking the data + // from the buffer elements. + for attribute in vbm.attributes { + let location = attribute.shader_location; + let am_option = am_resolved.get(&location); + if am_option.is_none() { + // This bound attribute isn't used in this entry point, so + // don't bother extracting the data. Too bad we emitted the + // unpacking function earlier -- it might not get used. + continue; + } + let am = am_option.unwrap(); + let attribute_name = &am.name; + let attribute_ty_name = &am.ty_name; + + let offset = attribute.offset; + let func = unpacking_functions + .get(&attribute.format) + .expect("Should have generated this unpacking function earlier."); + let func_name = &func.name; + + write!(self.out, "{}{attribute_name} = ", back::Level(2),)?; + + // Check dimensionality of the attribute compared to the unpacking + // function. If attribute dimension is < unpack dimension, then + // we need to explicitly cast down the result. Otherwise, if attribute + // dimension > unpack dimension, we have to pad out the unpack value + // from a vec4(0, 0, 0, 1) of matching scalar type. + + let needs_truncate_or_padding = am.dimension != func.dimension; + if needs_truncate_or_padding { + write!(self.out, "{attribute_ty_name}(")?; + } + + write!(self.out, "{func_name}({elem_name}.data[{offset}]",)?; + for i in (offset + 1)..(offset + func.byte_count) { + write!(self.out, ", {elem_name}.data[{i}]")?; + } + write!(self.out, ")")?; + + if needs_truncate_or_padding { + let zero_value = if am.ty_is_int { "0" } else { "0.0" }; + let one_value = if am.ty_is_int { "1" } else { "1.0" }; + for i in func.dimension..am.dimension { + write!( + self.out, + ", {}", + if i == 3 { one_value } else { zero_value } + )?; + } + write!(self.out, ")")?; + } + + writeln!(self.out, ";")?; + } + + // End the bounds check / attribute setting block. + writeln!(self.out, "{}}}", back::Level(1))?; + } + } + if need_workgroup_variables_initialization { self.write_workgroup_variables_initialization( module, @@ -4518,7 +5531,9 @@ impl Writer { write!(self.out, "{{}}, ")?; } if let Some(crate::Binding::Location { .. }) = member.binding { - write!(self.out, "{varyings_member_name}.")?; + if has_varyings { + write!(self.out, "{varyings_member_name}.")?; + } } write!(self.out, "{name}")?; } @@ -4526,14 +5541,16 @@ impl Writer { } _ => { if let Some(crate::Binding::Location { .. }) = arg.binding { - writeln!( - self.out, - "{}const auto {} = {}.{};", - back::INDENT, - arg_name, - varyings_member_name, - arg_name - )?; + if has_varyings { + writeln!( + self.out, + "{}const auto {} = {}.{};", + back::INDENT, + arg_name, + varyings_member_name, + arg_name + )?; + } } } } diff --git a/naga/tests/in/interface.param.ron b/naga/tests/in/interface.param.ron index 4d8566176..b5dce6b8a 100644 --- a/naga/tests/in/interface.param.ron +++ b/naga/tests/in/interface.param.ron @@ -27,5 +27,7 @@ ), msl_pipeline: ( allow_and_force_point_size: true, + vertex_pulling_transform: false, + vertex_buffer_mappings: [], ), ) diff --git a/naga/tests/in/vertex-pulling-transform.param.ron b/naga/tests/in/vertex-pulling-transform.param.ron new file mode 100644 index 000000000..d05e21256 --- /dev/null +++ b/naga/tests/in/vertex-pulling-transform.param.ron @@ -0,0 +1,31 @@ +( + msl_pipeline: ( + allow_and_force_point_size: false, + vertex_pulling_transform: true, + vertex_buffer_mappings: [( + id: 1, + stride: 20, + indexed_by_vertex: true, + attributes: [( + shader_location: 0, // position + offset: 0, + format: Float32, // too small, inflated to a vec4 + ), + ( + shader_location: 1, // normal + offset: 4, + format: Float32x4, // too big, truncated to a vec3 + )], + ), + ( + id: 2, + stride: 16, + indexed_by_vertex: false, + attributes: [( + shader_location: 2, // texcoord + offset: 0, + format: Float32x2, + )], + )], + ), +) diff --git a/naga/tests/in/vertex-pulling-transform.wgsl b/naga/tests/in/vertex-pulling-transform.wgsl new file mode 100644 index 000000000..87b318ec1 --- /dev/null +++ b/naga/tests/in/vertex-pulling-transform.wgsl @@ -0,0 +1,32 @@ +struct VertexOutput { + @builtin(position) position: vec4, + @location(0) color: vec4, + @location(1) texcoord: vec2, +} + +struct VertexInput { + @location(0) position: vec4, + @location(1) normal: vec3, + @location(2) texcoord: vec2, +} + +@group(0) @binding(0) var mvp_matrix: mat4x4; + +@vertex +fn render_vertex( + v_in: VertexInput, + @builtin(vertex_index) v_existing_id: u32, +) -> VertexOutput +{ + var v_out: VertexOutput; + v_out.position = v_in.position * mvp_matrix; + v_out.color = do_lighting(v_in.position, + v_in.normal); + v_out.texcoord = v_in.texcoord; + return v_out; +} + +fn do_lighting(position: vec4, normal: vec3) -> vec4 { + // blah blah blah + return vec4(0); +} diff --git a/naga/tests/out/msl/vertex-pulling-transform.msl b/naga/tests/out/msl/vertex-pulling-transform.msl new file mode 100644 index 000000000..6481e24c2 --- /dev/null +++ b/naga/tests/out/msl/vertex-pulling-transform.msl @@ -0,0 +1,76 @@ +// language: metal1.0 +#include +#include + +using metal::uint; + +struct _mslBufferSizes { + uint buffer_size1; + uint buffer_size2; +}; + +struct VertexOutput { + metal::float4 position; + metal::float4 color; + metal::float2 texcoord; +}; +struct VertexInput { + metal::float4 position; + metal::float3 normal; + metal::float2 texcoord; +}; +float unpackFloat32_(uint b0, uint b1, uint b2, uint b3) { + return as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0); +} +metal::float4 unpackFloat32x4_(uint b0, uint b1, uint b2, uint b3, uint b4, uint b5, uint b6, uint b7, uint b8, uint b9, uint b10, uint b11, uint b12, uint b13, uint b14, uint b15) { + return metal::float4(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4), as_type(b11 << 24 | b10 << 16 | b9 << 8 | b8), as_type(b15 << 24 | b14 << 16 | b13 << 8 | b12)); +} +metal::float2 unpackFloat32x2_(uint b0, uint b1, uint b2, uint b3, uint b4, uint b5, uint b6, uint b7) { + return metal::float2(as_type(b3 << 24 | b2 << 16 | b1 << 8 | b0), as_type(b7 << 24 | b6 << 16 | b5 << 8 | b4)); +} + +metal::float4 do_lighting( + metal::float4 position, + metal::float3 normal +) { + return metal::float4(0.0); +} + +struct render_vertexOutput { + metal::float4 position [[position]]; + metal::float4 color [[user(loc0), center_perspective]]; + metal::float2 texcoord [[user(loc1), center_perspective]]; +}; +struct vb_1_type { metal::uchar data[20]; }; +struct vb_2_type { metal::uchar data[16]; }; +vertex render_vertexOutput render_vertex( + uint v_existing_id [[vertex_id]] +, constant metal::float4x4& mvp_matrix [[user(fake0)]] +, uint i_id [[instance_id]] +, const device vb_1_type* vb_1_in [[buffer(1)]] +, const device vb_2_type* vb_2_in [[buffer(2)]] +, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]] +) { + metal::float4 position_1 = {}; + metal::float3 normal_1 = {}; + if (v_existing_id < (_buffer_sizes.buffer_size1 / 20)) { + const vb_1_type vb_1_elem = vb_1_in[v_existing_id]; + position_1 = metal::float4(unpackFloat32_(vb_1_elem.data[0], vb_1_elem.data[1], vb_1_elem.data[2], vb_1_elem.data[3]), 0.0, 0.0, 1.0); + normal_1 = metal::float3(unpackFloat32x4_(vb_1_elem.data[4], vb_1_elem.data[5], vb_1_elem.data[6], vb_1_elem.data[7], vb_1_elem.data[8], vb_1_elem.data[9], vb_1_elem.data[10], vb_1_elem.data[11], vb_1_elem.data[12], vb_1_elem.data[13], vb_1_elem.data[14], vb_1_elem.data[15], vb_1_elem.data[16], vb_1_elem.data[17], vb_1_elem.data[18], vb_1_elem.data[19])); + } + metal::float2 texcoord = {}; + if (i_id < (_buffer_sizes.buffer_size2 / 16)) { + const vb_2_type vb_2_elem = vb_2_in[i_id]; + texcoord = unpackFloat32x2_(vb_2_elem.data[0], vb_2_elem.data[1], vb_2_elem.data[2], vb_2_elem.data[3], vb_2_elem.data[4], vb_2_elem.data[5], vb_2_elem.data[6], vb_2_elem.data[7]); + } + const VertexInput v_in = { position_1, normal_1, texcoord }; + VertexOutput v_out = {}; + metal::float4x4 _e6 = mvp_matrix; + v_out.position = v_in.position * _e6; + metal::float4 _e11 = do_lighting(v_in.position, v_in.normal); + v_out.color = _e11; + v_out.texcoord = v_in.texcoord; + VertexOutput _e14 = v_out; + const auto _tmp = _e14; + return render_vertexOutput { _tmp.position, _tmp.color, _tmp.texcoord }; +} diff --git a/naga/tests/snapshots.rs b/naga/tests/snapshots.rs index 5e2441e0d..826337515 100644 --- a/naga/tests/snapshots.rs +++ b/naga/tests/snapshots.rs @@ -890,6 +890,7 @@ fn convert_wgsl() { "overrides-ray-query", Targets::IR | Targets::SPIRV | Targets::METAL, ), + ("vertex-pulling-transform", Targets::METAL), ]; for &(name, targets) in inputs.iter() { diff --git a/player/tests/data/bind-group.ron b/player/tests/data/bind-group.ron index 9da7abe09..a53a77b16 100644 --- a/player/tests/data/bind-group.ron +++ b/player/tests/data/bind-group.ron @@ -59,6 +59,7 @@ entry_point: None, constants: {}, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, ), ), ), diff --git a/player/tests/data/pipeline-statistics-query.ron b/player/tests/data/pipeline-statistics-query.ron index f0f96d42c..8a6e4239b 100644 --- a/player/tests/data/pipeline-statistics-query.ron +++ b/player/tests/data/pipeline-statistics-query.ron @@ -32,6 +32,7 @@ entry_point: None, constants: {}, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, ), ), ), diff --git a/player/tests/data/quad.ron b/player/tests/data/quad.ron index 1a8b4028b..aad576c42 100644 --- a/player/tests/data/quad.ron +++ b/player/tests/data/quad.ron @@ -60,6 +60,7 @@ entry_point: None, constants: {}, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, ), buffers: [], ), @@ -69,6 +70,7 @@ entry_point: None, constants: {}, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, ), targets: [ Some(( diff --git a/player/tests/data/zero-init-buffer.ron b/player/tests/data/zero-init-buffer.ron index 1ce7924dd..b13786e26 100644 --- a/player/tests/data/zero-init-buffer.ron +++ b/player/tests/data/zero-init-buffer.ron @@ -136,6 +136,7 @@ entry_point: None, constants: {}, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, ), ), ), diff --git a/player/tests/data/zero-init-texture-binding.ron b/player/tests/data/zero-init-texture-binding.ron index 2aeaf22c7..ba4951c19 100644 --- a/player/tests/data/zero-init-texture-binding.ron +++ b/player/tests/data/zero-init-texture-binding.ron @@ -137,6 +137,7 @@ entry_point: None, constants: {}, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, ), ), ), diff --git a/tests/tests/vertex_indices/mod.rs b/tests/tests/vertex_indices/mod.rs index 7bd172d85..b85f3274e 100644 --- a/tests/tests/vertex_indices/mod.rs +++ b/tests/tests/vertex_indices/mod.rs @@ -185,6 +185,7 @@ struct Test { id_source: IdSource, draw_call_kind: DrawCallKind, encoder_kind: EncoderKind, + vertex_pulling_transform: bool, } impl Test { @@ -298,6 +299,16 @@ async fn vertex_index_common(ctx: TestingContext) { cache: None, }; let builtin_pipeline = ctx.device.create_render_pipeline(&pipeline_desc); + pipeline_desc + .vertex + .compilation_options + .vertex_pulling_transform = true; + let builtin_pipeline_vpt = ctx.device.create_render_pipeline(&pipeline_desc); + pipeline_desc + .vertex + .compilation_options + .vertex_pulling_transform = false; + pipeline_desc.vertex.entry_point = "vs_main_buffers"; pipeline_desc.vertex.buffers = &[ wgpu::VertexBufferLayout { @@ -312,6 +323,15 @@ async fn vertex_index_common(ctx: TestingContext) { }, ]; let buffer_pipeline = ctx.device.create_render_pipeline(&pipeline_desc); + pipeline_desc + .vertex + .compilation_options + .vertex_pulling_transform = true; + let buffer_pipeline_vpt = ctx.device.create_render_pipeline(&pipeline_desc); + pipeline_desc + .vertex + .compilation_options + .vertex_pulling_transform = false; let dummy = ctx .device @@ -336,17 +356,20 @@ async fn vertex_index_common(ctx: TestingContext) { ) .create_view(&wgpu::TextureViewDescriptor::default()); - let mut tests = Vec::with_capacity(5 * 2 * 2); + let mut tests = Vec::with_capacity(5 * 2 * 2 * 2); for case in TestCase::ARRAY { for id_source in IdSource::ARRAY { for draw_call_kind in DrawCallKind::ARRAY { for encoder_kind in EncoderKind::ARRAY { - tests.push(Test { - case, - id_source, - draw_call_kind, - encoder_kind, - }) + for vertex_pulling_transform in [false, true] { + tests.push(Test { + case, + id_source, + draw_call_kind, + encoder_kind, + vertex_pulling_transform, + }) + } } } } @@ -357,8 +380,20 @@ async fn vertex_index_common(ctx: TestingContext) { let mut failed = false; for test in tests { let pipeline = match test.id_source { - IdSource::Buffers => &buffer_pipeline, - IdSource::Builtins => &builtin_pipeline, + IdSource::Buffers => { + if test.vertex_pulling_transform { + &buffer_pipeline_vpt + } else { + &buffer_pipeline + } + } + IdSource::Builtins => { + if test.vertex_pulling_transform { + &builtin_pipeline_vpt + } else { + &builtin_pipeline + } + } }; let expected = test.expectation(&ctx); diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index ba51507d1..f9242848c 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -2737,6 +2737,7 @@ impl Device { entry_point: final_entry_point_name.as_ref(), constants: desc.stage.constants.as_ref(), zero_initialize_workgroup_memory: desc.stage.zero_initialize_workgroup_memory, + vertex_pulling_transform: false, }, cache: cache.as_ref().and_then(|it| it.raw.as_ref()), }; @@ -3165,6 +3166,7 @@ impl Device { entry_point: &vertex_entry_point_name, constants: stage_desc.constants.as_ref(), zero_initialize_workgroup_memory: stage_desc.zero_initialize_workgroup_memory, + vertex_pulling_transform: stage_desc.vertex_pulling_transform, } }; @@ -3228,6 +3230,7 @@ impl Device { zero_initialize_workgroup_memory: fragment_state .stage .zero_initialize_workgroup_memory, + vertex_pulling_transform: false, }) } None => None, diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index ee8f8668c..f3e7dbacb 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -166,6 +166,8 @@ pub struct ProgrammableStageDescriptor<'a> { /// This is required by the WebGPU spec, but may have overhead which can be avoided /// for cross-platform applications pub zero_initialize_workgroup_memory: bool, + /// Should the pipeline attempt to transform vertex shaders to use vertex pulling. + pub vertex_pulling_transform: bool, } /// Number of implicit bind groups derived at pipeline creation. diff --git a/wgpu-hal/examples/halmark/main.rs b/wgpu-hal/examples/halmark/main.rs index ee59fa259..560aa6f8c 100644 --- a/wgpu-hal/examples/halmark/main.rs +++ b/wgpu-hal/examples/halmark/main.rs @@ -254,6 +254,7 @@ impl Example { entry_point: "vs_main", constants: &constants, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, }, vertex_buffers: &[], fragment_stage: Some(hal::ProgrammableStage { @@ -261,6 +262,7 @@ impl Example { entry_point: "fs_main", constants: &constants, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, }), primitive: wgt::PrimitiveState { topology: wgt::PrimitiveTopology::TriangleStrip, diff --git a/wgpu-hal/examples/ray-traced-triangle/main.rs b/wgpu-hal/examples/ray-traced-triangle/main.rs index 8f404dc4d..90f0e6fc5 100644 --- a/wgpu-hal/examples/ray-traced-triangle/main.rs +++ b/wgpu-hal/examples/ray-traced-triangle/main.rs @@ -373,6 +373,7 @@ impl Example { entry_point: "main", constants: &Default::default(), zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, }, cache: None, }) diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index 35b9ea0d0..da3834bcb 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -1714,6 +1714,8 @@ pub struct ProgrammableStage<'a, A: Api> { /// This is required by the WebGPU spec, but may have overhead which can be avoided /// for cross-platform applications pub zero_initialize_workgroup_memory: bool, + /// Should the pipeline attempt to transform vertex shaders to use vertex pulling. + pub vertex_pulling_transform: bool, } // Rust gets confused about the impl requirements for `A` @@ -1724,6 +1726,7 @@ impl Clone for ProgrammableStage<'_, A> { entry_point: self.entry_point, constants: self.constants, zero_initialize_workgroup_memory: self.zero_initialize_workgroup_memory, + vertex_pulling_transform: self.vertex_pulling_transform, } } } diff --git a/wgpu-hal/src/metal/command.rs b/wgpu-hal/src/metal/command.rs index 341712c32..fb9c7e9c0 100644 --- a/wgpu-hal/src/metal/command.rs +++ b/wgpu-hal/src/metal/command.rs @@ -16,6 +16,7 @@ impl Default for super::CommandState { raw_wg_size: metal::MTLSize::new(0, 0, 0), stage_infos: Default::default(), storage_buffer_length_map: Default::default(), + vertex_buffer_size_map: Default::default(), work_group_memory_sizes: Vec::new(), push_constants: Vec::new(), pending_timer_queries: Vec::new(), @@ -137,6 +138,7 @@ impl super::CommandEncoder { impl super::CommandState { fn reset(&mut self) { self.storage_buffer_length_map.clear(); + self.vertex_buffer_size_map.clear(); self.stage_infos.vs.clear(); self.stage_infos.fs.clear(); self.stage_infos.cs.clear(); @@ -160,6 +162,15 @@ impl super::CommandState { .unwrap_or_default() })); + // Extend with the sizes of the mapped vertex buffers, in the order + // they were added to the map. + result_sizes.extend(stage_info.vertex_buffer_mappings.iter().map(|vbm| { + self.vertex_buffer_size_map + .get(&(vbm.id as u64)) + .map(|size| u32::try_from(size.get()).unwrap_or(u32::MAX)) + .unwrap_or_default() + })); + if !result_sizes.is_empty() { Some((slot as _, result_sizes)) } else { @@ -927,6 +938,27 @@ impl crate::CommandEncoder for super::CommandEncoder { let buffer_index = self.shared.private_caps.max_vertex_buffers as u64 - 1 - index as u64; let encoder = self.state.render.as_ref().unwrap(); encoder.set_vertex_buffer(buffer_index, Some(&binding.buffer.raw), binding.offset); + + let buffer_size = binding.resolve_size(); + if buffer_size > 0 { + self.state.vertex_buffer_size_map.insert( + buffer_index, + std::num::NonZeroU64::new(buffer_size).unwrap(), + ); + } else { + self.state.vertex_buffer_size_map.remove(&buffer_index); + } + + if let Some((index, sizes)) = self + .state + .make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes) + { + encoder.set_vertex_bytes( + index as _, + (sizes.len() * WORD_SIZE) as u64, + sizes.as_ptr() as _, + ); + } } unsafe fn set_viewport(&mut self, rect: &crate::Rect, depth_range: Range) { diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 81ab5dbdb..77ea8a0d8 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -59,10 +59,48 @@ fn create_depth_stencil_desc(state: &wgt::DepthStencilState) -> metal::DepthSten desc } +const fn convert_vertex_format_to_naga(format: wgt::VertexFormat) -> naga::back::msl::VertexFormat { + match format { + wgt::VertexFormat::Uint8x2 => naga::back::msl::VertexFormat::Uint8x2, + wgt::VertexFormat::Uint8x4 => naga::back::msl::VertexFormat::Uint8x4, + wgt::VertexFormat::Sint8x2 => naga::back::msl::VertexFormat::Sint8x2, + wgt::VertexFormat::Sint8x4 => naga::back::msl::VertexFormat::Sint8x4, + wgt::VertexFormat::Unorm8x2 => naga::back::msl::VertexFormat::Unorm8x2, + wgt::VertexFormat::Unorm8x4 => naga::back::msl::VertexFormat::Unorm8x4, + wgt::VertexFormat::Snorm8x2 => naga::back::msl::VertexFormat::Snorm8x2, + wgt::VertexFormat::Snorm8x4 => naga::back::msl::VertexFormat::Snorm8x4, + wgt::VertexFormat::Uint16x2 => naga::back::msl::VertexFormat::Uint16x2, + wgt::VertexFormat::Uint16x4 => naga::back::msl::VertexFormat::Uint16x4, + wgt::VertexFormat::Sint16x2 => naga::back::msl::VertexFormat::Sint16x2, + wgt::VertexFormat::Sint16x4 => naga::back::msl::VertexFormat::Sint16x4, + wgt::VertexFormat::Unorm16x2 => naga::back::msl::VertexFormat::Unorm16x2, + wgt::VertexFormat::Unorm16x4 => naga::back::msl::VertexFormat::Unorm16x4, + wgt::VertexFormat::Snorm16x2 => naga::back::msl::VertexFormat::Snorm16x2, + wgt::VertexFormat::Snorm16x4 => naga::back::msl::VertexFormat::Snorm16x4, + wgt::VertexFormat::Float16x2 => naga::back::msl::VertexFormat::Float16x2, + wgt::VertexFormat::Float16x4 => naga::back::msl::VertexFormat::Float16x4, + wgt::VertexFormat::Float32 => naga::back::msl::VertexFormat::Float32, + wgt::VertexFormat::Float32x2 => naga::back::msl::VertexFormat::Float32x2, + wgt::VertexFormat::Float32x3 => naga::back::msl::VertexFormat::Float32x3, + wgt::VertexFormat::Float32x4 => naga::back::msl::VertexFormat::Float32x4, + wgt::VertexFormat::Uint32 => naga::back::msl::VertexFormat::Uint32, + wgt::VertexFormat::Uint32x2 => naga::back::msl::VertexFormat::Uint32x2, + wgt::VertexFormat::Uint32x3 => naga::back::msl::VertexFormat::Uint32x3, + wgt::VertexFormat::Uint32x4 => naga::back::msl::VertexFormat::Uint32x4, + wgt::VertexFormat::Sint32 => naga::back::msl::VertexFormat::Sint32, + wgt::VertexFormat::Sint32x2 => naga::back::msl::VertexFormat::Sint32x2, + wgt::VertexFormat::Sint32x3 => naga::back::msl::VertexFormat::Sint32x3, + wgt::VertexFormat::Sint32x4 => naga::back::msl::VertexFormat::Sint32x4, + wgt::VertexFormat::Unorm10_10_10_2 => naga::back::msl::VertexFormat::Unorm10_10_10_2, + _ => unimplemented!(), + } +} + impl super::Device { fn load_shader( &self, stage: &crate::ProgrammableStage, + vertex_buffer_mappings: &[naga::back::msl::VertexBufferMapping], layout: &super::PipelineLayout, primitive_class: metal::MTLPrimitiveTopologyClass, naga_stage: naga::ShaderStage, @@ -120,6 +158,8 @@ impl super::Device { metal::MTLPrimitiveTopologyClass::Point => true, _ => false, }, + vertex_pulling_transform: stage.vertex_pulling_transform, + vertex_buffer_mappings: vertex_buffer_mappings.to_vec(), }; let (source, info) = @@ -548,7 +588,7 @@ impl crate::Device for super::Device { pc_buffer: Option, pc_limit: u32, sizes_buffer: Option, - sizes_count: u8, + need_sizes_buffer: bool, resources: naga::back::msl::BindingMap, } @@ -558,7 +598,7 @@ impl crate::Device for super::Device { pc_buffer: None, pc_limit: 0, sizes_buffer: None, - sizes_count: 0, + need_sizes_buffer: false, resources: Default::default(), }); let mut bind_group_infos = arrayvec::ArrayVec::new(); @@ -603,7 +643,7 @@ impl crate::Device for super::Device { { for info in stage_data.iter_mut() { if entry.visibility.contains(map_naga_stage(info.stage)) { - info.sizes_count += 1; + info.need_sizes_buffer = true; } } } @@ -661,11 +701,13 @@ impl crate::Device for super::Device { // Finally, make sure we fit the limits for info in stage_data.iter_mut() { - // handle the sizes buffer assignment and shader overrides - if info.sizes_count != 0 { + if info.need_sizes_buffer || info.stage == naga::ShaderStage::Vertex { + // Set aside space for the sizes_buffer, which is required + // for variable-length buffers, or to support vertex pulling. info.sizes_buffer = Some(info.counters.buffers); info.counters.buffers += 1; } + if info.counters.buffers > self.shared.private_caps.max_buffers_per_stage || info.counters.textures > self.shared.private_caps.max_textures_per_stage || info.counters.samplers > self.shared.private_caps.max_samplers_per_stage @@ -832,8 +874,38 @@ impl crate::Device for super::Device { // Vertex shader let (vs_lib, vs_info) = { + let mut vertex_buffer_mappings = Vec::::new(); + for (i, vbl) in desc.vertex_buffers.iter().enumerate() { + let mut attributes = Vec::::new(); + for attribute in vbl.attributes.iter() { + attributes.push(naga::back::msl::AttributeMapping { + shader_location: attribute.shader_location, + offset: attribute.offset as u32, + format: convert_vertex_format_to_naga(attribute.format), + }); + } + + vertex_buffer_mappings.push(naga::back::msl::VertexBufferMapping { + id: self.shared.private_caps.max_vertex_buffers - 1 - i as u32, + stride: if vbl.array_stride > 0 { + vbl.array_stride.try_into().unwrap() + } else { + vbl.attributes + .iter() + .map(|attribute| attribute.offset + attribute.format.size()) + .max() + .unwrap_or(0) + .try_into() + .unwrap() + }, + indexed_by_vertex: (vbl.step_mode == wgt::VertexStepMode::Vertex {}), + attributes, + }); + } + let vs = self.load_shader( &desc.vertex_stage, + &vertex_buffer_mappings, desc.layout, primitive_class, naga::ShaderStage::Vertex, @@ -851,6 +923,7 @@ impl crate::Device for super::Device { push_constants: desc.layout.push_constants_infos.vs, sizes_slot: desc.layout.per_stage_map.vs.sizes_buffer, sized_bindings: vs.sized_bindings, + vertex_buffer_mappings, }; (vs.library, info) @@ -861,6 +934,7 @@ impl crate::Device for super::Device { Some(ref stage) => { let fs = self.load_shader( stage, + &[], desc.layout, primitive_class, naga::ShaderStage::Fragment, @@ -878,6 +952,7 @@ impl crate::Device for super::Device { push_constants: desc.layout.push_constants_infos.fs, sizes_slot: desc.layout.per_stage_map.fs.sizes_buffer, sized_bindings: fs.sized_bindings, + vertex_buffer_mappings: vec![], }; (Some(fs.library), Some(info)) @@ -1053,6 +1128,7 @@ impl crate::Device for super::Device { let cs = self.load_shader( &desc.stage, + &[], desc.layout, metal::MTLPrimitiveTopologyClass::Unspecified, naga::ShaderStage::Compute, @@ -1070,6 +1146,7 @@ impl crate::Device for super::Device { push_constants: desc.layout.push_constants_infos.cs, sizes_slot: desc.layout.per_stage_map.cs.sizes_buffer, sized_bindings: cs.sized_bindings, + vertex_buffer_mappings: vec![], }; if let Some(name) = desc.label { diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index a5ea63b03..ce8e01592 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -466,6 +466,15 @@ impl Buffer { } } +impl crate::BufferBinding<'_, Api> { + fn resolve_size(&self) -> wgt::BufferAddress { + match self.size { + Some(size) => size.get(), + None => self.buffer.size - self.offset, + } + } +} + #[derive(Debug)] pub struct Texture { raw: metal::Texture, @@ -690,6 +699,9 @@ struct PipelineStageInfo { /// /// See `device::CompiledShader::sized_bindings` for more details. sized_bindings: Vec, + + /// Info on all bound vertex buffers. + vertex_buffer_mappings: Vec, } impl PipelineStageInfo { @@ -697,6 +709,7 @@ impl PipelineStageInfo { self.push_constants = None; self.sizes_slot = None; self.sized_bindings.clear(); + self.vertex_buffer_mappings.clear(); } fn assign_from(&mut self, other: &Self) { @@ -704,6 +717,9 @@ impl PipelineStageInfo { self.sizes_slot = other.sizes_slot; self.sized_bindings.clear(); self.sized_bindings.extend_from_slice(&other.sized_bindings); + self.vertex_buffer_mappings.clear(); + self.vertex_buffer_mappings + .extend_from_slice(&other.vertex_buffer_mappings); } } @@ -821,6 +837,8 @@ struct CommandState { /// [`ResourceBinding`]: naga::ResourceBinding storage_buffer_length_map: rustc_hash::FxHashMap, + vertex_buffer_size_map: rustc_hash::FxHashMap, + work_group_memory_sizes: Vec, push_constants: Vec, diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index 5ed055f2b..d5210900b 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -1189,6 +1189,10 @@ impl crate::Context for ContextWgpuCore { .vertex .compilation_options .zero_initialize_workgroup_memory, + vertex_pulling_transform: desc + .vertex + .compilation_options + .vertex_pulling_transform, }, buffers: Borrowed(&vertex_buffers), }, @@ -1203,6 +1207,7 @@ impl crate::Context for ContextWgpuCore { zero_initialize_workgroup_memory: frag .compilation_options .zero_initialize_workgroup_memory, + vertex_pulling_transform: false, }, targets: Borrowed(frag.targets), }), @@ -1256,6 +1261,7 @@ impl crate::Context for ContextWgpuCore { zero_initialize_workgroup_memory: desc .compilation_options .zero_initialize_workgroup_memory, + vertex_pulling_transform: false, }, cache: desc.cache.map(|c| c.id.into()), }; diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index 00130a99c..e94ae27fe 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -1987,6 +1987,8 @@ pub struct PipelineCompilationOptions<'a> { /// This is required by the WebGPU spec, but may have overhead which can be avoided /// for cross-platform applications pub zero_initialize_workgroup_memory: bool, + /// Should the pipeline attempt to transform vertex shaders to use vertex pulling. + pub vertex_pulling_transform: bool, } impl<'a> Default for PipelineCompilationOptions<'a> { @@ -2000,6 +2002,7 @@ impl<'a> Default for PipelineCompilationOptions<'a> { Self { constants, zero_initialize_workgroup_memory: true, + vertex_pulling_transform: false, } } }