From efd416d96448d0c749d332b857b67bed3b7fc66e Mon Sep 17 00:00:00 2001 From: Frizi Date: Thu, 17 Jun 2021 21:24:50 +0200 Subject: [PATCH] [glsl-in] parse all math functions --- src/front/glsl/functions.rs | 75 +++++++++++---- tests/in/glsl/math-functions.vert | 55 +++++++++++ tests/out/math-functions-vert.wgsl | 147 +++++++++++++++++++++++++++++ 3 files changed, 260 insertions(+), 17 deletions(-) create mode 100644 tests/in/glsl/math-functions.vert create mode 100644 tests/out/math-functions-vert.wgsl diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 8965fd08a..3e25a0d84 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -205,7 +205,9 @@ impl Program<'_> { } "ceil" | "round" | "floor" | "fract" | "trunc" | "sin" | "abs" | "sqrt" | "inversesqrt" | "exp" | "exp2" | "sign" | "transpose" | "inverse" - | "normalize" => { + | "normalize" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin" + | "log" | "log2" | "length" | "determinant" | "bitCount" + | "bitfieldReverse" => { if args.len() != 1 { return Err(ErrorKind::wrong_function_args(name, 1, args.len(), meta)); } @@ -227,6 +229,19 @@ impl Program<'_> { "transpose" => MathFunction::Transpose, "inverse" => MathFunction::Inverse, "normalize" => MathFunction::Normalize, + "sinh" => MathFunction::Sinh, + "cos" => MathFunction::Cos, + "cosh" => MathFunction::Cosh, + "tan" => MathFunction::Tan, + "tanh" => MathFunction::Tanh, + "acos" => MathFunction::Acos, + "asin" => MathFunction::Asin, + "log" => MathFunction::Log, + "log2" => MathFunction::Log2, + "length" => MathFunction::Length, + "determinant" => MathFunction::Determinant, + "bitCount" => MathFunction::CountOneBits, + "bitfieldReverse" => MathFunction::ReverseBits, _ => unreachable!(), }, arg: args[0].0, @@ -236,6 +251,31 @@ impl Program<'_> { body, ))) } + "atan" => { + let expr = match args.len() { + 1 => Expression::Math { + fun: MathFunction::Atan, + arg: args[0].0, + arg1: None, + arg2: None, + }, + 2 => Expression::Math { + fun: MathFunction::Atan2, + arg: args[0].0, + arg1: Some(args[1].0), + arg2: None, + }, + _ => { + return Err(ErrorKind::wrong_function_args( + name, + 2, + args.len(), + meta, + )) + } + }; + Ok(Some(ctx.add_expression(expr, body))) + } "mod" => { if args.len() != 2 { return Err(ErrorKind::wrong_function_args(name, 2, args.len(), meta)); @@ -248,26 +288,17 @@ impl Program<'_> { self, &mut left, left_meta, &mut right, right_meta, )?; - let expr = if let Some(ScalarKind::Float) = - self.resolve_type(ctx, args[0].0, args[1].1)?.scalar_kind() - { - Expression::Math { - fun: MathFunction::Modf, - arg: left, - arg1: Some(right), - arg2: None, - } - } else { + Ok(Some(ctx.add_expression( Expression::Binary { op: BinaryOperator::Modulo, left, right, - } - }; - - Ok(Some(ctx.add_expression(expr, body))) + }, + body, + ))) } - "pow" | "dot" | "max" | "min" | "reflect" | "cross" => { + "pow" | "dot" | "max" | "min" | "reflect" | "cross" | "outerProduct" + | "distance" | "step" | "modf" | "frexp" | "ldexp" => { if args.len() != 2 { return Err(ErrorKind::wrong_function_args(name, 2, args.len(), meta)); } @@ -280,6 +311,12 @@ impl Program<'_> { "min" => MathFunction::Min, "reflect" => MathFunction::Reflect, "cross" => MathFunction::Cross, + "outerProduct" => MathFunction::Outer, + "distance" => MathFunction::Distance, + "step" => MathFunction::Step, + "modf" => MathFunction::Modf, + "frexp" => MathFunction::Frexp, + "ldexp" => MathFunction::Ldexp, _ => unreachable!(), }, arg: args[0].0, @@ -289,7 +326,7 @@ impl Program<'_> { body, ))) } - "mix" | "clamp" => { + "mix" | "clamp" | "faceforward" | "refract" | "fma" | "smoothstep" => { if args.len() != 3 { return Err(ErrorKind::wrong_function_args(name, 3, args.len(), meta)); } @@ -298,6 +335,10 @@ impl Program<'_> { fun: match name.as_str() { "mix" => MathFunction::Mix, "clamp" => MathFunction::Clamp, + "faceforward" => MathFunction::FaceForward, + "refract" => MathFunction::Refract, + "fma" => MathFunction::Fma, + "smoothstep" => MathFunction::SmoothStep, _ => unreachable!(), }, arg: args[0].0, diff --git a/tests/in/glsl/math-functions.vert b/tests/in/glsl/math-functions.vert new file mode 100644 index 000000000..ff1851a69 --- /dev/null +++ b/tests/in/glsl/math-functions.vert @@ -0,0 +1,55 @@ +#version 450 + +void main() { + vec4 a = vec4(1.0); + vec4 b = vec4(2.0); + mat4 m = mat4(a, b, a, b); + int i = 5; + + vec4 ceilOut = ceil(a); + vec4 roundOut = round(a); + vec4 floorOut = floor(a); + vec4 fractOut = fract(a); + vec4 truncOut = trunc(a); + vec4 sinOut = sin(a); + vec4 absOut = abs(a); + vec4 sqrtOut = sqrt(a); + vec4 inversesqrtOut = inversesqrt(a); + vec4 expOut = exp(a); + vec4 exp2Out = exp2(a); + vec4 signOut = sign(a); + mat4 transposeOut = transpose(m); + // TODO: support inverse function in wgsl output + // mat4 inverseOut = inverse(m); + vec4 normalizeOut = normalize(a); + vec4 sinhOut = sinh(a); + vec4 cosOut = cos(a); + vec4 coshOut = cosh(a); + vec4 tanOut = tan(a); + vec4 tanhOut = tanh(a); + vec4 acosOut = acos(a); + vec4 asinOut = asin(a); + vec4 logOut = log(a); + vec4 log2Out = log2(a); + float lengthOut = length(a); + float determinantOut = determinant(m); + int bitCountOut = bitCount(i); + int bitfieldReverseOut = bitfieldReverse(i); + float atanOut = atan(a.x); + float atan2Out = atan(a.x, a.y); + float modOut = mod(a.x, b.x); + vec4 powOut = pow(a, b); + float dotOut = dot(a, b); + vec4 maxOut = max(a, b); + vec4 minOut = min(a, b); + vec4 reflectOut = reflect(a, b); + vec3 crossOut = cross(a.xyz, b.xyz); + mat4 outerProductOut = outerProduct(a, b); + float distanceOut = distance(a, b); + vec4 stepOut = step(a, b); + // TODO: support out params in wgsl output + // vec4 modfOut = modf(a, b); + // vec4 frexpOut = frexp(a, b); + // float ldexpOut = ldexp(a.x, i); + +} \ No newline at end of file diff --git a/tests/out/math-functions-vert.wgsl b/tests/out/math-functions-vert.wgsl new file mode 100644 index 000000000..8cb0af35e --- /dev/null +++ b/tests/out/math-functions-vert.wgsl @@ -0,0 +1,147 @@ +fn main() { + var a: vec4 = vec4(1.0, 1.0, 1.0, 1.0); + var b: vec4 = vec4(2.0, 2.0, 2.0, 2.0); + var m: mat4x4; + var i: i32 = 5; + var ceilOut: vec4; + var roundOut: vec4; + var floorOut: vec4; + var fractOut: vec4; + var truncOut: vec4; + var sinOut: vec4; + var absOut: vec4; + var sqrtOut: vec4; + var inversesqrtOut: vec4; + var expOut: vec4; + var exp2Out: vec4; + var signOut: vec4; + var transposeOut: mat4x4; + var normalizeOut: vec4; + var sinhOut: vec4; + var cosOut: vec4; + var coshOut: vec4; + var tanOut: vec4; + var tanhOut: vec4; + var acosOut: vec4; + var asinOut: vec4; + var logOut: vec4; + var log2Out: vec4; + var lengthOut: f32; + var determinantOut: f32; + var bitCountOut: i32; + var bitfieldReverseOut: i32; + var atanOut: f32; + var atan2Out: f32; + var modOut: f32; + var powOut: vec4; + var dotOut: f32; + var maxOut: vec4; + var minOut: vec4; + var reflectOut: vec4; + var crossOut: vec3; + var outerProductOut: mat4x4; + var distanceOut: f32; + var stepOut: vec4; + + let _e6: vec4 = a; + let _e7: vec4 = b; + let _e8: vec4 = a; + let _e9: vec4 = b; + m = mat4x4(_e6, _e7, _e8, _e9); + let _e14: vec4 = a; + ceilOut = ceil(_e14); + let _e17: vec4 = a; + roundOut = round(_e17); + let _e20: vec4 = a; + floorOut = floor(_e20); + let _e23: vec4 = a; + fractOut = fract(_e23); + let _e26: vec4 = a; + truncOut = trunc(_e26); + let _e29: vec4 = a; + sinOut = sin(_e29); + let _e32: vec4 = a; + absOut = abs(_e32); + let _e35: vec4 = a; + sqrtOut = sqrt(_e35); + let _e38: vec4 = a; + inversesqrtOut = inverseSqrt(_e38); + let _e41: vec4 = a; + expOut = exp(_e41); + let _e44: vec4 = a; + exp2Out = exp2(_e44); + let _e47: vec4 = a; + signOut = sign(_e47); + let _e50: mat4x4 = m; + transposeOut = transpose(_e50); + let _e53: vec4 = a; + normalizeOut = normalize(_e53); + let _e56: vec4 = a; + sinhOut = sinh(_e56); + let _e59: vec4 = a; + cosOut = cos(_e59); + let _e62: vec4 = a; + coshOut = cosh(_e62); + let _e65: vec4 = a; + tanOut = tan(_e65); + let _e68: vec4 = a; + tanhOut = tanh(_e68); + let _e71: vec4 = a; + acosOut = acos(_e71); + let _e74: vec4 = a; + asinOut = asin(_e74); + let _e77: vec4 = a; + logOut = log(_e77); + let _e80: vec4 = a; + log2Out = log2(_e80); + let _e83: vec4 = a; + lengthOut = length(_e83); + let _e86: mat4x4 = m; + determinantOut = determinant(_e86); + let _e89: i32 = i; + bitCountOut = countOneBits(_e89); + let _e92: i32 = i; + bitfieldReverseOut = reverseBits(_e92); + let _e95: vec4 = a; + atanOut = atan(_e95.x); + let _e99: vec4 = a; + let _e101: vec4 = a; + atan2Out = atan2(_e99.x, _e101.y); + let _e105: vec4 = a; + let _e107: vec4 = b; + modOut = (_e105.x % _e107.x); + let _e111: vec4 = a; + let _e112: vec4 = b; + powOut = pow(_e111, _e112); + let _e115: vec4 = a; + let _e116: vec4 = b; + dotOut = dot(_e115, _e116); + let _e119: vec4 = a; + let _e120: vec4 = b; + maxOut = max(_e119, _e120); + let _e123: vec4 = a; + let _e124: vec4 = b; + minOut = min(_e123, _e124); + let _e127: vec4 = a; + let _e128: vec4 = b; + reflectOut = reflect(_e127, _e128); + let _e131: vec4 = a; + let _e133: vec4 = b; + crossOut = cross(_e131.xyz, _e133.xyz); + let _e137: vec4 = a; + let _e138: vec4 = b; + outerProductOut = outerProduct(_e137, _e138); + let _e141: vec4 = a; + let _e142: vec4 = b; + distanceOut = distance(_e141, _e142); + let _e145: vec4 = a; + let _e146: vec4 = b; + stepOut = step(_e145, _e146); + return; +} + +[[stage(vertex)]] +fn main1() { + main(); + return; +}