Implement function calls and some generics

This commit is contained in:
khyperia 2020-08-12 13:55:56 +02:00
parent 687f142031
commit 75948dd70b
4 changed files with 129 additions and 31 deletions

View File

@ -4,15 +4,35 @@ use local_tracker::LocalTracker;
use rspirv::binary::Assemble; use rspirv::binary::Assemble;
use rspirv::dr::Builder; use rspirv::dr::Builder;
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, StorageClass, Word}; use rspirv::spirv::{AddressingModel, Capability, MemoryModel, StorageClass, Word};
use rustc_middle::mir::BasicBlock; use rustc_middle::mir::{BasicBlock, Body};
use rustc_middle::ty::{subst::SubstsRef, TyCtxt}; use rustc_middle::ty::subst::SubstsRef;
use rustc_middle::ty::{fold::TypeFoldable, Instance, ParamEnv, TyCtxt};
use rustc_span::def_id::DefId;
use std::collections::HashMap; use std::collections::HashMap;
use std::hash::Hash; use std::hash::Hash;
#[derive(Copy, Clone)]
pub enum PointerSize {
P32,
P64,
}
impl PointerSize {
pub fn val(self) -> u64 {
match self {
PointerSize::P32 => 32,
PointerSize::P64 => 64,
}
}
}
/// Context is the "bag of random variables" used for global state - i.e. variables that live the whole compilation.
pub struct Context<'tcx> { pub struct Context<'tcx> {
pub tcx: TyCtxt<'tcx>, pub tcx: TyCtxt<'tcx>,
pub spirv: Builder, pub spirv: Builder,
spirv_helper: SpirvHelper, spirv_helper: SpirvHelper,
pub pointer_size: PointerSize,
func_defs: ForwardReference<(DefId, SubstsRef<'tcx>)>,
} }
impl<'tcx> Context<'tcx> { impl<'tcx> Context<'tcx> {
@ -22,54 +42,92 @@ impl<'tcx> Context<'tcx> {
// Temp hack: Linkage allows us to get away with no OpEntryPoint // Temp hack: Linkage allows us to get away with no OpEntryPoint
spirv.capability(Capability::Linkage); spirv.capability(Capability::Linkage);
spirv.memory_model(AddressingModel::Logical, MemoryModel::GLSL450); spirv.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
let _ = PointerSize::P32; // just mark as used
Self { Self {
tcx, tcx,
spirv, spirv,
spirv_helper: SpirvHelper::new(), spirv_helper: SpirvHelper::new(),
pointer_size: PointerSize::P64,
func_defs: ForwardReference::new(),
} }
} }
pub fn assemble(self) -> Vec<u32> { pub fn assemble(self) -> Vec<u32> {
self.spirv.module().assemble() self.spirv.module().assemble()
} }
pub fn get_func_def(&mut self, id: DefId, substs: SubstsRef<'tcx>) -> Word {
self.func_defs.get(&mut self.spirv, (id, substs))
}
} }
/// FnCtx is the "bag of random variables" used for state when compiling a particular function - i.e. variables that are
/// specific to compiling a particular function. Note it carries a reference to the global Context, so those variables
/// can be accessed too.
pub struct FnCtx<'ctx, 'tcx> { pub struct FnCtx<'ctx, 'tcx> {
/// The global state - note this field is `pub`, so if needed, you can go through this field, instead of the helper
/// functions, to satisfy the borrow checker (since using the helper functions borrow this whole struct).
pub ctx: &'ctx mut Context<'tcx>, pub ctx: &'ctx mut Context<'tcx>,
pub substs: SubstsRef<'tcx>, pub instance: Instance<'tcx>,
pub body: &'tcx Body<'tcx>,
pub is_void: bool, pub is_void: bool,
basic_blocks: ForwardReference<BasicBlock>, basic_blocks: ForwardReference<BasicBlock>,
pub locals: LocalTracker, pub locals: LocalTracker,
} }
impl<'ctx, 'tcx> FnCtx<'ctx, 'tcx> { impl<'ctx, 'tcx> FnCtx<'ctx, 'tcx> {
pub fn new(ctx: &'ctx mut Context<'tcx>, substs: SubstsRef<'tcx>) -> Self { pub fn new(
ctx: &'ctx mut Context<'tcx>,
instance: Instance<'tcx>,
body: &'tcx Body<'tcx>,
) -> Self {
Self { Self {
ctx, ctx,
substs, instance,
body,
is_void: false, is_void: false,
basic_blocks: ForwardReference::new(), basic_blocks: ForwardReference::new(),
locals: LocalTracker::new(), locals: LocalTracker::new(),
} }
} }
/*
pub fn tcx(&mut self) -> &mut TyCtxt<'tcx> { pub fn tcx(&mut self) -> &mut TyCtxt<'tcx> {
&mut self.ctx.tcx &mut self.ctx.tcx
} }
*/
pub fn spirv(&mut self) -> &mut Builder { pub fn spirv(&mut self) -> &mut Builder {
&mut self.ctx.spirv &mut self.ctx.spirv
} }
/// Gets the spir-v label for a basic block, or generates one if it doesn't exist.
pub fn get_basic_block(&mut self, bb: BasicBlock) -> Word { pub fn get_basic_block(&mut self, bb: BasicBlock) -> Word {
self.basic_blocks.get(&mut self.ctx.spirv, bb) self.basic_blocks.get(&mut self.ctx.spirv, bb)
} }
/// rspirv doesn't cache type_pointer, so cache it ourselves here.
pub fn type_pointer(&mut self, pointee_type: Word) -> Word { pub fn type_pointer(&mut self, pointee_type: Word) -> Word {
self.ctx self.ctx
.spirv_helper .spirv_helper
.type_pointer(&mut self.ctx.spirv, pointee_type) .type_pointer(&mut self.ctx.spirv, pointee_type)
} }
// copied from rustc_codegen_cranelift
pub(crate) fn monomorphize<T>(&self, value: &T) -> T
where
T: TypeFoldable<'tcx> + Copy,
{
if let Some(substs) = self.instance.substs_for_mir_body() {
self.ctx
.tcx
.subst_and_normalize_erasing_regions(substs, ParamEnv::reveal_all(), value)
} else {
self.ctx
.tcx
.normalize_erasing_regions(ParamEnv::reveal_all(), *value)
}
}
} }
struct ForwardReference<T: Eq + Hash> { struct ForwardReference<T: Eq + Hash> {

View File

@ -26,7 +26,8 @@ impl LocalTracker {
} }
pub fn get(&self, local: Local) -> Word { pub fn get(&self, local: Local) -> Word {
// This probably needs to be fixed, forward-references might be a thing // This probably needs to be fixed, forward-references might be a thing.
// (MIR probably declares all locals up front, so use that?)
*self.locals.get(&local).expect("Undefined local") *self.locals.get(&local).expect("Undefined local")
} }
} }

View File

@ -114,6 +114,7 @@ impl CodegenBackend for TheBackend {
_sess: &Session, _sess: &Session,
_dep_graph: &DepGraph, _dep_graph: &DepGraph,
) -> Result<Box<dyn Any>, ErrorReported> { ) -> Result<Box<dyn Any>, ErrorReported> {
// I think this function is if `codegen_crate` spawns threads, this is supposed to join those threads?
Ok(ongoing_codegen) Ok(ongoing_codegen)
// let crate_name = ongoing_codegen // let crate_name = ongoing_codegen
// .downcast::<Symbol>() // .downcast::<Symbol>()

View File

@ -1,14 +1,9 @@
use crate::ctx::{Context, FnCtx}; use crate::ctx::{Context, FnCtx};
use rspirv::spirv::{FunctionControl, Word}; use rspirv::spirv::{FunctionControl, Word};
use rustc_middle::mir::{ use rustc_middle::mir::{Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind};
Body, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, use rustc_middle::ty::{Instance, ParamEnv, Ty, TyKind};
};
use rustc_middle::ty::{Ty, TyKind};
pub fn trans_fn<'ctx, 'tcx>( pub fn trans_fn<'ctx, 'tcx>(ctx: &'ctx mut Context<'tcx>, instance: Instance<'tcx>) {
ctx: &'ctx mut Context<'tcx>,
instance: rustc_middle::ty::Instance<'tcx>,
) {
{ {
let mut mir = ::std::io::Cursor::new(Vec::new()); let mut mir = ::std::io::Cursor::new(Vec::new());
@ -20,13 +15,13 @@ pub fn trans_fn<'ctx, 'tcx>(
println!("{}", s); println!("{}", s);
} }
let mir = ctx.tcx.optimized_mir(instance.def_id()); let body = ctx.tcx.optimized_mir(instance.def_id());
let mut fnctx = FnCtx::new(ctx, instance.substs); let mut fnctx = FnCtx::new(ctx, instance, body);
trans_fn_header(&mut fnctx, mir); trans_fn_header(&mut fnctx);
for (bb, bb_data) in mir.basic_blocks().iter_enumerated() { for (bb, bb_data) in fnctx.body.basic_blocks().iter_enumerated() {
let label_id = fnctx.get_basic_block(bb); let label_id = fnctx.get_basic_block(bb);
fnctx.spirv().begin_block(Some(label_id)).unwrap(); fnctx.spirv().begin_block(Some(label_id)).unwrap();
@ -38,17 +33,16 @@ pub fn trans_fn<'ctx, 'tcx>(
fnctx.spirv().end_function().unwrap(); fnctx.spirv().end_function().unwrap();
} }
// return_type, fn_type fn trans_fn_header<'tcx>(ctx: &mut FnCtx<'_, 'tcx>) {
fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) { let mir_return_type = ctx.body.local_decls[0u32.into()].ty;
let mir_return_type = body.local_decls[0u32.into()].ty;
let return_type = trans_type(ctx, mir_return_type); let return_type = trans_type(ctx, mir_return_type);
ctx.is_void = if let TyKind::Tuple(fields) = &mir_return_type.kind { ctx.is_void = if let TyKind::Tuple(fields) = &mir_return_type.kind {
fields.len() == 0 fields.len() == 0
} else { } else {
false false
}; };
let params = (0..body.arg_count) let params = (0..ctx.body.arg_count)
.map(|i| trans_type(ctx, body.local_decls[(i + 1).into()].ty)) .map(|i| trans_type(ctx, ctx.body.local_decls[(i + 1).into()].ty))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
let mut params_nonzero = params.clone(); let mut params_nonzero = params.clone();
if params_nonzero.is_empty() { if params_nonzero.is_empty() {
@ -56,12 +50,13 @@ fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) {
params_nonzero.push(ctx.spirv().type_void()); params_nonzero.push(ctx.spirv().type_void());
} }
let function_type = ctx.spirv().type_function(return_type, params_nonzero); let function_type = ctx.spirv().type_function(return_type, params_nonzero);
let function_id = None; let function_id = ctx
.ctx
.get_func_def(ctx.instance.def_id(), ctx.instance.substs);
let control = FunctionControl::NONE; let control = FunctionControl::NONE;
// TODO: keep track of function IDs
let _ = ctx let _ = ctx
.spirv() .spirv()
.begin_function(return_type, function_id, control, function_type) .begin_function(return_type, Some(function_id), control, function_type)
.unwrap(); .unwrap();
for (i, &param_type) in params.iter().enumerate() { for (i, &param_type) in params.iter().enumerate() {
@ -70,16 +65,18 @@ fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) {
} }
} }
fn trans_type<'tcx>(ctx: &mut FnCtx, ty: Ty<'tcx>) -> Word { fn trans_type<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, ty: Ty<'tcx>) -> Word {
match ty.kind { let mono = ctx.monomorphize(&ty);
match mono.kind {
TyKind::Param(param) => panic!("TyKind::Param after monomorphize: {:?}", param),
TyKind::Bool => ctx.spirv().type_bool(), TyKind::Bool => ctx.spirv().type_bool(),
TyKind::Tuple(fields) if fields.len() == 0 => ctx.spirv().type_void(), TyKind::Tuple(fields) if fields.len() == 0 => ctx.spirv().type_void(),
TyKind::Int(ty) => { TyKind::Int(ty) => {
let size = ty.bit_width().expect("isize not supported yet"); let size = ty.bit_width().unwrap_or_else(|| ctx.ctx.pointer_size.val());
ctx.spirv().type_int(size as u32, 1) ctx.spirv().type_int(size as u32, 1)
} }
TyKind::Uint(ty) => { TyKind::Uint(ty) => {
let size = ty.bit_width().expect("isize not supported yet"); let size = ty.bit_width().unwrap_or_else(|| ctx.ctx.pointer_size.val());
ctx.spirv().type_int(size as u32, 0) ctx.spirv().type_int(size as u32, 0)
} }
TyKind::Float(ty) => ctx.spirv().type_float(ty.bit_width() as u32), TyKind::Float(ty) => ctx.spirv().type_float(ty.bit_width() as u32),
@ -116,7 +113,7 @@ fn trans_stmt<'tcx>(ctx: &mut FnCtx, stmt: &Statement<'tcx>) {
} }
} }
fn trans_terminator<'tcx>(ctx: &mut FnCtx, term: &Terminator<'tcx>) { fn trans_terminator<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, term: &Terminator<'tcx>) {
match term.kind { match term.kind {
TerminatorKind::Return => { TerminatorKind::Return => {
if ctx.is_void { if ctx.is_void {
@ -138,6 +135,47 @@ fn trans_terminator<'tcx>(ctx: &mut FnCtx, term: &Terminator<'tcx>) {
let inst_id = ctx.get_basic_block(target); let inst_id = ctx.get_basic_block(target);
ctx.spirv().branch(inst_id).unwrap(); ctx.spirv().branch(inst_id).unwrap();
} }
TerminatorKind::Call {
ref func,
ref args,
ref destination,
..
} => {
let destination = destination.expect("Divergent function calls not supported yet");
let fn_ty = ctx.monomorphize(&func.ty(ctx.body, ctx.ctx.tcx));
let fn_sig = ctx.ctx.tcx.normalize_erasing_late_bound_regions(
ParamEnv::reveal_all(),
&fn_ty.fn_sig(ctx.ctx.tcx),
);
let fn_return_type = ParamEnv::reveal_all().and(fn_sig.output()).value;
let result_type = trans_type(ctx, fn_return_type);
let function = match func {
// TODO: Can constant.literal.val not be a ZST?
Operand::Constant(constant) => match constant.literal.ty.kind {
TyKind::FnDef(id, substs) => ctx.ctx.get_func_def(id, substs),
ref thing => panic!("Unknown type in fn call: {:?}", thing),
},
thing => panic!("Dynamic calls not supported yet: {:?}", thing),
};
let arguments = args
.iter()
.map(|arg| trans_operand(ctx, arg))
.collect::<Vec<_>>();
let dest_local = destination.0.local;
if destination.0.projection.len() != 0 {
panic!(
"Place projections aren't supported yet (fn call): {:?}",
destination.0.projection
);
}
let result = ctx
.spirv()
.function_call(result_type, None, function, arguments)
.unwrap();
ctx.locals.def(dest_local, result);
let destination_bb = ctx.get_basic_block(destination.1);
ctx.spirv().branch(destination_bb).unwrap();
}
ref thing => panic!("Unknown terminator: {:?}", thing), ref thing => panic!("Unknown terminator: {:?}", thing),
} }
} }