[wgsl-out] Implement all math functions. Rename entry point function based on stage. Rename entry point output struct.

This commit is contained in:
Gordon-F 2021-05-12 21:09:03 +03:00 committed by Dzmitry Malyshau
parent 5958f98489
commit 6efe347e90
3 changed files with 95 additions and 20 deletions

View File

@ -13,6 +13,8 @@ pub enum Error {
Custom(String),
#[error("{0}")]
Unimplemented(String), // TODO: Error used only during development
#[error("Unsupported math function: {0:?}")]
UnsupportedMathFunction(crate::MathFunction),
}
pub fn write_string(

View File

@ -61,6 +61,7 @@ pub struct Writer<W> {
names: FastHashMap<NameKey, String>,
namer: Namer,
named_expressions: BitSet,
ep_results: Vec<(ShaderStage, Handle<Type>)>,
}
impl<W: Write> Writer<W> {
@ -70,6 +71,7 @@ impl<W: Write> Writer<W> {
names: FastHashMap::default(),
namer: Namer::default(),
named_expressions: BitSet::new(),
ep_results: vec![],
}
}
@ -82,6 +84,13 @@ impl<W: Write> Writer<W> {
pub fn write(&mut self, module: &Module, info: &ModuleInfo) -> BackendResult {
self.reset(module);
// Save all ep result types
for (_, ep) in module.entry_points.iter().enumerate() {
if let Some(ref result) = ep.function.result {
self.ep_results.push((ep.stage, result.ty));
}
}
// Write all structs
for (handle, ty) in module.types.iter() {
if let TypeInner::Struct {
@ -173,6 +182,29 @@ impl<W: Write> Writer<W> {
Ok(())
}
/// Helper method used to write stuct name
///
/// # Notes
/// Adds no trailing or leading whitespace
fn write_struct_name(&mut self, module: &Module, handle: Handle<Type>) -> BackendResult {
if module.types[handle].name.is_none() {
if let Some(&(stage, _)) = self.ep_results.iter().find(|&&(_, ty)| ty == handle) {
let name = match stage {
ShaderStage::Compute => "ComputeOutput",
ShaderStage::Fragment => "FragmentOutput",
ShaderStage::Vertex => "VertexOutput",
};
write!(self.out, "{}", name)?;
return Ok(());
}
}
write!(self.out, "{}", self.names[&NameKey::Type(handle)])?;
Ok(())
}
/// Helper method used to write structs
/// https://gpuweb.github.io/gpuweb/wgsl/#functions
///
@ -222,14 +254,11 @@ impl<W: Write> Writer<W> {
// Write function return type
if let Some(ref result) = func.result {
write!(self.out, " -> ")?;
if let Some(ref binding) = result.binding {
write!(self.out, " -> ")?;
self.write_attributes(&map_binding_to_attribute(binding), true)?;
self.write_type(module, result.ty)?;
} else {
let struct_name = &self.names[&NameKey::Type(result.ty)].clone();
write!(self.out, " -> {}", struct_name)?;
}
self.write_type(module, result.ty)?;
}
write!(self.out, " {{")?;
@ -383,8 +412,10 @@ impl<W: Write> Writer<W> {
self.write_attributes(&[Attribute::Block], false)?;
writeln!(self.out)?;
}
let name = &self.names[&NameKey::Type(handle)].clone();
write!(self.out, "struct {} {{", name)?;
write!(self.out, "struct ")?;
self.write_struct_name(module, handle)?;
write!(self.out, " {{")?;
writeln!(self.out)?;
for (index, member) in members.iter().enumerate() {
// Skip struct member with unsupported built in
@ -431,12 +462,7 @@ impl<W: Write> Writer<W> {
fn write_type(&mut self, module: &Module, ty: Handle<Type>) -> BackendResult {
let inner = &module.types[ty].inner;
match *inner {
TypeInner::Struct { .. } => {
// Get the struct name
let name = &self.names[&NameKey::Type(ty)];
write!(self.out, "{}", name)?;
return Ok(());
}
TypeInner::Struct { .. } => self.write_struct_name(module, ty)?,
ref other => self.write_value_type(module, other)?,
}
@ -1010,13 +1036,60 @@ impl<W: Write> Writer<W> {
use crate::MathFunction as Mf;
let fun_name = match fun {
Mf::Abs => "abs",
Mf::Min => "min",
Mf::Max => "max",
Mf::Clamp => "clamp",
// trigonometry
Mf::Cos => "cos",
Mf::Cosh => "cosh",
Mf::Sin => "sin",
Mf::Sinh => "sinh",
Mf::Tan => "tan",
Mf::Tanh => "tanh",
Mf::Acos => "acos",
Mf::Asin => "asin",
Mf::Atan => "atan",
Mf::Atan2 => "atan2",
// decomposition
Mf::Ceil => "ceil",
Mf::Floor => "floor",
Mf::Round => "round",
Mf::Fract => "fract",
Mf::Trunc => "trunc",
Mf::Modf => "modf",
Mf::Frexp => "frexp",
Mf::Ldexp => "ldexp",
// exponent
Mf::Exp => "exp",
Mf::Exp2 => "exp2",
Mf::Log => "log",
Mf::Log2 => "log2",
Mf::Pow => "pow",
// geometry
Mf::Dot => "dot",
Mf::Outer => "outerProduct",
Mf::Cross => "cross",
Mf::Distance => "distance",
Mf::Length => "length",
Mf::Normalize => "normalize",
Mf::FaceForward => "faceForward",
Mf::Reflect => "reflect",
// computational
Mf::Sign => "sign",
Mf::Fma => "fma",
Mf::Mix => "mix",
Mf::Step => "step",
Mf::SmoothStep => "smoothStep",
Mf::Sqrt => "sqrt",
Mf::InverseSqrt => "inverseSqrt",
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountOneBits => "countOneBits",
Mf::ReverseBits => "reverseBits",
_ => {
return Err(Error::Unimplemented(format!(
"write_expr Math func {:?}",
fun
)));
return Err(Error::UnsupportedMathFunction(fun));
}
};

View File

@ -3,7 +3,7 @@ struct gl_PerVertex {
[[builtin(position)]] gl_Position: vec4<f32>;
};
struct type10 {
struct VertexOutput {
[[location(0), interpolate(perspective)]] member: vec2<f32>;
[[builtin(position)]] gl_Position1: vec4<f32>;
};
@ -21,9 +21,9 @@ fn main() {
}
[[stage(vertex)]]
fn main1([[location(1)]] a_uv1: vec2<f32>, [[location(0)]] a_pos1: vec2<f32>) -> type10 {
fn main1([[location(1)]] a_uv1: vec2<f32>, [[location(0)]] a_pos1: vec2<f32>) -> VertexOutput {
a_uv = a_uv1;
a_pos = a_pos1;
main();
return type10(v_uv, perVertexStruct.gl_Position);
return VertexOutput(v_uv, perVertexStruct.gl_Position);
}