[glsl-in] parse all math functions

This commit is contained in:
Frizi 2021-06-17 21:24:50 +02:00 committed by Dzmitry Malyshau
parent a07310536f
commit efd416d964
3 changed files with 260 additions and 17 deletions

View File

@ -205,7 +205,9 @@ impl Program<'_> {
}
"ceil" | "round" | "floor" | "fract" | "trunc" | "sin" | "abs" | "sqrt"
| "inversesqrt" | "exp" | "exp2" | "sign" | "transpose" | "inverse"
| "normalize" => {
| "normalize" | "sinh" | "cos" | "cosh" | "tan" | "tanh" | "acos" | "asin"
| "log" | "log2" | "length" | "determinant" | "bitCount"
| "bitfieldReverse" => {
if args.len() != 1 {
return Err(ErrorKind::wrong_function_args(name, 1, args.len(), meta));
}
@ -227,6 +229,19 @@ impl Program<'_> {
"transpose" => MathFunction::Transpose,
"inverse" => MathFunction::Inverse,
"normalize" => MathFunction::Normalize,
"sinh" => MathFunction::Sinh,
"cos" => MathFunction::Cos,
"cosh" => MathFunction::Cosh,
"tan" => MathFunction::Tan,
"tanh" => MathFunction::Tanh,
"acos" => MathFunction::Acos,
"asin" => MathFunction::Asin,
"log" => MathFunction::Log,
"log2" => MathFunction::Log2,
"length" => MathFunction::Length,
"determinant" => MathFunction::Determinant,
"bitCount" => MathFunction::CountOneBits,
"bitfieldReverse" => MathFunction::ReverseBits,
_ => unreachable!(),
},
arg: args[0].0,
@ -236,6 +251,31 @@ impl Program<'_> {
body,
)))
}
"atan" => {
let expr = match args.len() {
1 => Expression::Math {
fun: MathFunction::Atan,
arg: args[0].0,
arg1: None,
arg2: None,
},
2 => Expression::Math {
fun: MathFunction::Atan2,
arg: args[0].0,
arg1: Some(args[1].0),
arg2: None,
},
_ => {
return Err(ErrorKind::wrong_function_args(
name,
2,
args.len(),
meta,
))
}
};
Ok(Some(ctx.add_expression(expr, body)))
}
"mod" => {
if args.len() != 2 {
return Err(ErrorKind::wrong_function_args(name, 2, args.len(), meta));
@ -248,26 +288,17 @@ impl Program<'_> {
self, &mut left, left_meta, &mut right, right_meta,
)?;
let expr = if let Some(ScalarKind::Float) =
self.resolve_type(ctx, args[0].0, args[1].1)?.scalar_kind()
{
Expression::Math {
fun: MathFunction::Modf,
arg: left,
arg1: Some(right),
arg2: None,
}
} else {
Ok(Some(ctx.add_expression(
Expression::Binary {
op: BinaryOperator::Modulo,
left,
right,
}
};
Ok(Some(ctx.add_expression(expr, body)))
},
body,
)))
}
"pow" | "dot" | "max" | "min" | "reflect" | "cross" => {
"pow" | "dot" | "max" | "min" | "reflect" | "cross" | "outerProduct"
| "distance" | "step" | "modf" | "frexp" | "ldexp" => {
if args.len() != 2 {
return Err(ErrorKind::wrong_function_args(name, 2, args.len(), meta));
}
@ -280,6 +311,12 @@ impl Program<'_> {
"min" => MathFunction::Min,
"reflect" => MathFunction::Reflect,
"cross" => MathFunction::Cross,
"outerProduct" => MathFunction::Outer,
"distance" => MathFunction::Distance,
"step" => MathFunction::Step,
"modf" => MathFunction::Modf,
"frexp" => MathFunction::Frexp,
"ldexp" => MathFunction::Ldexp,
_ => unreachable!(),
},
arg: args[0].0,
@ -289,7 +326,7 @@ impl Program<'_> {
body,
)))
}
"mix" | "clamp" => {
"mix" | "clamp" | "faceforward" | "refract" | "fma" | "smoothstep" => {
if args.len() != 3 {
return Err(ErrorKind::wrong_function_args(name, 3, args.len(), meta));
}
@ -298,6 +335,10 @@ impl Program<'_> {
fun: match name.as_str() {
"mix" => MathFunction::Mix,
"clamp" => MathFunction::Clamp,
"faceforward" => MathFunction::FaceForward,
"refract" => MathFunction::Refract,
"fma" => MathFunction::Fma,
"smoothstep" => MathFunction::SmoothStep,
_ => unreachable!(),
},
arg: args[0].0,

View File

@ -0,0 +1,55 @@
#version 450
void main() {
vec4 a = vec4(1.0);
vec4 b = vec4(2.0);
mat4 m = mat4(a, b, a, b);
int i = 5;
vec4 ceilOut = ceil(a);
vec4 roundOut = round(a);
vec4 floorOut = floor(a);
vec4 fractOut = fract(a);
vec4 truncOut = trunc(a);
vec4 sinOut = sin(a);
vec4 absOut = abs(a);
vec4 sqrtOut = sqrt(a);
vec4 inversesqrtOut = inversesqrt(a);
vec4 expOut = exp(a);
vec4 exp2Out = exp2(a);
vec4 signOut = sign(a);
mat4 transposeOut = transpose(m);
// TODO: support inverse function in wgsl output
// mat4 inverseOut = inverse(m);
vec4 normalizeOut = normalize(a);
vec4 sinhOut = sinh(a);
vec4 cosOut = cos(a);
vec4 coshOut = cosh(a);
vec4 tanOut = tan(a);
vec4 tanhOut = tanh(a);
vec4 acosOut = acos(a);
vec4 asinOut = asin(a);
vec4 logOut = log(a);
vec4 log2Out = log2(a);
float lengthOut = length(a);
float determinantOut = determinant(m);
int bitCountOut = bitCount(i);
int bitfieldReverseOut = bitfieldReverse(i);
float atanOut = atan(a.x);
float atan2Out = atan(a.x, a.y);
float modOut = mod(a.x, b.x);
vec4 powOut = pow(a, b);
float dotOut = dot(a, b);
vec4 maxOut = max(a, b);
vec4 minOut = min(a, b);
vec4 reflectOut = reflect(a, b);
vec3 crossOut = cross(a.xyz, b.xyz);
mat4 outerProductOut = outerProduct(a, b);
float distanceOut = distance(a, b);
vec4 stepOut = step(a, b);
// TODO: support out params in wgsl output
// vec4 modfOut = modf(a, b);
// vec4 frexpOut = frexp(a, b);
// float ldexpOut = ldexp(a.x, i);
}

View File

@ -0,0 +1,147 @@
fn main() {
var a: vec4<f32> = vec4<f32>(1.0, 1.0, 1.0, 1.0);
var b: vec4<f32> = vec4<f32>(2.0, 2.0, 2.0, 2.0);
var m: mat4x4<f32>;
var i: i32 = 5;
var ceilOut: vec4<f32>;
var roundOut: vec4<f32>;
var floorOut: vec4<f32>;
var fractOut: vec4<f32>;
var truncOut: vec4<f32>;
var sinOut: vec4<f32>;
var absOut: vec4<f32>;
var sqrtOut: vec4<f32>;
var inversesqrtOut: vec4<f32>;
var expOut: vec4<f32>;
var exp2Out: vec4<f32>;
var signOut: vec4<f32>;
var transposeOut: mat4x4<f32>;
var normalizeOut: vec4<f32>;
var sinhOut: vec4<f32>;
var cosOut: vec4<f32>;
var coshOut: vec4<f32>;
var tanOut: vec4<f32>;
var tanhOut: vec4<f32>;
var acosOut: vec4<f32>;
var asinOut: vec4<f32>;
var logOut: vec4<f32>;
var log2Out: vec4<f32>;
var lengthOut: f32;
var determinantOut: f32;
var bitCountOut: i32;
var bitfieldReverseOut: i32;
var atanOut: f32;
var atan2Out: f32;
var modOut: f32;
var powOut: vec4<f32>;
var dotOut: f32;
var maxOut: vec4<f32>;
var minOut: vec4<f32>;
var reflectOut: vec4<f32>;
var crossOut: vec3<f32>;
var outerProductOut: mat4x4<f32>;
var distanceOut: f32;
var stepOut: vec4<f32>;
let _e6: vec4<f32> = a;
let _e7: vec4<f32> = b;
let _e8: vec4<f32> = a;
let _e9: vec4<f32> = b;
m = mat4x4<f32>(_e6, _e7, _e8, _e9);
let _e14: vec4<f32> = a;
ceilOut = ceil(_e14);
let _e17: vec4<f32> = a;
roundOut = round(_e17);
let _e20: vec4<f32> = a;
floorOut = floor(_e20);
let _e23: vec4<f32> = a;
fractOut = fract(_e23);
let _e26: vec4<f32> = a;
truncOut = trunc(_e26);
let _e29: vec4<f32> = a;
sinOut = sin(_e29);
let _e32: vec4<f32> = a;
absOut = abs(_e32);
let _e35: vec4<f32> = a;
sqrtOut = sqrt(_e35);
let _e38: vec4<f32> = a;
inversesqrtOut = inverseSqrt(_e38);
let _e41: vec4<f32> = a;
expOut = exp(_e41);
let _e44: vec4<f32> = a;
exp2Out = exp2(_e44);
let _e47: vec4<f32> = a;
signOut = sign(_e47);
let _e50: mat4x4<f32> = m;
transposeOut = transpose(_e50);
let _e53: vec4<f32> = a;
normalizeOut = normalize(_e53);
let _e56: vec4<f32> = a;
sinhOut = sinh(_e56);
let _e59: vec4<f32> = a;
cosOut = cos(_e59);
let _e62: vec4<f32> = a;
coshOut = cosh(_e62);
let _e65: vec4<f32> = a;
tanOut = tan(_e65);
let _e68: vec4<f32> = a;
tanhOut = tanh(_e68);
let _e71: vec4<f32> = a;
acosOut = acos(_e71);
let _e74: vec4<f32> = a;
asinOut = asin(_e74);
let _e77: vec4<f32> = a;
logOut = log(_e77);
let _e80: vec4<f32> = a;
log2Out = log2(_e80);
let _e83: vec4<f32> = a;
lengthOut = length(_e83);
let _e86: mat4x4<f32> = m;
determinantOut = determinant(_e86);
let _e89: i32 = i;
bitCountOut = countOneBits(_e89);
let _e92: i32 = i;
bitfieldReverseOut = reverseBits(_e92);
let _e95: vec4<f32> = a;
atanOut = atan(_e95.x);
let _e99: vec4<f32> = a;
let _e101: vec4<f32> = a;
atan2Out = atan2(_e99.x, _e101.y);
let _e105: vec4<f32> = a;
let _e107: vec4<f32> = b;
modOut = (_e105.x % _e107.x);
let _e111: vec4<f32> = a;
let _e112: vec4<f32> = b;
powOut = pow(_e111, _e112);
let _e115: vec4<f32> = a;
let _e116: vec4<f32> = b;
dotOut = dot(_e115, _e116);
let _e119: vec4<f32> = a;
let _e120: vec4<f32> = b;
maxOut = max(_e119, _e120);
let _e123: vec4<f32> = a;
let _e124: vec4<f32> = b;
minOut = min(_e123, _e124);
let _e127: vec4<f32> = a;
let _e128: vec4<f32> = b;
reflectOut = reflect(_e127, _e128);
let _e131: vec4<f32> = a;
let _e133: vec4<f32> = b;
crossOut = cross(_e131.xyz, _e133.xyz);
let _e137: vec4<f32> = a;
let _e138: vec4<f32> = b;
outerProductOut = outerProduct(_e137, _e138);
let _e141: vec4<f32> = a;
let _e142: vec4<f32> = b;
distanceOut = distance(_e141, _e142);
let _e145: vec4<f32> = a;
let _e146: vec4<f32> = b;
stepOut = step(_e145, _e146);
return;
}
[[stage(vertex)]]
fn main1() {
main();
return;
}