[naga spv-out] Clean up write_expression_pointer type adjustment.

Replace the `return_type_override` argument of
`BlockContext::write_expression_pointer` with an enum that says how to
derive the return type from `expr_handle`'s type.

Introduce a new type, `AccessTypeAdjustment`, that covers possible
derivation rules.

This simplifies callers and the callee, in part by making the possible
alternatives less general, and by giving them explicit names (the
variants of the `AccessTypeAdjustment` enum).
This commit is contained in:
Jim Blandy 2024-10-10 12:05:28 -07:00
parent d034c4b428
commit b9f1e4a266

View File

@ -3,7 +3,7 @@ Implementations for `BlockContext` methods.
*/
use super::{
helpers, index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
index::BoundsCheckResult, selection::Selection, Block, BlockContext, Dimension, Error,
Instruction, LocalType, LookupType, NumericType, ResultMember, Writer, WriterFlags,
};
use crate::{arena::Handle, proc::TypeResolution, Statement};
@ -18,6 +18,54 @@ fn get_dimension(type_inner: &crate::TypeInner) -> Dimension {
}
}
/// How to derive the type of `OpAccessChain` instructions from Naga IR.
///
/// Most of the time, we compile Naga IR to SPIR-V instructions whose result
/// types are simply the direct SPIR-V analog of the Naga IR's. But in some
/// cases, the Naga IR and SPIR-V types need to diverge.
///
/// This enum specifies how [`BlockContext::write_expression_pointer`] should
/// choose a SPIR-V result type for the `OpAccessChain` it generates, based on
/// the type of the given Naga IR [`Expression`] it's generating code for.
///
/// [`Expression`]: crate::Expression
enum AccessTypeAdjustment {
/// No adjustment needed: the SPIR-V type should be the direct
/// analog of the Naga IR expression type.
///
/// For most access chains, this is the right thing: the Naga IR access
/// expression produces a [`Pointer`] to the element / component, and the
/// SPIR-V `OpAccessChain` instruction does the same.
///
/// [`Pointer`]: crate::TypeInner::Pointer
None,
/// The SPIR-V type should be an `OpPointer` to the direct analog of the
/// Naga IR expression's type.
///
/// This is necessary for indexing binding arrays in the [`Handle`] address
/// space:
///
/// - In Naga IR, referencing a binding array [`GlobalVariable`] in the
/// [`Handle`] address space produces a value of type [`BindingArray`],
/// not a pointer to such. And [`Access`] and [`AccessIndex`] expressions
/// operate on handle binding arrays by value, and produce handle values,
/// not pointers.
///
/// - In SPIR-V, a binding array `OpVariable` produces a pointer to an
/// array, and `OpAccessChain` instructions operate on pointers,
/// regardless of whether the elements are opaque types or not.
///
/// See also the documentation for [`BindingArray`].
///
/// [`Handle`]: crate::AddressSpace::Handle
/// [`GlobalVariable`]: crate::GlobalVariable
/// [`BindingArray`]: crate::TypeInner::BindingArray
/// [`Access`]: crate::Expression::Access
/// [`AccessIndex`]: crate::Expression::AccessIndex
IntroducePointer(spirv::StorageClass),
}
/// The results of emitting code for a left-hand-side expression.
///
/// On success, `write_expression_pointer` returns one of these.
@ -301,21 +349,12 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: helpers::map_storage_class(space),
});
let result_id = match self.write_expression_pointer(
expr_handle,
block,
Some(binding_array_false_pointer),
AccessTypeAdjustment::IntroducePointer(
spirv::StorageClass::UniformConstant,
),
)? {
ExpressionPointer::Ready { pointer_id } => pointer_id,
ExpressionPointer::Conditional { .. } => {
@ -414,21 +453,12 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::BindingArray {
base: binding_type, ..
} => {
let space = match self.ir_function.expressions[base] {
crate::Expression::GlobalVariable(gvar) => {
self.ir_module.global_variables[gvar].space
}
_ => unreachable!(),
};
let binding_array_false_pointer = LookupType::Local(LocalType::Pointer {
base: binding_type,
class: helpers::map_storage_class(space),
});
let result_id = match self.write_expression_pointer(
expr_handle,
block,
Some(binding_array_false_pointer),
AccessTypeAdjustment::IntroducePointer(
spirv::StorageClass::UniformConstant,
),
)? {
ExpressionPointer::Ready { pointer_id } => pointer_id,
ExpressionPointer::Conditional { .. } => {
@ -1670,9 +1700,9 @@ impl<'w> BlockContext<'w> {
///
/// Emit any needed bounds-checking expressions to `block`.
///
/// Some cases we need to generate a different return type than what the IR gives us.
/// This is because pointers to binding arrays of handles (such as images or samplers)
/// don't exist in the IR, but we need to create them to create an access chain in SPIRV.
/// Give the `OpAccessChain` a result type based on `expr_handle`, adjusted
/// according to `type_adjustment`; see the documentation for
/// [`AccessTypeAdjustment`] for details.
///
/// On success, the return value is an [`ExpressionPointer`] value; see the
/// documentation for that type.
@ -1680,21 +1710,22 @@ impl<'w> BlockContext<'w> {
&mut self,
mut expr_handle: Handle<crate::Expression>,
block: &mut Block,
return_type_override: Option<LookupType>,
type_adjustment: AccessTypeAdjustment,
) -> Result<ExpressionPointer, Error> {
let result_lookup_ty = match self.fun_info[expr_handle].ty {
TypeResolution::Handle(ty_handle) => match return_type_override {
// We use the return type override as a special case for handle binding arrays as the OpAccessChain
// needs to return a pointer, but indexing into a handle binding array just gives you the type of
// the binding in the IR.
Some(ty) => ty,
None => LookupType::Handle(ty_handle),
},
TypeResolution::Value(ref inner) => {
LookupType::Local(LocalType::from_inner(inner).unwrap())
let result_type_id = {
let resolution = &self.fun_info[expr_handle].ty;
match type_adjustment {
AccessTypeAdjustment::None => self.writer.get_expression_type_id(resolution),
AccessTypeAdjustment::IntroducePointer(class) => match *resolution {
TypeResolution::Handle(handle) => self.writer.get_pointer_id(handle, class),
TypeResolution::Value(_) => {
unreachable!(
"IntroducePointer should only be used with images and samplers"
);
}
},
}
};
let result_type_id = self.get_type_id(result_lookup_ty);
// The id of the boolean `and` of all dynamic bounds checks up to this point.
//
@ -1902,7 +1933,7 @@ impl<'w> BlockContext<'w> {
block: &mut Block,
result_type_id: Word,
) -> Result<Word, Error> {
match self.write_expression_pointer(pointer, block, None)? {
match self.write_expression_pointer(pointer, block, AccessTypeAdjustment::None)? {
ExpressionPointer::Ready { pointer_id } => {
let id = self.gen_id();
let atomic_space =
@ -2458,7 +2489,11 @@ impl<'w> BlockContext<'w> {
}
Statement::Store { pointer, value } => {
let value_id = self.cached[value];
match self.write_expression_pointer(pointer, &mut block, None)? {
match self.write_expression_pointer(
pointer,
&mut block,
AccessTypeAdjustment::None,
)? {
ExpressionPointer::Ready { pointer_id } => {
let atomic_space = match *self.fun_info[pointer]
.ty
@ -2554,15 +2589,18 @@ impl<'w> BlockContext<'w> {
self.cached[result] = id;
}
let pointer_id =
match self.write_expression_pointer(pointer, &mut block, None)? {
ExpressionPointer::Ready { pointer_id } => pointer_id,
ExpressionPointer::Conditional { .. } => {
return Err(Error::FeatureNotImplemented(
"Atomics out-of-bounds handling",
));
}
};
let pointer_id = match self.write_expression_pointer(
pointer,
&mut block,
AccessTypeAdjustment::None,
)? {
ExpressionPointer::Ready { pointer_id } => pointer_id,
ExpressionPointer::Conditional { .. } => {
return Err(Error::FeatureNotImplemented(
"Atomics out-of-bounds handling",
));
}
};
let space = self.fun_info[pointer]
.ty
@ -2723,7 +2761,11 @@ impl<'w> BlockContext<'w> {
.write_barrier(crate::Barrier::WORK_GROUP, &mut block);
let result_type_id = self.get_expression_type_id(&self.fun_info[result].ty);
// Embed the body of
match self.write_expression_pointer(pointer, &mut block, None)? {
match self.write_expression_pointer(
pointer,
&mut block,
AccessTypeAdjustment::None,
)? {
ExpressionPointer::Ready { pointer_id } => {
let id = self.gen_id();
block.body.push(Instruction::load(