diff --git a/src/front/glsl/functions.rs b/src/front/glsl/functions.rs index 0bc44b101..95336a5d6 100644 --- a/src/front/glsl/functions.rs +++ b/src/front/glsl/functions.rs @@ -1058,6 +1058,82 @@ impl Parser { }); } + /// Helper function for building the input/output interface of the entry point + /// + /// Calls `f` with the data of the entry point argument, flattening composite types + /// recursively + /// + /// The passed arguments to the callback are: + /// - The name + /// - The pointer expression to the global storage + /// - The handle to the type of the entry point argument + /// - The binding of the entry point argument + /// - The expression arena + fn arg_type_walker( + &self, + name: Option, + binding: crate::Binding, + pointer: Handle, + ty: Handle, + expressions: &mut Arena, + f: &mut impl FnMut( + Option, + Handle, + Handle, + crate::Binding, + &mut Arena, + ), + ) { + match self.module.types[ty].inner { + TypeInner::Struct { ref members, .. } => { + let mut location = match binding { + crate::Binding::Location { location, .. } => location, + _ => return, + }; + + for (i, member) in members.iter().enumerate() { + let member_pointer = expressions.append( + Expression::AccessIndex { + base: pointer, + index: i as u32, + }, + crate::Span::default(), + ); + + let binding = match member.binding.clone() { + Some(binding) => binding, + None => { + let interpolation = self.module.types[member.ty] + .inner + .scalar_kind() + .map(|kind| match kind { + ScalarKind::Float => crate::Interpolation::Perspective, + _ => crate::Interpolation::Flat, + }); + let binding = crate::Binding::Location { + location, + interpolation, + sampling: None, + }; + location += 1; + binding + } + }; + + self.arg_type_walker( + member.name.clone(), + binding, + member_pointer, + member.ty, + expressions, + f, + ) + } + } + _ => f(name, pointer, ty, binding, expressions), + } + } + pub(crate) fn add_entry_point( &mut self, function: Handle, @@ -1079,20 +1155,29 @@ impl Parser { continue; } - let ty = self.module.global_variables[arg.handle].ty; - let idx = arguments.len() as u32; - - arguments.push(FunctionArgument { - name: arg.name.clone(), - ty, - binding: Some(arg.binding.clone()), - }); - let pointer = expressions.append(Expression::GlobalVariable(arg.handle), Default::default()); - let value = expressions.append(Expression::FunctionArgument(idx), Default::default()); - body.push(Statement::Store { pointer, value }, Default::default()); + self.arg_type_walker( + arg.name.clone(), + arg.binding.clone(), + pointer, + self.module.global_variables[arg.handle].ty, + &mut expressions, + &mut |name, pointer, ty, binding, expressions| { + let idx = arguments.len() as u32; + + arguments.push(FunctionArgument { + name, + ty, + binding: Some(binding), + }); + + let value = + expressions.append(Expression::FunctionArgument(idx), Default::default()); + body.push(Statement::Store { pointer, value }, Default::default()); + }, + ) } body.extend_block(global_init_body); @@ -1115,26 +1200,34 @@ impl Parser { continue; } - let ty = self.module.global_variables[arg.handle].ty; - - members.push(StructMember { - name: arg.name.clone(), - ty, - binding: Some(arg.binding.clone()), - offset: span, - }); - - span += self.module.types[ty].inner.span(&self.module.constants); - let pointer = expressions.append(Expression::GlobalVariable(arg.handle), Default::default()); - let len = expressions.len(); - let load = expressions.append(Expression::Load { pointer }, Default::default()); - body.push( - Statement::Emit(expressions.range_from(len)), - Default::default(), - ); - components.push(load) + + self.arg_type_walker( + arg.name.clone(), + arg.binding.clone(), + pointer, + self.module.global_variables[arg.handle].ty, + &mut expressions, + &mut |name, pointer, ty, binding, expressions| { + members.push(StructMember { + name, + ty, + binding: Some(binding), + offset: span, + }); + + span += self.module.types[ty].inner.span(&self.module.constants); + + let len = expressions.len(); + let load = expressions.append(Expression::Load { pointer }, Default::default()); + body.push( + Statement::Emit(expressions.range_from(len)), + Default::default(), + ); + components.push(load) + }, + ) } let (ty, value) = if !components.is_empty() { diff --git a/tests/in/glsl/declarations.vert b/tests/in/glsl/declarations.vert new file mode 100644 index 000000000..e0c2ec6e8 --- /dev/null +++ b/tests/in/glsl/declarations.vert @@ -0,0 +1,14 @@ +#version 450 + +layout(location = 0) in VertexData { + vec2 position; + vec2 a; +} vert; + +layout(location = 0) out FragmentData { + vec2 position; + vec2 a; +} frag; + +void main() { +} diff --git a/tests/out/wgsl/declarations-vert.wgsl b/tests/out/wgsl/declarations-vert.wgsl new file mode 100644 index 000000000..952ca9dca --- /dev/null +++ b/tests/out/wgsl/declarations-vert.wgsl @@ -0,0 +1,30 @@ +struct VertexData { + position: vec2; + a: vec2; +}; + +struct FragmentData { + position: vec2; + a: vec2; +}; + +struct VertexOutput { + [[location(0)]] position: vec2; + [[location(1)]] a: vec2; +}; + +var vert: VertexData; +var frag: FragmentData; + +fn main_1() { +} + +[[stage(vertex)]] +fn main([[location(0)]] position: vec2, [[location(1)]] a: vec2) -> VertexOutput { + vert.position = position; + vert.a = a; + main_1(); + let _e17 = frag.position; + let _e19 = frag.a; + return VertexOutput(_e17, _e19); +}