hal/mtl: array length support

This commit is contained in:
Dzmitry Malyshau 2021-06-10 14:22:32 -04:00
parent 006f8abbba
commit 38e13a178b
5 changed files with 211 additions and 35 deletions

1
Cargo.lock generated
View File

@ -1918,6 +1918,7 @@ dependencies = [
"bitflags",
"block",
"foreign-types",
"fxhash",
"log",
"metal",
"naga",

View File

@ -19,6 +19,7 @@ metal = ["block", "foreign-types", "mtl", "objc", "parking_lot", "naga/msl-out"]
[dependencies]
arrayvec = "0.5"
bitflags = "1.0"
fxhash = "0.2.1"
log = "0.4"
parking_lot = { version = "0.11", optional = true }
raw-window-handle = "0.3"

View File

@ -1,6 +1,8 @@
use super::{conv, AsNative};
use std::{mem, ops::Range};
/// Byte size of one 4-byte word. The `set_*_bytes` calls in this file multiply it
/// by the number of entries in the sizes slice to get the byte length they pass on.
const WORD_SIZE: usize = 4;
impl super::CommandBuffer {
fn enter_blit(&mut self) -> &mtl::BlitCommandEncoderRef {
if self.blit.is_none() {
@ -25,6 +27,34 @@ impl super::CommandBuffer {
self.enter_blit()
}
}
/// Resets the per-pass tracking state and leaves any open blit encoding.
///
/// Invoked at the start of both render and compute passes: the recorded storage
/// buffer lengths and the per-stage pipeline infos are wiped so nothing carries
/// over from a previous pass.
fn begin_pass(&mut self) {
    let state = &mut self.state;
    state.storage_buffer_length_map.clear();

    let infos = &mut state.stage_infos;
    infos.vs.clear();
    infos.fs.clear();
    infos.cs.clear();

    self.leave_blit();
}
}
impl super::CommandState {
    /// Gathers the recorded size of every sized binding used by `stage`.
    ///
    /// `result_sizes` is a caller-provided scratch vector; it is cleared and refilled
    /// in the order of the stage's `sized_bindings`.
    ///
    /// Returns the stage's sizes slot together with the gathered sizes. Returns `None`
    /// when the stage has no sizes slot (the scratch vector is left untouched), or when
    /// any sized binding has no recorded length yet (the scratch vector may then be
    /// partially filled).
    fn make_sizes_buffer_update<'a>(
        &self,
        stage: naga::ShaderStage,
        result_sizes: &'a mut Vec<wgt::BufferSize>,
    ) -> Option<(u32, &'a [wgt::BufferSize])> {
        // NOTE(review): callers upload `sizes.len() * WORD_SIZE` bytes (4 per entry),
        // while the entries are `wgt::BufferSize`, presumably a 64-bit `NonZeroU64`.
        // Confirm the element width matches what the shader side reads.
        let info = &self.stage_infos[stage];
        let slot = info.sizes_slot?;

        result_sizes.clear();
        for binding in &info.sized_bindings {
            let key = (binding.group, binding.binding);
            match self.storage_buffer_length_map.get(&key) {
                Some(&size) => result_sizes.push(size),
                // A missing entry means this isn't the right time to update the sizes.
                None => return None,
            }
        }

        Some((slot as _, result_sizes))
    }
}
impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
@ -228,7 +258,9 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
// render
unsafe fn begin_render_pass(&mut self, desc: &crate::RenderPassDescriptor<super::Api>) {
self.leave_blit();
self.begin_pass();
self.state.index = None;
let descriptor = mtl::RenderPassDescriptor::new();
//TODO: set visibility results buffer
@ -311,13 +343,14 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
unsafe fn set_bind_group(
&mut self,
layout: &super::PipelineLayout,
index: u32,
group_index: u32,
group: &super::BindGroup,
dynamic_offsets: &[wgt::DynamicOffset],
) {
let bg_info = &layout.bind_group_infos[index as usize];
let bg_info = &layout.bind_group_infos[group_index as usize];
if let Some(ref encoder) = self.render {
let mut changes_sizes_buffer = false;
for index in 0..group.counters.vs.buffers {
let buf = &group.buffers[index as usize];
let mut offset = buf.offset;
@ -329,7 +362,27 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
Some(buf.ptr.as_native()),
offset,
);
if let Some(size) = buf.binding_size {
self.state
.storage_buffer_length_map
.insert((group_index, buf.binding_location), size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Vertex,
&mut self.temp.binding_sizes,
) {
encoder.set_vertex_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr() as _,
);
}
}
changes_sizes_buffer = false;
for index in 0..group.counters.fs.buffers {
let buf = &group.buffers[(group.counters.vs.buffers + index) as usize];
let mut offset = buf.offset;
@ -341,6 +394,24 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
Some(buf.ptr.as_native()),
offset,
);
if let Some(size) = buf.binding_size {
self.state
.storage_buffer_length_map
.insert((group_index, buf.binding_location), size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Fragment,
&mut self.temp.binding_sizes,
) {
encoder.set_fragment_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr() as _,
);
}
}
for index in 0..group.counters.vs.samplers {
@ -380,6 +451,8 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
samplers: group.counters.vs.samplers + group.counters.fs.samplers,
textures: group.counters.vs.textures + group.counters.fs.textures,
};
let mut changes_sizes_buffer = false;
for index in 0..group.counters.cs.buffers {
let buf = &group.buffers[(index_base.buffers + index) as usize];
let mut offset = buf.offset;
@ -391,7 +464,26 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
Some(buf.ptr.as_native()),
offset,
);
if let Some(size) = buf.binding_size {
self.state
.storage_buffer_length_map
.insert((group_index, buf.binding_location), size);
changes_sizes_buffer = true;
}
}
if changes_sizes_buffer {
if let Some((index, sizes)) = self.state.make_sizes_buffer_update(
naga::ShaderStage::Compute,
&mut self.temp.binding_sizes,
) {
encoder.set_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr() as _,
);
}
}
for index in 0..group.counters.cs.samplers {
let res = group.samplers[(index_base.samplers + index) as usize];
encoder.set_sampler_state(
@ -430,7 +522,10 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
}
unsafe fn set_render_pipeline(&mut self, pipeline: &super::RenderPipeline) {
self.raw_primitive_type = pipeline.raw_primitive_type;
self.state.raw_primitive_type = pipeline.raw_primitive_type;
self.state.stage_infos.vs.assign_from(&pipeline.vs_info);
self.state.stage_infos.fs.assign_from(&pipeline.fs_info);
let encoder = self.render.as_ref().unwrap();
encoder.set_render_pipeline_state(&pipeline.raw);
encoder.set_front_facing_winding(pipeline.raw_front_winding);
@ -442,6 +537,31 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
encoder.set_depth_stencil_state(state);
encoder.set_depth_bias(bias.constant as f32, bias.slope_scale, bias.clamp);
}
{
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Vertex, &mut self.temp.binding_sizes)
{
encoder.set_vertex_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr() as _,
);
}
}
if pipeline.fs_lib.is_some() {
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Fragment, &mut self.temp.binding_sizes)
{
encoder.set_fragment_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr() as _,
);
}
}
}
unsafe fn set_index_buffer<'a>(
@ -453,7 +573,7 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
wgt::IndexFormat::Uint16 => (2, mtl::MTLIndexType::UInt16),
wgt::IndexFormat::Uint32 => (4, mtl::MTLIndexType::UInt32),
};
self.index_state = Some(super::IndexState {
self.state.index = Some(super::IndexState {
buffer_ptr: AsNative::from(binding.buffer.raw.as_ref()),
offset: binding.offset,
stride,
@ -522,7 +642,7 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
let encoder = self.render.as_ref().unwrap();
if start_instance != 0 {
encoder.draw_primitives_instanced_base_instance(
self.raw_primitive_type,
self.state.raw_primitive_type,
start_vertex as _,
vertex_count as _,
instance_count as _,
@ -530,14 +650,14 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
);
} else if instance_count != 1 {
encoder.draw_primitives_instanced(
self.raw_primitive_type,
self.state.raw_primitive_type,
start_vertex as _,
vertex_count as _,
instance_count as _,
);
} else {
encoder.draw_primitives(
self.raw_primitive_type,
self.state.raw_primitive_type,
start_vertex as _,
vertex_count as _,
);
@ -553,11 +673,11 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
instance_count: u32,
) {
let encoder = self.render.as_ref().unwrap();
let index = self.index_state.as_ref().unwrap();
let index = self.state.index.as_ref().unwrap();
let offset = index.offset + index.stride * start_index as wgt::BufferAddress;
if base_vertex != 0 || start_instance != 0 {
encoder.draw_indexed_primitives_instanced_base_instance(
self.raw_primitive_type,
self.state.raw_primitive_type,
index_count as _,
index.raw_type,
index.buffer_ptr.as_native(),
@ -568,7 +688,7 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
);
} else if instance_count != 1 {
encoder.draw_indexed_primitives_instanced(
self.raw_primitive_type,
self.state.raw_primitive_type,
index_count as _,
index.raw_type,
index.buffer_ptr.as_native(),
@ -577,7 +697,7 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
);
} else {
encoder.draw_indexed_primitives(
self.raw_primitive_type,
self.state.raw_primitive_type,
index_count as _,
index.raw_type,
index.buffer_ptr.as_native(),
@ -594,7 +714,7 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
) {
let encoder = self.render.as_ref().unwrap();
for _ in 0..draw_count {
encoder.draw_primitives_indirect(self.raw_primitive_type, &buffer.raw, offset);
encoder.draw_primitives_indirect(self.state.raw_primitive_type, &buffer.raw, offset);
offset += mem::size_of::<wgt::DrawIndirectArgs>() as wgt::BufferAddress;
}
}
@ -606,10 +726,10 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
draw_count: u32,
) {
let encoder = self.render.as_ref().unwrap();
let index = self.index_state.as_ref().unwrap();
let index = self.state.index.as_ref().unwrap();
for _ in 0..draw_count {
encoder.draw_indexed_primitives_indirect(
self.raw_primitive_type,
self.state.raw_primitive_type,
index.raw_type,
index.buffer_ptr.as_native(),
index.offset,
@ -644,7 +764,8 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
// compute
unsafe fn begin_compute_pass(&mut self, desc: &crate::ComputePassDescriptor) {
self.leave_blit();
self.begin_pass();
let encoder = self.raw.new_compute_command_encoder();
if let Some(label) = desc.label {
encoder.set_label(label);
@ -656,9 +777,22 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
}
/// Makes `pipeline` the current compute pipeline on the active compute encoder.
///
/// Stores the pipeline's work-group size, copies its compute-stage info into
/// `self.state`, sets the raw pipeline state on the encoder, and then uploads the
/// buffer sizes via `set_bytes` if `make_sizes_buffer_update` yields an update.
///
/// Panics if `self.compute` is `None` (the encoder is obtained with `unwrap`).
// NOTE(review): the work-group size is written to both `self.raw_wg_size` and
// `self.state.raw_wg_size`; this looks like the pre- and post-change lines of a diff
// shown together — confirm which one is live.
unsafe fn set_compute_pipeline(&mut self, pipeline: &super::ComputePipeline) {
self.raw_wg_size = pipeline.work_group_size;
self.state.raw_wg_size = pipeline.work_group_size;
self.state.stage_infos.cs.assign_from(&pipeline.cs_info);
let encoder = self.compute.as_ref().unwrap();
encoder.set_compute_pipeline_state(&pipeline.raw);
// `None` here means the stage has no sizes slot, or not every sized binding has a
// recorded length yet; in that case nothing is uploaded.
if let Some((index, sizes)) = self
.state
.make_sizes_buffer_update(naga::ShaderStage::Compute, &mut self.temp.binding_sizes)
{
encoder.set_bytes(
index as _,
(sizes.len() * WORD_SIZE) as u64,
sizes.as_ptr() as _,
);
}
}
unsafe fn dispatch(&mut self, count: [u32; 3]) {
@ -668,11 +802,11 @@ impl crate::CommandBuffer<super::Api> for super::CommandBuffer {
height: count[1] as u64,
depth: count[2] as u64,
};
encoder.dispatch_thread_groups(raw_count, self.raw_wg_size);
encoder.dispatch_thread_groups(raw_count, self.state.raw_wg_size);
}
/// Forwards `buffer`, `offset` and the stored work-group size to
/// `dispatch_thread_groups_indirect` on the active compute encoder.
///
/// Panics if `self.compute` is `None` (the encoder is obtained with `unwrap`).
// NOTE(review): the dispatch call appears twice below, once with `self.raw_wg_size`
// and once with `self.state.raw_wg_size`; this looks like the pre- and post-change
// lines of a diff shown together — confirm that only one dispatch is intended.
unsafe fn dispatch_indirect(&mut self, buffer: &super::Buffer, offset: wgt::BufferAddress) {
let encoder = self.compute.as_ref().unwrap();
encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.raw_wg_size);
encoder.dispatch_thread_groups_indirect(&buffer.raw, offset, self.state.raw_wg_size);
}
}

View File

@ -354,11 +354,16 @@ impl crate::Device<super::Api> for super::Device {
blit: None,
render: None,
compute: None,
raw_primitive_type: mtl::MTLPrimitiveType::Point,
index_state: None,
raw_wg_size: mtl::MTLSize::new(0, 0, 0),
max_buffers_per_stage: self.shared.private_caps.max_buffers_per_stage,
disabilities: self.shared.disabilities.clone(),
max_buffers_per_stage: self.shared.private_caps.max_buffers_per_stage,
state: super::CommandState {
raw_primitive_type: mtl::MTLPrimitiveType::Point,
index: None,
raw_wg_size: mtl::MTLSize::new(0, 0, 0),
stage_infos: Default::default(),
storage_buffer_length_map: Default::default(),
},
temp: super::Temp::default(),
})
}
unsafe fn destroy_command_buffer(&self, mut cmd_buf: super::CommandBuffer) {
@ -431,7 +436,6 @@ impl crate::Device<super::Api> for super::Device {
for (group_index, &bgl) in desc.bind_group_layouts.iter().enumerate() {
// remember where the resources for this set start at each shader stage
let mut dynamic_buffers = Vec::new();
let mut sized_buffer_bindings = Vec::new();
let base_resource_indices = stage_data.map(|info| info.counters.clone());
for entry in bgl.entries.iter() {
@ -451,7 +455,6 @@ impl crate::Device<super::Api> for super::Device {
}));
}
if let wgt::BufferBindingType::Storage { .. } = ty {
sized_buffer_bindings.push((entry.binding, entry.visibility));
for info in stage_data.iter_mut() {
if entry.visibility.contains(map_naga_stage(info.stage)) {
info.sizes_count += 1;
@ -506,8 +509,6 @@ impl crate::Device<super::Api> for super::Device {
bind_group_infos.push(super::BindGroupLayoutInfo {
base_resource_indices,
//dynamic_buffers,
sized_buffer_bindings,
});
}
@ -592,9 +593,17 @@ impl crate::Device<super::Api> for super::Device {
}
match layout.ty {
wgt::BindingType::Buffer {
has_dynamic_offset, ..
ty,
has_dynamic_offset,
..
} => {
let source = &desc.buffers[entry.resource_index as usize];
let binding_size = match ty {
wgt::BufferBindingType::Storage { .. } => source
.size
.or(wgt::BufferSize::new(source.buffer.size - source.offset)),
_ => None,
};
bg.buffers.push(super::BufferResource {
ptr: source.buffer.as_raw(),
offset: source.offset,
@ -603,6 +612,8 @@ impl crate::Device<super::Api> for super::Device {
} else {
None
},
binding_size,
binding_location: layout.binding,
});
counter.buffers += 1;
}

View File

@ -463,8 +463,6 @@ type MultiStageResourceCounters = MultiStageData<ResourceData<ResourceIndex>>;
/// Per-bind-group data stored in a pipeline layout's `bind_group_infos`.
#[derive(Debug)]
struct BindGroupLayoutInfo {
/// Per-stage resource counters captured before this group's entries are processed,
/// i.e. where the group's resources start in each shader stage.
base_resource_indices: MultiStageResourceCounters,
//dynamic_buffers: Vec<MultiStageData<ResourceIndex>>,
/// `(binding, visibility)` pairs for the group's storage buffer entries.
// NOTE(review): the pipeline-layout code shown alongside no longer appears to fill
// this field in — confirm whether it is still present after the change.
sized_buffer_bindings: Vec<(u32, wgt::ShaderStage)>,
}
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
@ -532,6 +530,8 @@ struct BufferResource {
ptr: BufferPtr,
offset: wgt::BufferAddress,
dynamic_index: Option<u32>,
binding_size: Option<wgt::BufferSize>,
binding_location: u32,
}
#[derive(Debug, Default)]
@ -557,10 +557,26 @@ struct PipelineStageInfo {
sized_bindings: Vec<naga::ResourceBinding>,
}
#[allow(dead_code)] // silence xx_lib and xx_info warnings
impl PipelineStageInfo {
    /// Resets this stage info to "no pipeline bound": both optional slots become
    /// `None` and the sized-binding list is emptied (its allocation is kept).
    fn clear(&mut self) {
        self.sized_bindings.clear();
        self.sizes_slot = None;
        self.push_constants = None;
    }

    /// Overwrites this stage info with a copy of `other`, reusing the existing
    /// `sized_bindings` allocation where possible.
    fn assign_from(&mut self, other: &Self) {
        self.sized_bindings.clone_from(&other.sized_bindings);
        self.sizes_slot = other.sizes_slot;
        self.push_constants = other.push_constants;
    }
}
pub struct RenderPipeline {
raw: mtl::RenderPipelineState,
#[allow(dead_code)]
vs_lib: mtl::Library,
#[allow(dead_code)]
fs_lib: Option<mtl::Library>,
vs_info: PipelineStageInfo,
fs_info: PipelineStageInfo,
@ -574,9 +590,9 @@ pub struct RenderPipeline {
unsafe impl Send for RenderPipeline {}
unsafe impl Sync for RenderPipeline {}
#[allow(dead_code)] // silence xx_lib and xx_info warnings
pub struct ComputePipeline {
raw: mtl::ComputePipelineState,
#[allow(dead_code)]
cs_lib: mtl::Library,
cs_info: PipelineStageInfo,
work_group_size: mtl::MTLSize,
@ -628,16 +644,29 @@ struct IndexState {
raw_type: mtl::MTLIndexType,
}
/// Scratch storage owned by a command buffer and reused between calls.
#[derive(Default)]
struct Temp {
/// Scratch vector handed to `make_sizes_buffer_update` as `result_sizes`; it is
/// cleared and refilled each time a sizes update is built.
binding_sizes: Vec<wgt::BufferSize>,
}
/// Mutable state tracked by a command buffer while recording passes.
struct CommandState {
/// Primitive type taken from the render pipeline in `set_render_pipeline` and
/// passed to the draw calls.
raw_primitive_type: mtl::MTLPrimitiveType,
/// Index buffer state set by `set_index_buffer`; reset to `None` when a render
/// pass begins.
index: Option<IndexState>,
/// Work-group size taken from the compute pipeline in `set_compute_pipeline` and
/// passed to the dispatch calls.
raw_wg_size: mtl::MTLSize,
/// Per-stage pipeline info copied from the bound pipeline; cleared in `begin_pass`.
stage_infos: MultiStageData<PipelineStageInfo>,
//TODO: use `naga::ResourceBinding` for keys
/// Binding sizes recorded by `set_bind_group`, keyed by
/// `(group index, binding location)`; cleared in `begin_pass`.
storage_buffer_length_map: fxhash::FxHashMap<(u32, u32), wgt::BufferSize>,
}
/// Metal command buffer together with its currently open encoders and the state
/// tracked while recording.
// NOTE(review): `raw_primitive_type`, `index_state` and `raw_wg_size` have counterparts
// inside `CommandState`; they look like pre-change lines of a diff shown together with
// the new `state` field — confirm which set of fields is live.
pub struct CommandBuffer {
raw: mtl::CommandBuffer,
/// Open blit encoder, if any.
blit: Option<mtl::BlitCommandEncoder>,
/// Open render encoder, if any; render commands `unwrap` it.
render: Option<mtl::RenderCommandEncoder>,
/// Open compute encoder, if any; compute commands `unwrap` it.
compute: Option<mtl::ComputeCommandEncoder>,
raw_primitive_type: mtl::MTLPrimitiveType,
index_state: Option<IndexState>,
raw_wg_size: mtl::MTLSize,
/// Copied from the device's private capabilities when the command buffer is created.
max_buffers_per_stage: u32,
/// Cloned from the device's shared `disabilities` when the command buffer is created.
disabilities: PrivateDisabilities,
/// Pipeline, index and binding-size state tracked across commands.
state: CommandState,
/// Reusable scratch storage.
temp: Temp,
}
unsafe impl Send for CommandBuffer {}