[msl] inline some of the constants

This commit is contained in:
Dzmitry Malyshau 2021-04-08 12:29:26 -04:00 committed by Dzmitry Malyshau
parent 2774dcb403
commit 6166b95b6a
8 changed files with 187 additions and 120 deletions

View File

@ -229,6 +229,42 @@ impl<'a> TypedGlobalVariable<'a> {
}
}
/// Display adapter that writes a single constant as shader source text.
///
/// Formatting resolves the constant as `arena[handle]` and then writes either
/// the name registered for it in `names` or its literal value.
struct ConstantContext<'a> {
/// Handle of the constant to write.
handle: Handle<crate::Constant>,
/// Arena that `handle` indexes into.
arena: &'a Arena<crate::Constant>,
/// Name table; queried with `NameKey::Constant(handle)` when the constant
/// is written by name rather than by value.
names: &'a FastHashMap<NameKey, String>,
/// When `true`, the literal value is written even if the constant needs an
/// alias. When `false`, a constant that needs an alias is written by name.
first_time: bool,
}
impl<'a> Display for ConstantContext<'a> {
    /// Writes the constant either by name or as an inline literal.
    ///
    /// A constant that needs an alias is written by its registered name,
    /// unless `first_time` is set, in which case its value is written.
    /// Scalars that need no alias are always written as literals:
    /// signed integers bare, unsigned integers with a `u` suffix, booleans
    /// as `true`/`false`, and floats with a guaranteed fractional part.
    ///
    /// # Panics
    ///
    /// Panics if asked to write the value of a composite constant; composites
    /// always need an alias and are expected to be referenced by name.
    fn fmt(&self, out: &mut Formatter<'_>) -> Result<(), FmtError> {
        let con = &self.arena[self.handle];
        if con.needs_alias() && !self.first_time {
            let name = &self.names[&NameKey::Constant(self.handle)];
            return write!(out, "{}", name);
        }

        match con.inner {
            crate::ConstantInner::Scalar { value, width: _ } => match value {
                crate::ScalarValue::Sint(value) => {
                    write!(out, "{}", value)
                }
                crate::ScalarValue::Uint(value) => {
                    write!(out, "{}u", value)
                }
                crate::ScalarValue::Float(value) => {
                    // Rust prints non-finite floats as `inf`/`NaN`, which are not
                    // valid tokens in the generated source; emit the standard
                    // math constants instead.
                    if value.is_nan() {
                        write!(out, "NAN")
                    } else if value.is_infinite() {
                        let sign = if value.is_sign_negative() { "-" } else { "" };
                        write!(out, "{}INFINITY", sign)
                    } else {
                        // Whole numbers print without a fractional part in Rust;
                        // append `.0` so the literal stays a floating-point one.
                        let suffix = if value.fract() == 0.0 { ".0" } else { "" };
                        write!(out, "{}{}", value, suffix)
                    }
                }
                crate::ScalarValue::Bool(value) => {
                    write!(out, "{}", value)
                }
            },
            crate::ConstantInner::Composite { .. } => unreachable!("should be aliased"),
        }
    }
}
pub struct Writer<W> {
out: W,
names: FastHashMap<NameKey, String>,
@ -299,7 +335,7 @@ impl crate::StorageClass {
}
impl crate::Type {
// Returns `true` if we need to emit a type alias for this type.
// Returns `true` if we need to emit an alias for this type.
fn needs_alias(&self) -> bool {
use crate::TypeInner as Ti;
match self.inner {
@ -317,6 +353,16 @@ impl crate::Type {
}
}
impl crate::Constant {
// Returns `true` if we need to emit an alias for this constant.
fn needs_alias(&self) -> bool {
match self.inner {
crate::ConstantInner::Scalar { .. } => self.name.is_some(),
crate::ConstantInner::Composite { .. } => true,
}
}
}
enum FunctionOrigin {
Handle(Handle<crate::Function>),
EntryPoint(EntryPointIndex),
@ -545,8 +591,13 @@ impl<W: Write> Writer<W> {
}
}
crate::Expression::Constant(handle) => {
let handle_name = &self.names[&NameKey::Constant(handle)];
write!(self.out, "{}", handle_name)?;
let coco = ConstantContext {
handle,
arena: &context.module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, "{}", coco)?;
}
crate::Expression::Compose { ty, ref components } => {
let inner = &context.module.types[ty].inner;
@ -670,8 +721,13 @@ impl<W: Write> Writer<W> {
}
}
if let Some(constant) = offset {
let offset_str = &self.names[&NameKey::Constant(constant)];
write!(self.out, ", {}", offset_str)?;
let coco = ConstantContext {
handle: constant,
arena: &context.module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, ", {}", coco)?;
}
write!(self.out, ")")?;
}
@ -914,8 +970,13 @@ impl<W: Write> Writer<W> {
size: crate::ArraySize::Constant(const_handle),
..
} => {
let size_str = &self.names[&NameKey::Constant(const_handle)];
write!(self.out, "{}", size_str)?;
let coco = ConstantContext {
handle: const_handle,
arena: &context.module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, "{}", coco)?;
}
crate::TypeInner::Array { .. } => {
return Err(Error::FeatureNotImplemented(
@ -1296,13 +1357,21 @@ impl<W: Write> Writer<W> {
access: crate::StorageAccess::empty(),
first_time: false,
};
let size_str = match size {
write!(self.out, "typedef {} {}", base_name, name)?;
match size {
crate::ArraySize::Constant(const_handle) => {
&self.names[&NameKey::Constant(const_handle)]
let coco = ConstantContext {
handle: const_handle,
arena: &module.constants,
names: &self.names,
first_time: false,
};
writeln!(self.out, "[{}];", coco)?;
}
crate::ArraySize::Dynamic => "1",
};
writeln!(self.out, "typedef {} {}[{}];", base_name, name, size_str)?;
crate::ArraySize::Dynamic => {
writeln!(self.out, "[1];")?;
}
}
}
crate::TypeInner::Struct {
block: _,
@ -1345,29 +1414,33 @@ impl<W: Write> Writer<W> {
crate::ConstantInner::Scalar {
width: _,
ref value,
} => {
let name = &self.names[&NameKey::Constant(handle)];
} if constant.name.is_some() => {
debug_assert!(constant.needs_alias());
write!(self.out, "constexpr constant ")?;
match *value {
crate::ScalarValue::Sint(value) => {
write!(self.out, "int {} = {}", name, value)?;
crate::ScalarValue::Sint(_) => {
write!(self.out, "int")?;
}
crate::ScalarValue::Uint(value) => {
write!(self.out, "unsigned {} = {}u", name, value)?;
crate::ScalarValue::Uint(_) => {
write!(self.out, "unsigned")?;
}
crate::ScalarValue::Float(value) => {
write!(self.out, "float {} = {}", name, value)?;
if value.fract() == 0.0 {
write!(self.out, ".0")?;
}
crate::ScalarValue::Float(_) => {
write!(self.out, "float")?;
}
crate::ScalarValue::Bool(value) => {
write!(self.out, "bool {} = {}", name, value)?;
crate::ScalarValue::Bool(_) => {
write!(self.out, "bool")?;
}
}
writeln!(self.out, ";")?;
let name = &self.names[&NameKey::Constant(handle)];
let coco = ConstantContext {
handle,
arena: &module.constants,
names: &self.names,
first_time: true,
};
writeln!(self.out, " {} = {};", name, coco)?;
}
crate::ConstantInner::Composite { .. } => {}
_ => {}
}
}
Ok(())
@ -1378,6 +1451,7 @@ impl<W: Write> Writer<W> {
match constant.inner {
crate::ConstantInner::Scalar { .. } => {}
crate::ConstantInner::Composite { ty, ref components } => {
debug_assert!(constant.needs_alias());
let name = &self.names[&NameKey::Constant(handle)];
let ty_name = TypeContext {
handle: ty,
@ -1390,8 +1464,13 @@ impl<W: Write> Writer<W> {
write!(self.out, "constexpr constant {} {} = {{", ty_name, name,)?;
for (i, &sub_handle) in components.iter().enumerate() {
let separator = if i != 0 { ", " } else { "" };
let sub_name = &self.names[&NameKey::Constant(sub_handle)];
write!(self.out, "{}{}", separator, sub_name)?;
let coco = ConstantContext {
handle: sub_handle,
arena: &module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, "{}{}", separator, coco)?;
}
writeln!(self.out, "}};")?;
}
@ -1561,8 +1640,13 @@ impl<W: Write> Writer<W> {
let local_name = &self.names[&NameKey::FunctionLocal(fun_handle, local_handle)];
write!(self.out, "{}{} {}", INDENT, ty_name, local_name)?;
if let Some(value) = local.init {
let value_str = &self.names[&NameKey::Constant(value)];
write!(self.out, " = {}", value_str)?;
let coco = ConstantContext {
handle: value,
arena: &module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, " = {}", coco)?;
}
writeln!(self.out, ";")?;
}
@ -1801,8 +1885,13 @@ impl<W: Write> Writer<W> {
resolved.try_fmt_decorated(&mut self.out, "")?;
}
if let Some(value) = var.init {
let value_str = &self.names[&NameKey::Constant(value)];
write!(self.out, " = {}", value_str)?;
let coco = ConstantContext {
handle: value,
arena: &module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, " = {}", coco)?;
}
writeln!(self.out)?;
}
@ -1827,11 +1916,20 @@ impl<W: Write> Writer<W> {
};
write!(self.out, "{}", INDENT)?;
tyvar.try_fmt(&mut self.out)?;
let value_str = match var.init {
Some(value) => &self.names[&NameKey::Constant(value)],
None => "{}",
match var.init {
Some(value) => {
let coco = ConstantContext {
handle: value,
arena: &module.constants,
names: &self.names,
first_time: false,
};
writeln!(self.out, " = {};", coco)?;
}
None => {
writeln!(self.out, " = {{}};")?;
}
};
writeln!(self.out, " = {};", value_str)?;
} else if let Some(ref binding) = var.binding {
// write an inline sampler
let resolved = options.resolve_global_binding(ep.stage, binding).unwrap();
@ -1902,8 +2000,13 @@ impl<W: Write> Writer<W> {
};
write!(self.out, "{}{} {}", INDENT, ty_name, name)?;
if let Some(value) = local.init {
let value_str = &self.names[&NameKey::Constant(value)];
write!(self.out, " = {}", value_str)?;
let coco = ConstantContext {
handle: value,
arena: &module.constants,
names: &self.names,
first_time: false,
};
write!(self.out, " = {}", coco)?;
}
writeln!(self.out, ";")?;
}

View File

@ -6,14 +6,6 @@ expression: msl
#include <simd/simd.h>
constexpr constant unsigned NUM_PARTICLES = 1500u;
constexpr constant float const_0f = 0.0;
constexpr constant int const_0i = 0;
constexpr constant unsigned const_0u = 0u;
constexpr constant int const_1i = 1;
constexpr constant unsigned const_1u = 1u;
constexpr constant float const_1f = 1.0;
constexpr constant float const_0_10f = 0.1;
constexpr constant float const_n1f = -1.0;
struct Particle {
metal::float2 pos;
metal::float2 vel;
@ -45,23 +37,23 @@ kernel void main1(
metal::float2 cMass;
metal::float2 cVel;
metal::float2 colVel;
int cMassCount = const_0i;
int cVelCount = const_0i;
int cMassCount = 0;
int cVelCount = 0;
metal::float2 pos1;
metal::float2 vel1;
metal::uint i = const_0u;
metal::uint i = 0u;
if (global_invocation_id.x >= NUM_PARTICLES) {
return;
}
vPos = particlesSrc.particles[global_invocation_id.x].pos;
vVel = particlesSrc.particles[global_invocation_id.x].vel;
cMass = metal::float2(const_0f, const_0f);
cVel = metal::float2(const_0f, const_0f);
colVel = metal::float2(const_0f, const_0f);
cMass = metal::float2(0.0, 0.0);
cVel = metal::float2(0.0, 0.0);
colVel = metal::float2(0.0, 0.0);
bool loop_init = true;
while(true) {
if (!loop_init) {
i = i + const_1u;
i = i + 1u;
}
loop_init = false;
if (i >= NUM_PARTICLES) {
@ -74,36 +66,36 @@ kernel void main1(
vel1 = particlesSrc.particles[i].vel;
if (metal::distance(pos1, vPos) < params.rule1Distance) {
cMass = cMass + pos1;
cMassCount = cMassCount + const_1i;
cMassCount = cMassCount + 1;
}
if (metal::distance(pos1, vPos) < params.rule2Distance) {
colVel = colVel - (pos1 - vPos);
}
if (metal::distance(pos1, vPos) < params.rule3Distance) {
cVel = cVel + vel1;
cVelCount = cVelCount + const_1i;
cVelCount = cVelCount + 1;
}
}
if (cMassCount > const_0i) {
cMass = (cMass * (const_1f / static_cast<float>(cMassCount))) - vPos;
if (cMassCount > 0) {
cMass = (cMass * (1.0 / static_cast<float>(cMassCount))) - vPos;
}
if (cVelCount > const_0i) {
cVel = cVel * (const_1f / static_cast<float>(cVelCount));
if (cVelCount > 0) {
cVel = cVel * (1.0 / static_cast<float>(cVelCount));
}
vVel = ((vVel + (cMass * params.rule1Scale)) + (colVel * params.rule2Scale)) + (cVel * params.rule3Scale);
vVel = metal::normalize(vVel) * metal::clamp(metal::length(vVel), const_0f, const_0_10f);
vVel = metal::normalize(vVel) * metal::clamp(metal::length(vVel), 0.0, 0.1);
vPos = vPos + (vVel * params.deltaT);
if (vPos.x < const_n1f) {
vPos.x = const_1f;
if (vPos.x < -1.0) {
vPos.x = 1.0;
}
if (vPos.x > const_1f) {
vPos.x = const_n1f;
if (vPos.x > 1.0) {
vPos.x = -1.0;
}
if (vPos.y < const_n1f) {
vPos.y = const_1f;
if (vPos.y < -1.0) {
vPos.y = 1.0;
}
if (vPos.y > const_1f) {
vPos.y = const_n1f;
if (vPos.y > 1.0) {
vPos.y = -1.0;
}
particlesDst.particles[global_invocation_id.x].pos = vPos;
particlesDst.particles[global_invocation_id.x].vel = vVel;

View File

@ -5,10 +5,6 @@ expression: msl
#include <metal_stdlib>
#include <simd/simd.h>
constexpr constant unsigned const_0u = 0u;
constexpr constant unsigned const_1u = 1u;
constexpr constant unsigned const_2u = 2u;
constexpr constant unsigned const_3u = 3u;
typedef metal::uint type1[1];
struct PrimeIndices {
type1 data;
@ -18,18 +14,18 @@ metal::uint collatz_iterations(
metal::uint n_base
) {
metal::uint n;
metal::uint i = const_0u;
metal::uint i = 0u;
n = n_base;
while(true) {
if (n <= const_1u) {
if (n <= 1u) {
break;
}
if ((n % const_2u) == const_0u) {
n = n / const_2u;
if ((n % 2u) == 0u) {
n = n / 2u;
} else {
n = (const_3u * n) + const_1u;
n = (3u * n) + 1u;
}
i = i + const_1u;
i = i + 1u;
}
return i;
}

View File

@ -5,8 +5,6 @@ expression: msl
#include <metal_stdlib>
#include <simd/simd.h>
constexpr constant int const_10i = 10;
constexpr constant int const_20i = 20;
struct main1Input {
};
@ -15,7 +13,7 @@ kernel void main1(
, metal::texture2d<uint, metal::access::read> image_src [[user(fake0)]]
, metal::texture1d<uint, metal::access::write> image_dst [[user(fake0)]]
) {
metal::int2 _expr12 = (int2(image_src.get_width(), image_src.get_height()) * static_cast<int2>(metal::uint2(local_id.x, local_id.y))) % metal::int2(const_10i, const_20i);
metal::int2 _expr12 = (int2(image_src.get_width(), image_src.get_height()) * static_cast<int2>(metal::uint2(local_id.x, local_id.y))) % metal::int2(10, 20);
metal::uint4 _expr13 = image_src.read(metal::uint2(_expr12));
image_dst.write(_expr13, metal::uint(_expr12.x));
return;

View File

@ -5,15 +5,7 @@ expression: msl
#include <metal_stdlib>
#include <simd/simd.h>
constexpr constant int const_0i = 0;
constexpr constant int const_1i = 1;
constexpr constant int const_2i = 2;
constexpr constant int const_3i = 3;
constexpr constant unsigned const_1u = 1u;
constexpr constant int const_0i1 = 0;
constexpr constant float const_0f = 0.0;
constexpr constant float const_1f = 1.0;
typedef float type6[const_1u];
typedef float type6[1u];
struct gl_PerVertex {
metal::float4 gl_Position;
float gl_PointSize;
@ -35,7 +27,7 @@ void main1(
) {
v_uv = a_uv;
metal::float2 _expr13 = a_pos;
_.gl_Position = metal::float4(_expr13.x, _expr13.y, const_0f, const_1f);
_.gl_Position = metal::float4(_expr13.x, _expr13.y, 0.0, 1.0);
return;
}

View File

@ -6,8 +6,6 @@ expression: msl
#include <simd/simd.h>
constexpr constant float c_scale = 1.2;
constexpr constant float const_0f = 0.0;
constexpr constant float const_1f = 1.0;
struct VertexOutput {
metal::float2 uv;
metal::float4 position;
@ -28,7 +26,7 @@ vertex main1Output main1(
const auto uv1 = varyings.uv1;
VertexOutput out;
out.uv = uv1;
out.position = metal::float4(c_scale * pos, const_0f, const_1f);
out.position = metal::float4(c_scale * pos, 0.0, 1.0);
const auto _tmp = out;
return main1Output { _tmp.uv, _tmp.position };
}
@ -47,7 +45,7 @@ fragment main2Output main2(
) {
const auto uv2 = varyings1.uv2;
metal::float4 _expr4 = u_texture.sample(u_sampler, uv2);
if (_expr4.w == const_0f) {
if (_expr4.w == 0.0) {
metal::discard_fragment();
}
return main2Output { _expr4.w * _expr4 };

View File

@ -5,14 +5,7 @@ expression: msl
#include <metal_stdlib>
#include <simd/simd.h>
constexpr constant float const_0f = 0.0;
constexpr constant float const_1f = 1.0;
constexpr constant float const_0_50f = 0.5;
constexpr constant float const_n0_50f = -0.5;
constexpr constant float const_0_05f = 0.05;
constexpr constant unsigned c_max_lights = 10u;
constexpr constant unsigned const_0u = 0u;
constexpr constant unsigned const_1u = 1u;
struct Globals {
metal::uint4 num_lights;
};
@ -25,7 +18,7 @@ typedef Light type3[1];
struct Lights {
type3 data;
};
constexpr constant metal::float3 c_ambient = {const_0_05f, const_0_05f, const_0_05f};
constexpr constant metal::float3 c_ambient = {0.05, 0.05, 0.05};
float fetch_shadow(
metal::uint light_id,
@ -33,11 +26,11 @@ float fetch_shadow(
metal::depth2d_array<float, metal::access::sample> t_shadow,
metal::sampler sampler_shadow
) {
if (homogeneous_coords.w <= const_0f) {
return const_1f;
if (homogeneous_coords.w <= 0.0) {
return 1.0;
}
float _expr15 = const_1f / homogeneous_coords.w;
float _expr28 = t_shadow.sample_compare(sampler_shadow, ((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(const_0_50f, const_n0_50f)) * _expr15) + metal::float2(const_0_50f, const_0_50f), static_cast<int>(light_id), homogeneous_coords.z * _expr15);
float _expr15 = 1.0 / homogeneous_coords.w;
float _expr28 = t_shadow.sample_compare(sampler_shadow, ((metal::float2(homogeneous_coords.x, homogeneous_coords.y) * metal::float2(0.5, -0.5)) * _expr15) + metal::float2(0.5, 0.5), static_cast<int>(light_id), homogeneous_coords.z * _expr15);
return _expr28;
}
@ -58,11 +51,11 @@ fragment fs_mainOutput fs_main(
const auto raw_normal = varyings.raw_normal;
const auto position = varyings.position;
metal::float3 color1 = c_ambient;
metal::uint i = const_0u;
metal::uint i = 0u;
bool loop_init = true;
while(true) {
if (!loop_init) {
i = i + const_1u;
i = i + 1u;
}
loop_init = false;
if (i >= metal::min(u_globals.num_lights.x, c_max_lights)) {
@ -70,8 +63,8 @@ fragment fs_mainOutput fs_main(
}
Light _expr21 = s_lights.data[i];
float _expr25 = fetch_shadow(i, _expr21.proj * position, t_shadow, sampler_shadow);
color1 = color1 + ((_expr25 * metal::max(const_0f, metal::dot(metal::normalize(raw_normal), metal::normalize(metal::float3(_expr21.pos.x, _expr21.pos.y, _expr21.pos.z) - metal::float3(position.x, position.y, position.z))))) * metal::float3(_expr21.color.x, _expr21.color.y, _expr21.color.z));
color1 = color1 + ((_expr25 * metal::max(0.0, metal::dot(metal::normalize(raw_normal), metal::normalize(metal::float3(_expr21.pos.x, _expr21.pos.y, _expr21.pos.z) - metal::float3(position.x, position.y, position.z))))) * metal::float3(_expr21.color.x, _expr21.color.y, _expr21.color.z));
}
return fs_mainOutput { metal::float4(color1, const_1f) };
return fs_mainOutput { metal::float4(color1, 1.0) };
}

View File

@ -5,11 +5,6 @@ expression: msl
#include <metal_stdlib>
#include <simd/simd.h>
constexpr constant int const_2i = 2;
constexpr constant int const_1i = 1;
constexpr constant float const_4f = 4.0;
constexpr constant float const_1f = 1.0;
constexpr constant float const_0f = 0.0;
struct VertexOutput {
metal::float4 position;
metal::float3 uv;
@ -32,9 +27,9 @@ vertex vs_mainOutput vs_main(
int tmp1_;
int tmp2_;
VertexOutput out;
tmp1_ = static_cast<int>(vertex_index) / const_2i;
tmp2_ = static_cast<int>(vertex_index) & const_1i;
metal::float4 _expr24 = metal::float4((static_cast<float>(tmp1_) * const_4f) - const_1f, (static_cast<float>(tmp2_) * const_4f) - const_1f, const_0f, const_1f);
tmp1_ = static_cast<int>(vertex_index) / 2;
tmp2_ = static_cast<int>(vertex_index) & 1;
metal::float4 _expr24 = metal::float4((static_cast<float>(tmp1_) * 4.0) - 1.0, (static_cast<float>(tmp2_) * 4.0) - 1.0, 0.0, 1.0);
metal::float4 _expr50 = r_data.proj_inv * _expr24;
out.uv = metal::transpose(metal::float3x3(metal::float3(r_data.view[0].x, r_data.view[0].y, r_data.view[0].z), metal::float3(r_data.view[1].x, r_data.view[1].y, r_data.view[1].z), metal::float3(r_data.view[2].x, r_data.view[2].y, r_data.view[2].z))) * metal::float3(_expr50.x, _expr50.y, _expr50.z);
out.position = _expr24;