entry: "infer" -> "deduce", anonymous pair -> dedicated struct.

This commit is contained in:
Eduard-Mihai Burtescu 2023-03-19 11:19:50 +02:00 committed by Eduard-Mihai Burtescu
parent 5fffc752a0
commit 939f00e89e
4 changed files with 85 additions and 39 deletions

View File

@ -19,6 +19,30 @@ use rustc_span::Span;
use rustc_target::abi::call::{ArgAbi, FnAbi, PassMode}; use rustc_target::abi::call::{ArgAbi, FnAbi, PassMode};
use std::assert_matches::assert_matches; use std::assert_matches::assert_matches;
/// Various information about an entry-point parameter, which can only be deduced
/// (and/or checked) in all cases by using the original reference/value Rust type
/// (e.g. `&mut T` vs `&T` vs `T`).
///
/// This is in contrast to other information about "shader interface variables",
/// that can rely on merely the SPIR-V type and/or `#[spirv(...)]` attributes.
///
/// See also `entry_param_deduce_from_rust_ref_or_value` (which computes this).
struct EntryParamDeducedFromRustRefOrValue<'tcx> {
/// The type/layout for the data to pass onto the entry-point parameter,
/// either by-value (only for `Input`) or behind some kind of reference.
///
/// That is, the original parameter type is (given `T = value_layout.ty`):
/// * `T` (iff `storage_class` is `Input`)
/// * `&T` (all shader interface storage classes other than `Input`/`Output`)
/// * `&mut T` (only writable storage classes)
value_layout: TyAndLayout<'tcx>,
/// The SPIR-V storage class to declare the shader interface variable in,
/// either deduced from the type (e.g. opaque handles use `UniformConstant`),
/// provided via `#[spirv(...)]` attributes, or an `Input`/`Output` default.
storage_class: StorageClass,
}
impl<'tcx> CodegenCx<'tcx> { impl<'tcx> CodegenCx<'tcx> {
// Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. // Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters.
// spir-v uses globals to declare the interface. So, we need to generate a lil stub for the // spir-v uses globals to declare the interface. So, we need to generate a lil stub for the
@ -174,20 +198,24 @@ impl<'tcx> CodegenCx<'tcx> {
stub_fn_id stub_fn_id
} }
fn infer_param_ty_and_storage_class( /// Attempt to compute `EntryParamDeducedFromRustRefOrValue` (see its docs)
/// from `ref_or_value_layout` (and potentially some of `attrs`).
///
// FIXME(eddyb) document this by itself.
fn entry_param_deduce_from_rust_ref_or_value(
&self, &self,
layout: TyAndLayout<'tcx>, ref_or_value_layout: TyAndLayout<'tcx>,
hir_param: &hir::Param<'tcx>, hir_param: &hir::Param<'tcx>,
attrs: &AggregatedSpirvAttributes, attrs: &AggregatedSpirvAttributes,
) -> (Word, StorageClass) { ) -> EntryParamDeducedFromRustRefOrValue<'tcx> {
// FIXME(eddyb) attribute validation should be done ahead of time. // FIXME(eddyb) attribute validation should be done ahead of time.
// FIXME(eddyb) also check the type for compatibility with being // FIXME(eddyb) also check the type for compatibility with being
// part of the interface, including potentially `Sync`ness etc. // part of the interface, including potentially `Sync`ness etc.
// FIXME(eddyb) really need to require `T: Sync` for references // FIXME(eddyb) really need to require `T: Sync` for references
// (especially relevant with interior mutability!). // (especially relevant with interior mutability!).
let (value_ty, explicit_mutbl, is_ref) = match *layout.ty.kind() { let (value_layout, explicit_mutbl, is_ref) = match *ref_or_value_layout.ty.kind() {
ty::Ref(_, pointee_ty, mutbl) => (pointee_ty, mutbl, true), ty::Ref(_, pointee_ty, mutbl) => (self.layout_of(pointee_ty), mutbl, true),
_ => (layout.ty, hir::Mutability::Not, false), _ => (ref_or_value_layout, hir::Mutability::Not, false),
}; };
let effective_mutbl = match explicit_mutbl { let effective_mutbl = match explicit_mutbl {
// NOTE(eddyb) `T: !Freeze` used to detect "`T` has interior mutability" // NOTE(eddyb) `T: !Freeze` used to detect "`T` has interior mutability"
@ -195,7 +223,10 @@ impl<'tcx> CodegenCx<'tcx> {
// containing `UnsafeCell` (but not behind any indirection), which // containing `UnsafeCell` (but not behind any indirection), which
// includes many safe abstractions (e.g. `Cell`, `RefCell`, `Atomic*`). // includes many safe abstractions (e.g. `Cell`, `RefCell`, `Atomic*`).
hir::Mutability::Not hir::Mutability::Not
if is_ref && !value_ty.is_freeze(self.tcx, ty::ParamEnv::reveal_all()) => if is_ref
&& !value_layout
.ty
.is_freeze(self.tcx, ty::ParamEnv::reveal_all()) =>
{ {
hir::Mutability::Mut hir::Mutability::Mut
} }
@ -213,15 +244,15 @@ impl<'tcx> CodegenCx<'tcx> {
// can be per-lane-owning `&mut T`. // can be per-lane-owning `&mut T`.
_ => hir::Mutability::Mut, _ => hir::Mutability::Mut,
}; };
let spirv_ty = self.layout_of(value_ty).spirv_type(hir_param.ty_span, self); let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self);
// Some types automatically specify a storage class. Compute that here. // Some types automatically specify a storage class. Compute that here.
let element_ty = match self.lookup_type(spirv_ty) { let element_ty = match self.lookup_type(value_spirv_type) {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => { SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => {
self.lookup_type(element) self.lookup_type(element)
} }
ty => ty, ty => ty,
}; };
let inferred_storage_class_from_ty = match element_ty { let deduced_storage_class_from_ty = match element_ty {
SpirvType::Image { .. } SpirvType::Image { .. }
| SpirvType::Sampler | SpirvType::Sampler
| SpirvType::SampledImage { .. } | SpirvType::SampledImage { .. }
@ -233,7 +264,7 @@ impl<'tcx> CodegenCx<'tcx> {
hir_param.ty_span, hir_param.ty_span,
format!( format!(
"entry parameter type must be by-reference: `&{}`", "entry parameter type must be by-reference: `&{}`",
layout.ty, value_layout.ty,
), ),
); );
None None
@ -249,36 +280,35 @@ impl<'tcx> CodegenCx<'tcx> {
self.tcx.sess.span_fatal( self.tcx.sess.span_fatal(
hir_param.ty_span, hir_param.ty_span,
format!( format!(
"invalid entry param type `{}` for storage class `{:?}` \ "invalid entry param type `{}` for storage class `{storage_class:?}` \
(expected `&{}T`)", (expected `&{}T`)",
layout.ty, value_layout.ty,
storage_class,
expected_mutbl_for(storage_class).prefix_str() expected_mutbl_for(storage_class).prefix_str()
), ),
) )
} }
match inferred_storage_class_from_ty { match deduced_storage_class_from_ty {
Some(inferred) if storage_class == inferred => self.tcx.sess.span_warn( Some(deduced) if storage_class == deduced => self.tcx.sess.span_warn(
storage_class_attr.span, storage_class_attr.span,
"redundant storage class specifier, storage class is inferred from type", "redundant storage class attribute, storage class is deduced from type",
), ),
Some(inferred) => { Some(deduced) => {
self.tcx self.tcx
.sess .sess
.struct_span_err(hir_param.span, "storage class mismatch") .struct_span_err(hir_param.span, "storage class mismatch")
.span_label( .span_label(
storage_class_attr.span, storage_class_attr.span,
format!("{storage_class:?} specified in attribute"), format!("`{storage_class:?}` specified in attribute"),
) )
.span_label( .span_label(
hir_param.ty_span, hir_param.ty_span,
format!("{inferred:?} inferred from type"), format!("`{deduced:?}` deduced from type"),
) )
.span_help( .span_help(
storage_class_attr.span, storage_class_attr.span,
&format!( &format!(
"remove storage class attribute to use {inferred:?} as storage class" "remove storage class attribute to use `{deduced:?}` as storage class"
), ),
) )
.emit(); .emit();
@ -288,8 +318,8 @@ impl<'tcx> CodegenCx<'tcx> {
storage_class storage_class
}); });
// If storage class was not inferred nor specified, compute the default (i.e. input/output) // If storage class was not deduced nor specified, compute the default (i.e. input/output)
let storage_class = inferred_storage_class_from_ty let storage_class = deduced_storage_class_from_ty
.or(attr_storage_class) .or(attr_storage_class)
.unwrap_or_else(|| match (is_ref, explicit_mutbl) { .unwrap_or_else(|| match (is_ref, explicit_mutbl) {
(false, _) => StorageClass::Input, (false, _) => StorageClass::Input,
@ -298,7 +328,7 @@ impl<'tcx> CodegenCx<'tcx> {
hir_param.ty_span, hir_param.ty_span,
format!( format!(
"invalid entry param type `{}` (expected `{}` or `&mut {1}`)", "invalid entry param type `{}` (expected `{}` or `&mut {1}`)",
layout.ty, value_ty ref_or_value_layout.ty, value_layout.ty
), ),
), ),
}); });
@ -332,7 +362,7 @@ impl<'tcx> CodegenCx<'tcx> {
} else { } else {
( (
hir_param.ty_span, hir_param.ty_span,
format!("`{storage_class:?}` implied by type"), format!("`{storage_class:?}` deduced from type"),
) )
}; };
// HACK(eddyb) have to use `MultiSpan` directly for labels, // HACK(eddyb) have to use `MultiSpan` directly for labels,
@ -345,7 +375,10 @@ impl<'tcx> CodegenCx<'tcx> {
} }
} }
(spirv_ty, storage_class) EntryParamDeducedFromRustRefOrValue {
value_layout,
storage_class,
}
} }
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
@ -364,8 +397,11 @@ impl<'tcx> CodegenCx<'tcx> {
// Pre-allocate the module-scoped `OpVariable`'s *Result* ID. // Pre-allocate the module-scoped `OpVariable`'s *Result* ID.
let var = self.emit_global().id(); let var = self.emit_global().id();
let (value_spirv_type, storage_class) = let EntryParamDeducedFromRustRefOrValue {
self.infer_param_ty_and_storage_class(entry_arg_abi.layout, hir_param, &attrs); value_layout,
storage_class,
} = self.entry_param_deduce_from_rust_ref_or_value(entry_arg_abi.layout, hir_param, &attrs);
let value_spirv_type = value_layout.spirv_type(hir_param.ty_span, self);
// Certain storage classes require an `OpTypeStruct` decorated with `Block`, // Certain storage classes require an `OpTypeStruct` decorated with `Block`,
// which we represent with `SpirvType::InterfaceBlock` (see its doc comment). // which we represent with `SpirvType::InterfaceBlock` (see its doc comment).
@ -373,6 +409,16 @@ impl<'tcx> CodegenCx<'tcx> {
let is_unsized = self.lookup_type(value_spirv_type).sizeof(self).is_none(); let is_unsized = self.lookup_type(value_spirv_type).sizeof(self).is_none();
let is_pair = matches!(entry_arg_abi.mode, PassMode::Pair(..)); let is_pair = matches!(entry_arg_abi.mode, PassMode::Pair(..));
let is_unsized_with_len = is_pair && is_unsized; let is_unsized_with_len = is_pair && is_unsized;
// HACK(eddyb) sanity check because we get the same information in two
// very different ways, and going out of sync could cause subtle issues.
assert_eq!(
is_unsized_with_len,
value_layout.is_unsized(),
"`{}` param mismatch in call ABI (is_pair={is_pair}) + \
SPIR-V type (is_unsized={is_unsized}) \
vs layout:\n{value_layout:#?}",
entry_arg_abi.layout.ty
);
if is_pair && !is_unsized { if is_pair && !is_unsized {
// If PassMode is Pair, then we need to fill in the second part of the pair with a // If PassMode is Pair, then we need to fill in the second part of the pair with a
// value. We currently only do that with unsized types, so if a type is a pair for some // value. We currently only do that with unsized types, so if a type is a pair for some

View File

@ -1,4 +1,4 @@
// Tests that storage class inference fails correctly // Tests that storage class deduction (from entry-point signature) fails correctly
// build-fail // build-fail
use spirv_std::{spirv, Image}; use spirv_std::{spirv, Image};

View File

@ -1,26 +1,26 @@
error: storage class mismatch error: storage class mismatch
--> $DIR/bad-infer-storage-class.rs:8:5 --> $DIR/bad-deduce-storage-class.rs:8:5
| |
8 | #[spirv(uniform)] error: &Image!(2D, type=f32), 8 | #[spirv(uniform)] error: &Image!(2D, type=f32),
| ^^^^^^^^-------^^^^^^^^^^--------------------- | ^^^^^^^^-------^^^^^^^^^^---------------------
| | | | | |
| | UniformConstant inferred from type | | `UniformConstant` deduced from type
| Uniform specified in attribute | `Uniform` specified in attribute
| |
help: remove storage class attribute to use UniformConstant as storage class help: remove storage class attribute to use `UniformConstant` as storage class
--> $DIR/bad-infer-storage-class.rs:8:13 --> $DIR/bad-deduce-storage-class.rs:8:13
| |
8 | #[spirv(uniform)] error: &Image!(2D, type=f32), 8 | #[spirv(uniform)] error: &Image!(2D, type=f32),
| ^^^^^^^ | ^^^^^^^
warning: redundant storage class specifier, storage class is inferred from type warning: redundant storage class attribute, storage class is deduced from type
--> $DIR/bad-infer-storage-class.rs:9:13 --> $DIR/bad-deduce-storage-class.rs:9:13
| |
9 | #[spirv(uniform_constant)] warning: &Image!(2D, type=f32), 9 | #[spirv(uniform_constant)] warning: &Image!(2D, type=f32),
| ^^^^^^^^^^^^^^^^ | ^^^^^^^^^^^^^^^^
error: entry parameter type must be by-reference: `&spirv_std::image::Image<f32, 1, 2, 0, 0, 0, 0>` error: entry parameter type must be by-reference: `&spirv_std::image::Image<f32, 1, 2, 0, 0, 0, 0>`
--> $DIR/bad-infer-storage-class.rs:15:27 --> $DIR/bad-deduce-storage-class.rs:15:27
| |
15 | pub fn issue_585(invalid: Image!(2D, type=f32)) {} 15 | pub fn issue_585(invalid: Image!(2D, type=f32)) {}
| ^^^^^^^^^^^^^^^^^^^^ | ^^^^^^^^^^^^^^^^^^^^

View File

@ -8,9 +8,9 @@ note: ...but storage class `UniformConstant` is read-only
--> $DIR/mutability-errors.rs:10:78 --> $DIR/mutability-errors.rs:10:78
| |
10 | #[spirv(descriptor_set = 0, binding = 0)] implicit_uniform_constant_mut: &mut Image2d, 10 | #[spirv(descriptor_set = 0, binding = 0)] implicit_uniform_constant_mut: &mut Image2d,
| ^^^^^^^^^^^^ `UniformConstant` implied by type | ^^^^^^^^^^^^ `UniformConstant` deduced from type
warning: redundant storage class specifier, storage class is inferred from type warning: redundant storage class attribute, storage class is deduced from type
--> $DIR/mutability-errors.rs:11:13 --> $DIR/mutability-errors.rs:11:13
| |
11 | #[spirv(uniform_constant, descriptor_set = 0, binding = 0)] uniform_constant_mut: &mut Image2d, 11 | #[spirv(uniform_constant, descriptor_set = 0, binding = 0)] uniform_constant_mut: &mut Image2d,