diff --git a/CHANGELOG.md b/CHANGELOG.md index f26392b38..4f67e017b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -72,6 +72,10 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410) ### New features +#### Vulkan + +- Added a `PipelineCache` resource to allow using Vulkan pipeline caches. By @DJMcNab in [#5319](https://github.com/gfx-rs/wgpu/pull/5319) + #### General #### Naga diff --git a/benches/benches/renderpass.rs b/benches/benches/renderpass.rs index 30543839a..fcb35c386 100644 --- a/benches/benches/renderpass.rs +++ b/benches/benches/renderpass.rs @@ -207,6 +207,7 @@ impl RenderpassState { compilation_options: wgpu::PipelineCompilationOptions::default(), }), multiview: None, + cache: None, }); let render_target = device_state @@ -304,6 +305,7 @@ impl RenderpassState { compilation_options: wgpu::PipelineCompilationOptions::default(), }), multiview: None, + cache: None, }, )); } diff --git a/deno_webgpu/pipeline.rs b/deno_webgpu/pipeline.rs index fc3a92bfc..c82b6a97c 100644 --- a/deno_webgpu/pipeline.rs +++ b/deno_webgpu/pipeline.rs @@ -115,6 +115,7 @@ pub fn op_webgpu_create_compute_pipeline( constants: Cow::Owned(compute.constants.unwrap_or_default()), zero_initialize_workgroup_memory: true, }, + cache: None, }; let implicit_pipelines = match layout { GPUPipelineLayoutOrGPUAutoLayoutMode::Layout(_) => None, @@ -395,6 +396,7 @@ pub fn op_webgpu_create_render_pipeline( multisample: args.multisample, fragment, multiview: None, + cache: None, }; let implicit_pipelines = match args.layout { diff --git a/examples/src/boids/mod.rs b/examples/src/boids/mod.rs index 6c8bb6e76..7b1b8f0bc 100644 --- a/examples/src/boids/mod.rs +++ b/examples/src/boids/mod.rs @@ -156,6 +156,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); // create compute pipeline @@ -166,6 +167,7 @@ impl crate::framework::Example for Example { module: &compute_shader, entry_point: 
"main", compilation_options: Default::default(), + cache: None, }); // buffer for the three 2d triangle vertices of each instance diff --git a/examples/src/bunnymark/mod.rs b/examples/src/bunnymark/mod.rs index 679fc5014..b5b33b54d 100644 --- a/examples/src/bunnymark/mod.rs +++ b/examples/src/bunnymark/mod.rs @@ -224,6 +224,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let texture = { diff --git a/examples/src/conservative_raster/mod.rs b/examples/src/conservative_raster/mod.rs index 89500a798..116ed8623 100644 --- a/examples/src/conservative_raster/mod.rs +++ b/examples/src/conservative_raster/mod.rs @@ -113,6 +113,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let pipeline_triangle_regular = @@ -135,6 +136,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let pipeline_lines = if device @@ -165,6 +167,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }), ) } else { @@ -224,6 +227,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }), bind_group_layout, ) diff --git a/examples/src/cube/mod.rs b/examples/src/cube/mod.rs index 934762781..9828157e5 100644 --- a/examples/src/cube/mod.rs +++ b/examples/src/cube/mod.rs @@ -260,6 +260,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let pipeline_wire = if device @@ -301,6 +302,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: 
None, + cache: None, }); Some(pipeline_wire) } else { diff --git a/examples/src/hello_compute/mod.rs b/examples/src/hello_compute/mod.rs index d04aaa430..cdd6d439d 100644 --- a/examples/src/hello_compute/mod.rs +++ b/examples/src/hello_compute/mod.rs @@ -110,6 +110,7 @@ async fn execute_gpu_inner( module: &cs_module, entry_point: "main", compilation_options: Default::default(), + cache: None, }); // Instantiates the bind group, once again specifying the binding of buffers. diff --git a/examples/src/hello_synchronization/mod.rs b/examples/src/hello_synchronization/mod.rs index 0a222fbe5..9b6675289 100644 --- a/examples/src/hello_synchronization/mod.rs +++ b/examples/src/hello_synchronization/mod.rs @@ -104,6 +104,7 @@ async fn execute( module: &shaders_module, entry_point: "patient_main", compilation_options: Default::default(), + cache: None, }); let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { label: None, @@ -111,6 +112,7 @@ async fn execute( module: &shaders_module, entry_point: "hasty_main", compilation_options: Default::default(), + cache: None, }); //---------------------------------------------------------- diff --git a/examples/src/hello_triangle/mod.rs b/examples/src/hello_triangle/mod.rs index 79162a695..e4d42674f 100644 --- a/examples/src/hello_triangle/mod.rs +++ b/examples/src/hello_triangle/mod.rs @@ -72,6 +72,7 @@ async fn run(event_loop: EventLoop<()>, window: Window) { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let mut config = surface diff --git a/examples/src/hello_workgroups/mod.rs b/examples/src/hello_workgroups/mod.rs index 572de36d3..0416451da 100644 --- a/examples/src/hello_workgroups/mod.rs +++ b/examples/src/hello_workgroups/mod.rs @@ -111,6 +111,7 @@ async fn run() { module: &shader, entry_point: "main", compilation_options: Default::default(), + cache: None, }); //---------------------------------------------------------- diff --git 
a/examples/src/mipmap/mod.rs b/examples/src/mipmap/mod.rs index 0848e94e1..eaed9c82e 100644 --- a/examples/src/mipmap/mod.rs +++ b/examples/src/mipmap/mod.rs @@ -109,6 +109,7 @@ impl Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let bind_group_layout = pipeline.get_bind_group_layout(0); @@ -310,6 +311,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); // Create bind group diff --git a/examples/src/msaa_line/mod.rs b/examples/src/msaa_line/mod.rs index cd22e75bc..46bb743e9 100644 --- a/examples/src/msaa_line/mod.rs +++ b/examples/src/msaa_line/mod.rs @@ -78,6 +78,7 @@ impl Example { ..Default::default() }, multiview: None, + cache: None, }); let mut encoder = device.create_render_bundle_encoder(&wgpu::RenderBundleEncoderDescriptor { diff --git a/examples/src/render_to_texture/mod.rs b/examples/src/render_to_texture/mod.rs index 5e571dc74..caed73674 100644 --- a/examples/src/render_to_texture/mod.rs +++ b/examples/src/render_to_texture/mod.rs @@ -72,6 +72,7 @@ async fn run(_path: Option) { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); log::info!("Wgpu context set up."); diff --git a/examples/src/repeated_compute/mod.rs b/examples/src/repeated_compute/mod.rs index 55e87eed9..72b615251 100644 --- a/examples/src/repeated_compute/mod.rs +++ b/examples/src/repeated_compute/mod.rs @@ -246,6 +246,7 @@ impl WgpuContext { module: &shader, entry_point: "main", compilation_options: Default::default(), + cache: None, }); WgpuContext { diff --git a/examples/src/shadow/mod.rs b/examples/src/shadow/mod.rs index 2cb6d6f3e..b2c27f589 100644 --- a/examples/src/shadow/mod.rs +++ b/examples/src/shadow/mod.rs @@ -526,6 +526,7 @@ impl crate::framework::Example for Example { }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); 
Pass { @@ -660,6 +661,7 @@ impl crate::framework::Example for Example { }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); Pass { diff --git a/examples/src/skybox/mod.rs b/examples/src/skybox/mod.rs index 35a4266d2..e526feeda 100644 --- a/examples/src/skybox/mod.rs +++ b/examples/src/skybox/mod.rs @@ -221,6 +221,7 @@ impl crate::framework::Example for Example { }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let entity_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { label: Some("Entity"), @@ -254,6 +255,7 @@ impl crate::framework::Example for Example { }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let sampler = device.create_sampler(&wgpu::SamplerDescriptor { diff --git a/examples/src/srgb_blend/mod.rs b/examples/src/srgb_blend/mod.rs index f701aff98..314fc92df 100644 --- a/examples/src/srgb_blend/mod.rs +++ b/examples/src/srgb_blend/mod.rs @@ -151,6 +151,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); // Done diff --git a/examples/src/stencil_triangles/mod.rs b/examples/src/stencil_triangles/mod.rs index e0f495177..8d638d20d 100644 --- a/examples/src/stencil_triangles/mod.rs +++ b/examples/src/stencil_triangles/mod.rs @@ -106,6 +106,7 @@ impl crate::framework::Example for Example { }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let outer_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { @@ -141,6 +142,7 @@ impl crate::framework::Example for Example { }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let stencil_buffer = device.create_texture(&wgpu::TextureDescriptor { diff --git a/examples/src/storage_texture/mod.rs b/examples/src/storage_texture/mod.rs index 02900c891..04253e818 100644 --- 
a/examples/src/storage_texture/mod.rs +++ b/examples/src/storage_texture/mod.rs @@ -101,6 +101,7 @@ async fn run(_path: Option) { module: &shader, entry_point: "main", compilation_options: Default::default(), + cache: None, }); log::info!("Wgpu context set up."); diff --git a/examples/src/texture_arrays/mod.rs b/examples/src/texture_arrays/mod.rs index dd7b4ec89..b0f474b95 100644 --- a/examples/src/texture_arrays/mod.rs +++ b/examples/src/texture_arrays/mod.rs @@ -341,6 +341,7 @@ impl crate::framework::Example for Example { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None }); Self { diff --git a/examples/src/timestamp_queries/mod.rs b/examples/src/timestamp_queries/mod.rs index 7042d60fe..703bafe49 100644 --- a/examples/src/timestamp_queries/mod.rs +++ b/examples/src/timestamp_queries/mod.rs @@ -299,6 +299,7 @@ fn compute_pass( module, entry_point: "main_cs", compilation_options: Default::default(), + cache: None, }); let bind_group_layout = compute_pipeline.get_bind_group_layout(0); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { @@ -366,8 +367,8 @@ fn render_pass( depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); - let render_target = device.create_texture(&wgpu::TextureDescriptor { label: Some("rendertarget"), size: wgpu::Extent3d { diff --git a/examples/src/uniform_values/mod.rs b/examples/src/uniform_values/mod.rs index 932c7aaee..c53a18972 100644 --- a/examples/src/uniform_values/mod.rs +++ b/examples/src/uniform_values/mod.rs @@ -192,8 +192,8 @@ impl WgpuContext { depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); - let surface_config = surface .get_default_config(&adapter, size.width, size.height) .unwrap(); diff --git a/examples/src/water/mod.rs b/examples/src/water/mod.rs index 94f12895a..b21ec70c4 100644 --- a/examples/src/water/mod.rs +++ b/examples/src/water/mod.rs @@ 
-574,6 +574,8 @@ impl crate::framework::Example for Example { // No multisampling is used. multisample: wgpu::MultisampleState::default(), multiview: None, + // No pipeline caching is used + cache: None, }); // Same idea as the water pipeline. @@ -610,6 +612,7 @@ impl crate::framework::Example for Example { }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None }); // A render bundle to draw the terrain. diff --git a/player/src/lib.rs b/player/src/lib.rs index c67c605e5..930fef151 100644 --- a/player/src/lib.rs +++ b/player/src/lib.rs @@ -302,6 +302,12 @@ impl GlobalPlay for wgc::global::Global { Action::DestroyRenderPipeline(id) => { self.render_pipeline_drop::(id); } + Action::CreatePipelineCache { id, desc } => { + let _ = unsafe { self.device_create_pipeline_cache::(device, &desc, Some(id)) }; + } + Action::DestroyPipelineCache(id) => { + self.pipeline_cache_drop::(id); + } Action::CreateRenderBundle { id, desc, base } => { let bundle = wgc::command::RenderBundleEncoder::new(&desc, device, Some(base)).unwrap(); diff --git a/tests/src/image.rs b/tests/src/image.rs index 8996f361c..19bbc1a91 100644 --- a/tests/src/image.rs +++ b/tests/src/image.rs @@ -370,6 +370,7 @@ fn copy_via_compute( module: &sm, entry_point: "copy_texture_to_buffer", compilation_options: Default::default(), + cache: None, }); { diff --git a/tests/tests/bgra8unorm_storage.rs b/tests/tests/bgra8unorm_storage.rs index 17082a9ed..7bc117f09 100644 --- a/tests/tests/bgra8unorm_storage.rs +++ b/tests/tests/bgra8unorm_storage.rs @@ -98,6 +98,7 @@ static BGRA8_UNORM_STORAGE: GpuTestConfiguration = GpuTestConfiguration::new() entry_point: "main", compilation_options: Default::default(), module: &module, + cache: None, }); let mut encoder = diff --git a/tests/tests/bind_group_layout_dedup.rs b/tests/tests/bind_group_layout_dedup.rs index 3466e1e24..3d74e62cb 100644 --- a/tests/tests/bind_group_layout_dedup.rs +++ b/tests/tests/bind_group_layout_dedup.rs @@ -91,6 +91,7 @@ 
async fn bgl_dedupe(ctx: TestingContext) { module: &module, entry_point: "no_resources", compilation_options: Default::default(), + cache: None, }; let pipeline = ctx.device.create_compute_pipeline(&desc); @@ -220,6 +221,7 @@ fn bgl_dedupe_with_dropped_user_handle(ctx: TestingContext) { module: &module, entry_point: "no_resources", compilation_options: Default::default(), + cache: None, }); let mut encoder = ctx.device.create_command_encoder(&Default::default()); @@ -266,6 +268,7 @@ fn bgl_dedupe_derived(ctx: TestingContext) { module: &module, entry_point: "resources", compilation_options: Default::default(), + cache: None, }); // We create two bind groups, pulling the bind_group_layout from the pipeline each time. @@ -337,6 +340,7 @@ fn separate_programs_have_incompatible_derived_bgls(ctx: TestingContext) { module: &module, entry_point: "resources", compilation_options: Default::default(), + cache: None, }; // Create two pipelines, creating a BG from the second. let pipeline1 = ctx.device.create_compute_pipeline(&desc); @@ -399,6 +403,7 @@ fn derived_bgls_incompatible_with_regular_bgls(ctx: TestingContext) { module: &module, entry_point: "resources", compilation_options: Default::default(), + cache: None, }); // Create a matching BGL diff --git a/tests/tests/buffer.rs b/tests/tests/buffer.rs index 0693877d0..241026731 100644 --- a/tests/tests/buffer.rs +++ b/tests/tests/buffer.rs @@ -225,6 +225,7 @@ static MINIMUM_BUFFER_BINDING_SIZE_LAYOUT: GpuTestConfiguration = GpuTestConfigu module: &shader_module, entry_point: "main", compilation_options: Default::default(), + cache: None, }); }); }); @@ -294,6 +295,7 @@ static MINIMUM_BUFFER_BINDING_SIZE_DISPATCH: GpuTestConfiguration = GpuTestConfi module: &shader_module, entry_point: "main", compilation_options: Default::default(), + cache: None, }); let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { diff --git a/tests/tests/compute_pass_resource_ownership.rs b/tests/tests/compute_pass_resource_ownership.rs 
index 6612ad006..4d48c2ad9 100644 --- a/tests/tests/compute_pass_resource_ownership.rs +++ b/tests/tests/compute_pass_resource_ownership.rs @@ -161,6 +161,7 @@ fn resource_setup(ctx: &TestingContext) -> ResourceSetup { module: &sm, entry_point: "main", compilation_options: Default::default(), + cache: None, }); ResourceSetup { diff --git a/tests/tests/device.rs b/tests/tests/device.rs index 649a850fa..3e7829329 100644 --- a/tests/tests/device.rs +++ b/tests/tests/device.rs @@ -488,6 +488,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne multisample: wgpu::MultisampleState::default(), fragment: None, multiview: None, + cache: None, }); }); @@ -500,6 +501,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne module: &shader_module, entry_point: "", compilation_options: Default::default(), + cache: None, }); }); @@ -757,6 +759,7 @@ fn vs_main() -> @builtin(position) vec4 { depth_stencil: None, multisample: wgt::MultisampleState::default(), multiview: None, + cache: None }); // fail(&ctx.device, || { diff --git a/tests/tests/mem_leaks.rs b/tests/tests/mem_leaks.rs index 7002ebabe..3c59aec03 100644 --- a/tests/tests/mem_leaks.rs +++ b/tests/tests/mem_leaks.rs @@ -113,6 +113,7 @@ async fn draw_test_with_reports( })], }), multiview: None, + cache: None, }); let global_report = ctx.instance.generate_report().unwrap(); diff --git a/tests/tests/nv12_texture/mod.rs b/tests/tests/nv12_texture/mod.rs index 70ee84983..fa386f865 100644 --- a/tests/tests/nv12_texture/mod.rs +++ b/tests/tests/nv12_texture/mod.rs @@ -41,6 +41,7 @@ static NV12_TEXTURE_CREATION_SAMPLING: GpuTestConfiguration = GpuTestConfigurati depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let tex = ctx.device.create_texture(&wgpu::TextureDescriptor { diff --git a/tests/tests/occlusion_query/mod.rs b/tests/tests/occlusion_query/mod.rs index 1a68ecf79..a888320e2 100644 --- 
a/tests/tests/occlusion_query/mod.rs +++ b/tests/tests/occlusion_query/mod.rs @@ -51,6 +51,7 @@ static OCCLUSION_QUERY: GpuTestConfiguration = GpuTestConfiguration::new() }), multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); // Create occlusion query set diff --git a/tests/tests/partially_bounded_arrays/mod.rs b/tests/tests/partially_bounded_arrays/mod.rs index 11eee5b20..83f9cee38 100644 --- a/tests/tests/partially_bounded_arrays/mod.rs +++ b/tests/tests/partially_bounded_arrays/mod.rs @@ -70,6 +70,7 @@ static PARTIALLY_BOUNDED_ARRAY: GpuTestConfiguration = GpuTestConfiguration::new module: &cs_module, entry_point: "main", compilation_options: Default::default(), + cache: None, }); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { diff --git a/tests/tests/pipeline.rs b/tests/tests/pipeline.rs index a07e158a5..0d725b8f4 100644 --- a/tests/tests/pipeline.rs +++ b/tests/tests/pipeline.rs @@ -29,6 +29,7 @@ static PIPELINE_DEFAULT_LAYOUT_BAD_MODULE: GpuTestConfiguration = GpuTestConfigu module: &module, entry_point: "doesn't exist", compilation_options: Default::default(), + cache: None, }); pipeline.get_bind_group_layout(0); diff --git a/tests/tests/pipeline_cache.rs b/tests/tests/pipeline_cache.rs new file mode 100644 index 000000000..58dae4694 --- /dev/null +++ b/tests/tests/pipeline_cache.rs @@ -0,0 +1,192 @@ +use std::{fmt::Write, num::NonZeroU64}; + +use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext}; + +/// We want to test that using a pipeline cache doesn't cause failure +/// +/// It would be nice if we could also assert that reusing a pipeline cache would make compilation +/// be faster however, some drivers use a fallback pipeline cache, which makes this inconsistent +/// (both intra- and inter-run). 
+#[gpu_test] +static PIPELINE_CACHE: GpuTestConfiguration = GpuTestConfiguration::new() + .parameters( + TestParameters::default() + .test_features_limits() + .features(wgpu::Features::PIPELINE_CACHE), + ) + .run_async(pipeline_cache_test); + +/// Set to a higher value if adding a timing based assertion. This is otherwise fast to compile +const ARRAY_SIZE: u64 = 256; + +/// Create a shader which should be slow-ish to compile +fn shader() -> String { + let mut body = String::new(); + for idx in 0..ARRAY_SIZE { + // "Safety": There will only be a single workgroup, and a single thread in that workgroup + writeln!(body, " output[{idx}] = {idx}u;") + .expect("`u64::fmt` and `String::write_fmt` are infallible"); + } + + format!( + r#" + @group(0) @binding(0) + var output: array; + + @compute @workgroup_size(1) + fn main() {{ + {body} + }} + "#, + ) +} + +async fn pipeline_cache_test(ctx: TestingContext) { + let shader = shader(); + let sm = ctx + .device + .create_shader_module(wgpu::ShaderModuleDescriptor { + label: Some("shader"), + source: wgpu::ShaderSource::Wgsl(shader.into()), + }); + + let bgl = ctx + .device + .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor { + label: Some("bind_group_layout"), + entries: &[wgpu::BindGroupLayoutEntry { + binding: 0, + visibility: wgpu::ShaderStages::COMPUTE, + ty: wgpu::BindingType::Buffer { + ty: wgpu::BufferBindingType::Storage { read_only: false }, + has_dynamic_offset: false, + min_binding_size: NonZeroU64::new(ARRAY_SIZE * 4), + }, + count: None, + }], + }); + + let gpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("gpu_buffer"), + size: ARRAY_SIZE * 4, + usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC, + mapped_at_creation: false, + }); + + let cpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { + label: Some("cpu_buffer"), + size: ARRAY_SIZE * 4, + usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ, + mapped_at_creation: false, + }); + + 
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { + label: Some("bind_group"), + layout: &bgl, + entries: &[wgpu::BindGroupEntry { + binding: 0, + resource: gpu_buffer.as_entire_binding(), + }], + }); + + let pipeline_layout = ctx + .device + .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor { + label: Some("pipeline_layout"), + bind_group_layouts: &[&bgl], + push_constant_ranges: &[], + }); + + let first_cache_data; + { + let first_cache = unsafe { + ctx.device + .create_pipeline_cache(&wgpu::PipelineCacheDescriptor { + label: Some("pipeline_cache"), + data: None, + fallback: false, + }) + }; + let first_pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("pipeline"), + layout: Some(&pipeline_layout), + module: &sm, + entry_point: "main", + compilation_options: Default::default(), + cache: Some(&first_cache), + }); + validate_pipeline(&ctx, first_pipeline, &bind_group, &gpu_buffer, &cpu_buffer).await; + first_cache_data = first_cache.get_data(); + } + assert!(first_cache_data.is_some()); + + let second_cache = unsafe { + ctx.device + .create_pipeline_cache(&wgpu::PipelineCacheDescriptor { + label: Some("pipeline_cache"), + data: first_cache_data.as_deref(), + fallback: false, + }) + }; + let first_pipeline = ctx + .device + .create_compute_pipeline(&wgpu::ComputePipelineDescriptor { + label: Some("pipeline"), + layout: Some(&pipeline_layout), + module: &sm, + entry_point: "main", + compilation_options: Default::default(), + cache: Some(&second_cache), + }); + validate_pipeline(&ctx, first_pipeline, &bind_group, &gpu_buffer, &cpu_buffer).await; + + // Ideally, we could assert here that the second compilation was faster than the first + // However, that doesn't actually work, because drivers have their own internal caches. 
+ // This does work on my machine if I set `MESA_DISABLE_PIPELINE_CACHE=1` + // before running the test; but of course that is not a realistic scenario +} + +async fn validate_pipeline( + ctx: &TestingContext, + pipeline: wgpu::ComputePipeline, + bind_group: &wgpu::BindGroup, + gpu_buffer: &wgpu::Buffer, + cpu_buffer: &wgpu::Buffer, +) { + let mut encoder = ctx + .device + .create_command_encoder(&wgpu::CommandEncoderDescriptor { + label: Some("encoder"), + }); + + { + let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor { + label: Some("compute_pass"), + timestamp_writes: None, + }); + cpass.set_pipeline(&pipeline); + cpass.set_bind_group(0, bind_group, &[]); + + cpass.dispatch_workgroups(1, 1, 1); + } + + encoder.copy_buffer_to_buffer(gpu_buffer, 0, cpu_buffer, 0, ARRAY_SIZE * 4); + ctx.queue.submit([encoder.finish()]); + cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ()); + ctx.async_poll(wgpu::Maintain::wait()) + .await + .panic_on_timeout(); + + let data = cpu_buffer.slice(..).get_mapped_range(); + + let arrays: &[u32] = bytemuck::cast_slice(&data); + + assert_eq!(arrays.len(), ARRAY_SIZE as usize); + for (idx, value) in arrays.iter().copied().enumerate() { + assert_eq!(value as usize, idx); + } + drop(data); + cpu_buffer.unmap(); +} diff --git a/tests/tests/push_constants.rs b/tests/tests/push_constants.rs index 04d9a00f7..a18207bef 100644 --- a/tests/tests/push_constants.rs +++ b/tests/tests/push_constants.rs @@ -104,6 +104,7 @@ async fn partial_update_test(ctx: TestingContext) { module: &sm, entry_point: "main", compilation_options: Default::default(), + cache: None, }); let mut encoder = ctx diff --git a/tests/tests/regression/issue_3349.rs b/tests/tests/regression/issue_3349.rs index 74c466b45..35d35e5bd 100644 --- a/tests/tests/regression/issue_3349.rs +++ b/tests/tests/regression/issue_3349.rs @@ -119,6 +119,7 @@ async fn multi_stage_data_binding_test(ctx: TestingContext) { depth_stencil: None, multisample: 
wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let texture = ctx.device.create_texture(&wgpu::TextureDescriptor { diff --git a/tests/tests/regression/issue_3457.rs b/tests/tests/regression/issue_3457.rs index f18d681ae..f0f7e6463 100644 --- a/tests/tests/regression/issue_3457.rs +++ b/tests/tests/regression/issue_3457.rs @@ -80,6 +80,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = })], }), multiview: None, + cache: None, }); let single_pipeline = ctx @@ -111,6 +112,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration = })], }), multiview: None, + cache: None, }); let view = ctx diff --git a/tests/tests/root.rs b/tests/tests/root.rs index ba5e02079..29f894ede 100644 --- a/tests/tests/root.rs +++ b/tests/tests/root.rs @@ -24,6 +24,7 @@ mod nv12_texture; mod occlusion_query; mod partially_bounded_arrays; mod pipeline; +mod pipeline_cache; mod poll; mod push_constants; mod query_set; diff --git a/tests/tests/scissor_tests/mod.rs b/tests/tests/scissor_tests/mod.rs index 15c35644e..3f1e7df13 100644 --- a/tests/tests/scissor_tests/mod.rs +++ b/tests/tests/scissor_tests/mod.rs @@ -61,6 +61,7 @@ async fn scissor_test_impl( })], }), multiview: None, + cache: None, }); let readback_buffer = image::ReadbackBuffers::new(&ctx.device, &texture); diff --git a/tests/tests/shader/mod.rs b/tests/tests/shader/mod.rs index 248b9c23e..2716caabd 100644 --- a/tests/tests/shader/mod.rs +++ b/tests/tests/shader/mod.rs @@ -310,6 +310,7 @@ async fn shader_input_output_test( module: &sm, entry_point: "cs_main", compilation_options: Default::default(), + cache: None, }); // -- Initializing data -- diff --git a/tests/tests/shader/zero_init_workgroup_mem.rs b/tests/tests/shader/zero_init_workgroup_mem.rs index cb9f341ee..0dcb81959 100644 --- a/tests/tests/shader/zero_init_workgroup_mem.rs +++ b/tests/tests/shader/zero_init_workgroup_mem.rs @@ -88,6 +88,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration:: module: &sm, 
entry_point: "read", compilation_options: Default::default(), + cache: None, }); let pipeline_write = ctx @@ -98,6 +99,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration:: module: &sm, entry_point: "write", compilation_options: Default::default(), + cache: None, }); // -- Initializing data -- diff --git a/tests/tests/shader_primitive_index/mod.rs b/tests/tests/shader_primitive_index/mod.rs index fb4339783..9972f81aa 100644 --- a/tests/tests/shader_primitive_index/mod.rs +++ b/tests/tests/shader_primitive_index/mod.rs @@ -147,6 +147,7 @@ async fn pulling_common( })], }), multiview: None, + cache: None, }); let width = 2; diff --git a/tests/tests/shader_view_format/mod.rs b/tests/tests/shader_view_format/mod.rs index 53c642bf7..d34b8d851 100644 --- a/tests/tests/shader_view_format/mod.rs +++ b/tests/tests/shader_view_format/mod.rs @@ -109,6 +109,7 @@ async fn reinterpret( depth_stencil: None, multisample: wgpu::MultisampleState::default(), multiview: None, + cache: None, }); let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { layout: &pipeline.get_bind_group_layout(0), diff --git a/tests/tests/subgroup_operations/mod.rs b/tests/tests/subgroup_operations/mod.rs index 2c518a9d9..7d0aec824 100644 --- a/tests/tests/subgroup_operations/mod.rs +++ b/tests/tests/subgroup_operations/mod.rs @@ -75,6 +75,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new() module: &cs_module, entry_point: "main", compilation_options: Default::default(), + cache: None, }); let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { diff --git a/tests/tests/vertex_indices/mod.rs b/tests/tests/vertex_indices/mod.rs index cad7e731d..7bd172d85 100644 --- a/tests/tests/vertex_indices/mod.rs +++ b/tests/tests/vertex_indices/mod.rs @@ -295,6 +295,7 @@ async fn vertex_index_common(ctx: TestingContext) { })], }), multiview: None, + cache: None, }; let builtin_pipeline = 
ctx.device.create_render_pipeline(&pipeline_desc); pipeline_desc.vertex.entry_point = "vs_main_buffers"; diff --git a/wgpu-core/src/device/global.rs b/wgpu-core/src/device/global.rs index 660f90279..a5c51b269 100644 --- a/wgpu-core/src/device/global.rs +++ b/wgpu-core/src/device/global.rs @@ -13,8 +13,10 @@ use crate::{ instance::{self, Adapter, Surface}, lock::{rank, RwLock}, pipeline, present, - resource::{self, BufferAccessResult}, - resource::{BufferAccessError, BufferMapOperation, CreateBufferError, Resource}, + resource::{ + self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError, + Resource, + }, validation::check_buffer_usage, Label, LabelHelpers as _, }; @@ -1823,6 +1825,66 @@ impl Global { } } + /// # Safety + /// The `data` argument of `desc` must have been returned by + /// [Self::pipeline_cache_get_data] for the same adapter + pub unsafe fn device_create_pipeline_cache( + &self, + device_id: DeviceId, + desc: &pipeline::PipelineCacheDescriptor<'_>, + id_in: Option, + ) -> ( + id::PipelineCacheId, + Option, + ) { + profiling::scope!("Device::create_pipeline_cache"); + + let hub = A::hub(self); + + let fid = hub.pipeline_caches.prepare(id_in); + let error: pipeline::CreatePipelineCacheError = 'error: { + let device = match hub.devices.get(device_id) { + Ok(device) => device, + // TODO: Handle error properly + Err(crate::storage::InvalidId) => break 'error DeviceError::Invalid.into(), + }; + if !device.is_valid() { + break 'error DeviceError::Lost.into(); + } + #[cfg(feature = "trace")] + if let Some(ref mut trace) = *device.trace.lock() { + trace.add(trace::Action::CreatePipelineCache { + id: fid.id(), + desc: desc.clone(), + }); + } + let cache = unsafe { device.create_pipeline_cache(desc) }; + match cache { + Ok(cache) => { + let (id, _) = fid.assign(Arc::new(cache)); + api_log!("Device::create_pipeline_cache -> {id:?}"); + return (id, None); + } + Err(e) => break 'error e, + } + }; + + let id = 
fid.assign_error(desc.label.borrow_or_default()); + + (id, Some(error)) + } + + pub fn pipeline_cache_drop(&self, pipeline_cache_id: id::PipelineCacheId) { + profiling::scope!("PipelineCache::drop"); + api_log!("PipelineCache::drop {pipeline_cache_id:?}"); + + let hub = A::hub(self); + + if let Some(cache) = hub.pipeline_caches.unregister(pipeline_cache_id) { + drop(cache) + } + } + pub fn surface_configure( &self, surface_id: SurfaceId, @@ -2270,6 +2332,37 @@ impl Global { .force_replace_with_error(device_id, "Made invalid."); } + pub fn pipeline_cache_get_data(&self, id: id::PipelineCacheId) -> Option> { + use crate::pipeline_cache; + api_log!("PipelineCache::get_data"); + let hub = A::hub(self); + + if let Ok(cache) = hub.pipeline_caches.get(id) { + // TODO: Is this check needed? + if !cache.device.is_valid() { + return None; + } + if let Some(raw_cache) = cache.raw.as_ref() { + let mut vec = unsafe { cache.device.raw().pipeline_cache_get_data(raw_cache) }?; + let validation_key = cache.device.raw().pipeline_cache_validation_key()?; + + let mut header_contents = [0; pipeline_cache::HEADER_LENGTH]; + pipeline_cache::add_cache_header( + &mut header_contents, + &vec, + &cache.device.adapter.raw.info, + validation_key, + ); + + let deleted = vec.splice(..0, header_contents).collect::>(); + debug_assert!(deleted.is_empty()); + + return Some(vec); + } + } + None + } + pub fn device_drop(&self, device_id: DeviceId) { profiling::scope!("Device::drop"); api_log!("Device::drop {device_id:?}"); diff --git a/wgpu-core/src/device/resource.rs b/wgpu-core/src/device/resource.rs index 2f1eecccb..7ac3878ef 100644 --- a/wgpu-core/src/device/resource.rs +++ b/wgpu-core/src/device/resource.rs @@ -2817,6 +2817,20 @@ impl Device { let late_sized_buffer_groups = Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout); + let cache = 'cache: { + let Some(cache) = desc.cache else { + break 'cache None; + }; + let Ok(cache) = hub.pipeline_caches.get(cache) else { + 
break 'cache None; + }; + + if cache.device.as_info().id() != self.as_info().id() { + return Err(DeviceError::WrongDevice.into()); + } + Some(cache) + }; + let pipeline_desc = hal::ComputePipelineDescriptor { label: desc.label.to_hal(self.instance_flags), layout: pipeline_layout.raw(), @@ -2826,6 +2840,7 @@ impl Device { constants: desc.stage.constants.as_ref(), zero_initialize_workgroup_memory: desc.stage.zero_initialize_workgroup_memory, }, + cache: cache.as_ref().and_then(|it| it.raw.as_ref()), }; let raw = unsafe { @@ -3199,6 +3214,7 @@ impl Device { let vertex_shader_module; let vertex_entry_point_name; + let vertex_stage = { let stage_desc = &desc.vertex.stage; let stage = wgt::ShaderStages::VERTEX; @@ -3393,6 +3409,20 @@ impl Device { let late_sized_buffer_groups = Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout); + let pipeline_cache = 'cache: { + let Some(cache) = desc.cache else { + break 'cache None; + }; + let Ok(cache) = hub.pipeline_caches.get(cache) else { + break 'cache None; + }; + + if cache.device.as_info().id() != self.as_info().id() { + return Err(DeviceError::WrongDevice.into()); + } + Some(cache) + }; + let pipeline_desc = hal::RenderPipelineDescriptor { label: desc.label.to_hal(self.instance_flags), layout: pipeline_layout.raw(), @@ -3404,6 +3434,7 @@ impl Device { fragment_stage, color_targets, multiview: desc.multiview, + cache: pipeline_cache.as_ref().and_then(|it| it.raw.as_ref()), }; let raw = unsafe { self.raw @@ -3484,6 +3515,53 @@ impl Device { Ok(pipeline) } + /// # Safety + /// The `data` field on `desc` must have previously been returned from [`crate::global::Global::pipeline_cache_get_data`] + pub unsafe fn create_pipeline_cache( + self: &Arc, + desc: &pipeline::PipelineCacheDescriptor, + ) -> Result, pipeline::CreatePipelineCacheError> { + use crate::pipeline_cache; + self.require_features(wgt::Features::PIPELINE_CACHE)?; + let data = if let Some((data, validation_key)) = desc + .data + .as_ref() + 
.zip(self.raw().pipeline_cache_validation_key()) + { + let data = pipeline_cache::validate_pipeline_cache( + data, + &self.adapter.raw.info, + validation_key, + ); + match data { + Ok(data) => Some(data), + Err(e) if e.was_avoidable() || !desc.fallback => return Err(e.into()), + // If the error was unavoidable and we are asked to fallback, do so + Err(_) => None, + } + } else { + None + }; + let cache_desc = hal::PipelineCacheDescriptor { + data, + label: desc.label.to_hal(self.instance_flags), + }; + let raw = match unsafe { self.raw().create_pipeline_cache(&cache_desc) } { + Ok(raw) => raw, + Err(e) => return Err(e.into()), + }; + let cache = pipeline::PipelineCache { + device: self.clone(), + info: ResourceInfo::new( + desc.label.borrow_or_default(), + Some(self.tracker_indices.pipeline_caches.clone()), + ), + // This would be none in the error condition, which we don't implement yet + raw: Some(raw), + }; + Ok(cache) + } + pub(crate) fn get_texture_format_features( &self, adapter: &Adapter, diff --git a/wgpu-core/src/device/trace.rs b/wgpu-core/src/device/trace.rs index 0802b610d..24790103a 100644 --- a/wgpu-core/src/device/trace.rs +++ b/wgpu-core/src/device/trace.rs @@ -98,6 +98,11 @@ pub enum Action<'a> { implicit_context: Option, }, DestroyRenderPipeline(id::RenderPipelineId), + CreatePipelineCache { + id: id::PipelineCacheId, + desc: crate::pipeline::PipelineCacheDescriptor<'a>, + }, + DestroyPipelineCache(id::PipelineCacheId), CreateRenderBundle { id: id::RenderBundleId, desc: crate::command::RenderBundleEncoderDescriptor<'a>, diff --git a/wgpu-core/src/hub.rs b/wgpu-core/src/hub.rs index eb57411d9..a318f91fc 100644 --- a/wgpu-core/src/hub.rs +++ b/wgpu-core/src/hub.rs @@ -110,7 +110,7 @@ use crate::{ device::{queue::Queue, Device}, hal_api::HalApi, instance::{Adapter, Surface}, - pipeline::{ComputePipeline, RenderPipeline, ShaderModule}, + pipeline::{ComputePipeline, PipelineCache, RenderPipeline, ShaderModule}, registry::{Registry, RegistryReport}, 
resource::{Buffer, QuerySet, Sampler, StagingBuffer, Texture, TextureView}, storage::{Element, Storage}, @@ -130,6 +130,7 @@ pub struct HubReport { pub render_bundles: RegistryReport, pub render_pipelines: RegistryReport, pub compute_pipelines: RegistryReport, + pub pipeline_caches: RegistryReport, pub query_sets: RegistryReport, pub buffers: RegistryReport, pub textures: RegistryReport, @@ -180,6 +181,7 @@ pub struct Hub { pub(crate) render_bundles: Registry>, pub(crate) render_pipelines: Registry>, pub(crate) compute_pipelines: Registry>, + pub(crate) pipeline_caches: Registry>, pub(crate) query_sets: Registry>, pub(crate) buffers: Registry>, pub(crate) staging_buffers: Registry>, @@ -202,6 +204,7 @@ impl Hub { render_bundles: Registry::new(A::VARIANT), render_pipelines: Registry::new(A::VARIANT), compute_pipelines: Registry::new(A::VARIANT), + pipeline_caches: Registry::new(A::VARIANT), query_sets: Registry::new(A::VARIANT), buffers: Registry::new(A::VARIANT), staging_buffers: Registry::new(A::VARIANT), @@ -235,6 +238,7 @@ impl Hub { self.pipeline_layouts.write().map.clear(); self.compute_pipelines.write().map.clear(); self.render_pipelines.write().map.clear(); + self.pipeline_caches.write().map.clear(); self.query_sets.write().map.clear(); for element in surface_guard.map.iter() { @@ -280,6 +284,7 @@ impl Hub { render_bundles: self.render_bundles.generate_report(), render_pipelines: self.render_pipelines.generate_report(), compute_pipelines: self.compute_pipelines.generate_report(), + pipeline_caches: self.pipeline_caches.generate_report(), query_sets: self.query_sets.generate_report(), buffers: self.buffers.generate_report(), textures: self.textures.generate_report(), diff --git a/wgpu-core/src/id.rs b/wgpu-core/src/id.rs index e999ef33c..5bc86b377 100644 --- a/wgpu-core/src/id.rs +++ b/wgpu-core/src/id.rs @@ -313,6 +313,7 @@ ids! 
{ pub type ShaderModuleId ShaderModule; pub type RenderPipelineId RenderPipeline; pub type ComputePipelineId ComputePipeline; + pub type PipelineCacheId PipelineCache; pub type CommandEncoderId CommandEncoder; pub type CommandBufferId CommandBuffer; pub type RenderPassEncoderId RenderPassEncoder; diff --git a/wgpu-core/src/lib.rs b/wgpu-core/src/lib.rs index 032d85a4b..ebf80091c 100644 --- a/wgpu-core/src/lib.rs +++ b/wgpu-core/src/lib.rs @@ -65,6 +65,7 @@ mod init_tracker; pub mod instance; mod lock; pub mod pipeline; +mod pipeline_cache; mod pool; pub mod present; pub mod registry; diff --git a/wgpu-core/src/pipeline.rs b/wgpu-core/src/pipeline.rs index 3c80929e6..bfb2c331d 100644 --- a/wgpu-core/src/pipeline.rs +++ b/wgpu-core/src/pipeline.rs @@ -1,11 +1,12 @@ #[cfg(feature = "trace")] use crate::device::trace; +pub use crate::pipeline_cache::PipelineCacheValidationError; use crate::{ binding_model::{CreateBindGroupLayoutError, CreatePipelineLayoutError, PipelineLayout}, command::ColorAttachmentError, device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures, RenderPassContext}, hal_api::HalApi, - id::{PipelineLayoutId, ShaderModuleId}, + id::{PipelineCacheId, PipelineLayoutId, ShaderModuleId}, resource::{Resource, ResourceInfo, ResourceType}, resource_log, validation, Label, }; @@ -192,6 +193,8 @@ pub struct ComputePipelineDescriptor<'a> { pub layout: Option, /// The compiled compute stage and its entry point. pub stage: ProgrammableStageDescriptor<'a>, + /// The pipeline cache to use when creating this pipeline. 
+ pub cache: Option, } #[derive(Clone, Debug, Error)] @@ -259,6 +262,68 @@ impl ComputePipeline { } } +#[derive(Clone, Debug, Error)] +#[non_exhaustive] +pub enum CreatePipelineCacheError { + #[error(transparent)] + Device(#[from] DeviceError), + #[error("Pipeline cache validation failed")] + Validation(#[from] PipelineCacheValidationError), + #[error(transparent)] + MissingFeatures(#[from] MissingFeatures), + #[error("Internal error: {0}")] + Internal(String), +} + +impl From for CreatePipelineCacheError { + fn from(value: hal::PipelineCacheError) -> Self { + match value { + hal::PipelineCacheError::Device(device) => { + CreatePipelineCacheError::Device(device.into()) + } + } + } +} + +#[derive(Debug)] +pub struct PipelineCache { + pub(crate) raw: Option, + pub(crate) device: Arc>, + pub(crate) info: ResourceInfo>, +} + +impl Drop for PipelineCache { + fn drop(&mut self) { + if let Some(raw) = self.raw.take() { + resource_log!("Destroy raw PipelineCache {:?}", self.info.label()); + + #[cfg(feature = "trace")] + if let Some(t) = self.device.trace.lock().as_mut() { + t.add(trace::Action::DestroyPipelineCache(self.info.id())); + } + + unsafe { + use hal::Device; + self.device.raw().destroy_pipeline_cache(raw); + } + } + } +} + +impl Resource for PipelineCache { + const TYPE: ResourceType = "PipelineCache"; + + type Marker = crate::id::markers::PipelineCache; + + fn as_info(&self) -> &ResourceInfo { + &self.info + } + + fn as_info_mut(&mut self) -> &mut ResourceInfo { + &mut self.info + } +} + /// Describes how the vertex buffer is interpreted. #[derive(Clone, Debug)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] @@ -315,6 +380,16 @@ pub struct RenderPipelineDescriptor<'a> { /// If the pipeline will be used with a multiview render pass, this indicates how many array /// layers the attachments will have. pub multiview: Option, + /// The pipeline cache to use when creating this pipeline. 
+ pub cache: Option, +} + +#[derive(Clone, Debug)] +#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] +pub struct PipelineCacheDescriptor<'a> { + pub label: Label<'a>, + pub data: Option>, + pub fallback: bool, } #[derive(Clone, Debug, Error)] diff --git a/wgpu-core/src/pipeline_cache.rs b/wgpu-core/src/pipeline_cache.rs new file mode 100644 index 000000000..d098cdafc --- /dev/null +++ b/wgpu-core/src/pipeline_cache.rs @@ -0,0 +1,530 @@ +use thiserror::Error; +use wgt::AdapterInfo; + +pub const HEADER_LENGTH: usize = std::mem::size_of::(); + +#[derive(Debug, PartialEq, Eq, Clone, Error)] +#[non_exhaustive] +pub enum PipelineCacheValidationError { + #[error("The pipeline cache data was truncated")] + Truncated, + #[error("The pipeline cache data was longer than recorded")] + // TODO: Is it plausible that this would happen + Extended, + #[error("The pipeline cache data was corrupted (e.g. the hash didn't match)")] + Corrupted, + #[error("The pipeline cache data was out of date and so cannot be safely used")] + Outdated, + #[error("The cache data was created for a different device")] + WrongDevice, + #[error("Pipeline cache data was created for a future version of wgpu")] + Unsupported, +}
+ /// That is, is there a mistake in user code interacting with the cache + pub fn was_avoidable(&self) -> bool { + match self { + PipelineCacheValidationError::WrongDevice => true, + PipelineCacheValidationError::Truncated + | PipelineCacheValidationError::Unsupported + | PipelineCacheValidationError::Extended + // It's unusual, but not implausible, to be downgrading wgpu + | PipelineCacheValidationError::Outdated + | PipelineCacheValidationError::Corrupted => false, + } + } +} + +/// Validate the data in a pipeline cache +pub fn validate_pipeline_cache<'d>( + cache_data: &'d [u8], + adapter: &AdapterInfo, + validation_key: [u8; 16], +) -> Result<&'d [u8], PipelineCacheValidationError> { + let adapter_key = adapter_key(adapter)?; + let Some((header, remaining_data)) = PipelineCacheHeader::read(cache_data) else { + return Err(PipelineCacheValidationError::Truncated); + }; + if header.magic != MAGIC { + return Err(PipelineCacheValidationError::Corrupted); + } + if header.header_version != HEADER_VERSION { + return Err(PipelineCacheValidationError::Outdated); + } + if header.cache_abi != ABI { + return Err(PipelineCacheValidationError::Outdated); + } + if header.backend != adapter.backend as u8 { + return Err(PipelineCacheValidationError::WrongDevice); + } + if header.adapter_key != adapter_key { + return Err(PipelineCacheValidationError::WrongDevice); + } + if header.validation_key != validation_key { + // If the validation key is wrong, that means that this device has changed + // in a way where the cache won't be compatible since the cache was made, + // so it is outdated + return Err(PipelineCacheValidationError::Outdated); + } + let data_size: usize = header + .data_size + .try_into() + // If the data was previously more than 4GiB, and we're still on a 32 bit system (ABI check, above) + // Then the data must be corrupted + .map_err(|_| PipelineCacheValidationError::Corrupted)?; + if remaining_data.len() < data_size { + return 
Err(PipelineCacheValidationError::Truncated); + } + if remaining_data.len() > data_size { + return Err(PipelineCacheValidationError::Extended); + } + if header.hash_space != HASH_SPACE_VALUE { + return Err(PipelineCacheValidationError::Corrupted); + } + Ok(remaining_data) +} + +pub fn add_cache_header( + in_region: &mut [u8], + data: &[u8], + adapter: &AdapterInfo, + validation_key: [u8; 16], +) { + assert_eq!(in_region.len(), HEADER_LENGTH); + let header = PipelineCacheHeader { + adapter_key: adapter_key(adapter) + .expect("Called add_cache_header for an adapter which doesn't support cache data. This is a wgpu internal bug"), + backend: adapter.backend as u8, + cache_abi: ABI, + magic: MAGIC, + header_version: HEADER_VERSION, + validation_key, + hash_space: HASH_SPACE_VALUE, + data_size: data + .len() + .try_into() + .expect("Cache larger than u64::MAX bytes"), + }; + header.write(in_region); +} + +const MAGIC: [u8; 8] = *b"WGPUPLCH"; +const HEADER_VERSION: u32 = 1; +const ABI: u32 = std::mem::size_of::<*const ()>() as u32; + +/// The value used to fill [`PipelineCacheHeader::hash_space`] +/// +/// If we receive reports of pipeline cache data corruption which is not otherwise caught +/// on a real device, it would be worth modifying this +/// +/// Note that wgpu does not protect against malicious writes to e.g. a file used +/// to store a pipeline cache. +/// That is the responsibility of the end application, such as by using a +/// private space.
+const HASH_SPACE_VALUE: u64 = 0xFEDCBA9_876543210; + +#[repr(C)] +#[derive(PartialEq, Eq)] +struct PipelineCacheHeader { + /// The magic header to ensure that we have the right file format + /// Has a value of MAGIC, as above + magic: [u8; 8], + // /// The total size of this header, in bytes + // header_size: u32, + /// The version of this wgpu header + /// Should be equal to HEADER_VERSION above + /// + /// This must always be the second item, after the value above + header_version: u32, + /// The number of bytes in the pointers of this ABI, because some drivers + /// have previously not distinguished between their 32 bit and 64 bit drivers + /// leading to Vulkan data corruption + cache_abi: u32, + /// The id for the backend in use, from [wgt::Backend] + backend: u8, + /// The key which identifies the device/adapter. + /// This is used to validate that this pipeline cache (probably) was produced for + /// the expected device. + /// On Vulkan: it is a combination of vendor ID and device ID + adapter_key: [u8; 15], + /// A key used to validate that this device is still compatible with the cache + /// + /// This should e.g.
contain driver version and/or intermediate compiler versions + validation_key: [u8; 16], + /// The length of the data which is sent to/received from the backend + data_size: u64, + /// Space reserved for a hash of the data in future + /// + /// We assume that your cache storage system will be relatively robust, and so + /// do not validate this hash + /// + /// Therefore, this will always have a value of [`HASH_SPACE_VALUE`] + hash_space: u64, +} + +impl PipelineCacheHeader { + fn read(data: &[u8]) -> Option<(PipelineCacheHeader, &[u8])> { + let mut reader = Reader { + data, + total_read: 0, + }; + let magic = reader.read_array()?; + let header_version = reader.read_u32()?; + let cache_abi = reader.read_u32()?; + let backend = reader.read_byte()?; + let adapter_key = reader.read_array()?; + let validation_key = reader.read_array()?; + let data_size = reader.read_u64()?; + let data_hash = reader.read_u64()?; + + assert_eq!( + reader.total_read, + std::mem::size_of::() + ); + + Some(( + PipelineCacheHeader { + magic, + header_version, + cache_abi, + backend, + adapter_key, + validation_key, + data_size, + hash_space: data_hash, + }, + reader.data, + )) + } + + fn write(&self, into: &mut [u8]) -> Option<()> { + let mut writer = Writer { data: into }; + writer.write_array(&self.magic)?; + writer.write_u32(self.header_version)?; + writer.write_u32(self.cache_abi)?; + writer.write_byte(self.backend)?; + writer.write_array(&self.adapter_key)?; + writer.write_array(&self.validation_key)?; + writer.write_u64(self.data_size)?; + writer.write_u64(self.hash_space)?; + + assert_eq!(writer.data.len(), 0); + Some(()) + } +} + +fn adapter_key(adapter: &AdapterInfo) -> Result<[u8; 15], PipelineCacheValidationError> { + match adapter.backend { + wgt::Backend::Vulkan => { + // If these change size, the header format needs to change + // We set the type explicitly so this won't compile in that case + let v: [u8; 4] = adapter.vendor.to_be_bytes(); + let d: [u8; 4] =
adapter.device.to_be_bytes(); + let adapter = [ + 255, 255, 255, v[0], v[1], v[2], v[3], d[0], d[1], d[2], d[3], 255, 255, 255, 255, + ]; + Ok(adapter) + } + _ => Err(PipelineCacheValidationError::Unsupported), + } +} + +struct Reader<'a> { + data: &'a [u8], + total_read: usize, +} + +impl<'a> Reader<'a> { + fn read_byte(&mut self) -> Option { + let res = *self.data.first()?; + self.total_read += 1; + self.data = &self.data[1..]; + Some(res) + } + fn read_array(&mut self) -> Option<[u8; N]> { + // Only greater than because we're indexing fenceposts, not items + if N > self.data.len() { + return None; + } + let (start, data) = self.data.split_at(N); + self.total_read += N; + self.data = data; + Some(start.try_into().expect("off-by-one-error in array size")) + } + + // fn read_u16(&mut self) -> Option { + // self.read_array().map(u16::from_be_bytes) + // } + fn read_u32(&mut self) -> Option { + self.read_array().map(u32::from_be_bytes) + } + fn read_u64(&mut self) -> Option { + self.read_array().map(u64::from_be_bytes) + } +} + +struct Writer<'a> { + data: &'a mut [u8], +} + +impl<'a> Writer<'a> { + fn write_byte(&mut self, byte: u8) -> Option<()> { + self.write_array(&[byte]) + } + fn write_array(&mut self, array: &[u8; N]) -> Option<()> { + // Only greater than because we're indexing fenceposts, not items + if N > self.data.len() { + return None; + } + let data = std::mem::take(&mut self.data); + let (start, data) = data.split_at_mut(N); + self.data = data; + start.copy_from_slice(array); + Some(()) + } + + // fn write_u16(&mut self, value: u16) -> Option<()> { + // self.write_array(&value.to_be_bytes()) + // } + fn write_u32(&mut self, value: u32) -> Option<()> { + self.write_array(&value.to_be_bytes()) + } + fn write_u64(&mut self, value: u64) -> Option<()> { + self.write_array(&value.to_be_bytes()) + } +} + +#[cfg(test)] +mod tests { + use wgt::AdapterInfo; + + use crate::pipeline_cache::{PipelineCacheValidationError as E, HEADER_LENGTH}; + + use super::ABI; + + 
// Assert the correct size + const _: [(); HEADER_LENGTH] = [(); 64]; + + const ADAPTER: AdapterInfo = AdapterInfo { + name: String::new(), + vendor: 0x0002_FEED, + device: 0xFEFE_FEFE, + device_type: wgt::DeviceType::Other, + driver: String::new(), + driver_info: String::new(), + backend: wgt::Backend::Vulkan, + }; + + // IMPORTANT: If these tests fail, then you MUST increment HEADER_VERSION + const VALIDATION_KEY: [u8; 16] = u128::to_be_bytes(0xFFFFFFFF_FFFFFFFF_88888888_88888888); + #[test] + fn written_header() { + let mut result = [0; HEADER_LENGTH]; + super::add_cache_header(&mut result, &[], &ADAPTER, VALIDATION_KEY); + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let expected = cache.into_iter().flatten().collect::>(); + + assert_eq!(result.as_slice(), expected.as_slice()); + } + + #[test] + fn valid_data() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let expected: &[u8] = &[]; + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Ok(expected)); + } + #[test] + fn invalid_magic() { + let 
cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"NOT_WGPU", // (Wrong) MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::Corrupted)); + } + + #[test] + fn wrong_version() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 2, 0, 0, 0, ABI as u8], // (wrong) Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::Outdated)); + } + #[test] + fn wrong_abi() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + // a 14 bit ABI is improbable + [0, 0, 0, 1, 0, 0, 0, 14], // Version and (wrong) ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Header + ]; + let cache = cache.into_iter().flatten().collect::>(); + let 
validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::Outdated)); + } + + #[test] + fn wrong_backend() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [2, 255, 255, 255, 0, 2, 0xFE, 0xED], // (wrong) Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::WrongDevice)); + } + #[test] + fn wrong_adapter() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0x00], // Backend and (wrong) Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::WrongDevice)); + } + #[test] + fn wrong_validation() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_00000000u64.to_be_bytes(), // (wrong) 
Validation key + 0x0u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::Outdated)); + } + #[test] + fn too_little_data() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x064u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::Truncated)); + } + #[test] + fn not_no_data() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 100u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache + .into_iter() + .flatten() + .chain(std::iter::repeat(0u8).take(100)) + .collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + let expected: &[u8] = &[0; 100]; + assert_eq!(validation_result, Ok(expected)); + } + #[test] + fn too_much_data() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 
0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x064u64.to_be_bytes(), // Data size + 0xFEDCBA9_876543210u64.to_be_bytes(), // Hash + ]; + let cache = cache + .into_iter() + .flatten() + .chain(std::iter::repeat(0u8).take(200)) + .collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::Extended)); + } + #[test] + fn wrong_hash() { + let cache: [[u8; 8]; HEADER_LENGTH / 8] = [ + *b"WGPUPLCH", // MAGIC + [0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI + [1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key + [0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key + 0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key + 0x88888888_88888888u64.to_be_bytes(), // Validation key + 0x0u64.to_be_bytes(), // Data size + 0x00000000_00000000u64.to_be_bytes(), // Hash + ]; + let cache = cache.into_iter().flatten().collect::>(); + let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY); + assert_eq!(validation_result, Err(E::Corrupted)); + } +} diff --git a/wgpu-core/src/track/mod.rs b/wgpu-core/src/track/mod.rs index ff40a36b9..dba34aa38 100644 --- a/wgpu-core/src/track/mod.rs +++ b/wgpu-core/src/track/mod.rs @@ -228,6 +228,7 @@ pub(crate) struct TrackerIndexAllocators { pub pipeline_layouts: Arc, pub bundles: Arc, pub query_sets: Arc, + pub pipeline_caches: Arc, } impl TrackerIndexAllocators { @@ -245,6 +246,7 @@ impl TrackerIndexAllocators { pipeline_layouts: Arc::new(SharedTrackerIndexAllocator::new()), bundles: Arc::new(SharedTrackerIndexAllocator::new()), query_sets: Arc::new(SharedTrackerIndexAllocator::new()), + pipeline_caches: Arc::new(SharedTrackerIndexAllocator::new()), } } } diff --git a/wgpu-hal/examples/halmark/main.rs 
b/wgpu-hal/examples/halmark/main.rs index aef6919c8..ee59fa259 100644 --- a/wgpu-hal/examples/halmark/main.rs +++ b/wgpu-hal/examples/halmark/main.rs @@ -274,6 +274,7 @@ impl Example { write_mask: wgt::ColorWrites::default(), })], multiview: None, + cache: None, }; let pipeline = unsafe { device.create_render_pipeline(&pipeline_desc).unwrap() }; diff --git a/wgpu-hal/examples/ray-traced-triangle/main.rs b/wgpu-hal/examples/ray-traced-triangle/main.rs index 3985cd60a..8f404dc4d 100644 --- a/wgpu-hal/examples/ray-traced-triangle/main.rs +++ b/wgpu-hal/examples/ray-traced-triangle/main.rs @@ -374,6 +374,7 @@ impl Example { constants: &Default::default(), zero_initialize_workgroup_memory: true, }, + cache: None, }) } .unwrap(); diff --git a/wgpu-hal/src/dx12/device.rs b/wgpu-hal/src/dx12/device.rs index d4d27ca3f..5625dfca3 100644 --- a/wgpu-hal/src/dx12/device.rs +++ b/wgpu-hal/src/dx12/device.rs @@ -1513,6 +1513,14 @@ impl crate::Device for super::Device { } unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {} + unsafe fn create_pipeline_cache( + &self, + _desc: &crate::PipelineCacheDescriptor<'_>, + ) -> Result<(), crate::PipelineCacheError> { + Ok(()) + } + unsafe fn destroy_pipeline_cache(&self, (): ()) {} + unsafe fn create_query_set( &self, desc: &wgt::QuerySetDescriptor, diff --git a/wgpu-hal/src/dx12/mod.rs b/wgpu-hal/src/dx12/mod.rs index 95a31d189..99800e87c 100644 --- a/wgpu-hal/src/dx12/mod.rs +++ b/wgpu-hal/src/dx12/mod.rs @@ -82,6 +82,7 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type PipelineCache = (); type AccelerationStructure = AccelerationStructure; } diff --git a/wgpu-hal/src/empty.rs b/wgpu-hal/src/empty.rs index ad00da1b7..f1986f770 100644 --- a/wgpu-hal/src/empty.rs +++ b/wgpu-hal/src/empty.rs @@ -30,6 +30,7 @@ impl crate::Api for Api { type QuerySet = Resource; type Fence = Resource; type AccelerationStructure = 
Resource; + type PipelineCache = Resource; type BindGroupLayout = Resource; type BindGroup = Resource; @@ -220,6 +221,13 @@ impl crate::Device for Context { Ok(Resource) } unsafe fn destroy_compute_pipeline(&self, pipeline: Resource) {} + unsafe fn create_pipeline_cache( + &self, + desc: &crate::PipelineCacheDescriptor<'_>, + ) -> Result { + Ok(Resource) + } + unsafe fn destroy_pipeline_cache(&self, cache: Resource) {} unsafe fn create_query_set( &self, diff --git a/wgpu-hal/src/gles/device.rs b/wgpu-hal/src/gles/device.rs index ae9b401a0..afdc6ad7c 100644 --- a/wgpu-hal/src/gles/device.rs +++ b/wgpu-hal/src/gles/device.rs @@ -1406,6 +1406,16 @@ impl crate::Device for super::Device { } } + unsafe fn create_pipeline_cache( + &self, + _: &crate::PipelineCacheDescriptor<'_>, + ) -> Result<(), crate::PipelineCacheError> { + // Even though the cache doesn't do anything, we still return something here + // as the least bad option + Ok(()) + } + unsafe fn destroy_pipeline_cache(&self, (): ()) {} + #[cfg_attr(target_arch = "wasm32", allow(unused))] unsafe fn create_query_set( &self, diff --git a/wgpu-hal/src/gles/mod.rs b/wgpu-hal/src/gles/mod.rs index 0fcb09be4..058bdcf6f 100644 --- a/wgpu-hal/src/gles/mod.rs +++ b/wgpu-hal/src/gles/mod.rs @@ -154,6 +154,7 @@ impl crate::Api for Api { type QuerySet = QuerySet; type Fence = Fence; type AccelerationStructure = (); + type PipelineCache = (); type BindGroupLayout = BindGroupLayout; type BindGroup = BindGroup; diff --git a/wgpu-hal/src/lib.rs b/wgpu-hal/src/lib.rs index d300ca30c..16cc5fe21 100644 --- a/wgpu-hal/src/lib.rs +++ b/wgpu-hal/src/lib.rs @@ -332,6 +332,12 @@ pub enum PipelineError { Device(#[from] DeviceError), } +#[derive(Clone, Debug, Eq, PartialEq, Error)] +pub enum PipelineCacheError { + #[error(transparent)] + Device(#[from] DeviceError), +} + #[derive(Clone, Debug, Eq, PartialEq, Error)] pub enum SurfaceError { #[error("Surface is lost")] @@ -432,6 +438,7 @@ pub trait Api: Clone + fmt::Debug + Sized { type 
ShaderModule: fmt::Debug + WasmNotSendSync; type RenderPipeline: fmt::Debug + WasmNotSendSync; type ComputePipeline: fmt::Debug + WasmNotSendSync; + type PipelineCache: fmt::Debug + WasmNotSendSync; type AccelerationStructure: fmt::Debug + WasmNotSendSync + 'static; } @@ -611,6 +618,14 @@ pub trait Device: WasmNotSendSync { desc: &ComputePipelineDescriptor, ) -> Result<::ComputePipeline, PipelineError>; unsafe fn destroy_compute_pipeline(&self, pipeline: ::ComputePipeline); + unsafe fn create_pipeline_cache( + &self, + desc: &PipelineCacheDescriptor<'_>, + ) -> Result<::PipelineCache, PipelineCacheError>; + fn pipeline_cache_validation_key(&self) -> Option<[u8; 16]> { + None + } + unsafe fn destroy_pipeline_cache(&self, cache: ::PipelineCache); unsafe fn create_query_set( &self, @@ -652,6 +667,14 @@ pub trait Device: WasmNotSendSync { unsafe fn start_capture(&self) -> bool; unsafe fn stop_capture(&self); + #[allow(unused_variables)] + unsafe fn pipeline_cache_get_data( + &self, + cache: &::PipelineCache, + ) -> Option> { + None + } + unsafe fn create_acceleration_structure( &self, desc: &AccelerationStructureDescriptor, @@ -1636,6 +1659,13 @@ pub struct ComputePipelineDescriptor<'a, A: Api> { pub layout: &'a A::PipelineLayout, /// The compiled compute stage and its entry point. pub stage: ProgrammableStage<'a, A>, + /// The cache which will be used and filled when compiling this pipeline + pub cache: Option<&'a A::PipelineCache>, +} + +pub struct PipelineCacheDescriptor<'a> { + pub label: Label<'a>, + pub data: Option<&'a [u8]>, } /// Describes how the vertex buffer is interpreted. @@ -1672,6 +1702,8 @@ pub struct RenderPipelineDescriptor<'a, A: Api> { /// If the pipeline will be used with a multiview render pass, this indicates how many array /// layers the attachments will have. 
pub multiview: Option, + /// The cache which will be used and filled when compiling this pipeline + pub cache: Option<&'a A::PipelineCache>, } #[derive(Debug, Clone)] diff --git a/wgpu-hal/src/metal/device.rs b/wgpu-hal/src/metal/device.rs index 2c8f5a2bf..81ab5dbdb 100644 --- a/wgpu-hal/src/metal/device.rs +++ b/wgpu-hal/src/metal/device.rs @@ -1099,6 +1099,14 @@ impl crate::Device for super::Device { } unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {} + unsafe fn create_pipeline_cache( + &self, + _desc: &crate::PipelineCacheDescriptor<'_>, + ) -> Result<(), crate::PipelineCacheError> { + Ok(()) + } + unsafe fn destroy_pipeline_cache(&self, (): ()) {} + unsafe fn create_query_set( &self, desc: &wgt::QuerySetDescriptor, diff --git a/wgpu-hal/src/metal/mod.rs b/wgpu-hal/src/metal/mod.rs index 7d547cfe3..a5ea63b03 100644 --- a/wgpu-hal/src/metal/mod.rs +++ b/wgpu-hal/src/metal/mod.rs @@ -66,6 +66,7 @@ impl crate::Api for Api { type ShaderModule = ShaderModule; type RenderPipeline = RenderPipeline; type ComputePipeline = ComputePipeline; + type PipelineCache = (); type AccelerationStructure = AccelerationStructure; } diff --git a/wgpu-hal/src/vulkan/adapter.rs b/wgpu-hal/src/vulkan/adapter.rs index 82a30617f..6641a24a7 100644 --- a/wgpu-hal/src/vulkan/adapter.rs +++ b/wgpu-hal/src/vulkan/adapter.rs @@ -462,7 +462,8 @@ impl PhysicalDeviceFeatures { | F::TIMESTAMP_QUERY_INSIDE_ENCODERS | F::TIMESTAMP_QUERY_INSIDE_PASSES | F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES - | F::CLEAR_TEXTURE; + | F::CLEAR_TEXTURE + | F::PIPELINE_CACHE; let mut dl_flags = Df::COMPUTE_SHADERS | Df::BASE_VERTEX @@ -1745,6 +1746,19 @@ impl super::Adapter { unsafe { raw_device.get_device_queue(family_index, queue_index) } }; + let driver_version = self + .phd_capabilities + .properties + .driver_version + .to_be_bytes(); + #[rustfmt::skip] + let pipeline_cache_validation_key = [ + driver_version[0], driver_version[1], driver_version[2], driver_version[3], + 0, 0, 0, 
0, + 0, 0, 0, 0, + 0, 0, 0, 0, + ]; + let shared = Arc::new(super::DeviceShared { raw: raw_device, family_index, @@ -1760,6 +1774,7 @@ impl super::Adapter { timeline_semaphore: timeline_semaphore_fn, ray_tracing: ray_tracing_fns, }, + pipeline_cache_validation_key, vendor_id: self.phd_capabilities.properties.vendor_id, timestamp_period: self.phd_capabilities.properties.limits.timestamp_period, private_caps: self.private_caps.clone(), diff --git a/wgpu-hal/src/vulkan/device.rs b/wgpu-hal/src/vulkan/device.rs index 0ac83b3fa..1ea627897 100644 --- a/wgpu-hal/src/vulkan/device.rs +++ b/wgpu-hal/src/vulkan/device.rs @@ -1,4 +1,4 @@ -use super::conv; +use super::{conv, PipelineCache}; use arrayvec::ArrayVec; use ash::{khr, vk}; @@ -1867,12 +1867,17 @@ impl crate::Device for super::Device { .render_pass(raw_pass) }]; + let pipeline_cache = desc + .cache + .map(|it| it.raw) + .unwrap_or(vk::PipelineCache::null()); + let mut raw_vec = { profiling::scope!("vkCreateGraphicsPipelines"); unsafe { self.shared .raw - .create_graphics_pipelines(vk::PipelineCache::null(), &vk_infos, None) + .create_graphics_pipelines(pipeline_cache, &vk_infos, None) .map_err(|(_, e)| crate::DeviceError::from(e)) }? }; @@ -1915,12 +1920,17 @@ impl crate::Device for super::Device { .stage(compiled.create_info) }]; + let pipeline_cache = desc + .cache + .map(|it| it.raw) + .unwrap_or(vk::PipelineCache::null()); + let mut raw_vec = { profiling::scope!("vkCreateComputePipelines"); unsafe { self.shared .raw - .create_compute_pipelines(vk::PipelineCache::null(), &vk_infos, None) + .create_compute_pipelines(pipeline_cache, &vk_infos, None) .map_err(|(_, e)| crate::DeviceError::from(e)) }? 
}; @@ -1940,6 +1950,26 @@ impl crate::Device for super::Device { unsafe { self.shared.raw.destroy_pipeline(pipeline.raw, None) }; } + unsafe fn create_pipeline_cache( + &self, + desc: &crate::PipelineCacheDescriptor<'_>, + ) -> Result { + let mut info = vk::PipelineCacheCreateInfo::default(); + if let Some(data) = desc.data { + info = info.initial_data(data) + } + profiling::scope!("vkCreatePipelineCache"); + let raw = unsafe { self.shared.raw.create_pipeline_cache(&info, None) } + .map_err(crate::DeviceError::from)?; + + Ok(PipelineCache { raw }) + } + fn pipeline_cache_validation_key(&self) -> Option<[u8; 16]> { + Some(self.shared.pipeline_cache_validation_key) + } + unsafe fn destroy_pipeline_cache(&self, cache: PipelineCache) { + unsafe { self.shared.raw.destroy_pipeline_cache(cache.raw, None) } + } unsafe fn create_query_set( &self, desc: &wgt::QuerySetDescriptor, @@ -2105,6 +2135,11 @@ impl crate::Device for super::Device { } } + unsafe fn pipeline_cache_get_data(&self, cache: &PipelineCache) -> Option> { + let data = unsafe { self.raw_device().get_pipeline_cache_data(cache.raw) }; + data.ok() + } + unsafe fn get_acceleration_structure_build_sizes<'a>( &self, desc: &crate::GetAccelerationStructureBuildSizesDescriptor<'a, super::Api>, diff --git a/wgpu-hal/src/vulkan/mod.rs b/wgpu-hal/src/vulkan/mod.rs index 9f244ff98..1716ee920 100644 --- a/wgpu-hal/src/vulkan/mod.rs +++ b/wgpu-hal/src/vulkan/mod.rs @@ -70,6 +70,7 @@ impl crate::Api for Api { type QuerySet = QuerySet; type Fence = Fence; type AccelerationStructure = AccelerationStructure; + type PipelineCache = PipelineCache; type BindGroupLayout = BindGroupLayout; type BindGroup = BindGroup; @@ -338,6 +339,7 @@ struct DeviceShared { enabled_extensions: Vec<&'static CStr>, extension_fns: DeviceExtensionFunctions, vendor_id: u32, + pipeline_cache_validation_key: [u8; 16], timestamp_period: f32, private_caps: PrivateCapabilities, workarounds: Workarounds, @@ -549,6 +551,11 @@ pub struct ComputePipeline { raw: 
vk::Pipeline, } +#[derive(Debug)] +pub struct PipelineCache { + raw: vk::PipelineCache, +} + #[derive(Debug)] pub struct QuerySet { raw: vk::QueryPool, diff --git a/wgpu-types/src/lib.rs b/wgpu-types/src/lib.rs index 92d5b68d4..4eac7c4c1 100644 --- a/wgpu-types/src/lib.rs +++ b/wgpu-types/src/lib.rs @@ -914,6 +914,15 @@ bitflags::bitflags! { /// /// This is a native only feature. const SUBGROUP_BARRIER = 1 << 58; + /// Allows the use of pipeline cache objects + /// + /// Supported platforms: + /// - Vulkan + /// + /// Unimplemented Platforms: + /// - DX12 + /// - Metal + const PIPELINE_CACHE = 1 << 59; } } diff --git a/wgpu/src/backend/webgpu.rs b/wgpu/src/backend/webgpu.rs index 2185d5b8b..fa2896dfc 100644 --- a/wgpu/src/backend/webgpu.rs +++ b/wgpu/src/backend/webgpu.rs @@ -1159,6 +1159,8 @@ impl crate::context::Context for ContextWebGpu { type SurfaceOutputDetail = SurfaceOutputDetail; type SubmissionIndex = Unused; type SubmissionIndexData = (); + type PipelineCacheId = Unused; + type PipelineCacheData = (); type RequestAdapterFuture = MakeSendFuture< wasm_bindgen_futures::JsFuture, @@ -1995,6 +1997,16 @@ impl crate::context::Context for ContextWebGpu { create_identified(device_data.0.create_compute_pipeline(&mapped_desc)) } + unsafe fn device_create_pipeline_cache( + &self, + _: &Self::DeviceId, + _: &Self::DeviceData, + _: &crate::PipelineCacheDescriptor<'_>, + ) -> (Self::PipelineCacheId, Self::PipelineCacheData) { + (Unused, ()) + } + fn pipeline_cache_drop(&self, _: &Self::PipelineCacheId, _: &Self::PipelineCacheData) {} + fn device_create_buffer( &self, _device: &Self::DeviceId, @@ -2981,6 +2993,14 @@ impl crate::context::Context for ContextWebGpu { fn device_start_capture(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {} fn device_stop_capture(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {} + fn pipeline_cache_get_data( + &self, + _: &Self::PipelineCacheId, + _: &Self::PipelineCacheData, + ) -> Option> { + None + } 
+ fn compute_pass_set_pipeline( &self, _pass: &mut Self::ComputePassId, diff --git a/wgpu/src/backend/wgpu_core.rs b/wgpu/src/backend/wgpu_core.rs index f03e4a569..70a7bef11 100644 --- a/wgpu/src/backend/wgpu_core.rs +++ b/wgpu/src/backend/wgpu_core.rs @@ -4,10 +4,10 @@ use crate::{ BufferDescriptor, CommandEncoderDescriptor, CompilationInfo, CompilationMessage, CompilationMessageType, ComputePassDescriptor, ComputePipelineDescriptor, DownlevelCapabilities, Features, Label, Limits, LoadOp, MapMode, Operations, - PipelineLayoutDescriptor, RenderBundleEncoderDescriptor, RenderPipelineDescriptor, - SamplerDescriptor, ShaderModuleDescriptor, ShaderModuleDescriptorSpirV, ShaderSource, StoreOp, - SurfaceStatus, SurfaceTargetUnsafe, TextureDescriptor, TextureViewDescriptor, - UncapturedErrorHandler, + PipelineCacheDescriptor, PipelineLayoutDescriptor, RenderBundleEncoderDescriptor, + RenderPipelineDescriptor, SamplerDescriptor, ShaderModuleDescriptor, + ShaderModuleDescriptorSpirV, ShaderSource, StoreOp, SurfaceStatus, SurfaceTargetUnsafe, + TextureDescriptor, TextureViewDescriptor, UncapturedErrorHandler, }; use arrayvec::ArrayVec; @@ -519,6 +519,8 @@ impl crate::Context for ContextWgpuCore { type RenderPipelineData = (); type ComputePipelineId = wgc::id::ComputePipelineId; type ComputePipelineData = (); + type PipelineCacheId = wgc::id::PipelineCacheId; + type PipelineCacheData = (); type CommandEncoderId = wgc::id::CommandEncoderId; type CommandEncoderData = CommandEncoder; type ComputePassId = Unused; @@ -1191,6 +1193,7 @@ impl crate::Context for ContextWgpuCore { targets: Borrowed(frag.targets), }), multiview: desc.multiview, + cache: desc.cache.map(|c| c.id.into()), }; let (id, error) = wgc::gfx_select!(device => self.0.device_create_render_pipeline( @@ -1240,6 +1243,7 @@ impl crate::Context for ContextWgpuCore { .compilation_options .zero_initialize_workgroup_memory, }, + cache: desc.cache.map(|c| c.id.into()), }; let (id, error) = wgc::gfx_select!(device => 
self.0.device_create_compute_pipeline( @@ -1267,6 +1271,37 @@ impl crate::Context for ContextWgpuCore { } (id, ()) } + + unsafe fn device_create_pipeline_cache( + &self, + device: &Self::DeviceId, + device_data: &Self::DeviceData, + desc: &PipelineCacheDescriptor<'_>, + ) -> (Self::PipelineCacheId, Self::PipelineCacheData) { + use wgc::pipeline as pipe; + + let descriptor = pipe::PipelineCacheDescriptor { + label: desc.label.map(Borrowed), + data: desc.data.map(Borrowed), + fallback: desc.fallback, + }; + let (id, error) = wgc::gfx_select!(device => self.0.device_create_pipeline_cache( + *device, + &descriptor, + None + )); + if let Some(cause) = error { + self.handle_error( + &device_data.error_sink, + cause, + LABEL, + desc.label, + "Device::device_create_pipeline_cache_init", + ); + } + (id, ()) + } + fn device_create_buffer( &self, device: &Self::DeviceId, @@ -1726,6 +1761,14 @@ impl crate::Context for ContextWgpuCore { wgc::gfx_select!(*pipeline => self.0.render_pipeline_drop(*pipeline)) } + fn pipeline_cache_drop( + &self, + cache: &Self::PipelineCacheId, + _cache_data: &Self::PipelineCacheData, + ) { + wgc::gfx_select!(*cache => self.0.pipeline_cache_drop(*cache)) + } + fn compute_pipeline_get_bind_group_layout( &self, pipeline: &Self::ComputePipelineId, @@ -2336,6 +2379,15 @@ impl crate::Context for ContextWgpuCore { wgc::gfx_select!(device => self.0.device_stop_capture(*device)); } + fn pipeline_cache_get_data( + &self, + cache: &Self::PipelineCacheId, + // TODO: Used for error handling? 
+ _cache_data: &Self::PipelineCacheData, + ) -> Option> { + wgc::gfx_select!(cache => self.0.pipeline_cache_get_data(*cache)) + } + fn compute_pass_set_pipeline( &self, _pass: &mut Self::ComputePassId, diff --git a/wgpu/src/context.rs b/wgpu/src/context.rs index 12ea5cc90..c29d88c2b 100644 --- a/wgpu/src/context.rs +++ b/wgpu/src/context.rs @@ -11,11 +11,12 @@ use crate::{ AnyWasmNotSendSync, BindGroupDescriptor, BindGroupLayoutDescriptor, Buffer, BufferAsyncError, BufferDescriptor, CommandEncoderDescriptor, CompilationInfo, ComputePassDescriptor, ComputePipelineDescriptor, DeviceDescriptor, Error, ErrorFilter, ImageCopyBuffer, - ImageCopyTexture, Maintain, MaintainResult, MapMode, PipelineLayoutDescriptor, - QuerySetDescriptor, RenderBundleDescriptor, RenderBundleEncoderDescriptor, - RenderPassDescriptor, RenderPipelineDescriptor, RequestAdapterOptions, RequestDeviceError, - SamplerDescriptor, ShaderModuleDescriptor, ShaderModuleDescriptorSpirV, SurfaceTargetUnsafe, - Texture, TextureDescriptor, TextureViewDescriptor, UncapturedErrorHandler, + ImageCopyTexture, Maintain, MaintainResult, MapMode, PipelineCacheDescriptor, + PipelineLayoutDescriptor, QuerySetDescriptor, RenderBundleDescriptor, + RenderBundleEncoderDescriptor, RenderPassDescriptor, RenderPipelineDescriptor, + RequestAdapterOptions, RequestDeviceError, SamplerDescriptor, ShaderModuleDescriptor, + ShaderModuleDescriptorSpirV, SurfaceTargetUnsafe, Texture, TextureDescriptor, + TextureViewDescriptor, UncapturedErrorHandler, }; /// Meta trait for an id tracked by a context. 
@@ -59,6 +60,8 @@ pub trait Context: Debug + WasmNotSendSync + Sized { type RenderPipelineData: ContextData; type ComputePipelineId: ContextId + WasmNotSendSync; type ComputePipelineData: ContextData; + type PipelineCacheId: ContextId + WasmNotSendSync; + type PipelineCacheData: ContextData; type CommandEncoderId: ContextId + WasmNotSendSync; type CommandEncoderData: ContextData; type ComputePassId: ContextId; @@ -233,6 +236,12 @@ pub trait Context: Debug + WasmNotSendSync + Sized { device_data: &Self::DeviceData, desc: &ComputePipelineDescriptor<'_>, ) -> (Self::ComputePipelineId, Self::ComputePipelineData); + unsafe fn device_create_pipeline_cache( + &self, + device: &Self::DeviceId, + device_data: &Self::DeviceData, + desc: &PipelineCacheDescriptor<'_>, + ) -> (Self::PipelineCacheId, Self::PipelineCacheData); fn device_create_buffer( &self, device: &Self::DeviceId, @@ -395,6 +404,11 @@ pub trait Context: Debug + WasmNotSendSync + Sized { pipeline: &Self::RenderPipelineId, pipeline_data: &Self::RenderPipelineData, ); + fn pipeline_cache_drop( + &self, + cache: &Self::PipelineCacheId, + cache_data: &Self::PipelineCacheData, + ); fn compute_pipeline_get_bind_group_layout( &self, @@ -613,6 +627,12 @@ pub trait Context: Debug + WasmNotSendSync + Sized { fn device_start_capture(&self, device: &Self::DeviceId, device_data: &Self::DeviceData); fn device_stop_capture(&self, device: &Self::DeviceId, device_data: &Self::DeviceData); + fn pipeline_cache_get_data( + &self, + cache: &Self::PipelineCacheId, + cache_data: &Self::PipelineCacheData, + ) -> Option>; + fn compute_pass_set_pipeline( &self, pass: &mut Self::ComputePassId, @@ -1271,6 +1291,12 @@ pub(crate) trait DynContext: Debug + WasmNotSendSync { device_data: &crate::Data, desc: &ComputePipelineDescriptor<'_>, ) -> (ObjectId, Box); + unsafe fn device_create_pipeline_cache( + &self, + device: &ObjectId, + device_data: &crate::Data, + desc: &PipelineCacheDescriptor<'_>, + ) -> (ObjectId, Box); fn 
device_create_buffer( &self, device: &ObjectId, @@ -1391,6 +1417,7 @@ pub(crate) trait DynContext: Debug + WasmNotSendSync { fn render_bundle_drop(&self, render_bundle: &ObjectId, render_bundle_data: &crate::Data); fn compute_pipeline_drop(&self, pipeline: &ObjectId, pipeline_data: &crate::Data); fn render_pipeline_drop(&self, pipeline: &ObjectId, pipeline_data: &crate::Data); + fn pipeline_cache_drop(&self, cache: &ObjectId, _cache_data: &crate::Data); fn compute_pipeline_get_bind_group_layout( &self, @@ -1601,6 +1628,12 @@ pub(crate) trait DynContext: Debug + WasmNotSendSync { fn device_start_capture(&self, device: &ObjectId, data: &crate::Data); fn device_stop_capture(&self, device: &ObjectId, data: &crate::Data); + fn pipeline_cache_get_data( + &self, + cache: &ObjectId, + cache_data: &crate::Data, + ) -> Option>; + fn compute_pass_set_pipeline( &self, pass: &mut ObjectId, @@ -2297,6 +2330,19 @@ where (compute_pipeline.into(), Box::new(data) as _) } + unsafe fn device_create_pipeline_cache( + &self, + device: &ObjectId, + device_data: &crate::Data, + desc: &PipelineCacheDescriptor<'_>, + ) -> (ObjectId, Box) { + let device = ::from(*device); + let device_data = downcast_ref(device_data); + let (pipeline_cache, data) = + unsafe { Context::device_create_pipeline_cache(self, &device, device_data, desc) }; + (pipeline_cache.into(), Box::new(data) as _) + } + fn device_create_buffer( &self, device: &ObjectId, @@ -2621,6 +2667,12 @@ where Context::render_pipeline_drop(self, &pipeline, pipeline_data) } + fn pipeline_cache_drop(&self, cache: &ObjectId, cache_data: &crate::Data) { + let cache = ::from(*cache); + let cache_data = downcast_ref(cache_data); + Context::pipeline_cache_drop(self, &cache, cache_data) + } + fn compute_pipeline_get_bind_group_layout( &self, pipeline: &ObjectId, @@ -3083,6 +3135,16 @@ where Context::device_stop_capture(self, &device, device_data) } + fn pipeline_cache_get_data( + &self, + cache: &ObjectId, + cache_data: &crate::Data, + ) -> 
Option> { + let cache = ::from(*cache); + let cache_data = downcast_ref::(cache_data); + Context::pipeline_cache_get_data(self, &cache, cache_data) + } + fn compute_pass_set_pipeline( &self, pass: &mut ObjectId, diff --git a/wgpu/src/lib.rs b/wgpu/src/lib.rs index ed5694173..0d2bd504f 100644 --- a/wgpu/src/lib.rs +++ b/wgpu/src/lib.rs @@ -1105,6 +1105,100 @@ impl ComputePipeline { } } +/// Handle to a pipeline cache, which is used to accelerate +/// creating [`RenderPipeline`]s and [`ComputePipeline`]s +/// in subsequent executions +/// +/// This reuse is only applicable for the same or similar devices. +/// See [`util::pipeline_cache_key`] for some details. +/// +/// # Background +/// +/// In most GPU drivers, shader code must be converted into a machine code +/// which can be executed on the GPU. +/// Generating this machine code can require a lot of computation. +/// Pipeline caches allow this computation to be reused between executions +/// of the program. +/// This can be very useful for reducing program startup time. +/// +/// Note that most desktop GPU drivers will manage their own caches, +/// meaning that little advantage can be gained from this on those platforms. +/// However, on some platforms, especially Android, drivers leave this to the +/// application to implement. +/// +/// Unfortunately, drivers do not expose whether they manage their own caches. +/// Some reasonable policies for applications to use are: +/// - Manage their own pipeline cache on all platforms +/// - Only manage pipeline caches on Android +/// +/// # Usage +/// +/// It is valid to use this resource when creating multiple pipelines, in +/// which case it will likely cache each of those pipelines. +/// It is also valid to create a new cache for each pipeline. +/// +/// This resource is most useful when the data produced from it (using +/// [`PipelineCache::get_data`]) is persisted. 
+/// Care should be taken that pipeline caches are only used for the same device, +/// as pipeline caches from compatible devices are unlikely to provide any advantage. +/// `util::pipeline_cache_key` can be used as a file/directory name to help ensure that. +/// +/// It is recommended to store pipeline caches atomically. If persisting to disk, +/// this can usually be achieved by creating a temporary file, then moving/[renaming] +/// the temporary file over the existing cache. +/// +/// # Storage Usage +/// +/// There is not currently an API available to reduce the size of a cache. +/// This is due to limitations in the underlying graphics APIs used. +/// This is especially impactful if your application is being updated, so +/// previous caches are no longer being used. +/// +/// One option to work around this is to regenerate the cache. +/// That is, creating the pipelines which your program runs using +/// the stored cached data, then recreating the *same* pipelines +/// using a new cache, which your application then stores. +/// +/// # Implementations +/// +/// This resource currently only works on the following backends: +/// - Vulkan +/// +/// This type is unique to the Rust API of `wgpu`. +/// +/// [renaming]: std::fs::rename +#[derive(Debug)] +pub struct PipelineCache { + context: Arc, + id: ObjectId, + data: Box, +} + +#[cfg(send_sync)] +static_assertions::assert_impl_all!(PipelineCache: Send, Sync); + +impl PipelineCache { + /// Get the data associated with this pipeline cache. + /// The data format is an implementation detail of `wgpu`. + /// The only defined operation on this data is setting it as the `data` field + /// on [`PipelineCacheDescriptor`], then to [`Device::create_pipeline_cache`]. + /// + /// This function is unique to the Rust API of `wgpu`. 
+ pub fn get_data(&self) -> Option> { + self.context + .pipeline_cache_get_data(&self.id, self.data.as_ref()) + } +} + +impl Drop for PipelineCache { + fn drop(&mut self) { + if !thread::panicking() { + self.context + .pipeline_cache_drop(&self.id, self.data.as_ref()); + } + } +} + /// Handle to a command buffer on the GPU. /// /// A `CommandBuffer` represents a complete sequence of commands that may be submitted to a command @@ -1832,6 +1926,8 @@ pub struct RenderPipelineDescriptor<'a> { /// If the pipeline will be used with a multiview render pass, this indicates how many array /// layers the attachments will have. pub multiview: Option, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option<&'a PipelineCache>, } #[cfg(send_sync)] static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync); @@ -1929,10 +2025,38 @@ pub struct ComputePipelineDescriptor<'a> { /// /// This implements `Default`, and for most users can be set to `Default::default()` pub compilation_options: PipelineCompilationOptions<'a>, + /// The pipeline cache to use when creating this pipeline. + pub cache: Option<&'a PipelineCache>, } #[cfg(send_sync)] static_assertions::assert_impl_all!(ComputePipelineDescriptor<'_>: Send, Sync); + +/// Describes a pipeline cache, which allows reusing compilation work +/// between program runs. +/// +/// For use with [`Device::create_pipeline_cache`] +/// +/// This type is unique to the Rust API of `wgpu`. +#[derive(Clone, Debug)] +pub struct PipelineCacheDescriptor<'a> { + /// Debug label of the pipeline cache. This might show up in some logs from `wgpu` + pub label: Label<'a>, + /// The data used to initialise the cache + /// + /// # Safety + /// + /// This data must have been provided from a previous call to + /// [`PipelineCache::get_data`], if not `None` + pub data: Option<&'a [u8]>, + /// Whether to create a cache without data when the provided data + /// is invalid. 
+ /// + /// Recommended to set to true + pub fallback: bool, +} +#[cfg(send_sync)] +static_assertions::assert_impl_all!(PipelineCacheDescriptor<'_>: Send, Sync); + pub use wgt::ImageCopyBuffer as ImageCopyBufferBase; /// View of a buffer which can be used to copy to/from a texture. /// @@ -3080,6 +3204,62 @@ impl Device { pub fn make_invalid(&self) { DynContext::device_make_invalid(&*self.context, &self.id, self.data.as_ref()) } + + /// Create a [`PipelineCache`] with initial data + /// + /// This can be passed to [`Device::create_compute_pipeline`] + /// and [`Device::create_render_pipeline`] to either accelerate these + /// or add the cache results from those. + /// + /// # Safety + /// + /// If the `data` field of `desc` is set, it must have previously been returned from a call + /// to [`PipelineCache::get_data`][^saving]. This `data` will only be used if it came + /// from an adapter with the same [`util::pipeline_cache_key`]. + /// This *is* compatible across wgpu versions, as any data format change will + /// be accounted for. + /// + /// It is *not* supported to bring caches from previous direct uses of backend APIs + /// into this method. + /// + /// # Errors + /// + /// Returns an error value if: + /// * the [`PIPELINE_CACHE`](wgt::Features::PIPELINE_CACHE) feature is not enabled + /// * this device is invalid; or + /// * the device is out of memory + /// + /// This method also returns an error value if: + /// * The `fallback` field on `desc` is false; and + /// * the `data` provided would not be used[^data_not_used] + /// + /// If an error value is used in subsequent calls, default caching will be used. + /// + /// [^saving]: We do recognise that saving this data to disk means this condition + /// is impossible to fully prove. Consider the risks for your own application in this case. 
+ /// + /// [^data_not_used]: This data may be not used if: the data was produced by a prior + /// version of wgpu; or was created for an incompatible adapter, or there was a GPU driver + /// update. In some cases, the data might not be used and a real value is returned, + /// this is left to the discretion of GPU drivers. + pub unsafe fn create_pipeline_cache( + &self, + desc: &PipelineCacheDescriptor<'_>, + ) -> PipelineCache { + let (id, data) = unsafe { + DynContext::device_create_pipeline_cache( + &*self.context, + &self.id, + self.data.as_ref(), + desc, + ) + }; + PipelineCache { + context: Arc::clone(&self.context), + id, + data, + } + } } impl Drop for Device { diff --git a/wgpu/src/util/mod.rs b/wgpu/src/util/mod.rs index 3ab6639cf..ce5af6fb6 100644 --- a/wgpu/src/util/mod.rs +++ b/wgpu/src/util/mod.rs @@ -140,3 +140,52 @@ impl std::ops::Deref for DownloadBuffer { self.1.slice() } } + +/// A recommended key for storing [`PipelineCache`]s for the adapter +/// associated with the given [`AdapterInfo`](wgt::AdapterInfo) +/// This key will define a class of adapters for which the same cache +/// might be valid. +/// +/// If this returns `None`, the adapter doesn't support [`PipelineCache`]. +/// This may be because the API doesn't support application managed caches +/// (such as browser WebGPU), or that `wgpu` hasn't implemented it for +/// that API yet. +/// +/// This key could be used as a filename, as seen in the example below. 
+/// +/// # Examples +/// +/// ``` no_run +/// # use std::path::PathBuf; +/// # let adapter_info = todo!(); +/// let cache_dir: PathBuf = PathBuf::new(); +/// let filename = wgpu::util::pipeline_cache_key(&adapter_info); +/// if let Some(filename) = filename { +/// let cache_file = cache_dir.join(&filename); +/// let cache_data = std::fs::read(&cache_file); +/// let pipeline_cache: wgpu::PipelineCache = todo!("Use data (if present) to create a pipeline cache"); +/// +/// let data = pipeline_cache.get_data(); +/// if let Some(data) = data { +/// let temp_file = cache_file.with_extension("temp"); +/// std::fs::write(&temp_file, &data)?; +/// std::fs::rename(&temp_file, &cache_file)?; +/// } +/// } +/// # Ok::<(), std::io::Error>(()) +/// ``` +/// +/// [`PipelineCache`]: super::PipelineCache +pub fn pipeline_cache_key(adapter_info: &wgt::AdapterInfo) -> Option { + match adapter_info.backend { + wgt::Backend::Vulkan => Some(format!( + // The vendor/device should uniquely define a driver + // We/the driver will also later validate that the vendor/device and driver + // version match, which may lead to clearing an outdated + // cache for the same device. + "wgpu_pipeline_cache_vulkan_{}_{}", + adapter_info.vendor, adapter_info.device + )), + _ => None, + } +}