mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 06:45:13 +00:00
Interfaces and kernel entry points
This commit is contained in:
parent
d212f489a1
commit
78f7d8f91c
@ -3,12 +3,12 @@
|
||||
#![register_attr(spirv)]
|
||||
|
||||
use core::panic::PanicInfo;
|
||||
use spirv_std::Workgroup;
|
||||
use spirv_std::{CrossWorkgroup, UniformConstant};
|
||||
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(entry = "kernel")]
|
||||
pub fn screaming_bananans(mut x: Workgroup<u32>) {
|
||||
x.store(x.load() + 1);
|
||||
pub fn add_two_ints(x: UniformConstant<u32>, y: UniformConstant<u32>, mut z: CrossWorkgroup<u32>) {
|
||||
z.store(x.load() + y.load())
|
||||
}
|
||||
|
||||
#[panic_handler]
|
||||
|
@ -1318,11 +1318,15 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
|
||||
let (result_type, argument_types) = loop {
|
||||
match self.lookup_type(llfn.ty) {
|
||||
SpirvType::Pointer { pointee, .. } => {
|
||||
llfn = self
|
||||
.emit()
|
||||
.load(pointee, None, llfn.def, None, empty())
|
||||
.unwrap()
|
||||
.with_type(pointee)
|
||||
if let Some(func) = self.cx.function_pointers.borrow().get(&llfn) {
|
||||
llfn = *func;
|
||||
} else {
|
||||
llfn = self
|
||||
.emit()
|
||||
.load(pointee, None, llfn.def, None, empty())
|
||||
.unwrap()
|
||||
.with_type(pointee)
|
||||
}
|
||||
}
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
|
@ -6,7 +6,7 @@ use std::collections::hash_map::Entry;
|
||||
use std::collections::HashMap;
|
||||
use std::{fs::File, io::Write, path::Path};
|
||||
|
||||
#[derive(Copy, Clone, Debug, Default, Ord, PartialOrd, Eq, PartialEq)]
|
||||
#[derive(Copy, Clone, Debug, Default, Ord, PartialOrd, Eq, PartialEq, Hash)]
|
||||
pub struct SpirvValue {
|
||||
pub def: Word,
|
||||
pub ty: Word,
|
||||
|
@ -3,7 +3,7 @@ use crate::abi::ConvSpirvType;
|
||||
use crate::builder_spirv::{SpirvConst, SpirvValue, SpirvValueExt};
|
||||
use crate::spirv_type::SpirvType;
|
||||
use crate::symbols::{parse_attr, SpirvAttribute};
|
||||
use rspirv::spirv::{FunctionControl, LinkageType, StorageClass};
|
||||
use rspirv::spirv::{ExecutionModel, FunctionControl, LinkageType, StorageClass};
|
||||
use rustc_attr::InlineAttr;
|
||||
use rustc_codegen_ssa::traits::{DeclareMethods, MiscMethods, PreDefineMethods, StaticMethods};
|
||||
use rustc_middle::middle::codegen_fn_attrs::{CodegenFnAttrFlags, CodegenFnAttrs};
|
||||
@ -147,6 +147,102 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
self.instances.borrow_mut().insert(instance, g);
|
||||
g
|
||||
}
|
||||
|
||||
// Entry points declare their "interface" (all uniforms, inputs, outputs, etc.) as parameters. spir-v uses globals
|
||||
// to declare the interface. So, we need to generate a lil stub for the "real" main that collects all those global
|
||||
// variables and calls the user-defined main function.
|
||||
fn entry_stub(&self, entry_func: SpirvValue, name: String, execution_model: ExecutionModel) {
|
||||
let void = SpirvType::Void.def(self);
|
||||
let fn_void_void = SpirvType::Function {
|
||||
return_type: void,
|
||||
arguments: vec![],
|
||||
}
|
||||
.def(self);
|
||||
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
arguments,
|
||||
} => (return_type, arguments),
|
||||
other => panic!(
|
||||
"Invalid entry_stub type: {}",
|
||||
other.debug(entry_func.ty, self)
|
||||
),
|
||||
};
|
||||
let mut emit = self.emit_global();
|
||||
// Create OpVariables before OpFunction so they're global instead of local vars.
|
||||
let arguments = entry_func_args
|
||||
.iter()
|
||||
.map(|&arg| {
|
||||
let storage_class = match self.lookup_type(arg) {
|
||||
SpirvType::Pointer { storage_class, .. } => storage_class,
|
||||
other => panic!("Invalid entry arg type {}", other.debug(arg, self)),
|
||||
};
|
||||
// Note: this *declares* the variable too.
|
||||
emit.variable(arg, None, storage_class, None)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let fn_id = emit
|
||||
.begin_function(void, None, FunctionControl::NONE, fn_void_void)
|
||||
.unwrap();
|
||||
emit.begin_block(None).unwrap();
|
||||
emit.function_call(
|
||||
entry_func_return,
|
||||
None,
|
||||
entry_func.def,
|
||||
arguments.iter().copied(),
|
||||
)
|
||||
.unwrap();
|
||||
emit.ret().unwrap();
|
||||
emit.end_function().unwrap();
|
||||
|
||||
let interface = arguments;
|
||||
emit.entry_point(execution_model, fn_id, name, interface);
|
||||
}
|
||||
|
||||
// Kernel mode takes its interface as function parameters(??)
|
||||
// OpEntryPoints cannot be OpLinkage, so write out a stub to call through.
|
||||
fn kernel_entry_stub(
|
||||
&self,
|
||||
entry_func: SpirvValue,
|
||||
name: String,
|
||||
execution_model: ExecutionModel,
|
||||
) {
|
||||
let (entry_func_return, entry_func_args) = match self.lookup_type(entry_func.ty) {
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
arguments,
|
||||
} => (return_type, arguments),
|
||||
other => panic!(
|
||||
"Invalid kernel_entry_stub type: {}",
|
||||
other.debug(entry_func.ty, self)
|
||||
),
|
||||
};
|
||||
let mut emit = self.emit_global();
|
||||
let fn_id = emit
|
||||
.begin_function(
|
||||
entry_func_return,
|
||||
None,
|
||||
FunctionControl::NONE,
|
||||
entry_func.ty,
|
||||
)
|
||||
.unwrap();
|
||||
let arguments = entry_func_args
|
||||
.iter()
|
||||
.map(|&ty| emit.function_parameter(ty).unwrap())
|
||||
.collect::<Vec<_>>();
|
||||
emit.begin_block(None).unwrap();
|
||||
let call_result = emit
|
||||
.function_call(entry_func_return, None, entry_func.def, arguments)
|
||||
.unwrap();
|
||||
if self.lookup_type(entry_func_return) == SpirvType::Void {
|
||||
emit.ret().unwrap();
|
||||
} else {
|
||||
emit.ret_value(call_result).unwrap();
|
||||
}
|
||||
emit.end_function().unwrap();
|
||||
|
||||
emit.entry_point(execution_model, fn_id, name, &[]);
|
||||
}
|
||||
}
|
||||
|
||||
impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
|
||||
@ -198,13 +294,11 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
|
||||
|
||||
for attr in self.tcx.get_attrs(instance.def_id()) {
|
||||
if let Some(SpirvAttribute::Entry(execution_model)) = parse_attr(self, attr) {
|
||||
let interface = &[];
|
||||
self.emit_global().entry_point(
|
||||
execution_model,
|
||||
declared.def,
|
||||
human_name.clone(),
|
||||
interface,
|
||||
);
|
||||
if execution_model == ExecutionModel::Kernel {
|
||||
self.kernel_entry_stub(declared, human_name.clone(), execution_model);
|
||||
} else {
|
||||
self.entry_stub(declared, human_name.clone(), execution_model);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3,15 +3,15 @@ mod declare;
|
||||
mod type_;
|
||||
|
||||
use crate::builder::ExtInst;
|
||||
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue};
|
||||
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueExt};
|
||||
use crate::finalizing_passes::{block_ordering_pass, delete_dead_blocks, zombie_pass};
|
||||
use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache};
|
||||
use crate::symbols::Symbols;
|
||||
use rspirv::dr::{Module, Operand};
|
||||
use rspirv::spirv::{Decoration, LinkageType, Word};
|
||||
use rspirv::spirv::{Decoration, LinkageType, StorageClass, Word};
|
||||
use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind};
|
||||
use rustc_codegen_ssa::traits::{
|
||||
AsmMethods, BackendTypes, CoverageInfoMethods, DebugInfoMethods, MiscMethods, StaticMethods,
|
||||
AsmMethods, BackendTypes, CoverageInfoMethods, DebugInfoMethods, MiscMethods,
|
||||
};
|
||||
use rustc_data_structures::fx::FxHashMap;
|
||||
use rustc_hir::GlobalAsm;
|
||||
@ -25,7 +25,7 @@ use rustc_span::source_map::Span;
|
||||
use rustc_span::symbol::Symbol;
|
||||
use rustc_span::SourceFile;
|
||||
use rustc_target::abi::call::FnAbi;
|
||||
use rustc_target::abi::{Align, HasDataLayout, TargetDataLayout};
|
||||
use rustc_target::abi::{HasDataLayout, TargetDataLayout};
|
||||
use rustc_target::spec::{HasTargetSpec, Target};
|
||||
use std::cell::RefCell;
|
||||
use std::collections::HashMap;
|
||||
@ -51,6 +51,8 @@ pub struct CodegenCx<'tcx> {
|
||||
pub kernel_mode: bool,
|
||||
/// Cache of all the builtin symbols we need
|
||||
pub sym: Box<Symbols>,
|
||||
/// Functions created in get_fn_addr
|
||||
pub function_pointers: RefCell<HashMap<SpirvValue, SpirvValue>>,
|
||||
}
|
||||
|
||||
impl<'tcx> CodegenCx<'tcx> {
|
||||
@ -68,6 +70,7 @@ impl<'tcx> CodegenCx<'tcx> {
|
||||
zombie_values: Default::default(),
|
||||
kernel_mode: true,
|
||||
sym: Box::new(Symbols::new()),
|
||||
function_pointers: Default::default(),
|
||||
}
|
||||
}
|
||||
|
||||
@ -186,7 +189,16 @@ impl<'tcx> MiscMethods<'tcx> for CodegenCx<'tcx> {
|
||||
|
||||
fn get_fn_addr(&self, instance: Instance<'tcx>) -> Self::Value {
|
||||
let function = self.get_fn_ext(instance);
|
||||
self.static_addr_of(function, Align::from_bytes(0).unwrap(), None)
|
||||
let ty = SpirvType::Pointer {
|
||||
storage_class: StorageClass::Function,
|
||||
pointee: function.ty,
|
||||
}
|
||||
.def(self);
|
||||
// We want a unique ID for these undefs, so don't use the caching system.
|
||||
let result = self.emit_global().undef(ty, None).with_type(ty);
|
||||
self.zombie(result.def, "get_fn_addr");
|
||||
self.function_pointers.borrow_mut().insert(result, function);
|
||||
result
|
||||
}
|
||||
|
||||
fn eh_personality(&self) -> Self::Value {
|
||||
|
@ -15,6 +15,7 @@ pub struct Symbols {
|
||||
|
||||
fn make_storage_classes() -> HashMap<Symbol, StorageClass> {
|
||||
use StorageClass::*;
|
||||
// make sure these strings stay synced with spirv-std's pointer types
|
||||
[
|
||||
("uniform_constant", UniformConstant),
|
||||
("input", Input),
|
||||
|
@ -2,38 +2,47 @@
|
||||
#![feature(register_attr)]
|
||||
#![register_attr(spirv)]
|
||||
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(storage_class = "private")]
|
||||
pub struct Private<'a, T> {
|
||||
x: &'a mut T,
|
||||
macro_rules! pointer_addrspace {
|
||||
($storage_class:literal, $type_name:ident) => {
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(storage_class = $storage_class)]
|
||||
pub struct $type_name<'a, T> {
|
||||
x: &'a mut T,
|
||||
}
|
||||
|
||||
impl<'a, T: Copy> $type_name<'a, T> {
|
||||
#[inline]
|
||||
pub fn load(&self) -> T {
|
||||
*self.x
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn store(&mut self, v: T) {
|
||||
*self.x = v
|
||||
}
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
impl<'a, T: Copy> Private<'a, T> {
|
||||
#[inline]
|
||||
pub fn load(&self) -> T {
|
||||
*self.x
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn store(&mut self, v: T) {
|
||||
*self.x = v
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(unused_attributes)]
|
||||
#[spirv(storage_class = "workgroup")]
|
||||
pub struct Workgroup<'a, T> {
|
||||
x: &'a mut T,
|
||||
}
|
||||
|
||||
impl<'a, T: Copy> Workgroup<'a, T> {
|
||||
#[inline]
|
||||
pub fn load(&self) -> T {
|
||||
*self.x
|
||||
}
|
||||
|
||||
#[inline]
|
||||
pub fn store(&mut self, v: T) {
|
||||
*self.x = v
|
||||
}
|
||||
}
|
||||
// Make sure these strings stay synced with symbols.rs
|
||||
// Note the type names don't have to match anything, they can be renamed (only the string must match)
|
||||
pointer_addrspace!("uniform_constant", UniformConstant);
|
||||
pointer_addrspace!("input", Input);
|
||||
pointer_addrspace!("uniform", Uniform);
|
||||
pointer_addrspace!("output", Output);
|
||||
pointer_addrspace!("workgroup", Workgroup);
|
||||
pointer_addrspace!("cross_workgroup", CrossWorkgroup);
|
||||
pointer_addrspace!("private", Private);
|
||||
pointer_addrspace!("function", Function);
|
||||
pointer_addrspace!("generic", Generic);
|
||||
pointer_addrspace!("push_constant", PushConstant);
|
||||
pointer_addrspace!("atomic_counter", AtomicCounter);
|
||||
pointer_addrspace!("image", Image);
|
||||
pointer_addrspace!("storage_buffer", StorageBuffer);
|
||||
pointer_addrspace!("callable_data_nv", CallableDataNV);
|
||||
pointer_addrspace!("incoming_callable_data_nv", IncomingCallableDataNV);
|
||||
pointer_addrspace!("ray_payload_nv", RayPayloadNV);
|
||||
pointer_addrspace!("hit_attribute_nv", HitAttributeNV);
|
||||
pointer_addrspace!("incoming_ray_payload_nv", IncomingRayPayloadNV);
|
||||
pointer_addrspace!("shader_record_buffer_nv", ShaderRecordBufferNV);
|
||||
pointer_addrspace!("physical_storage_buffer", PhysicalStorageBuffer);
|
||||
|
Loading…
Reference in New Issue
Block a user