mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 14:56:27 +00:00
Implement function calls and some generics
This commit is contained in:
parent
687f142031
commit
75948dd70b
68
src/ctx.rs
68
src/ctx.rs
@ -4,15 +4,35 @@ use local_tracker::LocalTracker;
|
|||||||
use rspirv::binary::Assemble;
|
use rspirv::binary::Assemble;
|
||||||
use rspirv::dr::Builder;
|
use rspirv::dr::Builder;
|
||||||
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, StorageClass, Word};
|
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, StorageClass, Word};
|
||||||
use rustc_middle::mir::BasicBlock;
|
use rustc_middle::mir::{BasicBlock, Body};
|
||||||
use rustc_middle::ty::{subst::SubstsRef, TyCtxt};
|
use rustc_middle::ty::subst::SubstsRef;
|
||||||
|
use rustc_middle::ty::{fold::TypeFoldable, Instance, ParamEnv, TyCtxt};
|
||||||
|
use rustc_span::def_id::DefId;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::hash::Hash;
|
use std::hash::Hash;
|
||||||
|
|
||||||
|
#[derive(Copy, Clone)]
|
||||||
|
pub enum PointerSize {
|
||||||
|
P32,
|
||||||
|
P64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PointerSize {
|
||||||
|
pub fn val(self) -> u64 {
|
||||||
|
match self {
|
||||||
|
PointerSize::P32 => 32,
|
||||||
|
PointerSize::P64 => 64,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Context is the "bag of random variables" used for global state - i.e. variables that live the whole compilation.
|
||||||
pub struct Context<'tcx> {
|
pub struct Context<'tcx> {
|
||||||
pub tcx: TyCtxt<'tcx>,
|
pub tcx: TyCtxt<'tcx>,
|
||||||
pub spirv: Builder,
|
pub spirv: Builder,
|
||||||
spirv_helper: SpirvHelper,
|
spirv_helper: SpirvHelper,
|
||||||
|
pub pointer_size: PointerSize,
|
||||||
|
func_defs: ForwardReference<(DefId, SubstsRef<'tcx>)>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'tcx> Context<'tcx> {
|
impl<'tcx> Context<'tcx> {
|
||||||
@ -22,54 +42,92 @@ impl<'tcx> Context<'tcx> {
|
|||||||
// Temp hack: Linkage allows us to get away with no OpEntryPoint
|
// Temp hack: Linkage allows us to get away with no OpEntryPoint
|
||||||
spirv.capability(Capability::Linkage);
|
spirv.capability(Capability::Linkage);
|
||||||
spirv.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
|
spirv.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
|
||||||
|
let _ = PointerSize::P32; // just mark as used
|
||||||
Self {
|
Self {
|
||||||
tcx,
|
tcx,
|
||||||
spirv,
|
spirv,
|
||||||
spirv_helper: SpirvHelper::new(),
|
spirv_helper: SpirvHelper::new(),
|
||||||
|
pointer_size: PointerSize::P64,
|
||||||
|
func_defs: ForwardReference::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn assemble(self) -> Vec<u32> {
|
pub fn assemble(self) -> Vec<u32> {
|
||||||
self.spirv.module().assemble()
|
self.spirv.module().assemble()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn get_func_def(&mut self, id: DefId, substs: SubstsRef<'tcx>) -> Word {
|
||||||
|
self.func_defs.get(&mut self.spirv, (id, substs))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// FnCtx is the "bag of random variables" used for state when compiling a particular function - i.e. variables that are
|
||||||
|
/// specific to compiling a particular function. Note it carries a reference to the global Context, so those variables
|
||||||
|
/// can be accessed too.
|
||||||
pub struct FnCtx<'ctx, 'tcx> {
|
pub struct FnCtx<'ctx, 'tcx> {
|
||||||
|
/// The global state - note this field is `pub`, so if needed, you can go through this field, instead of the helper
|
||||||
|
/// functions, to satisfy the borrow checker (since using the helper functions borrow this whole struct).
|
||||||
pub ctx: &'ctx mut Context<'tcx>,
|
pub ctx: &'ctx mut Context<'tcx>,
|
||||||
pub substs: SubstsRef<'tcx>,
|
pub instance: Instance<'tcx>,
|
||||||
|
pub body: &'tcx Body<'tcx>,
|
||||||
pub is_void: bool,
|
pub is_void: bool,
|
||||||
basic_blocks: ForwardReference<BasicBlock>,
|
basic_blocks: ForwardReference<BasicBlock>,
|
||||||
pub locals: LocalTracker,
|
pub locals: LocalTracker,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'ctx, 'tcx> FnCtx<'ctx, 'tcx> {
|
impl<'ctx, 'tcx> FnCtx<'ctx, 'tcx> {
|
||||||
pub fn new(ctx: &'ctx mut Context<'tcx>, substs: SubstsRef<'tcx>) -> Self {
|
pub fn new(
|
||||||
|
ctx: &'ctx mut Context<'tcx>,
|
||||||
|
instance: Instance<'tcx>,
|
||||||
|
body: &'tcx Body<'tcx>,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
ctx,
|
ctx,
|
||||||
substs,
|
instance,
|
||||||
|
body,
|
||||||
is_void: false,
|
is_void: false,
|
||||||
basic_blocks: ForwardReference::new(),
|
basic_blocks: ForwardReference::new(),
|
||||||
locals: LocalTracker::new(),
|
locals: LocalTracker::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
pub fn tcx(&mut self) -> &mut TyCtxt<'tcx> {
|
pub fn tcx(&mut self) -> &mut TyCtxt<'tcx> {
|
||||||
&mut self.ctx.tcx
|
&mut self.ctx.tcx
|
||||||
}
|
}
|
||||||
|
*/
|
||||||
|
|
||||||
pub fn spirv(&mut self) -> &mut Builder {
|
pub fn spirv(&mut self) -> &mut Builder {
|
||||||
&mut self.ctx.spirv
|
&mut self.ctx.spirv
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Gets the spir-v label for a basic block, or generates one if it doesn't exist.
|
||||||
pub fn get_basic_block(&mut self, bb: BasicBlock) -> Word {
|
pub fn get_basic_block(&mut self, bb: BasicBlock) -> Word {
|
||||||
self.basic_blocks.get(&mut self.ctx.spirv, bb)
|
self.basic_blocks.get(&mut self.ctx.spirv, bb)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// rspirv doesn't cache type_pointer, so cache it ourselves here.
|
||||||
pub fn type_pointer(&mut self, pointee_type: Word) -> Word {
|
pub fn type_pointer(&mut self, pointee_type: Word) -> Word {
|
||||||
self.ctx
|
self.ctx
|
||||||
.spirv_helper
|
.spirv_helper
|
||||||
.type_pointer(&mut self.ctx.spirv, pointee_type)
|
.type_pointer(&mut self.ctx.spirv, pointee_type)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copied from rustc_codegen_cranelift
|
||||||
|
pub(crate) fn monomorphize<T>(&self, value: &T) -> T
|
||||||
|
where
|
||||||
|
T: TypeFoldable<'tcx> + Copy,
|
||||||
|
{
|
||||||
|
if let Some(substs) = self.instance.substs_for_mir_body() {
|
||||||
|
self.ctx
|
||||||
|
.tcx
|
||||||
|
.subst_and_normalize_erasing_regions(substs, ParamEnv::reveal_all(), value)
|
||||||
|
} else {
|
||||||
|
self.ctx
|
||||||
|
.tcx
|
||||||
|
.normalize_erasing_regions(ParamEnv::reveal_all(), *value)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ForwardReference<T: Eq + Hash> {
|
struct ForwardReference<T: Eq + Hash> {
|
||||||
|
@ -26,7 +26,8 @@ impl LocalTracker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn get(&self, local: Local) -> Word {
|
pub fn get(&self, local: Local) -> Word {
|
||||||
// This probably needs to be fixed, forward-references might be a thing
|
// This probably needs to be fixed, forward-references might be a thing.
|
||||||
|
// (MIR probably declares all locals up front, so use that?)
|
||||||
*self.locals.get(&local).expect("Undefined local")
|
*self.locals.get(&local).expect("Undefined local")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -114,6 +114,7 @@ impl CodegenBackend for TheBackend {
|
|||||||
_sess: &Session,
|
_sess: &Session,
|
||||||
_dep_graph: &DepGraph,
|
_dep_graph: &DepGraph,
|
||||||
) -> Result<Box<dyn Any>, ErrorReported> {
|
) -> Result<Box<dyn Any>, ErrorReported> {
|
||||||
|
// I think this function is if `codegen_crate` spawns threads, this is supposed to join those threads?
|
||||||
Ok(ongoing_codegen)
|
Ok(ongoing_codegen)
|
||||||
// let crate_name = ongoing_codegen
|
// let crate_name = ongoing_codegen
|
||||||
// .downcast::<Symbol>()
|
// .downcast::<Symbol>()
|
||||||
|
88
src/trans.rs
88
src/trans.rs
@ -1,14 +1,9 @@
|
|||||||
use crate::ctx::{Context, FnCtx};
|
use crate::ctx::{Context, FnCtx};
|
||||||
use rspirv::spirv::{FunctionControl, Word};
|
use rspirv::spirv::{FunctionControl, Word};
|
||||||
use rustc_middle::mir::{
|
use rustc_middle::mir::{Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind};
|
||||||
Body, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind,
|
use rustc_middle::ty::{Instance, ParamEnv, Ty, TyKind};
|
||||||
};
|
|
||||||
use rustc_middle::ty::{Ty, TyKind};
|
|
||||||
|
|
||||||
pub fn trans_fn<'ctx, 'tcx>(
|
pub fn trans_fn<'ctx, 'tcx>(ctx: &'ctx mut Context<'tcx>, instance: Instance<'tcx>) {
|
||||||
ctx: &'ctx mut Context<'tcx>,
|
|
||||||
instance: rustc_middle::ty::Instance<'tcx>,
|
|
||||||
) {
|
|
||||||
{
|
{
|
||||||
let mut mir = ::std::io::Cursor::new(Vec::new());
|
let mut mir = ::std::io::Cursor::new(Vec::new());
|
||||||
|
|
||||||
@ -20,13 +15,13 @@ pub fn trans_fn<'ctx, 'tcx>(
|
|||||||
println!("{}", s);
|
println!("{}", s);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mir = ctx.tcx.optimized_mir(instance.def_id());
|
let body = ctx.tcx.optimized_mir(instance.def_id());
|
||||||
|
|
||||||
let mut fnctx = FnCtx::new(ctx, instance.substs);
|
let mut fnctx = FnCtx::new(ctx, instance, body);
|
||||||
|
|
||||||
trans_fn_header(&mut fnctx, mir);
|
trans_fn_header(&mut fnctx);
|
||||||
|
|
||||||
for (bb, bb_data) in mir.basic_blocks().iter_enumerated() {
|
for (bb, bb_data) in fnctx.body.basic_blocks().iter_enumerated() {
|
||||||
let label_id = fnctx.get_basic_block(bb);
|
let label_id = fnctx.get_basic_block(bb);
|
||||||
fnctx.spirv().begin_block(Some(label_id)).unwrap();
|
fnctx.spirv().begin_block(Some(label_id)).unwrap();
|
||||||
|
|
||||||
@ -38,17 +33,16 @@ pub fn trans_fn<'ctx, 'tcx>(
|
|||||||
fnctx.spirv().end_function().unwrap();
|
fnctx.spirv().end_function().unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// return_type, fn_type
|
fn trans_fn_header<'tcx>(ctx: &mut FnCtx<'_, 'tcx>) {
|
||||||
fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) {
|
let mir_return_type = ctx.body.local_decls[0u32.into()].ty;
|
||||||
let mir_return_type = body.local_decls[0u32.into()].ty;
|
|
||||||
let return_type = trans_type(ctx, mir_return_type);
|
let return_type = trans_type(ctx, mir_return_type);
|
||||||
ctx.is_void = if let TyKind::Tuple(fields) = &mir_return_type.kind {
|
ctx.is_void = if let TyKind::Tuple(fields) = &mir_return_type.kind {
|
||||||
fields.len() == 0
|
fields.len() == 0
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
};
|
};
|
||||||
let params = (0..body.arg_count)
|
let params = (0..ctx.body.arg_count)
|
||||||
.map(|i| trans_type(ctx, body.local_decls[(i + 1).into()].ty))
|
.map(|i| trans_type(ctx, ctx.body.local_decls[(i + 1).into()].ty))
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
let mut params_nonzero = params.clone();
|
let mut params_nonzero = params.clone();
|
||||||
if params_nonzero.is_empty() {
|
if params_nonzero.is_empty() {
|
||||||
@ -56,12 +50,13 @@ fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) {
|
|||||||
params_nonzero.push(ctx.spirv().type_void());
|
params_nonzero.push(ctx.spirv().type_void());
|
||||||
}
|
}
|
||||||
let function_type = ctx.spirv().type_function(return_type, params_nonzero);
|
let function_type = ctx.spirv().type_function(return_type, params_nonzero);
|
||||||
let function_id = None;
|
let function_id = ctx
|
||||||
|
.ctx
|
||||||
|
.get_func_def(ctx.instance.def_id(), ctx.instance.substs);
|
||||||
let control = FunctionControl::NONE;
|
let control = FunctionControl::NONE;
|
||||||
// TODO: keep track of function IDs
|
|
||||||
let _ = ctx
|
let _ = ctx
|
||||||
.spirv()
|
.spirv()
|
||||||
.begin_function(return_type, function_id, control, function_type)
|
.begin_function(return_type, Some(function_id), control, function_type)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
for (i, ¶m_type) in params.iter().enumerate() {
|
for (i, ¶m_type) in params.iter().enumerate() {
|
||||||
@ -70,16 +65,18 @@ fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn trans_type<'tcx>(ctx: &mut FnCtx, ty: Ty<'tcx>) -> Word {
|
fn trans_type<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, ty: Ty<'tcx>) -> Word {
|
||||||
match ty.kind {
|
let mono = ctx.monomorphize(&ty);
|
||||||
|
match mono.kind {
|
||||||
|
TyKind::Param(param) => panic!("TyKind::Param after monomorphize: {:?}", param),
|
||||||
TyKind::Bool => ctx.spirv().type_bool(),
|
TyKind::Bool => ctx.spirv().type_bool(),
|
||||||
TyKind::Tuple(fields) if fields.len() == 0 => ctx.spirv().type_void(),
|
TyKind::Tuple(fields) if fields.len() == 0 => ctx.spirv().type_void(),
|
||||||
TyKind::Int(ty) => {
|
TyKind::Int(ty) => {
|
||||||
let size = ty.bit_width().expect("isize not supported yet");
|
let size = ty.bit_width().unwrap_or_else(|| ctx.ctx.pointer_size.val());
|
||||||
ctx.spirv().type_int(size as u32, 1)
|
ctx.spirv().type_int(size as u32, 1)
|
||||||
}
|
}
|
||||||
TyKind::Uint(ty) => {
|
TyKind::Uint(ty) => {
|
||||||
let size = ty.bit_width().expect("isize not supported yet");
|
let size = ty.bit_width().unwrap_or_else(|| ctx.ctx.pointer_size.val());
|
||||||
ctx.spirv().type_int(size as u32, 0)
|
ctx.spirv().type_int(size as u32, 0)
|
||||||
}
|
}
|
||||||
TyKind::Float(ty) => ctx.spirv().type_float(ty.bit_width() as u32),
|
TyKind::Float(ty) => ctx.spirv().type_float(ty.bit_width() as u32),
|
||||||
@ -116,7 +113,7 @@ fn trans_stmt<'tcx>(ctx: &mut FnCtx, stmt: &Statement<'tcx>) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn trans_terminator<'tcx>(ctx: &mut FnCtx, term: &Terminator<'tcx>) {
|
fn trans_terminator<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, term: &Terminator<'tcx>) {
|
||||||
match term.kind {
|
match term.kind {
|
||||||
TerminatorKind::Return => {
|
TerminatorKind::Return => {
|
||||||
if ctx.is_void {
|
if ctx.is_void {
|
||||||
@ -138,6 +135,47 @@ fn trans_terminator<'tcx>(ctx: &mut FnCtx, term: &Terminator<'tcx>) {
|
|||||||
let inst_id = ctx.get_basic_block(target);
|
let inst_id = ctx.get_basic_block(target);
|
||||||
ctx.spirv().branch(inst_id).unwrap();
|
ctx.spirv().branch(inst_id).unwrap();
|
||||||
}
|
}
|
||||||
|
TerminatorKind::Call {
|
||||||
|
ref func,
|
||||||
|
ref args,
|
||||||
|
ref destination,
|
||||||
|
..
|
||||||
|
} => {
|
||||||
|
let destination = destination.expect("Divergent function calls not supported yet");
|
||||||
|
let fn_ty = ctx.monomorphize(&func.ty(ctx.body, ctx.ctx.tcx));
|
||||||
|
let fn_sig = ctx.ctx.tcx.normalize_erasing_late_bound_regions(
|
||||||
|
ParamEnv::reveal_all(),
|
||||||
|
&fn_ty.fn_sig(ctx.ctx.tcx),
|
||||||
|
);
|
||||||
|
let fn_return_type = ParamEnv::reveal_all().and(fn_sig.output()).value;
|
||||||
|
let result_type = trans_type(ctx, fn_return_type);
|
||||||
|
let function = match func {
|
||||||
|
// TODO: Can constant.literal.val not be a ZST?
|
||||||
|
Operand::Constant(constant) => match constant.literal.ty.kind {
|
||||||
|
TyKind::FnDef(id, substs) => ctx.ctx.get_func_def(id, substs),
|
||||||
|
ref thing => panic!("Unknown type in fn call: {:?}", thing),
|
||||||
|
},
|
||||||
|
thing => panic!("Dynamic calls not supported yet: {:?}", thing),
|
||||||
|
};
|
||||||
|
let arguments = args
|
||||||
|
.iter()
|
||||||
|
.map(|arg| trans_operand(ctx, arg))
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
let dest_local = destination.0.local;
|
||||||
|
if destination.0.projection.len() != 0 {
|
||||||
|
panic!(
|
||||||
|
"Place projections aren't supported yet (fn call): {:?}",
|
||||||
|
destination.0.projection
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let result = ctx
|
||||||
|
.spirv()
|
||||||
|
.function_call(result_type, None, function, arguments)
|
||||||
|
.unwrap();
|
||||||
|
ctx.locals.def(dest_local, result);
|
||||||
|
let destination_bb = ctx.get_basic_block(destination.1);
|
||||||
|
ctx.spirv().branch(destination_bb).unwrap();
|
||||||
|
}
|
||||||
ref thing => panic!("Unknown terminator: {:?}", thing),
|
ref thing => panic!("Unknown terminator: {:?}", thing),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user