From fc27b08dcacb9dbbeb4408a282c78946b00bb17c Mon Sep 17 00:00:00 2001 From: Erich Gubler Date: Wed, 10 Jan 2024 17:30:24 -0500 Subject: [PATCH] feat(const_eval): impl. `abs` with new `component_wise_scalar` --- CHANGELOG.md | 8 ++ naga/src/proc/constant_evaluator.rs | 27 ++++ .../glsl/math-functions.main.Fragment.glsl | 6 +- naga/tests/out/hlsl/math-functions.hlsl | 6 +- naga/tests/out/msl/math-functions.msl | 11 +- naga/tests/out/spv/math-functions.spvasm | 115 +++++++++--------- naga/tests/out/wgsl/math-functions.wgsl | 2 +- 7 files changed, 104 insertions(+), 71 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index c4361ea48..604399bdc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -37,6 +37,14 @@ Bottom level categories: - Hal --> +## Unreleased + +### New features + +- Many numeric built-ins have had a constant evaluation implementation added for them, which allows them to be used in a `const` context: + - [#4879](https://github.com/gfx-rs/wgpu/pull/4879) by @ErichDonGubler: + - `abs` + ## v0.19.0 (2024-01-17) This release includes: diff --git a/naga/src/proc/constant_evaluator.rs b/naga/src/proc/constant_evaluator.rs index e020c1692..13b4dd7ce 100644 --- a/naga/src/proc/constant_evaluator.rs +++ b/naga/src/proc/constant_evaluator.rs @@ -192,6 +192,24 @@ macro_rules! gen_component_wise_extractor { }; } +gen_component_wise_extractor! { + component_wise_scalar -> Scalar, + literals: [ + AbstractFloat => AbstractFloat: f64, + F32 => F32: f32, + AbstractInt => AbstractInt: i64, + U32 => U32: u32, + I32 => I32: i32, + ], + scalar_kinds: [ + Float, + AbstractFloat, + Sint, + Uint, + AbstractInt, + ], +} + gen_component_wise_extractor! { component_wise_float -> Float, literals: [ @@ -792,6 +810,15 @@ impl<'a> ConstantEvaluator<'a> { } match fun { + crate::MathFunction::Abs => { + component_wise_scalar(self, span, [arg], |args| match args { + Scalar::AbstractFloat([e]) => Ok(Scalar::AbstractFloat([e.abs()])), + Scalar::F32([e]) => Ok(Scalar::F32([e.abs()])), + Scalar::AbstractInt([e]) => Ok(Scalar::AbstractInt([e.abs()])), + Scalar::I32([e]) => Ok(Scalar::I32([e.wrapping_abs()])), + Scalar::U32([e]) => Ok(Scalar::U32([e])), // TODO: just re-use the expression, ezpz + }) + } crate::MathFunction::Pow => self.math_pow(arg, arg1.unwrap(), span), crate::MathFunction::Clamp => self.math_clamp(arg, arg1.unwrap(), arg2.unwrap(), span), fun => Err(ConstantEvaluatorError::NotImplemented(format!( diff --git a/naga/tests/out/glsl/math-functions.main.Fragment.glsl b/naga/tests/out/glsl/math-functions.main.Fragment.glsl index ed81535ab..bf0561f12 100644 --- a/naga/tests/out/glsl/math-functions.main.Fragment.glsl +++ b/naga/tests/out/glsl/math-functions.main.Fragment.glsl @@ -67,7 +67,7 @@ void main() { float sign_c = sign(-1.0); vec4 sign_d = sign(vec4(-1.0)); int const_dot = ( + ivec2(0).x * ivec2(0).x + ivec2(0).y * ivec2(0).y); - uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); + uint first_leading_bit_abs = uint(findMSB(0u)); int flb_a = findMSB(-1); ivec2 flb_b = findMSB(ivec2(-1)); uvec2 flb_c = uvec2(findMSB(uvec2(1u))); @@ -85,8 +85,8 @@ void main() { ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u))); int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1)); uint clz_b = uint(31 - findMSB(1u)); - ivec2 _e68 = ivec2(-1); - ivec2 clz_c = mix(ivec2(31) - findMSB(_e68), ivec2(0), lessThan(_e68, ivec2(0))); + ivec2 _e67 = ivec2(-1); + ivec2 clz_c = mix(ivec2(31) - findMSB(_e67), ivec2(0), lessThan(_e67, ivec2(0))); uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u))); float lde_a = ldexp(1.0, 2); vec2 lde_b = ldexp(vec2(1.0, 2.0), ivec2(3, 4)); diff --git a/naga/tests/out/hlsl/math-functions.hlsl b/naga/tests/out/hlsl/math-functions.hlsl index 53d3acf0c..5da3461da 100644 --- a/naga/tests/out/hlsl/math-functions.hlsl +++ b/naga/tests/out/hlsl/math-functions.hlsl @@ -77,7 +77,7 @@ void main() float sign_c = sign(-1.0); float4 sign_d = sign((-1.0).xxxx); int const_dot = dot((int2)0, (int2)0); - uint first_leading_bit_abs = firstbithigh(abs(0u)); + uint first_leading_bit_abs = firstbithigh(0u); int flb_a = asint(firstbithigh(-1)); int2 flb_b = asint(firstbithigh((-1).xx)); uint2 flb_c = firstbithigh((1u).xx); @@ -95,8 +95,8 @@ void main() int2 ctz_h = asint(min((32u).xx, firstbitlow((1).xx))); int clz_a = (-1 < 0 ? 0 : 31 - asint(firstbithigh(-1))); uint clz_b = (31u - firstbithigh(1u)); - int2 _expr68 = (-1).xx; - int2 clz_c = (_expr68 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr68))); + int2 _expr67 = (-1).xx; + int2 clz_c = (_expr67 < (0).xx ? (0).xx : (31).xx - asint(firstbithigh(_expr67))); uint2 clz_d = ((31u).xx - firstbithigh((1u).xx)); float lde_a = ldexp(1.0, 2); float2 lde_b = ldexp(float2(1.0, 2.0), int2(3, 4)); diff --git a/naga/tests/out/msl/math-functions.msl b/naga/tests/out/msl/math-functions.msl index d93e502dc..45fbcd00a 100644 --- a/naga/tests/out/msl/math-functions.msl +++ b/naga/tests/out/msl/math-functions.msl @@ -70,13 +70,12 @@ fragment void main_( float sign_c = metal::sign(-1.0); metal::float4 sign_d = metal::sign(metal::float4(-1.0)); int const_dot = ( + metal::int2 {}.x * metal::int2 {}.x + metal::int2 {}.y * metal::int2 {}.y); - uint _e23 = metal::abs(0u); - uint first_leading_bit_abs = metal::select(31 - metal::clz(_e23), uint(-1), _e23 == 0 || _e23 == -1); + uint first_leading_bit_abs = metal::select(31 - metal::clz(0u), uint(-1), 0u == 0 || 0u == -1); int flb_a = metal::select(31 - metal::clz(metal::select(-1, ~-1, -1 < 0)), int(-1), -1 == 0 || -1 == -1); - metal::int2 _e28 = metal::int2(-1); - metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e28, ~_e28, _e28 < 0)), int2(-1), _e28 == 0 || _e28 == -1); - metal::uint2 _e31 = metal::uint2(1u); - metal::uint2 flb_c = metal::select(31 - metal::clz(_e31), uint2(-1), _e31 == 0 || _e31 == -1); + metal::int2 _e27 = metal::int2(-1); + metal::int2 flb_b = metal::select(31 - metal::clz(metal::select(_e27, ~_e27, _e27 < 0)), int2(-1), _e27 == 0 || _e27 == -1); + metal::uint2 _e30 = metal::uint2(1u); + metal::uint2 flb_c = metal::select(31 - metal::clz(_e30), uint2(-1), _e30 == 0 || _e30 == -1); int ftb_a = (((metal::ctz(-1) + 1) % 33) - 1); uint ftb_b = (((metal::ctz(1u) + 1) % 33) - 1); metal::int2 ftb_c = (((metal::ctz(metal::int2(-1)) + 1) % 33) - 1); diff --git a/naga/tests/out/spv/math-functions.spvasm b/naga/tests/out/spv/math-functions.spvasm index ba3e7cffb..bbbf37097 100644 --- a/naga/tests/out/spv/math-functions.spvasm +++ b/naga/tests/out/spv/math-functions.spvasm @@ -1,7 +1,7 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 126 +; Bound: 125 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 @@ -61,10 +61,10 @@ OpMemberDecorate %13 1 Offset 16 %45 = OpConstantComposite %3 %43 %43 %43 %43 %52 = OpConstantComposite %3 %17 %17 %17 %17 %59 = OpConstantNull %6 -%77 = OpConstant %25 32 -%86 = OpConstantComposite %29 %77 %77 -%95 = OpConstant %6 31 -%100 = OpConstantComposite %5 %95 %95 +%76 = OpConstant %25 32 +%85 = OpConstantComposite %29 %76 %76 +%94 = OpConstant %6 31 +%99 = OpConstantComposite %5 %94 %94 %15 = OpFunction %2 None %16 %14 = OpLabel OpBranch %46 @@ -87,60 +87,59 @@ OpBranch %46 %65 = OpCompositeExtract %6 %24 1 %66 = OpIMul %6 %64 %65 %58 = OpIAdd %6 %63 %66 -%67 = OpCopyObject %25 %26 -%68 = OpExtInst %25 %1 FindUMsb %67 -%69 = OpExtInst %6 %1 FindSMsb %20 -%70 = OpExtInst %5 %1 FindSMsb %27 -%71 = OpExtInst %29 %1 FindUMsb %30 -%72 = OpExtInst %6 %1 FindILsb %20 -%73 = OpExtInst %25 %1 FindILsb %28 -%74 = OpExtInst %5 %1 FindILsb %27 -%75 = OpExtInst %29 %1 FindILsb %30 -%78 = OpExtInst %25 %1 FindILsb %26 -%76 = OpExtInst %25 %1 UMin %77 %78 -%80 = OpExtInst %6 %1 FindILsb %31 -%79 = OpExtInst %6 %1 UMin %77 %80 -%82 = OpExtInst %25 %1 FindILsb %32 -%81 = OpExtInst %25 %1 UMin %77 %82 -%84 = OpExtInst %6 %1 FindILsb %20 -%83 = OpExtInst %6 %1 UMin %77 %84 -%87 = OpExtInst %29 %1 FindILsb %33 -%85 = OpExtInst %29 %1 UMin %86 %87 -%89 = OpExtInst %5 %1 FindILsb %34 -%88 = OpExtInst %5 %1 UMin %86 %89 -%91 = OpExtInst %29 %1 FindILsb %30 -%90 = OpExtInst %29 %1 UMin %86 %91 -%93 = OpExtInst %5 %1 FindILsb %36 -%92 = OpExtInst %5 %1 UMin %86 %93 -%96 = OpExtInst %6 %1 FindUMsb %20 -%94 = OpISub %6 %95 %96 -%98 = OpExtInst %6 %1 FindUMsb %28 -%97 = OpISub %25 %95 %98 -%101 = OpExtInst %5 %1 FindUMsb %27 -%99 = OpISub %5 %100 %101 -%103 = OpExtInst %5 %1 FindUMsb %30 -%102 = OpISub %29 %100 %103 -%104 = OpExtInst %4 %1 Ldexp %17 %37 -%105 = OpExtInst %7 %1 Ldexp %39 %42 +%67 = OpExtInst %25 %1 FindUMsb %26 +%68 = OpExtInst %6 %1 FindSMsb %20 +%69 = OpExtInst %5 %1 FindSMsb %27 +%70 = OpExtInst %29 %1 FindUMsb %30 +%71 = OpExtInst %6 %1 FindILsb %20 +%72 = OpExtInst %25 %1 FindILsb %28 +%73 = OpExtInst %5 %1 FindILsb %27 +%74 = OpExtInst %29 %1 FindILsb %30 +%77 = OpExtInst %25 %1 FindILsb %26 +%75 = OpExtInst %25 %1 UMin %76 %77 +%79 = OpExtInst %6 %1 FindILsb %31 +%78 = OpExtInst %6 %1 UMin %76 %79 +%81 = OpExtInst %25 %1 FindILsb %32 +%80 = OpExtInst %25 %1 UMin %76 %81 +%83 = OpExtInst %6 %1 FindILsb %20 +%82 = OpExtInst %6 %1 UMin %76 %83 +%86 = OpExtInst %29 %1 FindILsb %33 +%84 = OpExtInst %29 %1 UMin %85 %86 +%88 = OpExtInst %5 %1 FindILsb %34 +%87 = OpExtInst %5 %1 UMin %85 %88 +%90 = OpExtInst %29 %1 FindILsb %30 +%89 = OpExtInst %29 %1 UMin %85 %90 +%92 = OpExtInst %5 %1 FindILsb %36 +%91 = OpExtInst %5 %1 UMin %85 %92 +%95 = OpExtInst %6 %1 FindUMsb %20 +%93 = OpISub %6 %94 %95 +%97 = OpExtInst %6 %1 FindUMsb %28 +%96 = OpISub %25 %94 %97 +%100 = OpExtInst %5 %1 FindUMsb %27 +%98 = OpISub %5 %99 %100 +%102 = OpExtInst %5 %1 FindUMsb %30 +%101 = OpISub %29 %99 %102 +%103 = OpExtInst %4 %1 Ldexp %17 %37 +%104 = OpExtInst %7 %1 Ldexp %39 %42 +%105 = OpExtInst %8 %1 ModfStruct %43 %106 = OpExtInst %8 %1 ModfStruct %43 -%107 = OpExtInst %8 %1 ModfStruct %43 -%108 = OpCompositeExtract %4 %107 0 -%109 = OpExtInst %8 %1 ModfStruct %43 -%110 = OpCompositeExtract %4 %109 1 -%111 = OpExtInst %9 %1 ModfStruct %44 -%112 = OpExtInst %10 %1 ModfStruct %45 -%113 = OpCompositeExtract %3 %112 1 -%114 = OpCompositeExtract %4 %113 0 -%115 = OpExtInst %9 %1 ModfStruct %44 -%116 = OpCompositeExtract %7 %115 0 -%117 = OpCompositeExtract %4 %116 1 +%107 = OpCompositeExtract %4 %106 0 +%108 = OpExtInst %8 %1 ModfStruct %43 +%109 = OpCompositeExtract %4 %108 1 +%110 = OpExtInst %9 %1 ModfStruct %44 +%111 = OpExtInst %10 %1 ModfStruct %45 +%112 = OpCompositeExtract %3 %111 1 +%113 = OpCompositeExtract %4 %112 0 +%114 = OpExtInst %9 %1 ModfStruct %44 +%115 = OpCompositeExtract %7 %114 0 +%116 = OpCompositeExtract %4 %115 1 +%117 = OpExtInst %11 %1 FrexpStruct %43 %118 = OpExtInst %11 %1 FrexpStruct %43 -%119 = OpExtInst %11 %1 FrexpStruct %43 -%120 = OpCompositeExtract %4 %119 0 -%121 = OpExtInst %11 %1 FrexpStruct %43 -%122 = OpCompositeExtract %6 %121 1 -%123 = OpExtInst %13 %1 FrexpStruct %45 -%124 = OpCompositeExtract %12 %123 1 -%125 = OpCompositeExtract %6 %124 0 +%119 = OpCompositeExtract %4 %118 0 +%120 = OpExtInst %11 %1 FrexpStruct %43 +%121 = OpCompositeExtract %6 %120 1 +%122 = OpExtInst %13 %1 FrexpStruct %45 +%123 = OpCompositeExtract %12 %122 1 +%124 = OpCompositeExtract %6 %123 0 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/naga/tests/out/wgsl/math-functions.wgsl b/naga/tests/out/wgsl/math-functions.wgsl index 92f391038..ce38fee98 100644 --- a/naga/tests/out/wgsl/math-functions.wgsl +++ b/naga/tests/out/wgsl/math-functions.wgsl @@ -12,7 +12,7 @@ fn main() { let sign_c = sign(-1f); let sign_d = sign(vec4(-1f)); let const_dot = dot(vec2(), vec2()); - let first_leading_bit_abs = firstLeadingBit(abs(0u)); + let first_leading_bit_abs = firstLeadingBit(0u); let flb_a = firstLeadingBit(-1i); let flb_b = firstLeadingBit(vec2(-1i)); let flb_c = firstLeadingBit(vec2(1u));