[msl] make sizes buffer local to each function

This commit is contained in:
Dzmitry Malyshau 2021-05-13 01:26:08 -04:00 committed by Dzmitry Malyshau
parent 0250ffe2fb
commit a8df6a2b34
4 changed files with 77 additions and 60 deletions

View File

@ -342,6 +342,21 @@ fn should_pack_struct_member(
}
}
fn needs_array_length(ty: Handle<crate::Type>, arena: &Arena<crate::Type>) -> bool {
if let crate::TypeInner::Struct { ref members, .. } = arena[ty].inner {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = arena[member.ty].inner
{
return true;
}
}
}
false
}
impl crate::StorageClass {
/// Returns true for storage classes, for which the global
/// variables are passed in function arguments.
@ -1488,8 +1503,12 @@ impl<W: Write> Writer<W> {
// follow-up with any global resources used
let mut separate = !arguments.is_empty();
let fun_info = &context.mod_info[function];
let mut supports_array_length = false;
for (handle, var) in context.expression.module.global_variables.iter() {
if !fun_info[handle].is_empty() && var.class.needs_pass_through() {
if fun_info[handle].is_empty() {
continue;
}
if var.class.needs_pass_through() {
let name = &self.names[&NameKey::GlobalVariable(handle)];
if separate {
write!(self.out, ", ")?;
@ -1498,12 +1517,13 @@ impl<W: Write> Writer<W> {
}
write!(self.out, "{}", name)?;
}
supports_array_length |=
needs_array_length(var.ty, &context.expression.module.types);
}
if !self.runtime_sized_buffers.is_empty() {
if supports_array_length {
if separate {
write!(self.out, ", ")?;
}
write!(self.out, "_buffer_sizes")?;
}
@ -1542,19 +1562,11 @@ impl<W: Write> Writer<W> {
{
let mut indices = vec![];
for (handle, gv) in module.global_variables.iter() {
if let crate::TypeInner::Struct { ref members, .. } = module.types[gv.ty].inner {
if let Some(member) = members.last() {
if let crate::TypeInner::Array {
size: crate::ArraySize::Dynamic,
..
} = module.types[member.ty].inner
{
let idx = handle.index();
self.runtime_sized_buffers.insert(handle, idx);
indices.push(idx);
}
}
for (handle, var) in module.global_variables.iter() {
if needs_array_length(var.ty, &module.types) {
let idx = handle.index();
self.runtime_sized_buffers.insert(handle, idx);
indices.push(idx);
}
}
@ -1845,9 +1857,13 @@ impl<W: Write> Writer<W> {
for (fun_handle, fun) in module.functions.iter() {
let fun_info = &mod_info[fun_handle];
pass_through_globals.clear();
let mut supports_array_length = false;
for (handle, var) in module.global_variables.iter() {
if !fun_info[handle].is_empty() && var.class.needs_pass_through() {
pass_through_globals.push(handle);
if !fun_info[handle].is_empty() {
if var.class.needs_pass_through() {
pass_through_globals.push(handle);
}
supports_array_length |= needs_array_length(var.ty, &module.types);
}
}
@ -1882,7 +1898,7 @@ impl<W: Write> Writer<W> {
let separator = separate(
!pass_through_globals.is_empty()
|| index + 1 != fun.arguments.len()
|| !self.runtime_sized_buffers.is_empty(),
|| supports_array_length,
);
writeln!(
self.out,
@ -1898,16 +1914,14 @@ impl<W: Write> Writer<W> {
usage: fun_info[handle],
reference: true,
};
let separator = separate(
index + 1 != pass_through_globals.len()
|| !self.runtime_sized_buffers.is_empty(),
);
let separator =
separate(index + 1 != pass_through_globals.len() || supports_array_length);
write!(self.out, "{}", INDENT)?;
tyvar.try_fmt(&mut self.out)?;
writeln!(self.out, "{}", separator)?;
}
if !self.runtime_sized_buffers.is_empty() {
if supports_array_length {
writeln!(
self.out,
"{}constant _mslBufferSizes& _buffer_sizes",
@ -1961,42 +1975,45 @@ impl<W: Write> Writer<W> {
for (ep_index, ep) in module.entry_points.iter().enumerate() {
let fun = &ep.function;
let fun_info = mod_info.get_entry_point(ep_index);
let mut ep_error = None;
let mut supports_array_length = false;
// skip this entry point if any global bindings are missing
if !options.fake_missing_bindings {
if let Some(err) = module
.global_variables
.iter()
.find_map(|(var_handle, var)| {
if !fun_info[var_handle].is_empty() {
if let Some(ref br) = var.binding {
if let Err(e) = options.resolve_resource_binding(ep.stage, br) {
return Some(e);
}
}
if var.class == crate::StorageClass::PushConstant {
if let Err(e) = options.resolve_push_constants(ep.stage) {
return Some(e);
}
}
}
None
})
{
info.entry_point_names.push(Err(err));
continue;
}
if !self.runtime_sized_buffers.is_empty() {
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
info.entry_point_names.push(Err(err));
for (var_handle, var) in module.global_variables.iter() {
if fun_info[var_handle].is_empty() {
continue;
}
if let Some(ref br) = var.binding {
if let Err(e) = options.resolve_resource_binding(ep.stage, br) {
ep_error = Some(e);
break;
}
}
if var.class == crate::StorageClass::PushConstant {
if let Err(e) = options.resolve_push_constants(ep.stage) {
ep_error = Some(e);
break;
}
}
supports_array_length |= needs_array_length(var.ty, &module.types);
}
if supports_array_length {
if let Err(err) = options.resolve_sizes_buffer(ep.stage) {
ep_error = Some(err);
}
}
}
writeln!(self.out)?;
if let Some(err) = ep_error {
info.entry_point_names.push(Err(err));
continue;
}
let fun_name = &self.names[&NameKey::EntryPoint(ep_index as _)];
info.entry_point_names.push(Ok(fun_name.clone()));
writeln!(self.out)?;
let stage_out_name = format!("{}Output", fun_name);
let stage_in_name = format!("{}Input", fun_name);
@ -2212,7 +2229,7 @@ impl<W: Write> Writer<W> {
writeln!(self.out)?;
}
if !self.runtime_sized_buffers.is_empty() {
if supports_array_length {
// this is checked earlier
let resolved = options.resolve_sizes_buffer(ep.stage).unwrap();
let separator = if module.global_variables.is_empty() {

View File

@ -11,8 +11,7 @@ struct PrimeIndices {
};
metal::uint collatz_iterations(
metal::uint n_base,
constant _mslBufferSizes& _buffer_sizes
metal::uint n_base
) {
metal::uint n;
metal::uint i = 0u;
@ -36,9 +35,8 @@ struct main1Input {
kernel void main1(
metal::uint3 global_id [[thread_position_in_grid]]
, device PrimeIndices& v_indices [[user(fake0)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x], _buffer_sizes);
metal::uint _e9 = collatz_iterations(v_indices.data[global_id.x]);
v_indices.data[global_id.x] = _e9;
return;
}

View File

@ -24,8 +24,7 @@ float fetch_shadow(
metal::uint light_id,
metal::float4 homogeneous_coords,
metal::depth2d_array<float, metal::access::sample> t_shadow,
metal::sampler sampler_shadow,
constant _mslBufferSizes& _buffer_sizes
metal::sampler sampler_shadow
) {
if (homogeneous_coords.w <= 0.0) {
return 1.0;
@ -47,7 +46,6 @@ fragment fs_mainOutput fs_main(
, constant Lights& s_lights [[user(fake0)]]
, metal::depth2d_array<float, metal::access::sample> t_shadow [[user(fake0)]]
, metal::sampler sampler_shadow [[user(fake0)]]
, constant _mslBufferSizes& _buffer_sizes [[user(fake0)]]
) {
const auto raw_normal = varyings.raw_normal;
const auto position = varyings.position;
@ -63,7 +61,7 @@ fragment fs_mainOutput fs_main(
break;
}
Light _e21 = s_lights.data[i];
float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow, _buffer_sizes);
float _e25 = fetch_shadow(i, _e21.proj * position, t_shadow, sampler_shadow);
color1 = color1 + ((_e25 * metal::max(0.0, metal::dot(metal::normalize(raw_normal), metal::normalize(_e21.pos.xyz - position.xyz)))) * _e21.color.xyz);
}
return fs_mainOutput { metal::float4(color1, 1.0) };

View File

@ -303,7 +303,11 @@ fn convert_spv(name: &str, adjust_coordinate_space: bool, targets: Targets) {
#[cfg(feature = "spv-in")]
#[test]
fn convert_spv_quad_vert() {
convert_spv("quad-vert", false, Targets::METAL | Targets::GLSL | Targets::WGSL);
convert_spv(
"quad-vert",
false,
Targets::METAL | Targets::GLSL | Targets::WGSL,
);
}
#[cfg(feature = "spv-in")]