Metal encoder & pass timestamp support (#4008)

Implements timer queries via write_timestamp on Metal for encoders (whenever timer queries are available) and passes (for Intel/AMD GPUs, where we should advertise TIMESTAMP_QUERY_INSIDE_PASSES now).

Due to some bugs in Metal this was a lot harder than expected. I believe the solution is close to optimal with the current restrictions in place. For details see code comments.
This commit is contained in:
Andreas Reich 2023-09-16 22:01:46 +02:00 committed by GitHub
parent 7c575a0b40
commit 0ffdae31a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 260 additions and 61 deletions

View File

@ -27,6 +27,7 @@ allow = [
[sources]
allow-git = [
"https://github.com/grovesNL/glow",
"https://github.com/gfx-rs/metal-rs",
]
unknown-registry = "deny"
unknown-git = "deny"

View File

@ -88,6 +88,10 @@ By @Valaphee in [#3402](https://github.com/gfx-rs/wgpu/pull/3402)
### Documentation
- Use WGSL for VertexFormat example types. By @ScanMountGoat in [#4305](https://github.com/gfx-rs/wgpu/pull/4035)
#### Metal
- Support for timestamp queries on encoders and passes. By @wumpf in [#4008](https://github.com/gfx-rs/wgpu/pull/4008)
### Bug Fixes
#### General

3
Cargo.lock generated
View File

@ -1551,8 +1551,7 @@ dependencies = [
[[package]]
name = "metal"
version = "0.26.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "623b5e6cefd76e58f774bd3cc0c6f5c7615c58c03a97815245a25c3c9bdee318"
source = "git+https://github.com/gfx-rs/metal-rs/?rev=d24f1a4#d24f1a4ae92470bf87a0c65ecfe78c9299835505"
dependencies = [
"bitflags 2.4.0",
"block",

View File

@ -158,6 +158,8 @@ termcolor = "1.2.0"
#glow = { path = "../glow" }
#d3d12 = { path = "../d3d12-rs" }
#metal = { path = "../metal-rs" }
#metal = { path = "../metal-rs" }
metal = { git = "https://github.com/gfx-rs/metal-rs/", rev = "d24f1a4" } # More timer support via https://github.com/gfx-rs/metal-rs/pull/280
#web-sys = { path = "../wasm-bindgen/crates/web-sys" }
#js-sys = { path = "../wasm-bindgen/crates/js-sys" }
#wasm-bindgen = { path = "../wasm-bindgen" }

View File

@ -47,6 +47,7 @@ impl QueryResults {
// * compute end
const NUM_QUERIES: u64 = 8;
#[allow(clippy::redundant_closure)] // False positive
fn from_raw_results(timestamps: Vec<u64>, timestamps_inside_passes: bool) -> Self {
assert_eq!(timestamps.len(), Self::NUM_QUERIES as usize);
@ -60,9 +61,9 @@ impl QueryResults {
let mut encoder_timestamps = [0, 0];
encoder_timestamps[0] = get_next_slot();
let render_start_end_timestamps = [get_next_slot(), get_next_slot()];
let render_inside_timestamp = timestamps_inside_passes.then_some(get_next_slot());
let render_inside_timestamp = timestamps_inside_passes.then(|| get_next_slot());
let compute_start_end_timestamps = [get_next_slot(), get_next_slot()];
let compute_inside_timestamp = timestamps_inside_passes.then_some(get_next_slot());
let compute_inside_timestamp = timestamps_inside_passes.then(|| get_next_slot());
encoder_timestamps[1] = get_next_slot();
QueryResults {
@ -79,8 +80,8 @@ impl QueryResults {
let elapsed_us = |start, end: u64| end.wrapping_sub(start) as f64 * period as f64 / 1000.0;
println!(
"Elapsed time render + compute: {:.2} μs",
elapsed_us(self.encoder_timestamps[0], self.encoder_timestamps[1])
"Elapsed time before render until after compute: {:.2} μs",
elapsed_us(self.encoder_timestamps[0], self.encoder_timestamps[1]),
);
println!(
"Elapsed time render pass: {:.2} μs",
@ -464,13 +465,10 @@ mod tests {
render_start_end_timestamps[1].wrapping_sub(render_start_end_timestamps[0]);
let compute_delta =
compute_start_end_timestamps[1].wrapping_sub(compute_start_end_timestamps[0]);
let encoder_delta = encoder_timestamps[1].wrapping_sub(encoder_timestamps[0]);
// TODO: Metal encoder timestamps aren't implemented yet.
if ctx.adapter.get_info().backend != wgpu::Backend::Metal {
let encoder_delta = encoder_timestamps[1].wrapping_sub(encoder_timestamps[0]);
assert!(encoder_delta > 0);
assert!(encoder_delta >= render_delta + compute_delta);
}
assert!(encoder_delta > 0);
assert!(encoder_delta >= render_delta + compute_delta);
if let Some(render_inside_timestamp) = render_inside_timestamp {
assert!(render_inside_timestamp >= render_start_end_timestamps[0]);

View File

@ -5,6 +5,8 @@ use wgt::{AstcBlock, AstcChannel};
use std::{sync::Arc, thread};
use super::TimestampQuerySupport;
const MAX_COMMAND_BUFFERS: u64 = 2048;
unsafe impl Send for super::Adapter {}
@ -536,6 +538,26 @@ impl super::PrivateCapabilities {
MTLReadWriteTextureTier::TierNone
};
let mut timestamp_query_support = TimestampQuerySupport::empty();
if version.at_least((11, 0), (14, 0), os_is_mac)
&& device.supports_counter_sampling(metal::MTLCounterSamplingPoint::AtStageBoundary)
{
// If we don't support at stage boundary, don't support anything else.
timestamp_query_support.insert(TimestampQuerySupport::STAGE_BOUNDARIES);
if device.supports_counter_sampling(metal::MTLCounterSamplingPoint::AtDrawBoundary) {
timestamp_query_support.insert(TimestampQuerySupport::ON_RENDER_ENCODER);
}
if device.supports_counter_sampling(metal::MTLCounterSamplingPoint::AtDispatchBoundary)
{
timestamp_query_support.insert(TimestampQuerySupport::ON_COMPUTE_ENCODER);
}
if device.supports_counter_sampling(metal::MTLCounterSamplingPoint::AtBlitBoundary) {
timestamp_query_support.insert(TimestampQuerySupport::ON_BLIT_ENCODER);
}
// `TimestampQuerySupport::INSIDE_WGPU_PASSES` emerges from the other flags.
}
Self {
family_check,
msl_version: if os_is_xr || version.at_least((12, 0), (15, 0), os_is_mac) {
@ -773,13 +795,7 @@ impl super::PrivateCapabilities {
} else {
None
},
support_timestamp_query: version.at_least((11, 0), (14, 0), os_is_mac)
&& device
.supports_counter_sampling(metal::MTLCounterSamplingPoint::AtStageBoundary),
support_timestamp_query_in_passes: version.at_least((11, 0), (14, 0), os_is_mac)
&& device.supports_counter_sampling(metal::MTLCounterSamplingPoint::AtDrawBoundary)
&& device
.supports_counter_sampling(metal::MTLCounterSamplingPoint::AtDispatchBoundary),
timestamp_query_support,
}
}
@ -807,12 +823,16 @@ impl super::PrivateCapabilities {
| F::DEPTH32FLOAT_STENCIL8
| F::MULTI_DRAW_INDIRECT;
features.set(F::TIMESTAMP_QUERY, self.support_timestamp_query);
// TODO: Not yet implemented.
// features.set(
// F::TIMESTAMP_QUERY_INSIDE_PASSES,
// self.support_timestamp_query_in_passes,
// );
features.set(
F::TIMESTAMP_QUERY,
self.timestamp_query_support
.contains(TimestampQuerySupport::STAGE_BOUNDARIES),
);
features.set(
F::TIMESTAMP_QUERY_INSIDE_PASSES,
self.timestamp_query_support
.contains(TimestampQuerySupport::INSIDE_WGPU_PASSES),
);
features.set(F::TEXTURE_COMPRESSION_ASTC, self.format_astc);
features.set(F::TEXTURE_COMPRESSION_ASTC_HDR, self.format_astc_hdr);
features.set(F::TEXTURE_COMPRESSION_BC, self.format_bc);

View File

@ -1,4 +1,4 @@
use super::{conv, AsNative};
use super::{conv, AsNative, TimestampQuerySupport};
use crate::CommandEncoder as _;
use std::{borrow::Cow, mem, ops::Range};
@ -18,6 +18,7 @@ impl Default for super::CommandState {
storage_buffer_length_map: Default::default(),
work_group_memory_sizes: Vec::new(),
push_constants: Vec::new(),
pending_timer_queries: Vec::new(),
}
}
}
@ -26,10 +27,85 @@ impl super::CommandEncoder {
fn enter_blit(&mut self) -> &metal::BlitCommandEncoderRef {
if self.state.blit.is_none() {
debug_assert!(self.state.render.is_none() && self.state.compute.is_none());
let cmd_buf = self.raw_cmd_buf.as_ref().unwrap();
// Take care of pending timer queries.
// If we can't use `sample_counters_in_buffer` we have to create a dummy blit encoder!
//
// There is a known bug in Metal where blit encoders won't write timestamps if they don't have a blit operation.
// See https://github.com/gpuweb/gpuweb/issues/2046#issuecomment-1205793680 & https://source.chromium.org/chromium/chromium/src/+/006c4eb70c96229834bbaf271290f40418144cd3:third_party/dawn/src/dawn/native/metal/BackendMTL.mm;l=350
//
// To make things worse:
// * what counts as a blit operation is a bit unclear, experimenting seemed to indicate that resolve_counters doesn't count.
// * in some cases (when?) using `set_start_of_encoder_sample_index` doesn't work, so we have to use `set_end_of_encoder_sample_index` instead
//
// All this means that pretty much the only *reliable* thing as of writing is to:
// * create a dummy blit encoder using set_end_of_encoder_sample_index
// * do a dummy write that is known to be not optimized out.
// * close the encoder since we used set_end_of_encoder_sample_index and don't want to get any extra stuff in there.
// * create another encoder for whatever we actually had in mind.
let supports_sample_counters_in_buffer = self
.shared
.private_caps
.timestamp_query_support
.contains(TimestampQuerySupport::ON_BLIT_ENCODER);
if !self.state.pending_timer_queries.is_empty() && !supports_sample_counters_in_buffer {
objc::rc::autoreleasepool(|| {
let descriptor = metal::BlitPassDescriptor::new();
let mut last_query = None;
for (i, (set, index)) in self.state.pending_timer_queries.drain(..).enumerate()
{
let sba_descriptor = descriptor
.sample_buffer_attachments()
.object_at(i as _)
.unwrap();
sba_descriptor
.set_sample_buffer(set.counter_sample_buffer.as_ref().unwrap());
// Here be dragons:
// As mentioned above, for some reasons using the start of the encoder won't yield any results sometimes!
sba_descriptor
.set_start_of_encoder_sample_index(metal::COUNTER_DONT_SAMPLE);
sba_descriptor.set_end_of_encoder_sample_index(index as _);
last_query = Some((set, index));
}
let encoder = cmd_buf.blit_command_encoder_with_descriptor(descriptor);
// As explained above, we need to do some write:
// Conveniently, we have a buffer with every query set, that we can use for this for a dummy write,
// since we know that it is going to be overwritten again on timer resolve and HAL doesn't define its state before that.
let raw_range = metal::NSRange {
location: last_query.as_ref().unwrap().1 as u64 * crate::QUERY_SIZE,
length: 1,
};
encoder.fill_buffer(
&last_query.as_ref().unwrap().0.raw_buffer,
raw_range,
255, // Don't write 0, so it's easier to identify if something went wrong.
);
encoder.end_encoding();
});
}
objc::rc::autoreleasepool(|| {
let cmd_buf = self.raw_cmd_buf.as_ref().unwrap();
self.state.blit = Some(cmd_buf.new_blit_command_encoder().to_owned());
});
let encoder = self.state.blit.as_ref().unwrap();
// UNTESTED:
// If the above described issue with empty blit encoder applies to `sample_counters_in_buffer` as well, we should use the same workaround instead!
for (set, index) in self.state.pending_timer_queries.drain(..) {
debug_assert!(supports_sample_counters_in_buffer);
encoder.sample_counters_in_buffer(
set.counter_sample_buffer.as_ref().unwrap(),
index as _,
true,
)
}
}
self.state.blit.as_ref().unwrap()
}
@ -40,7 +116,7 @@ impl super::CommandEncoder {
}
}
fn enter_any(&mut self) -> Option<&metal::CommandEncoderRef> {
fn active_encoder(&mut self) -> Option<&metal::CommandEncoderRef> {
if let Some(ref encoder) = self.state.render {
Some(encoder)
} else if let Some(ref encoder) = self.state.compute {
@ -127,9 +203,17 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
}
unsafe fn end_encoding(&mut self) -> Result<super::CommandBuffer, crate::DeviceError> {
// Handle pending timer query if any.
if !self.state.pending_timer_queries.is_empty() {
self.leave_blit();
self.enter_blit();
}
self.leave_blit();
debug_assert!(self.state.render.is_none());
debug_assert!(self.state.compute.is_none());
debug_assert!(self.state.pending_timer_queries.is_empty());
Ok(super::CommandBuffer {
raw: self.raw_cmd_buf.take().unwrap(),
})
@ -322,16 +406,43 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
_ => {}
}
}
unsafe fn write_timestamp(&mut self, _set: &super::QuerySet, _index: u32) {
// TODO: If MTLCounterSamplingPoint::AtDrawBoundary/AtBlitBoundary/AtDispatchBoundary is supported,
// we don't need to insert a new encoder, but can instead use respective current one.
//let encoder = self.enter_any().unwrap_or_else(|| self.enter_blit());
unsafe fn write_timestamp(&mut self, set: &super::QuerySet, index: u32) {
let support = self.shared.private_caps.timestamp_query_support;
debug_assert!(
support.contains(TimestampQuerySupport::STAGE_BOUNDARIES),
"Timestamp queries are not supported"
);
let sample_buffer = set.counter_sample_buffer.as_ref().unwrap();
let with_barrier = true;
// TODO: Otherwise, we need to create a new blit command encoder with a descriptor that inserts the timestamps.
// Note that as of writing creating a new encoder is not exposed by the metal crate.
// https://developer.apple.com/documentation/metal/mtlcommandbuffer/3564431-makeblitcommandencoder
// Try to use an existing encoder for timestamp query if possible.
// This works only if it's supported for the active encoder.
if let (true, Some(encoder)) = (
support.contains(TimestampQuerySupport::ON_BLIT_ENCODER),
self.state.blit.as_ref(),
) {
encoder.sample_counters_in_buffer(sample_buffer, index as _, with_barrier);
} else if let (true, Some(encoder)) = (
support.contains(TimestampQuerySupport::ON_RENDER_ENCODER),
self.state.render.as_ref(),
) {
encoder.sample_counters_in_buffer(sample_buffer, index as _, with_barrier);
} else if let (true, Some(encoder)) = (
support.contains(TimestampQuerySupport::ON_COMPUTE_ENCODER),
self.state.compute.as_ref(),
) {
encoder.sample_counters_in_buffer(sample_buffer, index as _, with_barrier);
} else {
// If we're here it means we either have no encoder open, or it's not supported to sample within them.
// If this happens with render/compute open, this is an invalid usage!
debug_assert!(self.state.render.is_none() && self.state.compute.is_none());
// TODO: Enable respective test in `examples/timestamp-queries/src/tests.rs`.
// But otherwise it means we'll put defer this to the next created encoder.
self.state.pending_timer_queries.push((set.clone(), index));
// Ensure we didn't already have a blit open.
self.leave_blit();
};
}
unsafe fn reset_queries(&mut self, set: &super::QuerySet, range: Range<u32>) {
@ -342,6 +453,7 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
};
encoder.fill_buffer(&set.raw_buffer, raw_range, 0);
}
unsafe fn copy_query_results(
&mut self,
set: &super::QuerySet,
@ -454,8 +566,29 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
}
}
let mut sba_index = 0;
let mut next_sba_descriptor = || {
let sba_descriptor = descriptor
.sample_buffer_attachments()
.object_at(sba_index)
.unwrap();
sba_descriptor.set_end_of_vertex_sample_index(metal::COUNTER_DONT_SAMPLE);
sba_descriptor.set_start_of_fragment_sample_index(metal::COUNTER_DONT_SAMPLE);
sba_index += 1;
sba_descriptor
};
for (set, index) in self.state.pending_timer_queries.drain(..) {
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(set.counter_sample_buffer.as_ref().unwrap());
sba_descriptor.set_start_of_vertex_sample_index(index as _);
sba_descriptor.set_end_of_fragment_sample_index(metal::COUNTER_DONT_SAMPLE);
}
if let Some(ref timestamp_writes) = desc.timestamp_writes {
let sba_descriptor = descriptor.sample_buffer_attachments().object_at(0).unwrap();
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(
timestamp_writes
.query_set
@ -464,12 +597,16 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
.unwrap(),
);
if let Some(start_index) = timestamp_writes.beginning_of_pass_write_index {
sba_descriptor.set_start_of_vertex_sample_index(start_index as _);
}
if let Some(end_index) = timestamp_writes.end_of_pass_write_index {
sba_descriptor.set_end_of_fragment_sample_index(end_index as _);
}
sba_descriptor.set_start_of_vertex_sample_index(
timestamp_writes
.beginning_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
sba_descriptor.set_end_of_fragment_sample_index(
timestamp_writes
.end_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
}
if let Some(occlusion_query_set) = desc.occlusion_query_set {
@ -697,19 +834,19 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
}
unsafe fn insert_debug_marker(&mut self, label: &str) {
if let Some(encoder) = self.enter_any() {
if let Some(encoder) = self.active_encoder() {
encoder.insert_debug_signpost(label);
}
}
unsafe fn begin_debug_marker(&mut self, group_label: &str) {
if let Some(encoder) = self.enter_any() {
if let Some(encoder) = self.active_encoder() {
encoder.push_debug_group(group_label);
} else if let Some(ref buf) = self.raw_cmd_buf {
buf.push_debug_group(group_label);
}
}
unsafe fn end_debug_marker(&mut self) {
if let Some(encoder) = self.enter_any() {
if let Some(encoder) = self.active_encoder() {
encoder.pop_debug_group();
} else if let Some(ref buf) = self.raw_cmd_buf {
buf.pop_debug_group();
@ -969,11 +1106,25 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
objc::rc::autoreleasepool(|| {
let descriptor = metal::ComputePassDescriptor::new();
if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
let mut sba_index = 0;
let mut next_sba_descriptor = || {
let sba_descriptor = descriptor
.sample_buffer_attachments()
.object_at(0 as _)
.object_at(sba_index)
.unwrap();
sba_index += 1;
sba_descriptor
};
for (set, index) in self.state.pending_timer_queries.drain(..) {
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(set.counter_sample_buffer.as_ref().unwrap());
sba_descriptor.set_start_of_encoder_sample_index(index as _);
sba_descriptor.set_end_of_encoder_sample_index(metal::COUNTER_DONT_SAMPLE);
}
if let Some(timestamp_writes) = desc.timestamp_writes.as_ref() {
let sba_descriptor = next_sba_descriptor();
sba_descriptor.set_sample_buffer(
timestamp_writes
.query_set
@ -982,12 +1133,16 @@ impl crate::CommandEncoder<super::Api> for super::CommandEncoder {
.unwrap(),
);
if let Some(start_index) = timestamp_writes.beginning_of_pass_write_index {
sba_descriptor.set_start_of_encoder_sample_index(start_index as _);
}
if let Some(end_index) = timestamp_writes.end_of_pass_write_index {
sba_descriptor.set_end_of_encoder_sample_index(end_index as _);
}
sba_descriptor.set_start_of_encoder_sample_index(
timestamp_writes
.beginning_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
sba_descriptor.set_end_of_encoder_sample_index(
timestamp_writes
.end_of_pass_write_index
.map_or(metal::COUNTER_DONT_SAMPLE, |i| i as _),
);
}
let encoder = raw.compute_command_encoder_with_descriptor(descriptor);

View File

@ -33,6 +33,7 @@ use std::{
};
use arrayvec::ArrayVec;
use bitflags::bitflags;
use metal::foreign_types::ForeignTypeRef as _;
use parking_lot::Mutex;
@ -143,6 +144,24 @@ impl crate::Instance<Api> for Instance {
}
}
bitflags!(
/// Similar to `MTLCounterSamplingPoint`, but a bit higher abstracted for our purposes.
#[derive(Debug, Copy, Clone)]
pub struct TimestampQuerySupport: u32 {
/// On creating Metal encoders.
const STAGE_BOUNDARIES = 1 << 1;
/// Within existing draw encoders.
const ON_RENDER_ENCODER = Self::STAGE_BOUNDARIES.bits() | (1 << 2);
/// Within existing dispatch encoders.
const ON_COMPUTE_ENCODER = Self::STAGE_BOUNDARIES.bits() | (1 << 3);
/// Within existing blit encoders.
const ON_BLIT_ENCODER = Self::STAGE_BOUNDARIES.bits() | (1 << 4);
/// Within any wgpu render/compute pass.
const INSIDE_WGPU_PASSES = Self::ON_RENDER_ENCODER.bits() | Self::ON_COMPUTE_ENCODER.bits();
}
);
#[allow(dead_code)]
#[derive(Clone, Debug)]
struct PrivateCapabilities {
@ -239,8 +258,7 @@ struct PrivateCapabilities {
supports_preserve_invariance: bool,
supports_shader_primitive_index: bool,
has_unified_memory: Option<bool>,
support_timestamp_query: bool,
support_timestamp_query_in_passes: bool,
timestamp_query_support: TimestampQuerySupport,
}
#[derive(Clone, Debug)]
@ -704,7 +722,7 @@ pub struct ComputePipeline {
unsafe impl Send for ComputePipeline {}
unsafe impl Sync for ComputePipeline {}
#[derive(Debug)]
#[derive(Debug, Clone)]
pub struct QuerySet {
raw_buffer: metal::Buffer,
//Metal has a custom buffer for counters.
@ -787,6 +805,9 @@ struct CommandState {
work_group_memory_sizes: Vec<u32>,
push_constants: Vec<u32>,
/// Timer query that should be executed when the next pass starts.
pending_timer_queries: Vec<(QuerySet, u32)>,
}
pub struct CommandEncoder {

View File

@ -270,7 +270,7 @@ bitflags::bitflags! {
/// Supported Platforms:
/// - Vulkan
/// - DX12
/// - Metal - TODO: Not yet supported on command encoder.
/// - Metal
///
/// This is a web and native feature.
const TIMESTAMP_QUERY = 1 << 1;
@ -458,10 +458,9 @@ bitflags::bitflags! {
/// Supported platforms:
/// - Vulkan
/// - DX12
/// - Metal (AMD & Intel, not Apple GPUs)
///
/// This is currently unimplemented on Metal.
/// When implemented, it will be supported on Metal on AMD and Intel GPUs, but not Apple GPUs.
/// (This is a common limitation of tile-based rasterization GPUs)
/// This is generally not available on tile-based rasterization GPUs.
///
/// This is a native only feature with a [proposal](https://github.com/gpuweb/gpuweb/blob/0008bd30da2366af88180b511a5d0d0c1dffbc36/proposals/timestamp-query-inside-passes.md) for the web.
const TIMESTAMP_QUERY_INSIDE_PASSES = 1 << 33;