diff --git a/wgpu-native/src/track.rs b/wgpu-native/src/track/mod.rs similarity index 73% rename from wgpu-native/src/track.rs rename to wgpu-native/src/track/mod.rs index c6d31f0f2..9acbcca85 100644 --- a/wgpu-native/src/track.rs +++ b/wgpu-native/src/track/mod.rs @@ -1,3 +1,5 @@ +mod range; + use crate::{ conv, device::MAX_MIP_LEVELS, @@ -18,15 +20,14 @@ use hal::backend::FastHashMap; use std::{ borrow::Borrow, - cmp::Ordering, collections::hash_map::Entry, - iter::Peekable, marker::PhantomData, ops::Range, - slice, vec::Drain, }; +use range::RangedStates; + /// A single unit of state tracking. #[derive(Clone, Copy, Debug, PartialEq)] @@ -362,163 +363,6 @@ impl ResourceState for BufferState { } -#[derive(Clone, Debug)] -pub struct RangedStates { - ranges: Vec<(Range, T)>, -} - -impl Default for RangedStates { - fn default() -> Self { - RangedStates { - ranges: Vec::new(), - } - } -} - -impl RangedStates { - fn _check_sanity(&self) { - for a in self.ranges.iter() { - assert!(a.0.start < a.0.end); - } - for (a, b) in self.ranges.iter().zip(self.ranges[1..].iter()) { - assert!(a.0.end <= b.0.start); - } - } - - fn _coalesce(&mut self) { - let mut num_removed = 0; - let mut iter = self.ranges.iter_mut(); - let mut cur = match iter.next() { - Some(elem) => elem, - None => return, - }; - while let Some(next) = iter.next() { - if cur.0.end == next.0.start && cur.1 == next.1 { - num_removed += 1; - cur.0.end = next.0.end; - next.0.end = next.0.start; - } else { - cur = next; - } - } - if num_removed != 0 { - self.ranges.retain(|pair| pair.0.start != pair.0.end); - } - } - - fn isolate(&mut self, index: &Range, default: T) -> &mut [(Range, T)] { - let start_pos = match self.ranges - .iter() - .position(|pair| pair.0.end > index.start) - { - Some(pos) => pos, - None => { - let pos = self.ranges.len(); - self.ranges.push((index.clone(), default)); - return &mut self.ranges[pos ..]; - } - }; - - let mut pos = start_pos; - let mut range_pos = index.start; - loop { - let (range, unit) = self.ranges[pos].clone(); - if range.start >= index.end { - self.ranges.insert(pos, (range_pos .. index.end, default)); - pos += 1; - break; - } - if range.start > range_pos { - self.ranges.insert(pos, (range_pos .. range.start, default)); - pos += 1; - range_pos = range.start; - } - if range.end >= index.end { - self.ranges[pos].0.start = index.end; - self.ranges.insert(pos, (range_pos .. index.end, unit)); - pos += 1; - break; - } - pos += 1; - range_pos = range.end; - if pos == self.ranges.len() { - self.ranges.push((range_pos .. index.end, default)); - pos += 1; - break; - } - } - - &mut self.ranges[start_pos .. pos] - } -} - -struct Merge<'a, I, T> { - base: I, - sa: Peekable, T)>>, - sb: Peekable, T)>>, -} - -impl<'a, I: Copy + Ord, T: Copy> Iterator for Merge<'a, I, T> { - type Item = (Range, Range); - fn next(&mut self) -> Option { - match (self.sa.peek(), self.sb.peek()) { - // we have both streams - (Some(&(ref ra, va)), Some(&(ref rb, vb))) => { - let (range, usage) = if ra.start < self.base { // in the middle of the left stream - if self.base == rb.start { // right stream is starting - debug_assert!(self.base < ra.end); - (self.base .. ra.end.min(rb.end), *va .. *vb) - } else { // right hasn't started yet - debug_assert!(self.base < rb.start); - (self.base .. rb.start, *va .. *va) - } - } else if rb.start < self.base { // in the middle of the right stream - if self.base == ra.start { // left stream is starting - debug_assert!(self.base < rb.end); - (self.base .. ra.end.min(rb.end), *va .. *vb) - } else { // left hasn't started yet - debug_assert!(self.base < ra.start); - (self.base .. ra.start, *vb .. *vb) - } - } else { // no active streams - match ra.start.cmp(&rb.start) { - // both are starting - Ordering::Equal => (ra.start .. ra.end.min(rb.end), *va .. *vb), - // only left is starting - Ordering::Less => (ra.start .. rb.start, *va .. *va), - // only right is starting - Ordering::Greater => (rb.start .. ra.start, *vb .. *vb), - } - }; - self.base = range.end; - if ra.end == range.end { - let _ = self.sa.next(); - } - if rb.end == range.end { - let _ = self.sb.next(); - } - Some((range, usage)) - } - // only right stream - (None, Some(&(ref rb, vb))) => { - let range = self.base.max(rb.start) .. rb.end; - self.base = rb.end; - let _ = self.sb.next(); - Some((range, *vb .. *vb)) - } - // only left stream - (Some(&(ref ra, va)), None) => { - let range = self.base.max(ra.start) .. ra.end; - self.base = ra.end; - let _ = self.sa.next(); - Some((range, *va .. *va)) - } - // done - (None, None) => None, - } - } -} - type PlaneStates = RangedStates; //TODO: store `hal::image::State` here to avoid extra conversions @@ -556,7 +400,7 @@ impl ResourceState for TextureStates { let layer_start = num_levels.min(selector.levels.start as usize); let layer_end = num_levels.min(selector.levels.end as usize); for layer in self.color_mips[layer_start .. layer_end].iter() { - for &(ref range, ref unit) in layer.ranges.iter() { + for &(ref range, ref unit) in layer.iter() { if range.end > selector.layers.start && range.start < selector.layers.end { let old = usage.replace(unit.last); if old.is_some() && old != usage { @@ -567,7 +411,7 @@ impl ResourceState for TextureStates { } } if selector.aspects.intersects(hal::format::Aspects::DEPTH | hal::format::Aspects::STENCIL) { - for &(ref range, ref ds) in self.depth_stencil.ranges.iter() { + for &(ref range, ref ds) in self.depth_stencil.iter() { if range.end > selector.layers.start && range.start < selector.layers.end { if selector.aspects.contains(hal::format::Aspects::DEPTH) { let old = usage.replace(ds.depth.last); @@ -712,12 +556,8 @@ impl ResourceState for TextureStates { .zip(&other.color_mips) .enumerate() { - temp_color.extend(Merge { - base: 0, - sa: mip_self.ranges.iter().peekable(), - sb: mip_other.ranges.iter().peekable(), - }); - mip_self.ranges.clear(); + temp_color.extend(mip_self.merge(mip_other, 0)); + mip_self.clear(); for (layers, states) in temp_color.drain(..) { let color_usage = states.start.last .. states.end.select(stitch); if let Some(out) = output.as_mut() { @@ -734,20 +574,16 @@ impl ResourceState for TextureStates { }); } } - mip_self.ranges.push((layers, Unit { + mip_self.append(layers, Unit { init: states.start.init, last: color_usage.end, - })); + }); } } let mut temp_ds = Vec::new(); - temp_ds.extend(Merge { - base: 0, - sa: self.depth_stencil.ranges.iter().peekable(), - sb: other.depth_stencil.ranges.iter().peekable(), - }); - self.depth_stencil.ranges.clear(); + temp_ds.extend(self.depth_stencil.merge(&other.depth_stencil, 0)); + self.depth_stencil.clear(); for (layers, states) in temp_ds.drain(..) { let usage_depth = states.start.depth.last .. states.end.depth.select(stitch); let usage_stencil = states.start.stencil.last .. states.end.stencil.select(stitch); @@ -775,7 +611,7 @@ impl ResourceState for TextureStates { }); } } - self.depth_stencil.ranges.push((layers, DepthStencilState { + self.depth_stencil.append(layers, DepthStencilState { depth: Unit { init: states.start.depth.init, last: usage_depth.end, @@ -784,7 +620,7 @@ impl ResourceState for TextureStates { init: states.start.stencil.init, last: usage_stencil.end, }, - })); + }); } Ok(()) @@ -857,79 +693,3 @@ impl TrackerSet { self.bind_groups.merge_extend(&other.bind_groups).unwrap(); } } - -#[cfg(test)] -mod test_range { - use super::RangedStates; - - #[test] - fn test_sane0() { - let rs = RangedStates { ranges: vec![ - (1..4, 9u8), - (4..5, 9), - ]}; - rs._check_sanity(); - } - - #[test(must_fail)] - fn test_sane1() { - let rs = RangedStates { ranges: vec![ - (1..4, 9u8), - (5..5, 9), - ]}; - rs._check_sanity(); - } - - #[test(must_fail)] - fn test_sane2() { - let rs = RangedStates { ranges: vec![ - (1..4, 9u8), - (3..5, 9), - ]}; - rs._check_sanity(); - } - - #[test] - fn test_coalesce() { - let mut rs = RangedStates { ranges: vec![ - (1..4, 9u8), - (4..5, 9), - (5..7, 1), - (8..9, 1), - ]}; - rs._coalesce(); - assert_eq!(rs.ranges, vec![ - (1..5, 9), - (5..7, 1), - (8..9, 1), - ]); - } - - #[test] - fn test_isolate() { - let rs = RangedStates { ranges: vec![ - (1..4, 9u8), - (4..5, 9), - (5..7, 1), - (8..9, 1), - ]}; - assert_eq!(rs.clone().isolate(&(4..5), 0), [ - (4..5, 9u8), - ]); - assert_eq!(rs.clone().isolate(&(0..6), 0), [ - (0..1, 0), - (1..4, 9u8), - (4..5, 9), - (5..6, 1), - ]); - assert_eq!(rs.clone().isolate(&(8..10), 1), [ - (8..9, 1), - (9..10, 1), - ]); - assert_eq!(rs.clone().isolate(&(6..8), 0), [ - (6..7, 1), - (7..8, 0), - (8..9, 1), - ]); - } -} diff --git a/wgpu-native/src/track/range.rs b/wgpu-native/src/track/range.rs new file mode 100644 index 000000000..6ae346f87 --- /dev/null +++ b/wgpu-native/src/track/range.rs @@ -0,0 +1,278 @@ +use std::{ + cmp::Ordering, + iter::Peekable, + ops::Range, + slice::Iter, +}; + +#[derive(Clone, Debug)] +pub struct RangedStates { + ranges: Vec<(Range, T)>, +} + +impl Default for RangedStates { + fn default() -> Self { + RangedStates { + ranges: Vec::new(), + } + } +} + +impl RangedStates { + pub fn clear(&mut self) { + self.ranges.clear(); + } + + pub fn append(&mut self, index: Range, value: T) { + self.ranges.push((index, value)); + } + + pub fn iter(&self) -> Iter<(Range, T)> { + self.ranges.iter() + } + + #[cfg(test)] + fn check_sanity(&self) { + for a in self.ranges.iter() { + assert!(a.0.start < a.0.end); + } + for (a, b) in self.ranges.iter().zip(self.ranges[1..].iter()) { + assert!(a.0.end <= b.0.start); + } + } + + #[cfg(test)] + fn coalesce(&mut self) { + let mut num_removed = 0; + let mut iter = self.ranges.iter_mut(); + let mut cur = match iter.next() { + Some(elem) => elem, + None => return, + }; + while let Some(next) = iter.next() { + if cur.0.end == next.0.start && cur.1 == next.1 { + num_removed += 1; + cur.0.end = next.0.end; + next.0.end = next.0.start; + } else { + cur = next; + } + } + if num_removed != 0 { + self.ranges.retain(|pair| pair.0.start != pair.0.end); + } + } + + pub fn isolate(&mut self, index: &Range, default: T) -> &mut [(Range, T)] { + //TODO: implement this in 2 passes: + // 1. scan the ranges to figure out how many extra ones need to be inserted + // 2. go through the ranges by moving them them to the right and inserting the missing ones + + let mut start_pos = match self.ranges + .iter() + .position(|pair| pair.0.end > index.start) + { + Some(pos) => pos, + None => { + let pos = self.ranges.len(); + self.ranges.push((index.clone(), default)); + return &mut self.ranges[pos ..]; + } + }; + + { + let (range, value) = self.ranges[start_pos].clone(); + if range.start < index.start { + self.ranges[start_pos].0.start = index.start; + self.ranges.insert(start_pos, (range.start .. index.start, value)); + start_pos += 1; + } + } + let mut pos = start_pos; + let mut range_pos = index.start; + loop { + let (range, value) = self.ranges[pos].clone(); + if range.start >= index.end { + self.ranges.insert(pos, (range_pos .. index.end, default)); + pos += 1; + break; + } + if range.start > range_pos { + self.ranges.insert(pos, (range_pos .. range.start, default)); + pos += 1; + range_pos = range.start; + } + if range.end >= index.end { + if range.end != index.end { + self.ranges[pos].0.start = index.end; + self.ranges.insert(pos, (range_pos .. index.end, value)); + } + pos += 1; + break; + } + pos += 1; + range_pos = range.end; + if pos == self.ranges.len() { + self.ranges.push((range_pos .. index.end, default)); + pos += 1; + break; + } + } + + &mut self.ranges[start_pos .. pos] + } + + pub fn merge<'a>(&'a self, other: &'a Self, base: I) -> Merge<'a, I, T> { + Merge { + base, + sa: self.ranges.iter().peekable(), + sb: other.ranges.iter().peekable(), + } + } +} + +pub struct Merge<'a, I, T> { + base: I, + sa: Peekable, T)>>, + sb: Peekable, T)>>, +} + +impl<'a, I: Copy + Ord, T: Copy> Iterator for Merge<'a, I, T> { + type Item = (Range, Range); + fn next(&mut self) -> Option { + match (self.sa.peek(), self.sb.peek()) { + // we have both streams + (Some(&(ref ra, va)), Some(&(ref rb, vb))) => { + let (range, usage) = if ra.start < self.base { // in the middle of the left stream + if self.base == rb.start { // right stream is starting + debug_assert!(self.base < ra.end); + (self.base .. ra.end.min(rb.end), *va .. *vb) + } else { // right hasn't started yet + debug_assert!(self.base < rb.start); + (self.base .. rb.start, *va .. *va) + } + } else if rb.start < self.base { // in the middle of the right stream + if self.base == ra.start { // left stream is starting + debug_assert!(self.base < rb.end); + (self.base .. ra.end.min(rb.end), *va .. *vb) + } else { // left hasn't started yet + debug_assert!(self.base < ra.start); + (self.base .. ra.start, *vb .. *vb) + } + } else { // no active streams + match ra.start.cmp(&rb.start) { + // both are starting + Ordering::Equal => (ra.start .. ra.end.min(rb.end), *va .. *vb), + // only left is starting + Ordering::Less => (ra.start .. rb.start, *va .. *va), + // only right is starting + Ordering::Greater => (rb.start .. ra.start, *vb .. *vb), + } + }; + self.base = range.end; + if ra.end == range.end { + let _ = self.sa.next(); + } + if rb.end == range.end { + let _ = self.sb.next(); + } + Some((range, usage)) + } + // only right stream + (None, Some(&(ref rb, vb))) => { + let range = self.base.max(rb.start) .. rb.end; + self.base = rb.end; + let _ = self.sb.next(); + Some((range, *vb .. *vb)) + } + // only left stream + (Some(&(ref ra, va)), None) => { + let range = self.base.max(ra.start) .. ra.end; + self.base = ra.end; + let _ = self.sa.next(); + Some((range, *va .. *va)) + } + // done + (None, None) => None, + } + } +} + +#[cfg(test)] +mod test { + use super::RangedStates; + + #[test] + fn test_sane0() { + let rs = RangedStates { ranges: vec![ + (1..4, 9u8), + (4..5, 9), + ]}; + rs.check_sanity(); + } + + #[test] + #[should_panic] + fn test_sane1() { + let rs = RangedStates { ranges: vec![ + (1..4, 9u8), + (5..5, 9), + ]}; + rs.check_sanity(); + } + + #[test] + #[should_panic] + fn test_sane2() { + let rs = RangedStates { ranges: vec![ + (1..4, 9u8), + (3..5, 9), + ]}; + rs.check_sanity(); + } + + #[test] + fn test_coalesce() { + let mut rs = RangedStates { ranges: vec![ + (1..4, 9u8), + (4..5, 9), + (5..7, 1), + (8..9, 1), + ]}; + rs.coalesce(); + rs.check_sanity(); + assert_eq!(rs.ranges, vec![ + (1..5, 9), + (5..7, 1), + (8..9, 1), + ]); + } + + #[test] + fn test_isolate() { + let rs = RangedStates { ranges: vec![ + (1..4, 9u8), + (4..5, 9), + (5..7, 1), + (8..9, 1), + ]}; + assert_eq!(rs.clone().isolate(&(4..5), 0), [ + (4..5, 9u8), + ]); + assert_eq!(rs.clone().isolate(&(0..6), 0), [ + (0..1, 0), + (1..4, 9u8), + (4..5, 9), + (5..6, 1), + ]); + assert_eq!(rs.clone().isolate(&(8..10), 1), [ + (8..9, 1), + (9..10, 1), + ]); + assert_eq!(rs.clone().isolate(&(6..9), 0), [ + (6..7, 1), + (7..8, 0), + (8..9, 1), + ]); + } +}