Pipeline cache API and implementation for Vulkan (#5319)

Co-authored-by: Connor Fitzgerald <connorwadefitzgerald@gmail.com>
This commit is contained in:
Daniel McNab 2024-05-16 15:52:56 +02:00 committed by GitHub
parent eeb1a9d7b7
commit 4902e470ce
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
76 changed files with 1578 additions and 19 deletions

View File

@ -72,6 +72,10 @@ By @stefnotch in [#5410](https://github.com/gfx-rs/wgpu/pull/5410)
### New features ### New features
#### Vulkan
- Added a `PipelineCache` resource to allow using Vulkan pipeline caches. By @DJMcNab in [#5319](https://github.com/gfx-rs/wgpu/pull/5319)
#### General #### General
#### Naga #### Naga

View File

@ -207,6 +207,7 @@ impl RenderpassState {
compilation_options: wgpu::PipelineCompilationOptions::default(), compilation_options: wgpu::PipelineCompilationOptions::default(),
}), }),
multiview: None, multiview: None,
cache: None,
}); });
let render_target = device_state let render_target = device_state
@ -304,6 +305,7 @@ impl RenderpassState {
compilation_options: wgpu::PipelineCompilationOptions::default(), compilation_options: wgpu::PipelineCompilationOptions::default(),
}), }),
multiview: None, multiview: None,
cache: None,
}, },
)); ));
} }

View File

@ -115,6 +115,7 @@ pub fn op_webgpu_create_compute_pipeline(
constants: Cow::Owned(compute.constants.unwrap_or_default()), constants: Cow::Owned(compute.constants.unwrap_or_default()),
zero_initialize_workgroup_memory: true, zero_initialize_workgroup_memory: true,
}, },
cache: None,
}; };
let implicit_pipelines = match layout { let implicit_pipelines = match layout {
GPUPipelineLayoutOrGPUAutoLayoutMode::Layout(_) => None, GPUPipelineLayoutOrGPUAutoLayoutMode::Layout(_) => None,
@ -395,6 +396,7 @@ pub fn op_webgpu_create_render_pipeline(
multisample: args.multisample, multisample: args.multisample,
fragment, fragment,
multiview: None, multiview: None,
cache: None,
}; };
let implicit_pipelines = match args.layout { let implicit_pipelines = match args.layout {

View File

@ -156,6 +156,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
// create compute pipeline // create compute pipeline
@ -166,6 +167,7 @@ impl crate::framework::Example for Example {
module: &compute_shader, module: &compute_shader,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
// buffer for the three 2d triangle vertices of each instance // buffer for the three 2d triangle vertices of each instance

View File

@ -224,6 +224,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let texture = { let texture = {

View File

@ -113,6 +113,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let pipeline_triangle_regular = let pipeline_triangle_regular =
@ -135,6 +136,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let pipeline_lines = if device let pipeline_lines = if device
@ -165,6 +167,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}), }),
) )
} else { } else {
@ -224,6 +227,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}), }),
bind_group_layout, bind_group_layout,
) )

View File

@ -260,6 +260,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let pipeline_wire = if device let pipeline_wire = if device
@ -301,6 +302,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
Some(pipeline_wire) Some(pipeline_wire)
} else { } else {

View File

@ -110,6 +110,7 @@ async fn execute_gpu_inner(
module: &cs_module, module: &cs_module,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
// Instantiates the bind group, once again specifying the binding of buffers. // Instantiates the bind group, once again specifying the binding of buffers.

View File

@ -104,6 +104,7 @@ async fn execute(
module: &shaders_module, module: &shaders_module,
entry_point: "patient_main", entry_point: "patient_main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor { let hasty_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: None, label: None,
@ -111,6 +112,7 @@ async fn execute(
module: &shaders_module, module: &shaders_module,
entry_point: "hasty_main", entry_point: "hasty_main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
//---------------------------------------------------------- //----------------------------------------------------------

View File

@ -72,6 +72,7 @@ async fn run(event_loop: EventLoop<()>, window: Window) {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let mut config = surface let mut config = surface

View File

@ -111,6 +111,7 @@ async fn run() {
module: &shader, module: &shader,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
//---------------------------------------------------------- //----------------------------------------------------------

View File

@ -109,6 +109,7 @@ impl Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let bind_group_layout = pipeline.get_bind_group_layout(0); let bind_group_layout = pipeline.get_bind_group_layout(0);
@ -310,6 +311,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
// Create bind group // Create bind group

View File

@ -78,6 +78,7 @@ impl Example {
..Default::default() ..Default::default()
}, },
multiview: None, multiview: None,
cache: None,
}); });
let mut encoder = let mut encoder =
device.create_render_bundle_encoder(&wgpu::RenderBundleEncoderDescriptor { device.create_render_bundle_encoder(&wgpu::RenderBundleEncoderDescriptor {

View File

@ -72,6 +72,7 @@ async fn run(_path: Option<String>) {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
log::info!("Wgpu context set up."); log::info!("Wgpu context set up.");

View File

@ -246,6 +246,7 @@ impl WgpuContext {
module: &shader, module: &shader,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
WgpuContext { WgpuContext {

View File

@ -526,6 +526,7 @@ impl crate::framework::Example for Example {
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
Pass { Pass {
@ -660,6 +661,7 @@ impl crate::framework::Example for Example {
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
Pass { Pass {

View File

@ -221,6 +221,7 @@ impl crate::framework::Example for Example {
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let entity_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { let entity_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
label: Some("Entity"), label: Some("Entity"),
@ -254,6 +255,7 @@ impl crate::framework::Example for Example {
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let sampler = device.create_sampler(&wgpu::SamplerDescriptor { let sampler = device.create_sampler(&wgpu::SamplerDescriptor {

View File

@ -151,6 +151,7 @@ impl<const SRGB: bool> crate::framework::Example for Example<SRGB> {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
// Done // Done

View File

@ -106,6 +106,7 @@ impl crate::framework::Example for Example {
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let outer_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor { let outer_pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
@ -141,6 +142,7 @@ impl crate::framework::Example for Example {
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let stencil_buffer = device.create_texture(&wgpu::TextureDescriptor { let stencil_buffer = device.create_texture(&wgpu::TextureDescriptor {

View File

@ -101,6 +101,7 @@ async fn run(_path: Option<String>) {
module: &shader, module: &shader,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
log::info!("Wgpu context set up."); log::info!("Wgpu context set up.");

View File

@ -341,6 +341,7 @@ impl crate::framework::Example for Example {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None
}); });
Self { Self {

View File

@ -299,6 +299,7 @@ fn compute_pass(
module, module,
entry_point: "main_cs", entry_point: "main_cs",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let bind_group_layout = compute_pipeline.get_bind_group_layout(0); let bind_group_layout = compute_pipeline.get_bind_group_layout(0);
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
@ -366,8 +367,8 @@ fn render_pass(
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let render_target = device.create_texture(&wgpu::TextureDescriptor { let render_target = device.create_texture(&wgpu::TextureDescriptor {
label: Some("rendertarget"), label: Some("rendertarget"),
size: wgpu::Extent3d { size: wgpu::Extent3d {

View File

@ -192,8 +192,8 @@ impl WgpuContext {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let surface_config = surface let surface_config = surface
.get_default_config(&adapter, size.width, size.height) .get_default_config(&adapter, size.width, size.height)
.unwrap(); .unwrap();

View File

@ -574,6 +574,8 @@ impl crate::framework::Example for Example {
// No multisampling is used. // No multisampling is used.
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
// No pipeline caching is used
cache: None,
}); });
// Same idea as the water pipeline. // Same idea as the water pipeline.
@ -610,6 +612,7 @@ impl crate::framework::Example for Example {
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None
}); });
// A render bundle to draw the terrain. // A render bundle to draw the terrain.

View File

@ -302,6 +302,12 @@ impl GlobalPlay for wgc::global::Global {
Action::DestroyRenderPipeline(id) => { Action::DestroyRenderPipeline(id) => {
self.render_pipeline_drop::<A>(id); self.render_pipeline_drop::<A>(id);
} }
Action::CreatePipelineCache { id, desc } => {
let _ = unsafe { self.device_create_pipeline_cache::<A>(device, &desc, Some(id)) };
}
Action::DestroyPipelineCache(id) => {
self.pipeline_cache_drop::<A>(id);
}
Action::CreateRenderBundle { id, desc, base } => { Action::CreateRenderBundle { id, desc, base } => {
let bundle = let bundle =
wgc::command::RenderBundleEncoder::new(&desc, device, Some(base)).unwrap(); wgc::command::RenderBundleEncoder::new(&desc, device, Some(base)).unwrap();

View File

@ -370,6 +370,7 @@ fn copy_via_compute(
module: &sm, module: &sm,
entry_point: "copy_texture_to_buffer", entry_point: "copy_texture_to_buffer",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
{ {

View File

@ -98,6 +98,7 @@ static BGRA8_UNORM_STORAGE: GpuTestConfiguration = GpuTestConfiguration::new()
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
module: &module, module: &module,
cache: None,
}); });
let mut encoder = let mut encoder =

View File

@ -91,6 +91,7 @@ async fn bgl_dedupe(ctx: TestingContext) {
module: &module, module: &module,
entry_point: "no_resources", entry_point: "no_resources",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}; };
let pipeline = ctx.device.create_compute_pipeline(&desc); let pipeline = ctx.device.create_compute_pipeline(&desc);
@ -220,6 +221,7 @@ fn bgl_dedupe_with_dropped_user_handle(ctx: TestingContext) {
module: &module, module: &module,
entry_point: "no_resources", entry_point: "no_resources",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let mut encoder = ctx.device.create_command_encoder(&Default::default()); let mut encoder = ctx.device.create_command_encoder(&Default::default());
@ -266,6 +268,7 @@ fn bgl_dedupe_derived(ctx: TestingContext) {
module: &module, module: &module,
entry_point: "resources", entry_point: "resources",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
// We create two bind groups, pulling the bind_group_layout from the pipeline each time. // We create two bind groups, pulling the bind_group_layout from the pipeline each time.
@ -337,6 +340,7 @@ fn separate_programs_have_incompatible_derived_bgls(ctx: TestingContext) {
module: &module, module: &module,
entry_point: "resources", entry_point: "resources",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}; };
// Create two pipelines, creating a BG from the second. // Create two pipelines, creating a BG from the second.
let pipeline1 = ctx.device.create_compute_pipeline(&desc); let pipeline1 = ctx.device.create_compute_pipeline(&desc);
@ -399,6 +403,7 @@ fn derived_bgls_incompatible_with_regular_bgls(ctx: TestingContext) {
module: &module, module: &module,
entry_point: "resources", entry_point: "resources",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
// Create a matching BGL // Create a matching BGL

View File

@ -225,6 +225,7 @@ static MINIMUM_BUFFER_BINDING_SIZE_LAYOUT: GpuTestConfiguration = GpuTestConfigu
module: &shader_module, module: &shader_module,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
}); });
}); });
@ -294,6 +295,7 @@ static MINIMUM_BUFFER_BINDING_SIZE_DISPATCH: GpuTestConfiguration = GpuTestConfi
module: &shader_module, module: &shader_module,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor { let buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {

View File

@ -161,6 +161,7 @@ fn resource_setup(ctx: &TestingContext) -> ResourceSetup {
module: &sm, module: &sm,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
ResourceSetup { ResourceSetup {

View File

@ -488,6 +488,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
fragment: None, fragment: None,
multiview: None, multiview: None,
cache: None,
}); });
}); });
@ -500,6 +501,7 @@ static DEVICE_DESTROY_THEN_MORE: GpuTestConfiguration = GpuTestConfiguration::ne
module: &shader_module, module: &shader_module,
entry_point: "", entry_point: "",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
}); });
@ -757,6 +759,7 @@ fn vs_main() -> @builtin(position) vec4<f32> {
depth_stencil: None, depth_stencil: None,
multisample: wgt::MultisampleState::default(), multisample: wgt::MultisampleState::default(),
multiview: None, multiview: None,
cache: None
}); });
// fail(&ctx.device, || { // fail(&ctx.device, || {

View File

@ -113,6 +113,7 @@ async fn draw_test_with_reports(
})], })],
}), }),
multiview: None, multiview: None,
cache: None,
}); });
let global_report = ctx.instance.generate_report().unwrap(); let global_report = ctx.instance.generate_report().unwrap();

View File

@ -41,6 +41,7 @@ static NV12_TEXTURE_CREATION_SAMPLING: GpuTestConfiguration = GpuTestConfigurati
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let tex = ctx.device.create_texture(&wgpu::TextureDescriptor { let tex = ctx.device.create_texture(&wgpu::TextureDescriptor {

View File

@ -51,6 +51,7 @@ static OCCLUSION_QUERY: GpuTestConfiguration = GpuTestConfiguration::new()
}), }),
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
// Create occlusion query set // Create occlusion query set

View File

@ -70,6 +70,7 @@ static PARTIALLY_BOUNDED_ARRAY: GpuTestConfiguration = GpuTestConfiguration::new
module: &cs_module, module: &cs_module,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {

View File

@ -29,6 +29,7 @@ static PIPELINE_DEFAULT_LAYOUT_BAD_MODULE: GpuTestConfiguration = GpuTestConfigu
module: &module, module: &module,
entry_point: "doesn't exist", entry_point: "doesn't exist",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
pipeline.get_bind_group_layout(0); pipeline.get_bind_group_layout(0);

View File

@ -0,0 +1,192 @@
use std::{fmt::Write, num::NonZeroU64};
use wgpu_test::{gpu_test, GpuTestConfiguration, TestParameters, TestingContext};
/// We want to test that using a pipeline cache doesn't cause failure
///
/// It would be nice if we could also assert that reusing a pipeline cache would make compilation
/// be faster however, some drivers use a fallback pipeline cache, which makes this inconsistent
/// (both intra- and inter-run).
#[gpu_test]
static PIPELINE_CACHE: GpuTestConfiguration = GpuTestConfiguration::new()
.parameters(
TestParameters::default()
.test_features_limits()
.features(wgpu::Features::PIPELINE_CACHE),
)
.run_async(pipeline_cache_test);
/// Set to a higher value if adding a timing based assertion. This is otherwise fast to compile
const ARRAY_SIZE: u64 = 256;
/// Create a shader which should be slow-ish to compile
fn shader() -> String {
let mut body = String::new();
for idx in 0..ARRAY_SIZE {
// "Safety": There will only be a single workgroup, and a single thread in that workgroup
writeln!(body, " output[{idx}] = {idx}u;")
.expect("`u64::fmt` and `String::write_fmt` are infallible");
}
format!(
r#"
@group(0) @binding(0)
var<storage, read_write> output: array<u32>;
@compute @workgroup_size(1)
fn main() {{
{body}
}}
"#,
)
}
async fn pipeline_cache_test(ctx: TestingContext) {
let shader = shader();
let sm = ctx
.device
.create_shader_module(wgpu::ShaderModuleDescriptor {
label: Some("shader"),
source: wgpu::ShaderSource::Wgsl(shader.into()),
});
let bgl = ctx
.device
.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
label: Some("bind_group_layout"),
entries: &[wgpu::BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: NonZeroU64::new(ARRAY_SIZE * 4),
},
count: None,
}],
});
let gpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("gpu_buffer"),
size: ARRAY_SIZE * 4,
usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
mapped_at_creation: false,
});
let cpu_buffer = ctx.device.create_buffer(&wgpu::BufferDescriptor {
label: Some("cpu_buffer"),
size: ARRAY_SIZE * 4,
usage: wgpu::BufferUsages::COPY_DST | wgpu::BufferUsages::MAP_READ,
mapped_at_creation: false,
});
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
label: Some("bind_group"),
layout: &bgl,
entries: &[wgpu::BindGroupEntry {
binding: 0,
resource: gpu_buffer.as_entire_binding(),
}],
});
let pipeline_layout = ctx
.device
.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
label: Some("pipeline_layout"),
bind_group_layouts: &[&bgl],
push_constant_ranges: &[],
});
let first_cache_data;
{
let first_cache = unsafe {
ctx.device
.create_pipeline_cache(&wgpu::PipelineCacheDescriptor {
label: Some("pipeline_cache"),
data: None,
fallback: false,
})
};
let first_pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("pipeline"),
layout: Some(&pipeline_layout),
module: &sm,
entry_point: "main",
compilation_options: Default::default(),
cache: Some(&first_cache),
});
validate_pipeline(&ctx, first_pipeline, &bind_group, &gpu_buffer, &cpu_buffer).await;
first_cache_data = first_cache.get_data();
}
assert!(first_cache_data.is_some());
let second_cache = unsafe {
ctx.device
.create_pipeline_cache(&wgpu::PipelineCacheDescriptor {
label: Some("pipeline_cache"),
data: first_cache_data.as_deref(),
fallback: false,
})
};
let first_pipeline = ctx
.device
.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
label: Some("pipeline"),
layout: Some(&pipeline_layout),
module: &sm,
entry_point: "main",
compilation_options: Default::default(),
cache: Some(&second_cache),
});
validate_pipeline(&ctx, first_pipeline, &bind_group, &gpu_buffer, &cpu_buffer).await;
// Ideally, we could assert here that the second compilation was faster than the first
// However, that doesn't actually work, because drivers have their own internal caches.
// This does work on my machine if I set `MESA_DISABLE_PIPELINE_CACHE=1`
// before running the test; but of course that is not a realistic scenario
}
async fn validate_pipeline(
ctx: &TestingContext,
pipeline: wgpu::ComputePipeline,
bind_group: &wgpu::BindGroup,
gpu_buffer: &wgpu::Buffer,
cpu_buffer: &wgpu::Buffer,
) {
let mut encoder = ctx
.device
.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("encoder"),
});
{
let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("compute_pass"),
timestamp_writes: None,
});
cpass.set_pipeline(&pipeline);
cpass.set_bind_group(0, bind_group, &[]);
cpass.dispatch_workgroups(1, 1, 1);
}
encoder.copy_buffer_to_buffer(gpu_buffer, 0, cpu_buffer, 0, ARRAY_SIZE * 4);
ctx.queue.submit([encoder.finish()]);
cpu_buffer.slice(..).map_async(wgpu::MapMode::Read, |_| ());
ctx.async_poll(wgpu::Maintain::wait())
.await
.panic_on_timeout();
let data = cpu_buffer.slice(..).get_mapped_range();
let arrays: &[u32] = bytemuck::cast_slice(&data);
assert_eq!(arrays.len(), ARRAY_SIZE as usize);
for (idx, value) in arrays.iter().copied().enumerate() {
assert_eq!(value as usize, idx);
}
drop(data);
cpu_buffer.unmap();
}

View File

@ -104,6 +104,7 @@ async fn partial_update_test(ctx: TestingContext) {
module: &sm, module: &sm,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let mut encoder = ctx let mut encoder = ctx

View File

@ -119,6 +119,7 @@ async fn multi_stage_data_binding_test(ctx: TestingContext) {
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let texture = ctx.device.create_texture(&wgpu::TextureDescriptor { let texture = ctx.device.create_texture(&wgpu::TextureDescriptor {

View File

@ -80,6 +80,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration =
})], })],
}), }),
multiview: None, multiview: None,
cache: None,
}); });
let single_pipeline = ctx let single_pipeline = ctx
@ -111,6 +112,7 @@ static PASS_RESET_VERTEX_BUFFER: GpuTestConfiguration =
})], })],
}), }),
multiview: None, multiview: None,
cache: None,
}); });
let view = ctx let view = ctx

View File

@ -24,6 +24,7 @@ mod nv12_texture;
mod occlusion_query; mod occlusion_query;
mod partially_bounded_arrays; mod partially_bounded_arrays;
mod pipeline; mod pipeline;
mod pipeline_cache;
mod poll; mod poll;
mod push_constants; mod push_constants;
mod query_set; mod query_set;

View File

@ -61,6 +61,7 @@ async fn scissor_test_impl(
})], })],
}), }),
multiview: None, multiview: None,
cache: None,
}); });
let readback_buffer = image::ReadbackBuffers::new(&ctx.device, &texture); let readback_buffer = image::ReadbackBuffers::new(&ctx.device, &texture);

View File

@ -310,6 +310,7 @@ async fn shader_input_output_test(
module: &sm, module: &sm,
entry_point: "cs_main", entry_point: "cs_main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
// -- Initializing data -- // -- Initializing data --

View File

@ -88,6 +88,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration::
module: &sm, module: &sm,
entry_point: "read", entry_point: "read",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let pipeline_write = ctx let pipeline_write = ctx
@ -98,6 +99,7 @@ static ZERO_INIT_WORKGROUP_MEMORY: GpuTestConfiguration = GpuTestConfiguration::
module: &sm, module: &sm,
entry_point: "write", entry_point: "write",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
// -- Initializing data -- // -- Initializing data --

View File

@ -147,6 +147,7 @@ async fn pulling_common(
})], })],
}), }),
multiview: None, multiview: None,
cache: None,
}); });
let width = 2; let width = 2;

View File

@ -109,6 +109,7 @@ async fn reinterpret(
depth_stencil: None, depth_stencil: None,
multisample: wgpu::MultisampleState::default(), multisample: wgpu::MultisampleState::default(),
multiview: None, multiview: None,
cache: None,
}); });
let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor { let bind_group = ctx.device.create_bind_group(&wgpu::BindGroupDescriptor {
layout: &pipeline.get_bind_group_layout(0), layout: &pipeline.get_bind_group_layout(0),

View File

@ -75,6 +75,7 @@ static SUBGROUP_OPERATIONS: GpuTestConfiguration = GpuTestConfiguration::new()
module: &cs_module, module: &cs_module,
entry_point: "main", entry_point: "main",
compilation_options: Default::default(), compilation_options: Default::default(),
cache: None,
}); });
let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor { let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {

View File

@ -295,6 +295,7 @@ async fn vertex_index_common(ctx: TestingContext) {
})], })],
}), }),
multiview: None, multiview: None,
cache: None,
}; };
let builtin_pipeline = ctx.device.create_render_pipeline(&pipeline_desc); let builtin_pipeline = ctx.device.create_render_pipeline(&pipeline_desc);
pipeline_desc.vertex.entry_point = "vs_main_buffers"; pipeline_desc.vertex.entry_point = "vs_main_buffers";

View File

@ -13,8 +13,10 @@ use crate::{
instance::{self, Adapter, Surface}, instance::{self, Adapter, Surface},
lock::{rank, RwLock}, lock::{rank, RwLock},
pipeline, present, pipeline, present,
resource::{self, BufferAccessResult}, resource::{
resource::{BufferAccessError, BufferMapOperation, CreateBufferError, Resource}, self, BufferAccessError, BufferAccessResult, BufferMapOperation, CreateBufferError,
Resource,
},
validation::check_buffer_usage, validation::check_buffer_usage,
Label, LabelHelpers as _, Label, LabelHelpers as _,
}; };
@ -1823,6 +1825,66 @@ impl Global {
} }
} }
/// # Safety
/// The `data` argument of `desc` must have been returned by
/// [Self::pipeline_cache_get_data] for the same adapter
pub unsafe fn device_create_pipeline_cache<A: HalApi>(
&self,
device_id: DeviceId,
desc: &pipeline::PipelineCacheDescriptor<'_>,
id_in: Option<id::PipelineCacheId>,
) -> (
id::PipelineCacheId,
Option<pipeline::CreatePipelineCacheError>,
) {
profiling::scope!("Device::create_pipeline_cache");
let hub = A::hub(self);
let fid = hub.pipeline_caches.prepare(id_in);
let error: pipeline::CreatePipelineCacheError = 'error: {
let device = match hub.devices.get(device_id) {
Ok(device) => device,
// TODO: Handle error properly
Err(crate::storage::InvalidId) => break 'error DeviceError::Invalid.into(),
};
if !device.is_valid() {
break 'error DeviceError::Lost.into();
}
#[cfg(feature = "trace")]
if let Some(ref mut trace) = *device.trace.lock() {
trace.add(trace::Action::CreatePipelineCache {
id: fid.id(),
desc: desc.clone(),
});
}
let cache = unsafe { device.create_pipeline_cache(desc) };
match cache {
Ok(cache) => {
let (id, _) = fid.assign(Arc::new(cache));
api_log!("Device::create_pipeline_cache -> {id:?}");
return (id, None);
}
Err(e) => break 'error e,
}
};
let id = fid.assign_error(desc.label.borrow_or_default());
(id, Some(error))
}
pub fn pipeline_cache_drop<A: HalApi>(&self, pipeline_cache_id: id::PipelineCacheId) {
profiling::scope!("PipelineCache::drop");
api_log!("PipelineCache::drop {pipeline_cache_id:?}");
let hub = A::hub(self);
if let Some(cache) = hub.pipeline_caches.unregister(pipeline_cache_id) {
drop(cache)
}
}
pub fn surface_configure<A: HalApi>( pub fn surface_configure<A: HalApi>(
&self, &self,
surface_id: SurfaceId, surface_id: SurfaceId,
@ -2270,6 +2332,37 @@ impl Global {
.force_replace_with_error(device_id, "Made invalid."); .force_replace_with_error(device_id, "Made invalid.");
} }
pub fn pipeline_cache_get_data<A: HalApi>(&self, id: id::PipelineCacheId) -> Option<Vec<u8>> {
use crate::pipeline_cache;
api_log!("PipelineCache::get_data");
let hub = A::hub(self);
if let Ok(cache) = hub.pipeline_caches.get(id) {
// TODO: Is this check needed?
if !cache.device.is_valid() {
return None;
}
if let Some(raw_cache) = cache.raw.as_ref() {
let mut vec = unsafe { cache.device.raw().pipeline_cache_get_data(raw_cache) }?;
let validation_key = cache.device.raw().pipeline_cache_validation_key()?;
let mut header_contents = [0; pipeline_cache::HEADER_LENGTH];
pipeline_cache::add_cache_header(
&mut header_contents,
&vec,
&cache.device.adapter.raw.info,
validation_key,
);
let deleted = vec.splice(..0, header_contents).collect::<Vec<_>>();
debug_assert!(deleted.is_empty());
return Some(vec);
}
}
None
}
pub fn device_drop<A: HalApi>(&self, device_id: DeviceId) { pub fn device_drop<A: HalApi>(&self, device_id: DeviceId) {
profiling::scope!("Device::drop"); profiling::scope!("Device::drop");
api_log!("Device::drop {device_id:?}"); api_log!("Device::drop {device_id:?}");

View File

@ -2817,6 +2817,20 @@ impl<A: HalApi> Device<A> {
let late_sized_buffer_groups = let late_sized_buffer_groups =
Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout); Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout);
let cache = 'cache: {
let Some(cache) = desc.cache else {
break 'cache None;
};
let Ok(cache) = hub.pipeline_caches.get(cache) else {
break 'cache None;
};
if cache.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
Some(cache)
};
let pipeline_desc = hal::ComputePipelineDescriptor { let pipeline_desc = hal::ComputePipelineDescriptor {
label: desc.label.to_hal(self.instance_flags), label: desc.label.to_hal(self.instance_flags),
layout: pipeline_layout.raw(), layout: pipeline_layout.raw(),
@ -2826,6 +2840,7 @@ impl<A: HalApi> Device<A> {
constants: desc.stage.constants.as_ref(), constants: desc.stage.constants.as_ref(),
zero_initialize_workgroup_memory: desc.stage.zero_initialize_workgroup_memory, zero_initialize_workgroup_memory: desc.stage.zero_initialize_workgroup_memory,
}, },
cache: cache.as_ref().and_then(|it| it.raw.as_ref()),
}; };
let raw = unsafe { let raw = unsafe {
@ -3199,6 +3214,7 @@ impl<A: HalApi> Device<A> {
let vertex_shader_module; let vertex_shader_module;
let vertex_entry_point_name; let vertex_entry_point_name;
let vertex_stage = { let vertex_stage = {
let stage_desc = &desc.vertex.stage; let stage_desc = &desc.vertex.stage;
let stage = wgt::ShaderStages::VERTEX; let stage = wgt::ShaderStages::VERTEX;
@ -3393,6 +3409,20 @@ impl<A: HalApi> Device<A> {
let late_sized_buffer_groups = let late_sized_buffer_groups =
Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout); Device::make_late_sized_buffer_groups(&shader_binding_sizes, &pipeline_layout);
let pipeline_cache = 'cache: {
let Some(cache) = desc.cache else {
break 'cache None;
};
let Ok(cache) = hub.pipeline_caches.get(cache) else {
break 'cache None;
};
if cache.device.as_info().id() != self.as_info().id() {
return Err(DeviceError::WrongDevice.into());
}
Some(cache)
};
let pipeline_desc = hal::RenderPipelineDescriptor { let pipeline_desc = hal::RenderPipelineDescriptor {
label: desc.label.to_hal(self.instance_flags), label: desc.label.to_hal(self.instance_flags),
layout: pipeline_layout.raw(), layout: pipeline_layout.raw(),
@ -3404,6 +3434,7 @@ impl<A: HalApi> Device<A> {
fragment_stage, fragment_stage,
color_targets, color_targets,
multiview: desc.multiview, multiview: desc.multiview,
cache: pipeline_cache.as_ref().and_then(|it| it.raw.as_ref()),
}; };
let raw = unsafe { let raw = unsafe {
self.raw self.raw
@ -3484,6 +3515,53 @@ impl<A: HalApi> Device<A> {
Ok(pipeline) Ok(pipeline)
} }
/// # Safety
/// The `data` field on `desc` must have previously been returned from [`crate::global::Global::pipeline_cache_get_data`]
pub unsafe fn create_pipeline_cache(
self: &Arc<Self>,
desc: &pipeline::PipelineCacheDescriptor,
) -> Result<pipeline::PipelineCache<A>, pipeline::CreatePipelineCacheError> {
use crate::pipeline_cache;
self.require_features(wgt::Features::PIPELINE_CACHE)?;
let data = if let Some((data, validation_key)) = desc
.data
.as_ref()
.zip(self.raw().pipeline_cache_validation_key())
{
let data = pipeline_cache::validate_pipeline_cache(
data,
&self.adapter.raw.info,
validation_key,
);
match data {
Ok(data) => Some(data),
Err(e) if e.was_avoidable() || !desc.fallback => return Err(e.into()),
// If the error was unavoidable and we are asked to fallback, do so
Err(_) => None,
}
} else {
None
};
let cache_desc = hal::PipelineCacheDescriptor {
data,
label: desc.label.to_hal(self.instance_flags),
};
let raw = match unsafe { self.raw().create_pipeline_cache(&cache_desc) } {
Ok(raw) => raw,
Err(e) => return Err(e.into()),
};
let cache = pipeline::PipelineCache {
device: self.clone(),
info: ResourceInfo::new(
desc.label.borrow_or_default(),
Some(self.tracker_indices.pipeline_caches.clone()),
),
// This would be none in the error condition, which we don't implement yet
raw: Some(raw),
};
Ok(cache)
}
pub(crate) fn get_texture_format_features( pub(crate) fn get_texture_format_features(
&self, &self,
adapter: &Adapter<A>, adapter: &Adapter<A>,

View File

@ -98,6 +98,11 @@ pub enum Action<'a> {
implicit_context: Option<super::ImplicitPipelineContext>, implicit_context: Option<super::ImplicitPipelineContext>,
}, },
DestroyRenderPipeline(id::RenderPipelineId), DestroyRenderPipeline(id::RenderPipelineId),
CreatePipelineCache {
id: id::PipelineCacheId,
desc: crate::pipeline::PipelineCacheDescriptor<'a>,
},
DestroyPipelineCache(id::PipelineCacheId),
CreateRenderBundle { CreateRenderBundle {
id: id::RenderBundleId, id: id::RenderBundleId,
desc: crate::command::RenderBundleEncoderDescriptor<'a>, desc: crate::command::RenderBundleEncoderDescriptor<'a>,

View File

@ -110,7 +110,7 @@ use crate::{
device::{queue::Queue, Device}, device::{queue::Queue, Device},
hal_api::HalApi, hal_api::HalApi,
instance::{Adapter, Surface}, instance::{Adapter, Surface},
pipeline::{ComputePipeline, RenderPipeline, ShaderModule}, pipeline::{ComputePipeline, PipelineCache, RenderPipeline, ShaderModule},
registry::{Registry, RegistryReport}, registry::{Registry, RegistryReport},
resource::{Buffer, QuerySet, Sampler, StagingBuffer, Texture, TextureView}, resource::{Buffer, QuerySet, Sampler, StagingBuffer, Texture, TextureView},
storage::{Element, Storage}, storage::{Element, Storage},
@ -130,6 +130,7 @@ pub struct HubReport {
pub render_bundles: RegistryReport, pub render_bundles: RegistryReport,
pub render_pipelines: RegistryReport, pub render_pipelines: RegistryReport,
pub compute_pipelines: RegistryReport, pub compute_pipelines: RegistryReport,
pub pipeline_caches: RegistryReport,
pub query_sets: RegistryReport, pub query_sets: RegistryReport,
pub buffers: RegistryReport, pub buffers: RegistryReport,
pub textures: RegistryReport, pub textures: RegistryReport,
@ -180,6 +181,7 @@ pub struct Hub<A: HalApi> {
pub(crate) render_bundles: Registry<RenderBundle<A>>, pub(crate) render_bundles: Registry<RenderBundle<A>>,
pub(crate) render_pipelines: Registry<RenderPipeline<A>>, pub(crate) render_pipelines: Registry<RenderPipeline<A>>,
pub(crate) compute_pipelines: Registry<ComputePipeline<A>>, pub(crate) compute_pipelines: Registry<ComputePipeline<A>>,
pub(crate) pipeline_caches: Registry<PipelineCache<A>>,
pub(crate) query_sets: Registry<QuerySet<A>>, pub(crate) query_sets: Registry<QuerySet<A>>,
pub(crate) buffers: Registry<Buffer<A>>, pub(crate) buffers: Registry<Buffer<A>>,
pub(crate) staging_buffers: Registry<StagingBuffer<A>>, pub(crate) staging_buffers: Registry<StagingBuffer<A>>,
@ -202,6 +204,7 @@ impl<A: HalApi> Hub<A> {
render_bundles: Registry::new(A::VARIANT), render_bundles: Registry::new(A::VARIANT),
render_pipelines: Registry::new(A::VARIANT), render_pipelines: Registry::new(A::VARIANT),
compute_pipelines: Registry::new(A::VARIANT), compute_pipelines: Registry::new(A::VARIANT),
pipeline_caches: Registry::new(A::VARIANT),
query_sets: Registry::new(A::VARIANT), query_sets: Registry::new(A::VARIANT),
buffers: Registry::new(A::VARIANT), buffers: Registry::new(A::VARIANT),
staging_buffers: Registry::new(A::VARIANT), staging_buffers: Registry::new(A::VARIANT),
@ -235,6 +238,7 @@ impl<A: HalApi> Hub<A> {
self.pipeline_layouts.write().map.clear(); self.pipeline_layouts.write().map.clear();
self.compute_pipelines.write().map.clear(); self.compute_pipelines.write().map.clear();
self.render_pipelines.write().map.clear(); self.render_pipelines.write().map.clear();
self.pipeline_caches.write().map.clear();
self.query_sets.write().map.clear(); self.query_sets.write().map.clear();
for element in surface_guard.map.iter() { for element in surface_guard.map.iter() {
@ -280,6 +284,7 @@ impl<A: HalApi> Hub<A> {
render_bundles: self.render_bundles.generate_report(), render_bundles: self.render_bundles.generate_report(),
render_pipelines: self.render_pipelines.generate_report(), render_pipelines: self.render_pipelines.generate_report(),
compute_pipelines: self.compute_pipelines.generate_report(), compute_pipelines: self.compute_pipelines.generate_report(),
pipeline_caches: self.pipeline_caches.generate_report(),
query_sets: self.query_sets.generate_report(), query_sets: self.query_sets.generate_report(),
buffers: self.buffers.generate_report(), buffers: self.buffers.generate_report(),
textures: self.textures.generate_report(), textures: self.textures.generate_report(),

View File

@ -313,6 +313,7 @@ ids! {
pub type ShaderModuleId ShaderModule; pub type ShaderModuleId ShaderModule;
pub type RenderPipelineId RenderPipeline; pub type RenderPipelineId RenderPipeline;
pub type ComputePipelineId ComputePipeline; pub type ComputePipelineId ComputePipeline;
pub type PipelineCacheId PipelineCache;
pub type CommandEncoderId CommandEncoder; pub type CommandEncoderId CommandEncoder;
pub type CommandBufferId CommandBuffer; pub type CommandBufferId CommandBuffer;
pub type RenderPassEncoderId RenderPassEncoder; pub type RenderPassEncoderId RenderPassEncoder;

View File

@ -65,6 +65,7 @@ mod init_tracker;
pub mod instance; pub mod instance;
mod lock; mod lock;
pub mod pipeline; pub mod pipeline;
mod pipeline_cache;
mod pool; mod pool;
pub mod present; pub mod present;
pub mod registry; pub mod registry;

View File

@ -1,11 +1,12 @@
#[cfg(feature = "trace")] #[cfg(feature = "trace")]
use crate::device::trace; use crate::device::trace;
pub use crate::pipeline_cache::PipelineCacheValidationError;
use crate::{ use crate::{
binding_model::{CreateBindGroupLayoutError, CreatePipelineLayoutError, PipelineLayout}, binding_model::{CreateBindGroupLayoutError, CreatePipelineLayoutError, PipelineLayout},
command::ColorAttachmentError, command::ColorAttachmentError,
device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures, RenderPassContext}, device::{Device, DeviceError, MissingDownlevelFlags, MissingFeatures, RenderPassContext},
hal_api::HalApi, hal_api::HalApi,
id::{PipelineLayoutId, ShaderModuleId}, id::{PipelineCacheId, PipelineLayoutId, ShaderModuleId},
resource::{Resource, ResourceInfo, ResourceType}, resource::{Resource, ResourceInfo, ResourceType},
resource_log, validation, Label, resource_log, validation, Label,
}; };
@ -192,6 +193,8 @@ pub struct ComputePipelineDescriptor<'a> {
pub layout: Option<PipelineLayoutId>, pub layout: Option<PipelineLayoutId>,
/// The compiled compute stage and its entry point. /// The compiled compute stage and its entry point.
pub stage: ProgrammableStageDescriptor<'a>, pub stage: ProgrammableStageDescriptor<'a>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<PipelineCacheId>,
} }
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]
@ -259,6 +262,68 @@ impl<A: HalApi> ComputePipeline<A> {
} }
} }
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CreatePipelineCacheError {
#[error(transparent)]
Device(#[from] DeviceError),
#[error("Pipeline cache validation failed")]
Validation(#[from] PipelineCacheValidationError),
#[error(transparent)]
MissingFeatures(#[from] MissingFeatures),
#[error("Internal error: {0}")]
Internal(String),
}
impl From<hal::PipelineCacheError> for CreatePipelineCacheError {
fn from(value: hal::PipelineCacheError) -> Self {
match value {
hal::PipelineCacheError::Device(device) => {
CreatePipelineCacheError::Device(device.into())
}
}
}
}
#[derive(Debug)]
pub struct PipelineCache<A: HalApi> {
pub(crate) raw: Option<A::PipelineCache>,
pub(crate) device: Arc<Device<A>>,
pub(crate) info: ResourceInfo<PipelineCache<A>>,
}
impl<A: HalApi> Drop for PipelineCache<A> {
fn drop(&mut self) {
if let Some(raw) = self.raw.take() {
resource_log!("Destroy raw PipelineCache {:?}", self.info.label());
#[cfg(feature = "trace")]
if let Some(t) = self.device.trace.lock().as_mut() {
t.add(trace::Action::DestroyPipelineCache(self.info.id()));
}
unsafe {
use hal::Device;
self.device.raw().destroy_pipeline_cache(raw);
}
}
}
}
impl<A: HalApi> Resource for PipelineCache<A> {
const TYPE: ResourceType = "PipelineCache";
type Marker = crate::id::markers::PipelineCache;
fn as_info(&self) -> &ResourceInfo<Self> {
&self.info
}
fn as_info_mut(&mut self) -> &mut ResourceInfo<Self> {
&mut self.info
}
}
/// Describes how the vertex buffer is interpreted. /// Describes how the vertex buffer is interpreted.
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
@ -315,6 +380,16 @@ pub struct RenderPipelineDescriptor<'a> {
/// If the pipeline will be used with a multiview render pass, this indicates how many array /// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have. /// layers the attachments will have.
pub multiview: Option<NonZeroU32>, pub multiview: Option<NonZeroU32>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<PipelineCacheId>,
}
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct PipelineCacheDescriptor<'a> {
pub label: Label<'a>,
pub data: Option<Cow<'a, [u8]>>,
pub fallback: bool,
} }
#[derive(Clone, Debug, Error)] #[derive(Clone, Debug, Error)]

View File

@ -0,0 +1,530 @@
use thiserror::Error;
use wgt::AdapterInfo;
pub const HEADER_LENGTH: usize = std::mem::size_of::<PipelineCacheHeader>();
#[derive(Debug, PartialEq, Eq, Clone, Error)]
#[non_exhaustive]
pub enum PipelineCacheValidationError {
#[error("The pipeline cache data was truncated")]
Truncated,
#[error("The pipeline cache data was longer than recorded")]
// TODO: Is it plausible that this would happen
Extended,
#[error("The pipeline cache data was corrupted (e.g. the hash didn't match)")]
Corrupted,
#[error("The pipeline cacha data was out of date and so cannot be safely used")]
Outdated,
#[error("The cache data was created for a different device")]
WrongDevice,
#[error("Pipeline cacha data was created for a future version of wgpu")]
Unsupported,
}
impl PipelineCacheValidationError {
/// Could the error have been avoided?
/// That is, is there a mistake in user code interacting with the cache
pub fn was_avoidable(&self) -> bool {
match self {
PipelineCacheValidationError::WrongDevice => true,
PipelineCacheValidationError::Truncated
| PipelineCacheValidationError::Unsupported
| PipelineCacheValidationError::Extended
// It's unusual, but not implausible, to be downgrading wgpu
| PipelineCacheValidationError::Outdated
| PipelineCacheValidationError::Corrupted => false,
}
}
}
/// Validate the data in a pipeline cache
pub fn validate_pipeline_cache<'d>(
cache_data: &'d [u8],
adapter: &AdapterInfo,
validation_key: [u8; 16],
) -> Result<&'d [u8], PipelineCacheValidationError> {
let adapter_key = adapter_key(adapter)?;
let Some((header, remaining_data)) = PipelineCacheHeader::read(cache_data) else {
return Err(PipelineCacheValidationError::Truncated);
};
if header.magic != MAGIC {
return Err(PipelineCacheValidationError::Corrupted);
}
if header.header_version != HEADER_VERSION {
return Err(PipelineCacheValidationError::Outdated);
}
if header.cache_abi != ABI {
return Err(PipelineCacheValidationError::Outdated);
}
if header.backend != adapter.backend as u8 {
return Err(PipelineCacheValidationError::WrongDevice);
}
if header.adapter_key != adapter_key {
return Err(PipelineCacheValidationError::WrongDevice);
}
if header.validation_key != validation_key {
// If the validation key is wrong, that means that this device has changed
// in a way where the cache won't be compatible since the cache was made,
// so it is outdated
return Err(PipelineCacheValidationError::Outdated);
}
let data_size: usize = header
.data_size
.try_into()
// If the data was previously more than 4GiB, and we're still on a 32 bit system (ABI check, above)
// Then the data must be corrupted
.map_err(|_| PipelineCacheValidationError::Corrupted)?;
if remaining_data.len() < data_size {
return Err(PipelineCacheValidationError::Truncated);
}
if remaining_data.len() > data_size {
return Err(PipelineCacheValidationError::Extended);
}
if header.hash_space != HASH_SPACE_VALUE {
return Err(PipelineCacheValidationError::Corrupted);
}
Ok(remaining_data)
}
pub fn add_cache_header(
in_region: &mut [u8],
data: &[u8],
adapter: &AdapterInfo,
validation_key: [u8; 16],
) {
assert_eq!(in_region.len(), HEADER_LENGTH);
let header = PipelineCacheHeader {
adapter_key: adapter_key(adapter)
.expect("Called add_cache_header for an adapter which doesn't support cache data. This is a wgpu internal bug"),
backend: adapter.backend as u8,
cache_abi: ABI,
magic: MAGIC,
header_version: HEADER_VERSION,
validation_key,
hash_space: HASH_SPACE_VALUE,
data_size: data
.len()
.try_into()
.expect("Cache larger than u64::MAX bytes"),
};
header.write(in_region);
}
const MAGIC: [u8; 8] = *b"WGPUPLCH";
const HEADER_VERSION: u32 = 1;
const ABI: u32 = std::mem::size_of::<*const ()>() as u32;
/// The value used to fill [`PipelineCacheHeader::hash_space`]
///
/// If we receive reports of pipeline cache data corruption which is not otherwise caught
/// on a real device, it would be worth modifying this
///
/// Note that wgpu does not protect against malicious writes to e.g. a file used
/// to store a pipeline cache.
/// That is the resonsibility of the end application, such as by using a
/// private space.
const HASH_SPACE_VALUE: u64 = 0xFEDCBA9_876543210;
#[repr(C)]
#[derive(PartialEq, Eq)]
struct PipelineCacheHeader {
/// The magic header to ensure that we have the right file format
/// Has a value of MAGIC, as above
magic: [u8; 8],
// /// The total size of this header, in bytes
// header_size: u32,
/// The version of this wgpu header
/// Should be equal to HEADER_VERSION above
///
/// This must always be the second item, after the value above
header_version: u32,
/// The number of bytes in the pointers of this ABI, because some drivers
/// have previously not distinguished between their 32 bit and 64 bit drivers
/// leading to Vulkan data corruption
cache_abi: u32,
/// The id for the backend in use, from [wgt::Backend]
backend: u8,
/// The key which identifiers the device/adapter.
/// This is used to validate that this pipeline cache (probably) was produced for
/// the expected device.
/// On Vulkan: it is a combination of vendor ID and device ID
adapter_key: [u8; 15],
/// A key used to validate that this device is still compatible with the cache
///
/// This should e.g. contain driver version and/or intermediate compiler versions
validation_key: [u8; 16],
/// The length of the data which is sent to/recieved from the backend
data_size: u64,
/// Space reserved for a hash of the data in future
///
/// We assume that your cache storage system will be relatively robust, and so
/// do not validate this hash
///
/// Therefore, this will always have a value of [`HASH_SPACE_VALUE`]
hash_space: u64,
}
impl PipelineCacheHeader {
fn read(data: &[u8]) -> Option<(PipelineCacheHeader, &[u8])> {
let mut reader = Reader {
data,
total_read: 0,
};
let magic = reader.read_array()?;
let header_version = reader.read_u32()?;
let cache_abi = reader.read_u32()?;
let backend = reader.read_byte()?;
let adapter_key = reader.read_array()?;
let validation_key = reader.read_array()?;
let data_size = reader.read_u64()?;
let data_hash = reader.read_u64()?;
assert_eq!(
reader.total_read,
std::mem::size_of::<PipelineCacheHeader>()
);
Some((
PipelineCacheHeader {
magic,
header_version,
cache_abi,
backend,
adapter_key,
validation_key,
data_size,
hash_space: data_hash,
},
reader.data,
))
}
fn write(&self, into: &mut [u8]) -> Option<()> {
let mut writer = Writer { data: into };
writer.write_array(&self.magic)?;
writer.write_u32(self.header_version)?;
writer.write_u32(self.cache_abi)?;
writer.write_byte(self.backend)?;
writer.write_array(&self.adapter_key)?;
writer.write_array(&self.validation_key)?;
writer.write_u64(self.data_size)?;
writer.write_u64(self.hash_space)?;
assert_eq!(writer.data.len(), 0);
Some(())
}
}
fn adapter_key(adapter: &AdapterInfo) -> Result<[u8; 15], PipelineCacheValidationError> {
match adapter.backend {
wgt::Backend::Vulkan => {
// If these change size, the header format needs to change
// We set the type explicitly so this won't compile in that case
let v: [u8; 4] = adapter.vendor.to_be_bytes();
let d: [u8; 4] = adapter.device.to_be_bytes();
let adapter = [
255, 255, 255, v[0], v[1], v[2], v[3], d[0], d[1], d[2], d[3], 255, 255, 255, 255,
];
Ok(adapter)
}
_ => Err(PipelineCacheValidationError::Unsupported),
}
}
struct Reader<'a> {
data: &'a [u8],
total_read: usize,
}
impl<'a> Reader<'a> {
fn read_byte(&mut self) -> Option<u8> {
let res = *self.data.first()?;
self.total_read += 1;
self.data = &self.data[1..];
Some(res)
}
fn read_array<const N: usize>(&mut self) -> Option<[u8; N]> {
// Only greater than because we're indexing fenceposts, not items
if N > self.data.len() {
return None;
}
let (start, data) = self.data.split_at(N);
self.total_read += N;
self.data = data;
Some(start.try_into().expect("off-by-one-error in array size"))
}
// fn read_u16(&mut self) -> Option<u16> {
// self.read_array().map(u16::from_be_bytes)
// }
fn read_u32(&mut self) -> Option<u32> {
self.read_array().map(u32::from_be_bytes)
}
fn read_u64(&mut self) -> Option<u64> {
self.read_array().map(u64::from_be_bytes)
}
}
struct Writer<'a> {
data: &'a mut [u8],
}
impl<'a> Writer<'a> {
fn write_byte(&mut self, byte: u8) -> Option<()> {
self.write_array(&[byte])
}
fn write_array<const N: usize>(&mut self, array: &[u8; N]) -> Option<()> {
// Only greater than because we're indexing fenceposts, not items
if N > self.data.len() {
return None;
}
let data = std::mem::take(&mut self.data);
let (start, data) = data.split_at_mut(N);
self.data = data;
start.copy_from_slice(array);
Some(())
}
// fn write_u16(&mut self, value: u16) -> Option<()> {
// self.write_array(&value.to_be_bytes())
// }
fn write_u32(&mut self, value: u32) -> Option<()> {
self.write_array(&value.to_be_bytes())
}
fn write_u64(&mut self, value: u64) -> Option<()> {
self.write_array(&value.to_be_bytes())
}
}
#[cfg(test)]
mod tests {
use wgt::AdapterInfo;
use crate::pipeline_cache::{PipelineCacheValidationError as E, HEADER_LENGTH};
use super::ABI;
// Assert the correct size
const _: [(); HEADER_LENGTH] = [(); 64];
const ADAPTER: AdapterInfo = AdapterInfo {
name: String::new(),
vendor: 0x0002_FEED,
device: 0xFEFE_FEFE,
device_type: wgt::DeviceType::Other,
driver: String::new(),
driver_info: String::new(),
backend: wgt::Backend::Vulkan,
};
// IMPORTANT: If these tests fail, then you MUST increment HEADER_VERSION
const VALIDATION_KEY: [u8; 16] = u128::to_be_bytes(0xFFFFFFFF_FFFFFFFF_88888888_88888888);
#[test]
fn written_header() {
let mut result = [0; HEADER_LENGTH];
super::add_cache_header(&mut result, &[], &ADAPTER, VALIDATION_KEY);
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let expected = cache.into_iter().flatten().collect::<Vec<u8>>();
assert_eq!(result.as_slice(), expected.as_slice());
}
#[test]
fn valid_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let expected: &[u8] = &[];
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Ok(expected));
}
#[test]
fn invalid_magic() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"NOT_WGPU", // (Wrong) MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Corrupted));
}
#[test]
fn wrong_version() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 2, 0, 0, 0, ABI as u8], // (wrong) Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Outdated));
}
#[test]
fn wrong_abi() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
// a 14 bit ABI is improbable
[0, 0, 0, 1, 0, 0, 0, 14], // Version and (wrong) ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Header
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Outdated));
}
#[test]
fn wrong_backend() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[2, 255, 255, 255, 0, 2, 0xFE, 0xED], // (wrong) Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::WrongDevice));
}
#[test]
fn wrong_adapter() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0x00], // Backend and (wrong) Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::WrongDevice));
}
#[test]
fn wrong_validation() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_00000000u64.to_be_bytes(), // (wrong) Validation key
0x0u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Outdated));
}
#[test]
fn too_little_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x064u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Truncated));
}
#[test]
fn not_no_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
100u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache
.into_iter()
.flatten()
.chain(std::iter::repeat(0u8).take(100))
.collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
let expected: &[u8] = &[0; 100];
assert_eq!(validation_result, Ok(expected));
}
#[test]
fn too_much_data() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x064u64.to_be_bytes(), // Data size
0xFEDCBA9_876543210u64.to_be_bytes(), // Hash
];
let cache = cache
.into_iter()
.flatten()
.chain(std::iter::repeat(0u8).take(200))
.collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Extended));
}
#[test]
fn wrong_hash() {
let cache: [[u8; 8]; HEADER_LENGTH / 8] = [
*b"WGPUPLCH", // MAGIC
[0, 0, 0, 1, 0, 0, 0, ABI as u8], // Version and ABI
[1, 255, 255, 255, 0, 2, 0xFE, 0xED], // Backend and Adapter key
[0xFE, 0xFE, 0xFE, 0xFE, 255, 255, 255, 255], // Backend and Adapter key
0xFFFFFFFF_FFFFFFFFu64.to_be_bytes(), // Validation key
0x88888888_88888888u64.to_be_bytes(), // Validation key
0x0u64.to_be_bytes(), // Data size
0x00000000_00000000u64.to_be_bytes(), // Hash
];
let cache = cache.into_iter().flatten().collect::<Vec<u8>>();
let validation_result = super::validate_pipeline_cache(&cache, &ADAPTER, VALIDATION_KEY);
assert_eq!(validation_result, Err(E::Corrupted));
}
}

View File

@ -228,6 +228,7 @@ pub(crate) struct TrackerIndexAllocators {
pub pipeline_layouts: Arc<SharedTrackerIndexAllocator>, pub pipeline_layouts: Arc<SharedTrackerIndexAllocator>,
pub bundles: Arc<SharedTrackerIndexAllocator>, pub bundles: Arc<SharedTrackerIndexAllocator>,
pub query_sets: Arc<SharedTrackerIndexAllocator>, pub query_sets: Arc<SharedTrackerIndexAllocator>,
pub pipeline_caches: Arc<SharedTrackerIndexAllocator>,
} }
impl TrackerIndexAllocators { impl TrackerIndexAllocators {
@ -245,6 +246,7 @@ impl TrackerIndexAllocators {
pipeline_layouts: Arc::new(SharedTrackerIndexAllocator::new()), pipeline_layouts: Arc::new(SharedTrackerIndexAllocator::new()),
bundles: Arc::new(SharedTrackerIndexAllocator::new()), bundles: Arc::new(SharedTrackerIndexAllocator::new()),
query_sets: Arc::new(SharedTrackerIndexAllocator::new()), query_sets: Arc::new(SharedTrackerIndexAllocator::new()),
pipeline_caches: Arc::new(SharedTrackerIndexAllocator::new()),
} }
} }
} }

View File

@ -274,6 +274,7 @@ impl<A: hal::Api> Example<A> {
write_mask: wgt::ColorWrites::default(), write_mask: wgt::ColorWrites::default(),
})], })],
multiview: None, multiview: None,
cache: None,
}; };
let pipeline = unsafe { device.create_render_pipeline(&pipeline_desc).unwrap() }; let pipeline = unsafe { device.create_render_pipeline(&pipeline_desc).unwrap() };

View File

@ -374,6 +374,7 @@ impl<A: hal::Api> Example<A> {
constants: &Default::default(), constants: &Default::default(),
zero_initialize_workgroup_memory: true, zero_initialize_workgroup_memory: true,
}, },
cache: None,
}) })
} }
.unwrap(); .unwrap();

View File

@ -1513,6 +1513,14 @@ impl crate::Device for super::Device {
} }
unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {} unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {}
unsafe fn create_pipeline_cache(
&self,
_desc: &crate::PipelineCacheDescriptor<'_>,
) -> Result<(), crate::PipelineCacheError> {
Ok(())
}
unsafe fn destroy_pipeline_cache(&self, (): ()) {}
unsafe fn create_query_set( unsafe fn create_query_set(
&self, &self,
desc: &wgt::QuerySetDescriptor<crate::Label>, desc: &wgt::QuerySetDescriptor<crate::Label>,

View File

@ -82,6 +82,7 @@ impl crate::Api for Api {
type ShaderModule = ShaderModule; type ShaderModule = ShaderModule;
type RenderPipeline = RenderPipeline; type RenderPipeline = RenderPipeline;
type ComputePipeline = ComputePipeline; type ComputePipeline = ComputePipeline;
type PipelineCache = ();
type AccelerationStructure = AccelerationStructure; type AccelerationStructure = AccelerationStructure;
} }

View File

@ -30,6 +30,7 @@ impl crate::Api for Api {
type QuerySet = Resource; type QuerySet = Resource;
type Fence = Resource; type Fence = Resource;
type AccelerationStructure = Resource; type AccelerationStructure = Resource;
type PipelineCache = Resource;
type BindGroupLayout = Resource; type BindGroupLayout = Resource;
type BindGroup = Resource; type BindGroup = Resource;
@ -220,6 +221,13 @@ impl crate::Device for Context {
Ok(Resource) Ok(Resource)
} }
unsafe fn destroy_compute_pipeline(&self, pipeline: Resource) {} unsafe fn destroy_compute_pipeline(&self, pipeline: Resource) {}
unsafe fn create_pipeline_cache(
&self,
desc: &crate::PipelineCacheDescriptor<'_>,
) -> Result<Resource, crate::PipelineCacheError> {
Ok(Resource)
}
unsafe fn destroy_pipeline_cache(&self, cache: Resource) {}
unsafe fn create_query_set( unsafe fn create_query_set(
&self, &self,

View File

@ -1406,6 +1406,16 @@ impl crate::Device for super::Device {
} }
} }
unsafe fn create_pipeline_cache(
&self,
_: &crate::PipelineCacheDescriptor<'_>,
) -> Result<(), crate::PipelineCacheError> {
// Even though the cache doesn't do anything, we still return something here
// as the least bad option
Ok(())
}
unsafe fn destroy_pipeline_cache(&self, (): ()) {}
#[cfg_attr(target_arch = "wasm32", allow(unused))] #[cfg_attr(target_arch = "wasm32", allow(unused))]
unsafe fn create_query_set( unsafe fn create_query_set(
&self, &self,

View File

@ -154,6 +154,7 @@ impl crate::Api for Api {
type QuerySet = QuerySet; type QuerySet = QuerySet;
type Fence = Fence; type Fence = Fence;
type AccelerationStructure = (); type AccelerationStructure = ();
type PipelineCache = ();
type BindGroupLayout = BindGroupLayout; type BindGroupLayout = BindGroupLayout;
type BindGroup = BindGroup; type BindGroup = BindGroup;

View File

@ -332,6 +332,12 @@ pub enum PipelineError {
Device(#[from] DeviceError), Device(#[from] DeviceError),
} }
#[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum PipelineCacheError {
#[error(transparent)]
Device(#[from] DeviceError),
}
#[derive(Clone, Debug, Eq, PartialEq, Error)] #[derive(Clone, Debug, Eq, PartialEq, Error)]
pub enum SurfaceError { pub enum SurfaceError {
#[error("Surface is lost")] #[error("Surface is lost")]
@ -432,6 +438,7 @@ pub trait Api: Clone + fmt::Debug + Sized {
type ShaderModule: fmt::Debug + WasmNotSendSync; type ShaderModule: fmt::Debug + WasmNotSendSync;
type RenderPipeline: fmt::Debug + WasmNotSendSync; type RenderPipeline: fmt::Debug + WasmNotSendSync;
type ComputePipeline: fmt::Debug + WasmNotSendSync; type ComputePipeline: fmt::Debug + WasmNotSendSync;
type PipelineCache: fmt::Debug + WasmNotSendSync;
type AccelerationStructure: fmt::Debug + WasmNotSendSync + 'static; type AccelerationStructure: fmt::Debug + WasmNotSendSync + 'static;
} }
@ -611,6 +618,14 @@ pub trait Device: WasmNotSendSync {
desc: &ComputePipelineDescriptor<Self::A>, desc: &ComputePipelineDescriptor<Self::A>,
) -> Result<<Self::A as Api>::ComputePipeline, PipelineError>; ) -> Result<<Self::A as Api>::ComputePipeline, PipelineError>;
unsafe fn destroy_compute_pipeline(&self, pipeline: <Self::A as Api>::ComputePipeline); unsafe fn destroy_compute_pipeline(&self, pipeline: <Self::A as Api>::ComputePipeline);
unsafe fn create_pipeline_cache(
&self,
desc: &PipelineCacheDescriptor<'_>,
) -> Result<<Self::A as Api>::PipelineCache, PipelineCacheError>;
fn pipeline_cache_validation_key(&self) -> Option<[u8; 16]> {
None
}
unsafe fn destroy_pipeline_cache(&self, cache: <Self::A as Api>::PipelineCache);
unsafe fn create_query_set( unsafe fn create_query_set(
&self, &self,
@ -652,6 +667,14 @@ pub trait Device: WasmNotSendSync {
unsafe fn start_capture(&self) -> bool; unsafe fn start_capture(&self) -> bool;
unsafe fn stop_capture(&self); unsafe fn stop_capture(&self);
#[allow(unused_variables)]
unsafe fn pipeline_cache_get_data(
&self,
cache: &<Self::A as Api>::PipelineCache,
) -> Option<Vec<u8>> {
None
}
unsafe fn create_acceleration_structure( unsafe fn create_acceleration_structure(
&self, &self,
desc: &AccelerationStructureDescriptor, desc: &AccelerationStructureDescriptor,
@ -1636,6 +1659,13 @@ pub struct ComputePipelineDescriptor<'a, A: Api> {
pub layout: &'a A::PipelineLayout, pub layout: &'a A::PipelineLayout,
/// The compiled compute stage and its entry point. /// The compiled compute stage and its entry point.
pub stage: ProgrammableStage<'a, A>, pub stage: ProgrammableStage<'a, A>,
/// The cache which will be used and filled when compiling this pipeline
pub cache: Option<&'a A::PipelineCache>,
}
pub struct PipelineCacheDescriptor<'a> {
pub label: Label<'a>,
pub data: Option<&'a [u8]>,
} }
/// Describes how the vertex buffer is interpreted. /// Describes how the vertex buffer is interpreted.
@ -1672,6 +1702,8 @@ pub struct RenderPipelineDescriptor<'a, A: Api> {
/// If the pipeline will be used with a multiview render pass, this indicates how many array /// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have. /// layers the attachments will have.
pub multiview: Option<NonZeroU32>, pub multiview: Option<NonZeroU32>,
/// The cache which will be used and filled when compiling this pipeline
pub cache: Option<&'a A::PipelineCache>,
} }
#[derive(Debug, Clone)] #[derive(Debug, Clone)]

View File

@ -1099,6 +1099,14 @@ impl crate::Device for super::Device {
} }
unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {} unsafe fn destroy_compute_pipeline(&self, _pipeline: super::ComputePipeline) {}
unsafe fn create_pipeline_cache(
&self,
_desc: &crate::PipelineCacheDescriptor<'_>,
) -> Result<(), crate::PipelineCacheError> {
Ok(())
}
unsafe fn destroy_pipeline_cache(&self, (): ()) {}
unsafe fn create_query_set( unsafe fn create_query_set(
&self, &self,
desc: &wgt::QuerySetDescriptor<crate::Label>, desc: &wgt::QuerySetDescriptor<crate::Label>,

View File

@ -66,6 +66,7 @@ impl crate::Api for Api {
type ShaderModule = ShaderModule; type ShaderModule = ShaderModule;
type RenderPipeline = RenderPipeline; type RenderPipeline = RenderPipeline;
type ComputePipeline = ComputePipeline; type ComputePipeline = ComputePipeline;
type PipelineCache = ();
type AccelerationStructure = AccelerationStructure; type AccelerationStructure = AccelerationStructure;
} }

View File

@ -462,7 +462,8 @@ impl PhysicalDeviceFeatures {
| F::TIMESTAMP_QUERY_INSIDE_ENCODERS | F::TIMESTAMP_QUERY_INSIDE_ENCODERS
| F::TIMESTAMP_QUERY_INSIDE_PASSES | F::TIMESTAMP_QUERY_INSIDE_PASSES
| F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES | F::TEXTURE_ADAPTER_SPECIFIC_FORMAT_FEATURES
| F::CLEAR_TEXTURE; | F::CLEAR_TEXTURE
| F::PIPELINE_CACHE;
let mut dl_flags = Df::COMPUTE_SHADERS let mut dl_flags = Df::COMPUTE_SHADERS
| Df::BASE_VERTEX | Df::BASE_VERTEX
@ -1745,6 +1746,19 @@ impl super::Adapter {
unsafe { raw_device.get_device_queue(family_index, queue_index) } unsafe { raw_device.get_device_queue(family_index, queue_index) }
}; };
let driver_version = self
.phd_capabilities
.properties
.driver_version
.to_be_bytes();
#[rustfmt::skip]
let pipeline_cache_validation_key = [
driver_version[0], driver_version[1], driver_version[2], driver_version[3],
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0,
];
let shared = Arc::new(super::DeviceShared { let shared = Arc::new(super::DeviceShared {
raw: raw_device, raw: raw_device,
family_index, family_index,
@ -1760,6 +1774,7 @@ impl super::Adapter {
timeline_semaphore: timeline_semaphore_fn, timeline_semaphore: timeline_semaphore_fn,
ray_tracing: ray_tracing_fns, ray_tracing: ray_tracing_fns,
}, },
pipeline_cache_validation_key,
vendor_id: self.phd_capabilities.properties.vendor_id, vendor_id: self.phd_capabilities.properties.vendor_id,
timestamp_period: self.phd_capabilities.properties.limits.timestamp_period, timestamp_period: self.phd_capabilities.properties.limits.timestamp_period,
private_caps: self.private_caps.clone(), private_caps: self.private_caps.clone(),

View File

@ -1,4 +1,4 @@
use super::conv; use super::{conv, PipelineCache};
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
use ash::{khr, vk}; use ash::{khr, vk};
@ -1867,12 +1867,17 @@ impl crate::Device for super::Device {
.render_pass(raw_pass) .render_pass(raw_pass)
}]; }];
let pipeline_cache = desc
.cache
.map(|it| it.raw)
.unwrap_or(vk::PipelineCache::null());
let mut raw_vec = { let mut raw_vec = {
profiling::scope!("vkCreateGraphicsPipelines"); profiling::scope!("vkCreateGraphicsPipelines");
unsafe { unsafe {
self.shared self.shared
.raw .raw
.create_graphics_pipelines(vk::PipelineCache::null(), &vk_infos, None) .create_graphics_pipelines(pipeline_cache, &vk_infos, None)
.map_err(|(_, e)| crate::DeviceError::from(e)) .map_err(|(_, e)| crate::DeviceError::from(e))
}? }?
}; };
@ -1915,12 +1920,17 @@ impl crate::Device for super::Device {
.stage(compiled.create_info) .stage(compiled.create_info)
}]; }];
let pipeline_cache = desc
.cache
.map(|it| it.raw)
.unwrap_or(vk::PipelineCache::null());
let mut raw_vec = { let mut raw_vec = {
profiling::scope!("vkCreateComputePipelines"); profiling::scope!("vkCreateComputePipelines");
unsafe { unsafe {
self.shared self.shared
.raw .raw
.create_compute_pipelines(vk::PipelineCache::null(), &vk_infos, None) .create_compute_pipelines(pipeline_cache, &vk_infos, None)
.map_err(|(_, e)| crate::DeviceError::from(e)) .map_err(|(_, e)| crate::DeviceError::from(e))
}? }?
}; };
@ -1940,6 +1950,26 @@ impl crate::Device for super::Device {
unsafe { self.shared.raw.destroy_pipeline(pipeline.raw, None) }; unsafe { self.shared.raw.destroy_pipeline(pipeline.raw, None) };
} }
unsafe fn create_pipeline_cache(
&self,
desc: &crate::PipelineCacheDescriptor<'_>,
) -> Result<PipelineCache, crate::PipelineCacheError> {
let mut info = vk::PipelineCacheCreateInfo::default();
if let Some(data) = desc.data {
info = info.initial_data(data)
}
profiling::scope!("vkCreatePipelineCache");
let raw = unsafe { self.shared.raw.create_pipeline_cache(&info, None) }
.map_err(crate::DeviceError::from)?;
Ok(PipelineCache { raw })
}
fn pipeline_cache_validation_key(&self) -> Option<[u8; 16]> {
Some(self.shared.pipeline_cache_validation_key)
}
unsafe fn destroy_pipeline_cache(&self, cache: PipelineCache) {
unsafe { self.shared.raw.destroy_pipeline_cache(cache.raw, None) }
}
unsafe fn create_query_set( unsafe fn create_query_set(
&self, &self,
desc: &wgt::QuerySetDescriptor<crate::Label>, desc: &wgt::QuerySetDescriptor<crate::Label>,
@ -2105,6 +2135,11 @@ impl crate::Device for super::Device {
} }
} }
unsafe fn pipeline_cache_get_data(&self, cache: &PipelineCache) -> Option<Vec<u8>> {
let data = unsafe { self.raw_device().get_pipeline_cache_data(cache.raw) };
data.ok()
}
unsafe fn get_acceleration_structure_build_sizes<'a>( unsafe fn get_acceleration_structure_build_sizes<'a>(
&self, &self,
desc: &crate::GetAccelerationStructureBuildSizesDescriptor<'a, super::Api>, desc: &crate::GetAccelerationStructureBuildSizesDescriptor<'a, super::Api>,

View File

@ -70,6 +70,7 @@ impl crate::Api for Api {
type QuerySet = QuerySet; type QuerySet = QuerySet;
type Fence = Fence; type Fence = Fence;
type AccelerationStructure = AccelerationStructure; type AccelerationStructure = AccelerationStructure;
type PipelineCache = PipelineCache;
type BindGroupLayout = BindGroupLayout; type BindGroupLayout = BindGroupLayout;
type BindGroup = BindGroup; type BindGroup = BindGroup;
@ -338,6 +339,7 @@ struct DeviceShared {
enabled_extensions: Vec<&'static CStr>, enabled_extensions: Vec<&'static CStr>,
extension_fns: DeviceExtensionFunctions, extension_fns: DeviceExtensionFunctions,
vendor_id: u32, vendor_id: u32,
pipeline_cache_validation_key: [u8; 16],
timestamp_period: f32, timestamp_period: f32,
private_caps: PrivateCapabilities, private_caps: PrivateCapabilities,
workarounds: Workarounds, workarounds: Workarounds,
@ -549,6 +551,11 @@ pub struct ComputePipeline {
raw: vk::Pipeline, raw: vk::Pipeline,
} }
#[derive(Debug)]
pub struct PipelineCache {
raw: vk::PipelineCache,
}
#[derive(Debug)] #[derive(Debug)]
pub struct QuerySet { pub struct QuerySet {
raw: vk::QueryPool, raw: vk::QueryPool,

View File

@ -914,6 +914,15 @@ bitflags::bitflags! {
/// ///
/// This is a native only feature. /// This is a native only feature.
const SUBGROUP_BARRIER = 1 << 58; const SUBGROUP_BARRIER = 1 << 58;
/// Allows the use of pipeline cache objects
///
/// Supported platforms:
/// - Vulkan
///
/// Unimplemented Platforms:
/// - DX12
/// - Metal
const PIPELINE_CACHE = 1 << 59;
} }
} }

View File

@ -1159,6 +1159,8 @@ impl crate::context::Context for ContextWebGpu {
type SurfaceOutputDetail = SurfaceOutputDetail; type SurfaceOutputDetail = SurfaceOutputDetail;
type SubmissionIndex = Unused; type SubmissionIndex = Unused;
type SubmissionIndexData = (); type SubmissionIndexData = ();
type PipelineCacheId = Unused;
type PipelineCacheData = ();
type RequestAdapterFuture = MakeSendFuture< type RequestAdapterFuture = MakeSendFuture<
wasm_bindgen_futures::JsFuture, wasm_bindgen_futures::JsFuture,
@ -1995,6 +1997,16 @@ impl crate::context::Context for ContextWebGpu {
create_identified(device_data.0.create_compute_pipeline(&mapped_desc)) create_identified(device_data.0.create_compute_pipeline(&mapped_desc))
} }
unsafe fn device_create_pipeline_cache(
&self,
_: &Self::DeviceId,
_: &Self::DeviceData,
_: &crate::PipelineCacheDescriptor<'_>,
) -> (Self::PipelineCacheId, Self::PipelineCacheData) {
(Unused, ())
}
fn pipeline_cache_drop(&self, _: &Self::PipelineCacheId, _: &Self::PipelineCacheData) {}
fn device_create_buffer( fn device_create_buffer(
&self, &self,
_device: &Self::DeviceId, _device: &Self::DeviceId,
@ -2981,6 +2993,14 @@ impl crate::context::Context for ContextWebGpu {
fn device_start_capture(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {} fn device_start_capture(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {}
fn device_stop_capture(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {} fn device_stop_capture(&self, _device: &Self::DeviceId, _device_data: &Self::DeviceData) {}
fn pipeline_cache_get_data(
&self,
_: &Self::PipelineCacheId,
_: &Self::PipelineCacheData,
) -> Option<Vec<u8>> {
None
}
fn compute_pass_set_pipeline( fn compute_pass_set_pipeline(
&self, &self,
_pass: &mut Self::ComputePassId, _pass: &mut Self::ComputePassId,

View File

@ -4,10 +4,10 @@ use crate::{
BufferDescriptor, CommandEncoderDescriptor, CompilationInfo, CompilationMessage, BufferDescriptor, CommandEncoderDescriptor, CompilationInfo, CompilationMessage,
CompilationMessageType, ComputePassDescriptor, ComputePipelineDescriptor, CompilationMessageType, ComputePassDescriptor, ComputePipelineDescriptor,
DownlevelCapabilities, Features, Label, Limits, LoadOp, MapMode, Operations, DownlevelCapabilities, Features, Label, Limits, LoadOp, MapMode, Operations,
PipelineLayoutDescriptor, RenderBundleEncoderDescriptor, RenderPipelineDescriptor, PipelineCacheDescriptor, PipelineLayoutDescriptor, RenderBundleEncoderDescriptor,
SamplerDescriptor, ShaderModuleDescriptor, ShaderModuleDescriptorSpirV, ShaderSource, StoreOp, RenderPipelineDescriptor, SamplerDescriptor, ShaderModuleDescriptor,
SurfaceStatus, SurfaceTargetUnsafe, TextureDescriptor, TextureViewDescriptor, ShaderModuleDescriptorSpirV, ShaderSource, StoreOp, SurfaceStatus, SurfaceTargetUnsafe,
UncapturedErrorHandler, TextureDescriptor, TextureViewDescriptor, UncapturedErrorHandler,
}; };
use arrayvec::ArrayVec; use arrayvec::ArrayVec;
@ -519,6 +519,8 @@ impl crate::Context for ContextWgpuCore {
type RenderPipelineData = (); type RenderPipelineData = ();
type ComputePipelineId = wgc::id::ComputePipelineId; type ComputePipelineId = wgc::id::ComputePipelineId;
type ComputePipelineData = (); type ComputePipelineData = ();
type PipelineCacheId = wgc::id::PipelineCacheId;
type PipelineCacheData = ();
type CommandEncoderId = wgc::id::CommandEncoderId; type CommandEncoderId = wgc::id::CommandEncoderId;
type CommandEncoderData = CommandEncoder; type CommandEncoderData = CommandEncoder;
type ComputePassId = Unused; type ComputePassId = Unused;
@ -1191,6 +1193,7 @@ impl crate::Context for ContextWgpuCore {
targets: Borrowed(frag.targets), targets: Borrowed(frag.targets),
}), }),
multiview: desc.multiview, multiview: desc.multiview,
cache: desc.cache.map(|c| c.id.into()),
}; };
let (id, error) = wgc::gfx_select!(device => self.0.device_create_render_pipeline( let (id, error) = wgc::gfx_select!(device => self.0.device_create_render_pipeline(
@ -1240,6 +1243,7 @@ impl crate::Context for ContextWgpuCore {
.compilation_options .compilation_options
.zero_initialize_workgroup_memory, .zero_initialize_workgroup_memory,
}, },
cache: desc.cache.map(|c| c.id.into()),
}; };
let (id, error) = wgc::gfx_select!(device => self.0.device_create_compute_pipeline( let (id, error) = wgc::gfx_select!(device => self.0.device_create_compute_pipeline(
@ -1267,6 +1271,37 @@ impl crate::Context for ContextWgpuCore {
} }
(id, ()) (id, ())
} }
unsafe fn device_create_pipeline_cache(
&self,
device: &Self::DeviceId,
device_data: &Self::DeviceData,
desc: &PipelineCacheDescriptor<'_>,
) -> (Self::PipelineCacheId, Self::PipelineCacheData) {
use wgc::pipeline as pipe;
let descriptor = pipe::PipelineCacheDescriptor {
label: desc.label.map(Borrowed),
data: desc.data.map(Borrowed),
fallback: desc.fallback,
};
let (id, error) = wgc::gfx_select!(device => self.0.device_create_pipeline_cache(
*device,
&descriptor,
None
));
if let Some(cause) = error {
self.handle_error(
&device_data.error_sink,
cause,
LABEL,
desc.label,
"Device::device_create_pipeline_cache_init",
);
}
(id, ())
}
fn device_create_buffer( fn device_create_buffer(
&self, &self,
device: &Self::DeviceId, device: &Self::DeviceId,
@ -1726,6 +1761,14 @@ impl crate::Context for ContextWgpuCore {
wgc::gfx_select!(*pipeline => self.0.render_pipeline_drop(*pipeline)) wgc::gfx_select!(*pipeline => self.0.render_pipeline_drop(*pipeline))
} }
fn pipeline_cache_drop(
&self,
cache: &Self::PipelineCacheId,
_cache_data: &Self::PipelineCacheData,
) {
wgc::gfx_select!(*cache => self.0.pipeline_cache_drop(*cache))
}
fn compute_pipeline_get_bind_group_layout( fn compute_pipeline_get_bind_group_layout(
&self, &self,
pipeline: &Self::ComputePipelineId, pipeline: &Self::ComputePipelineId,
@ -2336,6 +2379,15 @@ impl crate::Context for ContextWgpuCore {
wgc::gfx_select!(device => self.0.device_stop_capture(*device)); wgc::gfx_select!(device => self.0.device_stop_capture(*device));
} }
fn pipeline_cache_get_data(
&self,
cache: &Self::PipelineCacheId,
// TODO: Used for error handling?
_cache_data: &Self::PipelineCacheData,
) -> Option<Vec<u8>> {
wgc::gfx_select!(cache => self.0.pipeline_cache_get_data(*cache))
}
fn compute_pass_set_pipeline( fn compute_pass_set_pipeline(
&self, &self,
_pass: &mut Self::ComputePassId, _pass: &mut Self::ComputePassId,

View File

@ -11,11 +11,12 @@ use crate::{
AnyWasmNotSendSync, BindGroupDescriptor, BindGroupLayoutDescriptor, Buffer, BufferAsyncError, AnyWasmNotSendSync, BindGroupDescriptor, BindGroupLayoutDescriptor, Buffer, BufferAsyncError,
BufferDescriptor, CommandEncoderDescriptor, CompilationInfo, ComputePassDescriptor, BufferDescriptor, CommandEncoderDescriptor, CompilationInfo, ComputePassDescriptor,
ComputePipelineDescriptor, DeviceDescriptor, Error, ErrorFilter, ImageCopyBuffer, ComputePipelineDescriptor, DeviceDescriptor, Error, ErrorFilter, ImageCopyBuffer,
ImageCopyTexture, Maintain, MaintainResult, MapMode, PipelineLayoutDescriptor, ImageCopyTexture, Maintain, MaintainResult, MapMode, PipelineCacheDescriptor,
QuerySetDescriptor, RenderBundleDescriptor, RenderBundleEncoderDescriptor, PipelineLayoutDescriptor, QuerySetDescriptor, RenderBundleDescriptor,
RenderPassDescriptor, RenderPipelineDescriptor, RequestAdapterOptions, RequestDeviceError, RenderBundleEncoderDescriptor, RenderPassDescriptor, RenderPipelineDescriptor,
SamplerDescriptor, ShaderModuleDescriptor, ShaderModuleDescriptorSpirV, SurfaceTargetUnsafe, RequestAdapterOptions, RequestDeviceError, SamplerDescriptor, ShaderModuleDescriptor,
Texture, TextureDescriptor, TextureViewDescriptor, UncapturedErrorHandler, ShaderModuleDescriptorSpirV, SurfaceTargetUnsafe, Texture, TextureDescriptor,
TextureViewDescriptor, UncapturedErrorHandler,
}; };
/// Meta trait for an id tracked by a context. /// Meta trait for an id tracked by a context.
@ -59,6 +60,8 @@ pub trait Context: Debug + WasmNotSendSync + Sized {
type RenderPipelineData: ContextData; type RenderPipelineData: ContextData;
type ComputePipelineId: ContextId + WasmNotSendSync; type ComputePipelineId: ContextId + WasmNotSendSync;
type ComputePipelineData: ContextData; type ComputePipelineData: ContextData;
type PipelineCacheId: ContextId + WasmNotSendSync;
type PipelineCacheData: ContextData;
type CommandEncoderId: ContextId + WasmNotSendSync; type CommandEncoderId: ContextId + WasmNotSendSync;
type CommandEncoderData: ContextData; type CommandEncoderData: ContextData;
type ComputePassId: ContextId; type ComputePassId: ContextId;
@ -233,6 +236,12 @@ pub trait Context: Debug + WasmNotSendSync + Sized {
device_data: &Self::DeviceData, device_data: &Self::DeviceData,
desc: &ComputePipelineDescriptor<'_>, desc: &ComputePipelineDescriptor<'_>,
) -> (Self::ComputePipelineId, Self::ComputePipelineData); ) -> (Self::ComputePipelineId, Self::ComputePipelineData);
unsafe fn device_create_pipeline_cache(
&self,
device: &Self::DeviceId,
device_data: &Self::DeviceData,
desc: &PipelineCacheDescriptor<'_>,
) -> (Self::PipelineCacheId, Self::PipelineCacheData);
fn device_create_buffer( fn device_create_buffer(
&self, &self,
device: &Self::DeviceId, device: &Self::DeviceId,
@ -395,6 +404,11 @@ pub trait Context: Debug + WasmNotSendSync + Sized {
pipeline: &Self::RenderPipelineId, pipeline: &Self::RenderPipelineId,
pipeline_data: &Self::RenderPipelineData, pipeline_data: &Self::RenderPipelineData,
); );
fn pipeline_cache_drop(
&self,
cache: &Self::PipelineCacheId,
cache_data: &Self::PipelineCacheData,
);
fn compute_pipeline_get_bind_group_layout( fn compute_pipeline_get_bind_group_layout(
&self, &self,
@ -613,6 +627,12 @@ pub trait Context: Debug + WasmNotSendSync + Sized {
fn device_start_capture(&self, device: &Self::DeviceId, device_data: &Self::DeviceData); fn device_start_capture(&self, device: &Self::DeviceId, device_data: &Self::DeviceData);
fn device_stop_capture(&self, device: &Self::DeviceId, device_data: &Self::DeviceData); fn device_stop_capture(&self, device: &Self::DeviceId, device_data: &Self::DeviceData);
fn pipeline_cache_get_data(
&self,
cache: &Self::PipelineCacheId,
cache_data: &Self::PipelineCacheData,
) -> Option<Vec<u8>>;
fn compute_pass_set_pipeline( fn compute_pass_set_pipeline(
&self, &self,
pass: &mut Self::ComputePassId, pass: &mut Self::ComputePassId,
@ -1271,6 +1291,12 @@ pub(crate) trait DynContext: Debug + WasmNotSendSync {
device_data: &crate::Data, device_data: &crate::Data,
desc: &ComputePipelineDescriptor<'_>, desc: &ComputePipelineDescriptor<'_>,
) -> (ObjectId, Box<crate::Data>); ) -> (ObjectId, Box<crate::Data>);
unsafe fn device_create_pipeline_cache(
&self,
device: &ObjectId,
device_data: &crate::Data,
desc: &PipelineCacheDescriptor<'_>,
) -> (ObjectId, Box<crate::Data>);
fn device_create_buffer( fn device_create_buffer(
&self, &self,
device: &ObjectId, device: &ObjectId,
@ -1391,6 +1417,7 @@ pub(crate) trait DynContext: Debug + WasmNotSendSync {
fn render_bundle_drop(&self, render_bundle: &ObjectId, render_bundle_data: &crate::Data); fn render_bundle_drop(&self, render_bundle: &ObjectId, render_bundle_data: &crate::Data);
fn compute_pipeline_drop(&self, pipeline: &ObjectId, pipeline_data: &crate::Data); fn compute_pipeline_drop(&self, pipeline: &ObjectId, pipeline_data: &crate::Data);
fn render_pipeline_drop(&self, pipeline: &ObjectId, pipeline_data: &crate::Data); fn render_pipeline_drop(&self, pipeline: &ObjectId, pipeline_data: &crate::Data);
fn pipeline_cache_drop(&self, cache: &ObjectId, _cache_data: &crate::Data);
fn compute_pipeline_get_bind_group_layout( fn compute_pipeline_get_bind_group_layout(
&self, &self,
@ -1601,6 +1628,12 @@ pub(crate) trait DynContext: Debug + WasmNotSendSync {
fn device_start_capture(&self, device: &ObjectId, data: &crate::Data); fn device_start_capture(&self, device: &ObjectId, data: &crate::Data);
fn device_stop_capture(&self, device: &ObjectId, data: &crate::Data); fn device_stop_capture(&self, device: &ObjectId, data: &crate::Data);
fn pipeline_cache_get_data(
&self,
cache: &ObjectId,
cache_data: &crate::Data,
) -> Option<Vec<u8>>;
fn compute_pass_set_pipeline( fn compute_pass_set_pipeline(
&self, &self,
pass: &mut ObjectId, pass: &mut ObjectId,
@ -2297,6 +2330,19 @@ where
(compute_pipeline.into(), Box::new(data) as _) (compute_pipeline.into(), Box::new(data) as _)
} }
unsafe fn device_create_pipeline_cache(
&self,
device: &ObjectId,
device_data: &crate::Data,
desc: &PipelineCacheDescriptor<'_>,
) -> (ObjectId, Box<crate::Data>) {
let device = <T::DeviceId>::from(*device);
let device_data = downcast_ref(device_data);
let (pipeline_cache, data) =
unsafe { Context::device_create_pipeline_cache(self, &device, device_data, desc) };
(pipeline_cache.into(), Box::new(data) as _)
}
fn device_create_buffer( fn device_create_buffer(
&self, &self,
device: &ObjectId, device: &ObjectId,
@ -2621,6 +2667,12 @@ where
Context::render_pipeline_drop(self, &pipeline, pipeline_data) Context::render_pipeline_drop(self, &pipeline, pipeline_data)
} }
fn pipeline_cache_drop(&self, cache: &ObjectId, cache_data: &crate::Data) {
let cache = <T::PipelineCacheId>::from(*cache);
let cache_data = downcast_ref(cache_data);
Context::pipeline_cache_drop(self, &cache, cache_data)
}
fn compute_pipeline_get_bind_group_layout( fn compute_pipeline_get_bind_group_layout(
&self, &self,
pipeline: &ObjectId, pipeline: &ObjectId,
@ -3083,6 +3135,16 @@ where
Context::device_stop_capture(self, &device, device_data) Context::device_stop_capture(self, &device, device_data)
} }
fn pipeline_cache_get_data(
&self,
cache: &ObjectId,
cache_data: &crate::Data,
) -> Option<Vec<u8>> {
let cache = <T::PipelineCacheId>::from(*cache);
let cache_data = downcast_ref::<T::PipelineCacheData>(cache_data);
Context::pipeline_cache_get_data(self, &cache, cache_data)
}
fn compute_pass_set_pipeline( fn compute_pass_set_pipeline(
&self, &self,
pass: &mut ObjectId, pass: &mut ObjectId,

View File

@ -1105,6 +1105,100 @@ impl ComputePipeline {
} }
} }
/// Handle to a pipeline cache, which is used to accelerate
/// creating [`RenderPipeline`]s and [`ComputePipeline`]s
/// in subsequent executions
///
/// This reuse is only applicable for the same or similar devices.
/// See [`util::pipeline_cache_key`] for some details.
///
/// # Background
///
/// In most GPU drivers, shader code must be converted into a machine code
/// which can be executed on the GPU.
/// Generating this machine code can require a lot of computation.
/// Pipeline caches allow this computation to be reused between executions
/// of the program.
/// This can be very useful for reducing program startup time.
///
/// Note that most desktop GPU drivers will manage their own caches,
/// meaning that little advantage can be gained from this on those platforms.
/// However, on some platforms, especially Android, drivers leave this to the
/// application to implement.
///
/// Unfortunately, drivers do not expose whether they manage their own caches.
/// Some reasonable policies for applications to use are:
/// - Manage their own pipeline cache on all platforms
/// - Only manage pipeline caches on Android
///
/// # Usage
///
/// It is valid to use this resource when creating multiple pipelines, in
/// which case it will likely cache each of those pipelines.
/// It is also valid to create a new cache for each pipeline.
///
/// This resource is most useful when the data produced from it (using
/// [`PipelineCache::get_data`]) is persisted.
/// Care should be taken that pipeline caches are only used for the same device,
/// as pipeline caches from compatible devices are unlikely to provide any advantage.
/// `util::pipeline_cache_key` can be used as a file/directory name to help ensure that.
///
/// It is recommended to store pipeline caches atomically. If persisting to disk,
/// this can usually be achieved by creating a temporary file, then moving/[renaming]
/// the temporary file over the existing cache
///
/// # Storage Usage
///
/// There is not currently an API available to reduce the size of a cache.
/// This is due to limitations in the underlying graphics APIs used.
/// This is especially impactful if your application is being updated, so
/// previous caches are no longer being used.
///
/// One option to work around this is to regenerate the cache.
/// That is, creating the pipelines which your program runs using
/// with the stored cached data, then recreating the *same* pipelines
/// using a new cache, which your application then store.
///
/// # Implementations
///
/// This resource currently only works on the following backends:
/// - Vulkan
///
/// This type is unique to the Rust API of `wgpu`.
///
/// [renaming]: std::fs::rename
#[derive(Debug)]
pub struct PipelineCache {
context: Arc<C>,
id: ObjectId,
data: Box<Data>,
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(PipelineCache: Send, Sync);
impl PipelineCache {
/// Get the data associated with this pipeline cache.
/// The data format is an implementation detail of `wgpu`.
/// The only defined operation on this data setting it as the `data` field
/// on [`PipelineCacheDescriptor`], then to [`Device::create_pipeline_cache`].
///
/// This function is unique to the Rust API of `wgpu`.
pub fn get_data(&self) -> Option<Vec<u8>> {
self.context
.pipeline_cache_get_data(&self.id, self.data.as_ref())
}
}
impl Drop for PipelineCache {
fn drop(&mut self) {
if !thread::panicking() {
self.context
.pipeline_cache_drop(&self.id, self.data.as_ref());
}
}
}
/// Handle to a command buffer on the GPU. /// Handle to a command buffer on the GPU.
/// ///
/// A `CommandBuffer` represents a complete sequence of commands that may be submitted to a command /// A `CommandBuffer` represents a complete sequence of commands that may be submitted to a command
@ -1832,6 +1926,8 @@ pub struct RenderPipelineDescriptor<'a> {
/// If the pipeline will be used with a multiview render pass, this indicates how many array /// If the pipeline will be used with a multiview render pass, this indicates how many array
/// layers the attachments will have. /// layers the attachments will have.
pub multiview: Option<NonZeroU32>, pub multiview: Option<NonZeroU32>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<&'a PipelineCache>,
} }
#[cfg(send_sync)] #[cfg(send_sync)]
static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync); static_assertions::assert_impl_all!(RenderPipelineDescriptor<'_>: Send, Sync);
@ -1929,10 +2025,38 @@ pub struct ComputePipelineDescriptor<'a> {
/// ///
/// This implements `Default`, and for most users can be set to `Default::default()` /// This implements `Default`, and for most users can be set to `Default::default()`
pub compilation_options: PipelineCompilationOptions<'a>, pub compilation_options: PipelineCompilationOptions<'a>,
/// The pipeline cache to use when creating this pipeline.
pub cache: Option<&'a PipelineCache>,
} }
#[cfg(send_sync)] #[cfg(send_sync)]
static_assertions::assert_impl_all!(ComputePipelineDescriptor<'_>: Send, Sync); static_assertions::assert_impl_all!(ComputePipelineDescriptor<'_>: Send, Sync);
/// Describes a pipeline cache, which allows reusing compilation work
/// between program runs.
///
/// For use with [`Device::create_pipeline_cache`]
///
/// This type is unique to the Rust API of `wgpu`.
#[derive(Clone, Debug)]
pub struct PipelineCacheDescriptor<'a> {
/// Debug label of the pipeline cache. This might show up in some logs from `wgpu`
pub label: Label<'a>,
/// The data used to initialise the cache initialise
///
/// # Safety
///
/// This data must have been provided from a previous call to
/// [`PipelineCache::get_data`], if not `None`
pub data: Option<&'a [u8]>,
/// Whether to create a cache without data when the provided data
/// is invalid.
///
/// Recommended to set to true
pub fallback: bool,
}
#[cfg(send_sync)]
static_assertions::assert_impl_all!(PipelineCacheDescriptor<'_>: Send, Sync);
pub use wgt::ImageCopyBuffer as ImageCopyBufferBase; pub use wgt::ImageCopyBuffer as ImageCopyBufferBase;
/// View of a buffer which can be used to copy to/from a texture. /// View of a buffer which can be used to copy to/from a texture.
/// ///
@ -3080,6 +3204,62 @@ impl Device {
pub fn make_invalid(&self) { pub fn make_invalid(&self) {
DynContext::device_make_invalid(&*self.context, &self.id, self.data.as_ref()) DynContext::device_make_invalid(&*self.context, &self.id, self.data.as_ref())
} }
/// Create a [`PipelineCache`] with initial data
///
/// This can be passed to [`Device::create_compute_pipeline`]
/// and [`Device::create_render_pipeline`] to either accelerate these
/// or add the cache results from those.
///
/// # Safety
///
/// If the `data` field of `desc` is set, it must have previously been returned from a call
/// to [`PipelineCache::get_data`][^saving]. This `data` will only be used if it came
/// from an adapter with the same [`util::pipeline_cache_key`].
/// This *is* compatible across wgpu versions, as any data format change will
/// be accounted for.
///
/// It is *not* supported to bring caches from previous direct uses of backend APIs
/// into this method.
///
/// # Errors
///
/// Returns an error value if:
/// * the [`PIPELINE_CACHE`](wgt::Features::PIPELINE_CACHE) feature is not enabled
/// * this device is invalid; or
/// * the device is out of memory
///
/// This method also returns an error value if:
/// * The `fallback` field on `desc` is false; and
/// * the `data` provided would not be used[^data_not_used]
///
/// If an error value is used in subsequent calls, default caching will be used.
///
/// [^saving]: We do recognise that saving this data to disk means this condition
/// is impossible to fully prove. Consider the risks for your own application in this case.
///
/// [^data_not_used]: This data may be not used if: the data was produced by a prior
/// version of wgpu; or was created for an incompatible adapter, or there was a GPU driver
/// update. In some cases, the data might not be used and a real value is returned,
/// this is left to the discretion of GPU drivers.
pub unsafe fn create_pipeline_cache(
&self,
desc: &PipelineCacheDescriptor<'_>,
) -> PipelineCache {
let (id, data) = unsafe {
DynContext::device_create_pipeline_cache(
&*self.context,
&self.id,
self.data.as_ref(),
desc,
)
};
PipelineCache {
context: Arc::clone(&self.context),
id,
data,
}
}
} }
impl Drop for Device { impl Drop for Device {

View File

@ -140,3 +140,52 @@ impl std::ops::Deref for DownloadBuffer {
self.1.slice() self.1.slice()
} }
} }
/// A recommended key for storing [`PipelineCache`]s for the adapter
/// associated with the given [`AdapterInfo`](wgt::AdapterInfo)
/// This key will define a class of adapters for which the same cache
/// might be valid.
///
/// If this returns `None`, the adapter doesn't support [`PipelineCache`].
/// This may be because the API doesn't support application managed caches
/// (such as browser WebGPU), or that `wgpu` hasn't implemented it for
/// that API yet.
///
/// This key could be used as a filename, as seen in the example below.
///
/// # Examples
///
/// ``` no_run
/// # use std::path::PathBuf;
/// # let adapter_info = todo!();
/// let cache_dir: PathBuf = PathBuf::new();
/// let filename = wgpu::util::pipeline_cache_key(&adapter_info);
/// if let Some(filename) = filename {
/// let cache_file = cache_dir.join(&filename);
/// let cache_data = std::fs::read(&cache_file);
/// let pipeline_cache: wgpu::PipelineCache = todo!("Use data (if present) to create a pipeline cache");
///
/// let data = pipeline_cache.get_data();
/// if let Some(data) = data {
/// let temp_file = cache_file.with_extension("temp");
/// std::fs::write(&temp_file, &data)?;
/// std::fs::rename(&temp_file, &cache_file)?;
/// }
/// }
/// # Ok::<(), std::io::Error>(())
/// ```
///
/// [`PipelineCache`]: super::PipelineCache
pub fn pipeline_cache_key(adapter_info: &wgt::AdapterInfo) -> Option<String> {
match adapter_info.backend {
wgt::Backend::Vulkan => Some(format!(
// The vendor/device should uniquely define a driver
// We/the driver will also later validate that the vendor/device and driver
// version match, which may lead to clearing an outdated
// cache for the same device.
"wgpu_pipeline_cache_vulkan_{}_{}",
adapter_info.vendor, adapter_info.device
)),
_ => None,
}
}