[wgsl-in] Replace TypedExpression with a generic enum, Typed.

Replace the `TypedExpression` struct, used to distinguish between WGSL
pointers and references since Naga has only `Pointer`, with an enum,
`Typed`, with variants for references and plain types. This cleans up
a bunch of code, since the struct's `is_reference` field basically
served as a detached enum discriminant. This also prepares the code
for adding abstract types.
This commit is contained in:
Jim Blandy 2023-10-26 16:51:43 -07:00 committed by Teodor Tanasoaia
parent e78e9dd400
commit e6063c5255

View File

@ -138,7 +138,7 @@ pub struct StatementContext<'source, 'temp, 'out> {
///
/// [`LocalVariable`]: crate::Expression::LocalVariable
/// [`FunctionArgument`]: crate::Expression::FunctionArgument
local_table: &'temp mut FastHashMap<Handle<ast::Local>, TypedExpression>,
local_table: &'temp mut FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
const_typifier: &'temp mut Typifier,
typifier: &'temp mut Typifier,
@ -215,7 +215,7 @@ pub struct RuntimeExpressionContext<'temp, 'out> {
///
/// This is always [`StatementContext::local_table`] for the
/// enclosing statement; see that documentation for details.
local_table: &'temp FastHashMap<Handle<ast::Local>, TypedExpression>,
local_table: &'temp FastHashMap<Handle<ast::Local>, Typed<Handle<crate::Expression>>>,
function: &'out mut crate::Function,
block: &'temp mut crate::Block,
@ -620,16 +620,15 @@ impl<'source, 'temp, 'out> ExpressionContext<'source, 'temp, 'out> {
/// `T`. Otherwise, return `expr` unchanged.
fn apply_load_rule(
&mut self,
expr: TypedExpression,
expr: Typed<Handle<crate::Expression>>,
) -> Result<Handle<crate::Expression>, Error<'source>> {
if expr.is_reference {
let load = crate::Expression::Load {
pointer: expr.handle,
};
let span = self.get_expression_span(expr.handle);
match expr {
Typed::Reference(pointer) => {
let load = crate::Expression::Load { pointer };
let span = self.get_expression_span(pointer);
self.append_expression(load, span)
} else {
Ok(expr.handle)
}
Typed::Plain(handle) => Ok(handle),
}
}
@ -693,31 +692,50 @@ impl<'source> ArgumentContext<'_, 'source> {
}
}
/// A Naga [`Expression`] handle, with WGSL type information.
/// WGSL type annotations on expressions, types, values, etc.
///
/// Naga and WGSL types are very close, but Naga lacks WGSL's 'reference' types,
/// which we need to know to apply the Load Rule. This struct carries a Naga
/// `Handle<Expression>` along with enough information to determine its WGSL type.
/// Naga and WGSL types are very close, but Naga lacks WGSL's `ref` types, which
/// we need to know to apply the Load Rule. This enum carries some WGSL or Naga
/// datum along with enough information to determine its corresponding WGSL
/// type.
///
/// The `T` type parameter can be any expression-like thing:
///
/// - `Typed<Handle<crate::Type>>` can represent a full WGSL type. For example,
/// given some Naga `Pointer` type `ptr`, a WGSL reference type is a
/// `Typed::Reference(ptr)` whereas a WGSL pointer type is a
/// `Typed::Plain(ptr)`.
///
/// - `Typed<crate::Expression>` or `Typed<Handle<crate::Expression>>` can
/// represent references similarly.
///
/// Use the `map` and `try_map` methods to convert from one expression
/// representation to another.
///
/// [`Expression`]: crate::Expression
#[derive(Debug, Copy, Clone)]
struct TypedExpression {
/// The handle of the Naga expression.
handle: Handle<crate::Expression>,
enum Typed<T> {
/// A WGSL reference.
Reference(T),
/// True if this expression's WGSL type is a reference.
///
/// When this is true, `handle` must be a pointer.
is_reference: bool,
/// A WGSL plain type.
Plain(T),
}
impl TypedExpression {
const fn non_reference(handle: Handle<crate::Expression>) -> TypedExpression {
TypedExpression {
handle,
is_reference: false,
impl<T> Typed<T> {
fn map<U>(self, mut f: impl FnMut(T) -> U) -> Typed<U> {
match self {
Self::Reference(v) => Typed::Reference(f(v)),
Self::Plain(v) => Typed::Plain(f(v)),
}
}
fn try_map<U, E>(self, mut f: impl FnMut(T) -> Result<U, E>) -> Result<Typed<U>, E> {
Ok(match self {
Self::Reference(expr) => Typed::Reference(f(expr)?),
Self::Plain(expr) => Typed::Plain(f(expr)?),
})
}
}
/// A single vector component or swizzle.
@ -974,7 +992,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let ty = self.resolve_ast_type(arg.ty, ctx)?;
let expr = expressions
.append(crate::Expression::FunctionArgument(i as u32), arg.name.span);
local_table.insert(arg.handle, TypedExpression::non_reference(expr));
local_table.insert(arg.handle, Typed::Plain(expr));
named_expressions.insert(expr, (arg.name.name.to_string(), arg.name.span));
Ok(crate::FunctionArgument {
@ -1119,8 +1137,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
block.extend(emitter.finish(&ctx.function.expressions));
ctx.local_table
.insert(l.handle, TypedExpression::non_reference(value));
ctx.local_table.insert(l.handle, Typed::Plain(value));
ctx.named_expressions
.insert(value, (l.name.name.to_string(), l.name.span));
@ -1200,13 +1217,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Span::UNDEFINED,
)?;
block.extend(emitter.finish(&ctx.function.expressions));
ctx.local_table.insert(
v.handle,
TypedExpression {
handle,
is_reference: true,
},
);
ctx.local_table.insert(v.handle, Typed::Reference(handle));
match initializer {
Some(initializer) => crate::Statement::Store {
@ -1334,30 +1345,36 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
block.extend(emitter.finish(&ctx.function.expressions));
return Ok(());
}
ast::StatementKind::Assign { target, op, value } => {
ast::StatementKind::Assign {
target: ast_target,
op,
value,
} => {
let mut emitter = Emitter::default();
emitter.start(&ctx.function.expressions);
let expr = self.expression_for_reference(
target,
let target = self.expression_for_reference(
ast_target,
&mut ctx.as_expression(block, &mut emitter),
)?;
let mut value =
self.expression(value, &mut ctx.as_expression(block, &mut emitter))?;
if !expr.is_reference {
let ty = ctx.invalid_assignment_type(expr.handle);
let target_handle = match target {
Typed::Reference(handle) => handle,
Typed::Plain(handle) => {
let ty = ctx.invalid_assignment_type(handle);
return Err(Error::InvalidAssignment {
span: ctx.ast_expressions.get_span(target),
span: ctx.ast_expressions.get_span(ast_target),
ty,
});
}
};
let value = match op {
Some(op) => {
let mut ctx = ctx.as_expression(block, &mut emitter);
let mut left = ctx.apply_load_rule(expr)?;
let mut left = ctx.apply_load_rule(target)?;
ctx.binary_op_splat(op, &mut left, &mut value)?;
ctx.append_expression(
crate::Expression::Binary {
@ -1373,7 +1390,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
block.extend(emitter.finish(&ctx.function.expressions));
crate::Statement::Store {
pointer: expr.handle,
pointer: target_handle,
value,
}
}
@ -1388,11 +1405,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
};
let value_span = ctx.ast_expressions.get_span(value);
let reference = self
let target = self
.expression_for_reference(value, &mut ctx.as_expression(block, &mut emitter))?;
let mut ectx = ctx.as_expression(block, &mut emitter);
let target_handle = match target {
Typed::Reference(handle) => handle,
Typed::Plain(_) => return Err(Error::BadIncrDecrReferenceType(value_span)),
};
let (kind, width) = match *resolve_inner!(ectx, reference.handle) {
let mut ectx = ctx.as_expression(block, &mut emitter);
let (kind, width) = match *resolve_inner!(ectx, target_handle) {
crate::TypeInner::ValuePointer {
size: None,
kind,
@ -1418,7 +1439,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let rctx = ectx.runtime_expression_ctx(stmt.span)?;
let left = rctx.function.expressions.append(
crate::Expression::Load {
pointer: reference.handle,
pointer: target_handle,
},
value_span,
);
@ -1429,7 +1450,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
block.extend(emitter.finish(&ctx.function.expressions));
crate::Statement::Store {
pointer: reference.handle,
pointer: target_handle,
value,
}
}
@ -1448,6 +1469,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
Ok(())
}
/// Lower `expr` and apply the Load Rule if possible.
fn expression(
&mut self,
expr: Handle<ast::Expression<'source>>,
@ -1461,11 +1483,11 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
&mut self,
expr: Handle<ast::Expression<'source>>,
ctx: &mut ExpressionContext<'source, '_, '_>,
) -> Result<TypedExpression, Error<'source>> {
) -> Result<Typed<Handle<crate::Expression>>, Error<'source>> {
let span = ctx.ast_expressions.get_span(expr);
let expr = &ctx.ast_expressions[expr];
let (expr, is_reference) = match *expr {
let expr: Typed<crate::Expression> = match *expr {
ast::Expression::Literal(literal) => {
let literal = match literal {
ast::Literal::Number(Number::F32(f)) => crate::Literal::F32(f),
@ -1477,36 +1499,34 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ast::Literal::Bool(b) => crate::Literal::Bool(b),
};
let handle = ctx.interrupt_emitter(crate::Expression::Literal(literal), span)?;
return Ok(TypedExpression::non_reference(handle));
return Ok(Typed::Plain(handle));
}
ast::Expression::Ident(ast::IdentExpr::Local(local)) => {
let rctx = ctx.runtime_expression_ctx(span)?;
return Ok(rctx.local_table[&local]);
}
ast::Expression::Ident(ast::IdentExpr::Unresolved(name)) => {
return if let Some(global) = ctx.globals.get(name) {
let (expr, is_reference) = match *global {
LoweredGlobalDecl::Var(handle) => (
crate::Expression::GlobalVariable(handle),
ctx.module.global_variables[handle].space
!= crate::AddressSpace::Handle,
),
let global = ctx
.globals
.get(name)
.ok_or(Error::UnknownIdent(span, name))?;
let expr = match *global {
LoweredGlobalDecl::Var(handle) => {
let expr = crate::Expression::GlobalVariable(handle);
match ctx.module.global_variables[handle].space {
crate::AddressSpace::Handle => Typed::Plain(expr),
_ => Typed::Reference(expr),
}
}
LoweredGlobalDecl::Const(handle) => {
(crate::Expression::Constant(handle), false)
Typed::Plain(crate::Expression::Constant(handle))
}
_ => {
return Err(Error::Unexpected(span, ExpectedToken::Variable));
}
};
let handle = ctx.interrupt_emitter(expr, span)?;
Ok(TypedExpression {
handle,
is_reference,
})
} else {
Err(Error::UnknownIdent(span, name))
}
return expr.try_map(|handle| ctx.interrupt_emitter(handle, span));
}
ast::Expression::Construct {
ref ty,
@ -1514,25 +1534,24 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
ref components,
} => {
let handle = self.construct(span, ty, ty_span, components, ctx)?;
return Ok(TypedExpression::non_reference(handle));
return Ok(Typed::Plain(handle));
}
ast::Expression::Unary { op, expr } => {
let expr = self.expression(expr, ctx)?;
(crate::Expression::Unary { op, expr }, false)
Typed::Plain(crate::Expression::Unary { op, expr })
}
ast::Expression::AddrOf(expr) => {
// The `&` operator simply converts a reference to a pointer. And since a
// reference is required, the Load Rule is not applied.
let expr = self.expression_for_reference(expr, ctx)?;
if !expr.is_reference {
match self.expression_for_reference(expr, ctx)? {
Typed::Reference(handle) => {
// No code is generated. We just declare the reference a pointer now.
return Ok(Typed::Plain(handle));
}
Typed::Plain(_) => {
return Err(Error::NotReference("the operand of the `&` operator", span));
}
// No code is generated. We just declare the pointer a reference now.
return Ok(TypedExpression {
is_reference: false,
..expr
});
}
}
ast::Expression::Deref(expr) => {
// The pointer we dereference must be loaded.
@ -1542,17 +1561,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
return Err(Error::NotPointer(span));
}
return Ok(TypedExpression {
handle: pointer,
is_reference: true,
});
// No code is generated. We just declare the pointer a reference now.
return Ok(Typed::Reference(pointer));
}
ast::Expression::Binary { op, left, right } => {
// Load both operands.
let mut left = self.expression(left, ctx)?;
let mut right = self.expression(right, ctx)?;
ctx.binary_op_splat(op, &mut left, &mut right)?;
(crate::Expression::Binary { op, left, right }, false)
Typed::Plain(crate::Expression::Binary { op, left, right })
}
ast::Expression::Call {
ref function,
@ -1561,51 +1578,35 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
let handle = self
.call(span, function, arguments, ctx)?
.ok_or(Error::FunctionReturnsVoid(function.span))?;
return Ok(TypedExpression::non_reference(handle));
return Ok(Typed::Plain(handle));
}
ast::Expression::Index { base, index } => {
let expr = self.expression_for_reference(base, ctx)?;
let lowered_base = self.expression_for_reference(base, ctx)?;
let index = self.expression(index, ctx)?;
let wgsl_pointer = resolve_inner!(ctx, expr.handle).pointer_space().is_some()
&& !expr.is_reference;
if wgsl_pointer {
if let Typed::Plain(handle) = lowered_base {
if resolve_inner!(ctx, handle).pointer_space().is_some() {
return Err(Error::Pointer(
"the value indexed by a `[]` subscripting expression",
ctx.ast_expressions.get_span(base),
));
}
if let Some(index) = ctx.const_access(index) {
(
crate::Expression::AccessIndex {
base: expr.handle,
index,
},
expr.is_reference,
)
} else {
(
crate::Expression::Access {
base: expr.handle,
index,
},
expr.is_reference,
)
}
lowered_base.map(|base| match ctx.const_access(index) {
Some(index) => crate::Expression::AccessIndex { base, index },
None => crate::Expression::Access { base, index },
})
}
ast::Expression::Member { base, ref field } => {
let TypedExpression {
handle,
is_reference,
} = self.expression_for_reference(base, ctx)?;
let lowered_base = self.expression_for_reference(base, ctx)?;
let temp_inner;
let (composite, wgsl_pointer) = match *resolve_inner!(ctx, handle) {
crate::TypeInner::Pointer { base, .. } => {
(&ctx.module.types[base].inner, !is_reference)
}
let composite_type: &crate::TypeInner = match lowered_base {
Typed::Reference(handle) => {
let inner = resolve_inner!(ctx, handle);
match *inner {
crate::TypeInner::Pointer { base, .. } => &ctx.module.types[base].inner,
crate::TypeInner::ValuePointer {
size: None,
kind,
@ -1613,7 +1614,7 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
..
} => {
temp_inner = crate::TypeInner::Scalar { kind, width };
(&temp_inner, !is_reference)
&temp_inner
}
crate::TypeInner::ValuePointer {
size: Some(size),
@ -1622,19 +1623,29 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
..
} => {
temp_inner = crate::TypeInner::Vector { size, kind, width };
(&temp_inner, !is_reference)
&temp_inner
}
_ => unreachable!(
"In Typed::Reference(handle), handle must be a Naga pointer"
),
}
}
ref other => (other, false),
};
if wgsl_pointer {
Typed::Plain(handle) => {
let inner = resolve_inner!(ctx, handle);
if let crate::TypeInner::Pointer { .. }
| crate::TypeInner::ValuePointer { .. } = *inner
{
return Err(Error::Pointer(
"the value accessed by a `.member` expression",
ctx.ast_expressions.get_span(base),
));
}
inner
}
};
let access = match *composite {
let access = match *composite_type {
crate::TypeInner::Struct { ref members, .. } => {
let index = members
.iter()
@ -1642,38 +1653,21 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
.ok_or(Error::BadAccessor(field.span))?
as u32;
(
crate::Expression::AccessIndex {
base: handle,
index,
},
is_reference,
)
lowered_base.map(|base| crate::Expression::AccessIndex { base, index })
}
crate::TypeInner::Vector { .. } | crate::TypeInner::Matrix { .. } => {
match Components::new(field.name, field.span)? {
Components::Swizzle { size, pattern } => {
let vector = ctx.apply_load_rule(TypedExpression {
handle,
is_reference,
})?;
(
crate::Expression::Swizzle {
// Swizzles aren't allowed on matrices, but
// validation will catch that.
Typed::Plain(crate::Expression::Swizzle {
size,
vector,
vector: ctx.apply_load_rule(lowered_base)?,
pattern,
},
false,
)
})
}
Components::Single(index) => (
crate::Expression::AccessIndex {
base: handle,
index,
},
is_reference,
),
Components::Single(index) => lowered_base
.map(|base| crate::Expression::AccessIndex { base, index }),
}
}
_ => return Err(Error::BadAccessor(field.span)),
@ -1698,22 +1692,15 @@ impl<'source, 'temp> Lowerer<'source, 'temp> {
}
};
(
crate::Expression::As {
Typed::Plain(crate::Expression::As {
expr,
kind,
convert: None,
},
false,
)
})
}
};
let handle = ctx.append_expression(expr, span)?;
Ok(TypedExpression {
handle,
is_reference,
})
expr.try_map(|handle| ctx.append_expression(handle, span))
}
/// Generate Naga IR for call expressions and statements, and type