diff --git a/src/ctx.rs b/src/ctx.rs index e40669c14f..35e0e72798 100644 --- a/src/ctx.rs +++ b/src/ctx.rs @@ -4,15 +4,35 @@ use local_tracker::LocalTracker; use rspirv::binary::Assemble; use rspirv::dr::Builder; use rspirv::spirv::{AddressingModel, Capability, MemoryModel, StorageClass, Word}; -use rustc_middle::mir::BasicBlock; -use rustc_middle::ty::{subst::SubstsRef, TyCtxt}; +use rustc_middle::mir::{BasicBlock, Body}; +use rustc_middle::ty::subst::SubstsRef; +use rustc_middle::ty::{fold::TypeFoldable, Instance, ParamEnv, TyCtxt}; +use rustc_span::def_id::DefId; use std::collections::HashMap; use std::hash::Hash; +#[derive(Copy, Clone)] +pub enum PointerSize { + P32, + P64, +} + +impl PointerSize { + pub fn val(self) -> u64 { + match self { + PointerSize::P32 => 32, + PointerSize::P64 => 64, + } + } +} + +/// Context is the "bag of random variables" used for global state - i.e. variables that live the whole compilation. pub struct Context<'tcx> { pub tcx: TyCtxt<'tcx>, pub spirv: Builder, spirv_helper: SpirvHelper, + pub pointer_size: PointerSize, + func_defs: ForwardReference<(DefId, SubstsRef<'tcx>)>, } impl<'tcx> Context<'tcx> { @@ -22,54 +42,92 @@ impl<'tcx> Context<'tcx> { // Temp hack: Linkage allows us to get away with no OpEntryPoint spirv.capability(Capability::Linkage); spirv.memory_model(AddressingModel::Logical, MemoryModel::GLSL450); + let _ = PointerSize::P32; // just mark as used Self { tcx, spirv, spirv_helper: SpirvHelper::new(), + pointer_size: PointerSize::P64, + func_defs: ForwardReference::new(), } } pub fn assemble(self) -> Vec { self.spirv.module().assemble() } + + pub fn get_func_def(&mut self, id: DefId, substs: SubstsRef<'tcx>) -> Word { + self.func_defs.get(&mut self.spirv, (id, substs)) + } } +/// FnCtx is the "bag of random variables" used for state when compiling a particular function - i.e. variables that are +/// specific to compiling a particular function. Note it carries a reference to the global Context, so those variables +/// can be accessed too. pub struct FnCtx<'ctx, 'tcx> { + /// The global state - note this field is `pub`, so if needed, you can go through this field, instead of the helper + /// functions, to satisfy the borrow checker (since using the helper functions borrow this whole struct). pub ctx: &'ctx mut Context<'tcx>, - pub substs: SubstsRef<'tcx>, + pub instance: Instance<'tcx>, + pub body: &'tcx Body<'tcx>, pub is_void: bool, basic_blocks: ForwardReference, pub locals: LocalTracker, } impl<'ctx, 'tcx> FnCtx<'ctx, 'tcx> { - pub fn new(ctx: &'ctx mut Context<'tcx>, substs: SubstsRef<'tcx>) -> Self { + pub fn new( + ctx: &'ctx mut Context<'tcx>, + instance: Instance<'tcx>, + body: &'tcx Body<'tcx>, + ) -> Self { Self { ctx, - substs, + instance, + body, is_void: false, basic_blocks: ForwardReference::new(), locals: LocalTracker::new(), } } + /* pub fn tcx(&mut self) -> &mut TyCtxt<'tcx> { &mut self.ctx.tcx } + */ pub fn spirv(&mut self) -> &mut Builder { &mut self.ctx.spirv } + /// Gets the spir-v label for a basic block, or generates one if it doesn't exist. pub fn get_basic_block(&mut self, bb: BasicBlock) -> Word { self.basic_blocks.get(&mut self.ctx.spirv, bb) } + /// rspirv doesn't cache type_pointer, so cache it ourselves here. pub fn type_pointer(&mut self, pointee_type: Word) -> Word { self.ctx .spirv_helper .type_pointer(&mut self.ctx.spirv, pointee_type) } + + // copied from rustc_codegen_cranelift + pub(crate) fn monomorphize(&self, value: &T) -> T + where + T: TypeFoldable<'tcx> + Copy, + { + if let Some(substs) = self.instance.substs_for_mir_body() { + self.ctx + .tcx + .subst_and_normalize_erasing_regions(substs, ParamEnv::reveal_all(), value) + } else { + self.ctx + .tcx + .normalize_erasing_regions(ParamEnv::reveal_all(), *value) + } + } } struct ForwardReference { diff --git a/src/ctx/local_tracker.rs b/src/ctx/local_tracker.rs index 5989f311f7..7fcf5b0706 100644 --- a/src/ctx/local_tracker.rs +++ b/src/ctx/local_tracker.rs @@ -26,7 +26,8 @@ impl LocalTracker { } pub fn get(&self, local: Local) -> Word { - // This probably needs to be fixed, forward-references might be a thing + // This probably needs to be fixed, forward-references might be a thing. + // (MIR probably declares all locals up front, so use that?) *self.locals.get(&local).expect("Undefined local") } } diff --git a/src/lib.rs b/src/lib.rs index e6ad87f558..148fea7266 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -114,6 +114,7 @@ impl CodegenBackend for TheBackend { _sess: &Session, _dep_graph: &DepGraph, ) -> Result, ErrorReported> { + // I think this function is if `codegen_crate` spawns threads, this is supposed to join those threads? Ok(ongoing_codegen) // let crate_name = ongoing_codegen // .downcast::() diff --git a/src/trans.rs b/src/trans.rs index 9a0a2ba6e2..0b147d66cc 100644 --- a/src/trans.rs +++ b/src/trans.rs @@ -1,14 +1,9 @@ use crate::ctx::{Context, FnCtx}; use rspirv::spirv::{FunctionControl, Word}; -use rustc_middle::mir::{ - Body, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, -}; -use rustc_middle::ty::{Ty, TyKind}; +use rustc_middle::mir::{Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind}; +use rustc_middle::ty::{Instance, ParamEnv, Ty, TyKind}; -pub fn trans_fn<'ctx, 'tcx>( - ctx: &'ctx mut Context<'tcx>, - instance: rustc_middle::ty::Instance<'tcx>, -) { +pub fn trans_fn<'ctx, 'tcx>(ctx: &'ctx mut Context<'tcx>, instance: Instance<'tcx>) { { let mut mir = ::std::io::Cursor::new(Vec::new()); @@ -20,13 +15,13 @@ pub fn trans_fn<'ctx, 'tcx>( println!("{}", s); } - let mir = ctx.tcx.optimized_mir(instance.def_id()); + let body = ctx.tcx.optimized_mir(instance.def_id()); - let mut fnctx = FnCtx::new(ctx, instance.substs); + let mut fnctx = FnCtx::new(ctx, instance, body); - trans_fn_header(&mut fnctx, mir); + trans_fn_header(&mut fnctx); - for (bb, bb_data) in mir.basic_blocks().iter_enumerated() { + for (bb, bb_data) in fnctx.body.basic_blocks().iter_enumerated() { let label_id = fnctx.get_basic_block(bb); fnctx.spirv().begin_block(Some(label_id)).unwrap(); @@ -38,17 +33,16 @@ pub fn trans_fn<'ctx, 'tcx>( fnctx.spirv().end_function().unwrap(); } -// return_type, fn_type -fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) { - let mir_return_type = body.local_decls[0u32.into()].ty; +fn trans_fn_header<'tcx>(ctx: &mut FnCtx<'_, 'tcx>) { + let mir_return_type = ctx.body.local_decls[0u32.into()].ty; let return_type = trans_type(ctx, mir_return_type); ctx.is_void = if let TyKind::Tuple(fields) = &mir_return_type.kind { fields.len() == 0 } else { false }; - let params = (0..body.arg_count) - .map(|i| trans_type(ctx, body.local_decls[(i + 1).into()].ty)) + let params = (0..ctx.body.arg_count) + .map(|i| trans_type(ctx, ctx.body.local_decls[(i + 1).into()].ty)) .collect::>(); let mut params_nonzero = params.clone(); if params_nonzero.is_empty() { @@ -56,12 +50,13 @@ fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) { params_nonzero.push(ctx.spirv().type_void()); } let function_type = ctx.spirv().type_function(return_type, params_nonzero); - let function_id = None; + let function_id = ctx + .ctx + .get_func_def(ctx.instance.def_id(), ctx.instance.substs); let control = FunctionControl::NONE; - // TODO: keep track of function IDs let _ = ctx .spirv() - .begin_function(return_type, function_id, control, function_type) + .begin_function(return_type, Some(function_id), control, function_type) .unwrap(); for (i, ¶m_type) in params.iter().enumerate() { @@ -70,16 +65,18 @@ fn trans_fn_header<'tcx>(ctx: &mut FnCtx, body: &Body<'tcx>) { } } -fn trans_type<'tcx>(ctx: &mut FnCtx, ty: Ty<'tcx>) -> Word { - match ty.kind { +fn trans_type<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, ty: Ty<'tcx>) -> Word { + let mono = ctx.monomorphize(&ty); + match mono.kind { + TyKind::Param(param) => panic!("TyKind::Param after monomorphize: {:?}", param), TyKind::Bool => ctx.spirv().type_bool(), TyKind::Tuple(fields) if fields.len() == 0 => ctx.spirv().type_void(), TyKind::Int(ty) => { - let size = ty.bit_width().expect("isize not supported yet"); + let size = ty.bit_width().unwrap_or_else(|| ctx.ctx.pointer_size.val()); ctx.spirv().type_int(size as u32, 1) } TyKind::Uint(ty) => { - let size = ty.bit_width().expect("isize not supported yet"); + let size = ty.bit_width().unwrap_or_else(|| ctx.ctx.pointer_size.val()); ctx.spirv().type_int(size as u32, 0) } TyKind::Float(ty) => ctx.spirv().type_float(ty.bit_width() as u32), @@ -116,7 +113,7 @@ fn trans_stmt<'tcx>(ctx: &mut FnCtx, stmt: &Statement<'tcx>) { } } -fn trans_terminator<'tcx>(ctx: &mut FnCtx, term: &Terminator<'tcx>) { +fn trans_terminator<'tcx>(ctx: &mut FnCtx<'_, 'tcx>, term: &Terminator<'tcx>) { match term.kind { TerminatorKind::Return => { if ctx.is_void { @@ -138,6 +135,47 @@ fn trans_terminator<'tcx>(ctx: &mut FnCtx, term: &Terminator<'tcx>) { let inst_id = ctx.get_basic_block(target); ctx.spirv().branch(inst_id).unwrap(); } + TerminatorKind::Call { + ref func, + ref args, + ref destination, + .. + } => { + let destination = destination.expect("Divergent function calls not supported yet"); + let fn_ty = ctx.monomorphize(&func.ty(ctx.body, ctx.ctx.tcx)); + let fn_sig = ctx.ctx.tcx.normalize_erasing_late_bound_regions( + ParamEnv::reveal_all(), + &fn_ty.fn_sig(ctx.ctx.tcx), + ); + let fn_return_type = ParamEnv::reveal_all().and(fn_sig.output()).value; + let result_type = trans_type(ctx, fn_return_type); + let function = match func { + // TODO: Can constant.literal.val not be a ZST? + Operand::Constant(constant) => match constant.literal.ty.kind { + TyKind::FnDef(id, substs) => ctx.ctx.get_func_def(id, substs), + ref thing => panic!("Unknown type in fn call: {:?}", thing), + }, + thing => panic!("Dynamic calls not supported yet: {:?}", thing), + }; + let arguments = args + .iter() + .map(|arg| trans_operand(ctx, arg)) + .collect::>(); + let dest_local = destination.0.local; + if destination.0.projection.len() != 0 { + panic!( + "Place projections aren't supported yet (fn call): {:?}", + destination.0.projection + ); + } + let result = ctx + .spirv() + .function_call(result_type, None, function, arguments) + .unwrap(); + ctx.locals.def(dest_local, result); + let destination_bb = ctx.get_basic_block(destination.1); + ctx.spirv().branch(destination_bb).unwrap(); + } ref thing => panic!("Unknown terminator: {:?}", thing), } }