mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2025-02-16 08:54:56 +00:00
Add basic support for struct DSTs (#504)
* Add basic support for struct DSTs * Add tests * cleanup tests * Update with entry changes, address review * Address review * Update allocate_const_scalar.stderr * Add ArrayStride decoration to OpTypeRuntimeArray
This commit is contained in:
parent
c3a3b20e3c
commit
05ce407278
@ -5,12 +5,15 @@ use crate::builder_spirv::SpirvValue;
|
||||
use crate::spirv_type::SpirvType;
|
||||
use rspirv::dr::Operand;
|
||||
use rspirv::spirv::{Decoration, ExecutionModel, FunctionControl, StorageClass, Word};
|
||||
use rustc_codegen_ssa::traits::BaseTypeMethods;
|
||||
use rustc_hir as hir;
|
||||
use rustc_middle::ty::layout::TyAndLayout;
|
||||
use rustc_middle::ty::layout::{HasParamEnv, TyAndLayout};
|
||||
use rustc_middle::ty::{Instance, Ty, TyKind};
|
||||
use rustc_span::Span;
|
||||
use rustc_target::abi::call::{FnAbi, PassMode};
|
||||
use rustc_target::abi::LayoutOf;
|
||||
use rustc_target::abi::{
|
||||
call::{ArgAbi, ArgAttribute, ArgAttributes, FnAbi, PassMode},
|
||||
LayoutOf, Size,
|
||||
};
|
||||
use std::collections::HashMap;
|
||||
|
||||
impl<'tcx> CodegenCx<'tcx> {
|
||||
@ -37,9 +40,27 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
};
|
||||
let fn_hir_id = self.tcx.hir().local_def_id_to_hir_id(local_id);
|
||||
let body = self.tcx.hir().body(self.tcx.hir().body_owned_by(fn_hir_id));
|
||||
const EMPTY: ArgAttribute = ArgAttribute::empty();
|
||||
for (abi, arg) in fn_abi.args.iter().zip(body.params) {
|
||||
match abi.mode {
|
||||
PassMode::Direct(_) | PassMode::Indirect { .. } => {}
|
||||
PassMode::Direct(_)
|
||||
| PassMode::Indirect { .. }
|
||||
// plain DST/RTA/VLA
|
||||
| PassMode::Pair(
|
||||
ArgAttributes {
|
||||
pointee_size: Size::ZERO,
|
||||
..
|
||||
},
|
||||
ArgAttributes { regular: EMPTY, .. },
|
||||
)
|
||||
// DST struct with fields before the DST member
|
||||
| PassMode::Pair(
|
||||
ArgAttributes { .. },
|
||||
ArgAttributes {
|
||||
pointee_size: Size::ZERO,
|
||||
..
|
||||
},
|
||||
) => {}
|
||||
_ => self.tcx.sess.span_err(
|
||||
arg.span,
|
||||
&format!("PassMode {:?} invalid for entry point parameter", abi.mode),
|
||||
@ -63,7 +84,7 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
self.shader_entry_stub(
|
||||
self.tcx.def_span(instance.def_id()),
|
||||
entry_func,
|
||||
fn_abi,
|
||||
&fn_abi.args,
|
||||
body.params,
|
||||
name,
|
||||
execution_model,
|
||||
@ -82,7 +103,7 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
&self,
|
||||
span: Span,
|
||||
entry_func: SpirvValue,
|
||||
entry_fn_abi: &FnAbi<'tcx, Ty<'tcx>>,
|
||||
arg_abis: &[ArgAbi<'tcx, Ty<'tcx>>],
|
||||
hir_params: &[hir::Param<'tcx>],
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
@ -94,10 +115,7 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
}
|
||||
.def(span, self);
|
||||
let entry_func_return_type = match self.lookup_type(entry_func.ty) {
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
arguments: _,
|
||||
} => return_type,
|
||||
SpirvType::Function { return_type, .. } => return_type,
|
||||
other => self.tcx.sess.fatal(&format!(
|
||||
"Invalid entry_stub type: {}",
|
||||
other.debug(entry_func.ty, self)
|
||||
@ -105,14 +123,14 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
};
|
||||
let mut decoration_locations = HashMap::new();
|
||||
// Create OpVariables before OpFunction so they're global instead of local vars.
|
||||
let declared_params = entry_fn_abi
|
||||
.args
|
||||
let declared_params = arg_abis
|
||||
.iter()
|
||||
.zip(hir_params)
|
||||
.map(|(entry_fn_arg, hir_param)| {
|
||||
self.declare_parameter(entry_fn_arg.layout, hir_param, &mut decoration_locations)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let len_t = self.type_isize();
|
||||
let mut emit = self.emit_global();
|
||||
let fn_id = emit
|
||||
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
|
||||
@ -121,12 +139,19 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
// Adjust any global `OpVariable`s as needed (e.g. loading from `Input`s).
|
||||
let arguments: Vec<_> = declared_params
|
||||
.iter()
|
||||
.zip(&entry_fn_abi.args)
|
||||
.zip(arg_abis)
|
||||
.zip(hir_params)
|
||||
.map(|((&(var, storage_class), entry_fn_arg), hir_param)| {
|
||||
match entry_fn_arg.layout.ty.kind() {
|
||||
TyKind::Ref(..) => var,
|
||||
|
||||
.flat_map(|((&(var, storage_class), entry_fn_arg), hir_param)| {
|
||||
let mut dst_len_arg = None;
|
||||
let arg = match entry_fn_arg.layout.ty.kind() {
|
||||
TyKind::Ref(_, ty, _) => {
|
||||
if !ty.is_sized(self.tcx.at(span), self.param_env()) {
|
||||
dst_len_arg.replace(
|
||||
self.dst_length_argument(&mut emit, ty, hir_param, len_t, var),
|
||||
);
|
||||
}
|
||||
var
|
||||
}
|
||||
_ => match entry_fn_arg.mode {
|
||||
PassMode::Indirect { .. } => var,
|
||||
PassMode::Direct(_) => {
|
||||
@ -142,7 +167,8 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
}
|
||||
_ => unreachable!(),
|
||||
},
|
||||
}
|
||||
};
|
||||
std::iter::once(arg).chain(dst_len_arg)
|
||||
})
|
||||
.collect();
|
||||
emit.function_call(
|
||||
@ -170,6 +196,38 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
fn_id
|
||||
}
|
||||
|
||||
fn dst_length_argument(
|
||||
&self,
|
||||
emit: &mut std::cell::RefMut<'_, rspirv::dr::Builder>,
|
||||
ty: Ty<'tcx>,
|
||||
hir_param: &hir::Param<'tcx>,
|
||||
len_t: Word,
|
||||
var: Word,
|
||||
) -> Word {
|
||||
match ty.kind() {
|
||||
TyKind::Adt(adt_def, substs) => {
|
||||
let (member_idx, field_def) = adt_def.all_fields().enumerate().last().unwrap();
|
||||
let field_ty = field_def.ty(self.tcx, substs);
|
||||
if !matches!(field_ty.kind(), TyKind::Slice(..)) {
|
||||
self.tcx.sess.span_fatal(
|
||||
hir_param.ty_span,
|
||||
"DST parameters are currently restricted to a reference to a struct whose last field is a slice.",
|
||||
)
|
||||
}
|
||||
emit.array_length(len_t, None, var, member_idx as u32)
|
||||
.unwrap()
|
||||
}
|
||||
TyKind::Slice(..) | TyKind::Str => self.tcx.sess.span_fatal(
|
||||
hir_param.ty_span,
|
||||
"Straight slices are not yet supported, wrap the slice in a newtype.",
|
||||
),
|
||||
_ => self
|
||||
.tcx
|
||||
.sess
|
||||
.span_fatal(hir_param.ty_span, "Unsupported parameter type."),
|
||||
}
|
||||
}
|
||||
|
||||
fn declare_parameter(
|
||||
&self,
|
||||
layout: TyAndLayout<'tcx>,
|
||||
|
@ -188,6 +188,17 @@ impl SpirvType {
|
||||
}
|
||||
Self::RuntimeArray { element } => {
|
||||
let result = cx.emit_global().type_runtime_array(element);
|
||||
// ArrayStride decoration wants in *bytes*
|
||||
let element_size = cx
|
||||
.lookup_type(element)
|
||||
.sizeof(cx)
|
||||
.expect("Element of sized array must be sized")
|
||||
.bytes();
|
||||
cx.emit_global().decorate(
|
||||
result,
|
||||
Decoration::ArrayStride,
|
||||
iter::once(Operand::LiteralInt32(element_size as u32)),
|
||||
);
|
||||
if cx.kernel_mode {
|
||||
cx.zombie_with_span(result, def_span, "RuntimeArray in kernel mode");
|
||||
}
|
||||
|
@ -1,4 +1,4 @@
|
||||
use super::{dis_fn, dis_globals, val, val_vulkan};
|
||||
use super::{dis_entry_fn, dis_fn, dis_globals, val, val_vulkan};
|
||||
use std::ffi::OsStr;
|
||||
|
||||
struct SetEnvVar<'a> {
|
||||
@ -183,20 +183,21 @@ OpEntryPoint Fragment %1 "main"
|
||||
OpExecutionMode %1 OriginUpperLeft
|
||||
OpName %2 "test_project::add_decorate"
|
||||
OpName %3 "test_project::main"
|
||||
OpDecorate %4 DescriptorSet 0
|
||||
OpDecorate %4 Binding 0
|
||||
%5 = OpTypeVoid
|
||||
%6 = OpTypeFunction %5
|
||||
%7 = OpTypeInt 32 0
|
||||
%8 = OpTypePointer Function %7
|
||||
%9 = OpConstant %7 1
|
||||
%10 = OpTypeFloat 32
|
||||
%11 = OpTypeImage %10 2D 0 0 0 1 Unknown
|
||||
%12 = OpTypeSampledImage %11
|
||||
%13 = OpTypeRuntimeArray %12
|
||||
%14 = OpTypePointer UniformConstant %13
|
||||
%4 = OpVariable %14 UniformConstant
|
||||
%15 = OpTypePointer UniformConstant %12"#,
|
||||
OpDecorate %4 ArrayStride 4
|
||||
OpDecorate %5 DescriptorSet 0
|
||||
OpDecorate %5 Binding 0
|
||||
%6 = OpTypeVoid
|
||||
%7 = OpTypeFunction %6
|
||||
%8 = OpTypeInt 32 0
|
||||
%9 = OpTypePointer Function %8
|
||||
%10 = OpConstant %8 1
|
||||
%11 = OpTypeFloat 32
|
||||
%12 = OpTypeImage %11 2D 0 0 0 1 Unknown
|
||||
%13 = OpTypeSampledImage %12
|
||||
%4 = OpTypeRuntimeArray %13
|
||||
%14 = OpTypePointer UniformConstant %4
|
||||
%5 = OpVariable %14 UniformConstant
|
||||
%15 = OpTypePointer UniformConstant %13"#,
|
||||
);
|
||||
}
|
||||
|
||||
@ -479,3 +480,54 @@ fn ptr_copy_from_method() {
|
||||
"#
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn index_user_dst() {
|
||||
dis_entry_fn(
|
||||
r#"
|
||||
#[spirv(fragment)]
|
||||
pub fn main(
|
||||
#[spirv(uniform, descriptor_set = 0, binding = 0)] slice: &mut SliceF32,
|
||||
) {
|
||||
let float: f32 = slice.rta[0];
|
||||
let _ = float;
|
||||
}
|
||||
|
||||
pub struct SliceF32 {
|
||||
rta: [f32],
|
||||
}
|
||||
"#,
|
||||
"main",
|
||||
r#"%1 = OpFunction %2 None %3
|
||||
%4 = OpLabel
|
||||
%5 = OpArrayLength %6 %7 0
|
||||
%8 = OpCompositeInsert %9 %7 %10 0
|
||||
%11 = OpCompositeInsert %9 %5 %8 1
|
||||
%12 = OpAccessChain %13 %7 %14
|
||||
%15 = OpULessThan %16 %14 %5
|
||||
OpSelectionMerge %17 None
|
||||
OpBranchConditional %15 %18 %19
|
||||
%18 = OpLabel
|
||||
%20 = OpAccessChain %13 %7 %14
|
||||
%21 = OpInBoundsAccessChain %22 %20 %14
|
||||
%23 = OpLoad %24 %21
|
||||
OpReturn
|
||||
%19 = OpLabel
|
||||
OpBranch %25
|
||||
%25 = OpLabel
|
||||
OpBranch %26
|
||||
%26 = OpLabel
|
||||
%27 = OpPhi %16 %28 %25 %28 %29
|
||||
OpLoopMerge %30 %29 None
|
||||
OpBranchConditional %27 %31 %30
|
||||
%31 = OpLabel
|
||||
OpBranch %29
|
||||
%29 = OpLabel
|
||||
OpBranch %26
|
||||
%30 = OpLabel
|
||||
OpUnreachable
|
||||
%17 = OpLabel
|
||||
OpUnreachable
|
||||
OpFunctionEnd"#,
|
||||
)
|
||||
}
|
||||
|
@ -159,6 +159,33 @@ fn dis_fn(src: &str, func: &str, expect: &str) {
|
||||
assert_str_eq(expect, &func.disassemble())
|
||||
}
|
||||
|
||||
fn dis_entry_fn(src: &str, func: &str, expect: &str) {
|
||||
let _lock = global_lock();
|
||||
let module = read_module(&build(src)).unwrap();
|
||||
let id = module
|
||||
.entry_points
|
||||
.iter()
|
||||
.find(|inst| inst.operands.last().unwrap().unwrap_literal_string() == func)
|
||||
.unwrap_or_else(|| {
|
||||
panic!(
|
||||
"no entry point with the name `{}` found in:\n{}\n",
|
||||
func,
|
||||
module.disassemble()
|
||||
)
|
||||
})
|
||||
.operands[1]
|
||||
.unwrap_id_ref();
|
||||
let mut func = module
|
||||
.functions
|
||||
.into_iter()
|
||||
.find(|f| f.def_id().unwrap() == id)
|
||||
.unwrap();
|
||||
// Compact to make IDs more stable
|
||||
compact_ids(&mut func);
|
||||
use rspirv::binary::Disassemble;
|
||||
assert_str_eq(expect, &func.disassemble())
|
||||
}
|
||||
|
||||
fn dis_globals(src: &str, expect: &str) {
|
||||
let _lock = global_lock();
|
||||
let module = read_module(&build(src)).unwrap();
|
||||
|
@ -2,7 +2,7 @@ error: pointer has non-null integer address
|
||||
|
|
||||
= note: Stack:
|
||||
allocate_const_scalar::main
|
||||
Unnamed function ID %4
|
||||
Unnamed function ID %5
|
||||
|
||||
error: invalid binary:0:0 - No OpEntryPoint instruction was found. This is only allowed if the Linkage capability is being used.
|
||||
|
|
||||
|
Loading…
Reference in New Issue
Block a user