diff --git a/Cargo.toml b/Cargo.toml index 4569b7bf9b..e417082ef3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,6 @@ crate-type = ["dylib"] [dependencies] rspirv = { git = "https://github.com/gfx-rs/rspirv" } -spirv_headers = { git = "https://github.com/gfx-rs/rspirv" } [dev-dependencies] pretty_assertions = "0.6" diff --git a/src/ctx.rs b/src/ctx.rs new file mode 100644 index 0000000000..2a893b39ad --- /dev/null +++ b/src/ctx.rs @@ -0,0 +1,65 @@ +mod local_tracker; + +use crate::spirv_ctx::SpirvContext; +use local_tracker::LocalTracker; +use rspirv::binary::Assemble; +use rspirv::dr::Builder; +use rspirv::spirv::Word; +use rustc_middle::mir::BasicBlock; +use rustc_middle::ty::TyCtxt; +use std::collections::HashMap; +use std::hash::Hash; + +pub struct Context<'tcx> { + pub spirv: SpirvContext, + pub tcx: TyCtxt<'tcx>, + pub current_function_is_void: bool, + basic_blocks: ForwardReference, + pub locals: LocalTracker, +} + +impl<'tcx> Context<'tcx> { + pub fn new(tcx: TyCtxt<'tcx>) -> Self { + let spirv = SpirvContext::new(); + Self { + spirv, + tcx, + current_function_is_void: false, + basic_blocks: ForwardReference::new(), + locals: LocalTracker::new(), + } + } + + pub fn assemble(self) -> Vec { + self.spirv.builder.module().assemble() + } + + pub fn get_basic_block(&mut self, bb: BasicBlock) -> Word { + self.basic_blocks.get(&mut self.spirv.builder, bb) + } + + pub fn clear_after_fn(&mut self) { + self.basic_blocks.clear(); + self.locals.clear(); + } +} + +struct ForwardReference { + values: HashMap, +} + +impl ForwardReference { + fn new() -> Self { + Self { + values: HashMap::new(), + } + } + + fn get(&mut self, builder: &mut Builder, key: T) -> Word { + *self.values.entry(key).or_insert_with(|| builder.id()) + } + + fn clear(&mut self) { + self.values.clear(); + } +} diff --git a/src/ctx/local_tracker.rs b/src/ctx/local_tracker.rs new file mode 100644 index 0000000000..d2c9ccc225 --- /dev/null +++ b/src/ctx/local_tracker.rs @@ -0,0 +1,36 @@ +use rspirv::spirv::Word; +use rustc_middle::mir::Local; +use std::collections::HashMap; + +// TODO: Expand this to SSA-conversion +pub struct LocalTracker { + locals: HashMap, +} + +impl LocalTracker { + pub fn new() -> Self { + Self { + locals: HashMap::new(), + } + } + + pub fn def(&mut self, local: Local, expr: Word) { + match self.locals.entry(local) { + std::collections::hash_map::Entry::Occupied(_) => { + println!("Non-SSA code not supported yet") + } + std::collections::hash_map::Entry::Vacant(entry) => { + entry.insert(expr); + } + } + } + + pub fn get(&self, local: Local) -> Word { + // This probably needs to be fixed, forward-references might be a thing + *self.locals.get(&local).expect("Undefined local") + } + + pub fn clear(&mut self) { + self.locals.clear(); + } +} diff --git a/src/lib.rs b/src/lib.rs index 035d3b5186..d8cbb0aac7 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -29,6 +29,7 @@ use rustc_target::spec::Target; use std::any::Any; use std::path::Path; +mod ctx; mod spirv_ctx; mod trans; @@ -87,7 +88,7 @@ impl CodegenBackend for TheBackend { &[] }; - let mut translator = trans::Translator::new(tcx); + let mut context = ctx::Context::new(tcx); cgus.iter().for_each(|cgu| { let cgu = tcx.codegen_unit(cgu.name()); @@ -95,13 +96,13 @@ impl CodegenBackend for TheBackend { for (mono_item, (_linkage, _visibility)) in mono_items { match mono_item { - MonoItem::Fn(instance) => translator.trans_fn(instance), + MonoItem::Fn(instance) => trans::trans_fn(&mut context, instance), thing => println!("Unknown MonoItem: {:?}", thing), } } }); - let module = translator.assemble(); + let module = context.assemble(); Box::new(module) } diff --git a/src/spirv_ctx.rs b/src/spirv_ctx.rs index b08d174e04..32e41b0aac 100644 --- a/src/spirv_ctx.rs +++ b/src/spirv_ctx.rs @@ -1,5 +1,5 @@ use rspirv::dr::Builder; -use spirv_headers::{AddressingModel, MemoryModel, Word}; +use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Word}; use std::collections::HashMap; macro_rules! impl_cache { @@ -34,6 +34,9 @@ struct Cache { impl SpirvContext { pub fn new() -> Self { let mut builder = Builder::new(); + builder.capability(Capability::Shader); + // Temp hack: Linkage allows us to get away with no OpEntryPoint + builder.capability(Capability::Linkage); builder.memory_model(AddressingModel::Logical, MemoryModel::GLSL450); SpirvContext { builder, diff --git a/src/trans.rs b/src/trans.rs index f72d6653c8..2bf77eba23 100644 --- a/src/trans.rs +++ b/src/trans.rs @@ -1,207 +1,177 @@ -use crate::spirv_ctx::SpirvContext; -use rspirv::binary::Assemble; +use crate::ctx::Context; +use rspirv::spirv::{FunctionControl, Word}; use rustc_middle::mir::{ - BasicBlock, Body, Local, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, + Body, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind, }; -use rustc_middle::ty::{Ty, TyCtxt, TyKind}; -use spirv_headers::{FunctionControl, Word}; -use std::collections::HashMap; +use rustc_middle::ty::{Ty, TyKind}; -pub struct Translator<'tcx> { - spirv: SpirvContext, - tcx: TyCtxt<'tcx>, - basic_blocks: HashMap, - locals: HashMap, -} +pub fn trans_fn<'tcx>(ctx: &mut Context<'tcx>, instance: rustc_middle::ty::Instance<'tcx>) { + { + let mut mir = ::std::io::Cursor::new(Vec::new()); -impl<'tcx> Translator<'tcx> { - pub fn new(tcx: TyCtxt<'tcx>) -> Self { - let spirv = SpirvContext::new(); - Self { - spirv, - tcx, - basic_blocks: HashMap::new(), - locals: HashMap::new(), - } - } - - pub fn assemble(self) -> Vec { - self.spirv.builder.module().assemble() - } - - pub fn get_bb_label(&self, bb: BasicBlock) -> Option { - self.basic_blocks.get(&bb).cloned() - } - - pub fn get_or_gen_bb_label(&mut self, bb: BasicBlock) -> Word { - let builder = &mut self.spirv.builder; - *self.basic_blocks.entry(bb).or_insert_with(|| builder.id()) - } - - pub fn def_local(&mut self, local: Local, expr: Word) { - match self.locals.entry(local) { - std::collections::hash_map::Entry::Occupied(_) => { - println!("Non-SSA code not supported yet") - } - std::collections::hash_map::Entry::Vacant(entry) => { - entry.insert(expr); - } - } - } - - pub fn trans_fn(&mut self, instance: rustc_middle::ty::Instance<'tcx>) { - { - let mut mir = ::std::io::Cursor::new(Vec::new()); - - crate::rustc_mir::util::write_mir_pretty(self.tcx, Some(instance.def_id()), &mut mir) - .unwrap(); - - let s = String::from_utf8(mir.into_inner()).unwrap(); - - println!("{}", s); - } - - let mir = self.tcx.instance_mir(instance.def); - - self.trans_fn_header(mir); - - for (bb, bb_data) in mir.basic_blocks().iter_enumerated() { - let label_id = self.get_bb_label(bb); - let result = self.spirv.builder.begin_block(label_id).unwrap(); - if label_id.is_none() { - self.basic_blocks.insert(bb, result); - } - - for stmt in &bb_data.statements { - self.trans_stmt(stmt); - } - self.trans_terminator(bb_data.terminator()); - } - self.spirv.builder.end_function().unwrap(); - - self.basic_blocks.clear(); - self.locals.clear(); - } - - // return_type, fn_type - fn trans_fn_header(&mut self, body: &Body<'tcx>) { - let return_type = self.trans_type(body.local_decls[0u32.into()].ty); - let params = (0..body.arg_count) - .map(|i| self.trans_type(body.local_decls[(i + 1).into()].ty)) - .collect::>(); - // TODO: this clone is gross - let function_type = self.spirv.type_function(return_type, params.clone()); - let function_id = None; - let control = FunctionControl::NONE; - // TODO: keep track of function IDs - let _ = self - .spirv - .builder - .begin_function(return_type, function_id, control, function_type) + crate::rustc_mir::util::write_mir_pretty(ctx.tcx, Some(instance.def_id()), &mut mir) .unwrap(); - for (i, ¶m_type) in params.iter().enumerate() { - let param_value = self.spirv.builder.function_parameter(param_type).unwrap(); - self.def_local((i + 1).into(), param_value); - } + let s = String::from_utf8(mir.into_inner()).unwrap(); + + println!("{}", s); } - fn trans_type(&mut self, ty: Ty<'tcx>) -> Word { - match ty.kind { - TyKind::Bool => self.spirv.type_bool(), - TyKind::Tuple(fields) if fields.len() == 0 => self.spirv.type_void(), - TyKind::Int(ty) => { - let size = ty.bit_width().expect("isize not supported yet"); - self.spirv.type_int(size as u32, 1) - } - TyKind::Uint(ty) => { - let size = ty.bit_width().expect("isize not supported yet"); - self.spirv.type_int(size as u32, 0) - } - TyKind::Float(ty) => self.spirv.type_float(ty.bit_width() as u32), - ref thing => { - println!("Unknown type: {:?}", thing); - self.spirv.builder.id() - } + let mir = ctx.tcx.instance_mir(instance.def); + + trans_fn_header(ctx, mir); + + for (bb, bb_data) in mir.basic_blocks().iter_enumerated() { + let label_id = ctx.get_basic_block(bb); + ctx.spirv.builder.begin_block(Some(label_id)).unwrap(); + + for stmt in &bb_data.statements { + trans_stmt(ctx, stmt); } + trans_terminator(ctx, bb_data.terminator()); } + ctx.spirv.builder.end_function().unwrap(); - fn trans_stmt(&mut self, stmt: &Statement<'tcx>) { - match &stmt.kind { - StatementKind::Assign(place_and_rval) => { - // can't destructure this since it's a Box<(place, rvalue)> - let place = place_and_rval.0; - let rvalue = &place_and_rval.1; + ctx.clear_after_fn(); +} - if place.projection.len() != 0 { - println!( - "Place projections aren't supported yet (assignment): {:?}", - place.projection - ); - } +// return_type, fn_type +fn trans_fn_header<'tcx>(ctx: &mut Context<'tcx>, body: &Body<'tcx>) { + let mir_return_type = body.local_decls[0u32.into()].ty; + let return_type = trans_type(ctx, mir_return_type); + ctx.current_function_is_void = if let TyKind::Tuple(fields) = &mir_return_type.kind { + fields.len() == 0 + } else { + false + }; + let params = (0..body.arg_count) + .map(|i| trans_type(ctx, body.local_decls[(i + 1).into()].ty)) + .collect::>(); + let mut params_nonzero = params.clone(); + if params_nonzero.is_empty() { + // spirv says take 1 argument of type void if no arguments + params_nonzero.push(ctx.spirv.type_void()); + } + let function_type = ctx.spirv.type_function(return_type, params_nonzero); + let function_id = None; + let control = FunctionControl::NONE; + // TODO: keep track of function IDs + let _ = ctx + .spirv + .builder + .begin_function(return_type, function_id, control, function_type) + .unwrap(); - let expr = self.trans_rvalue(rvalue); - self.def_local(place.local, expr); - } - // ignore StorageLive/Dead for now - StatementKind::StorageLive(_local) => (), - StatementKind::StorageDead(_local) => (), - thing => println!("Unknown statement: {:?}", thing), + for (i, ¶m_type) in params.iter().enumerate() { + let param_value = ctx.spirv.builder.function_parameter(param_type).unwrap(); + ctx.locals.def((i + 1).into(), param_value); + } +} + +fn trans_type<'tcx>(ctx: &mut Context<'tcx>, ty: Ty<'tcx>) -> Word { + match ty.kind { + TyKind::Bool => ctx.spirv.type_bool(), + TyKind::Tuple(fields) if fields.len() == 0 => ctx.spirv.type_void(), + TyKind::Int(ty) => { + let size = ty.bit_width().expect("isize not supported yet"); + ctx.spirv.type_int(size as u32, 1) } - } - - fn trans_terminator(&mut self, term: &Terminator<'tcx>) { - match term.kind { - TerminatorKind::Return => self.spirv.builder.ret().unwrap(), - TerminatorKind::Assert { target, .. } => { - // ignore asserts for now, just do direct goto - let inst_id = self.get_or_gen_bb_label(target); - self.spirv.builder.branch(inst_id).unwrap(); - } - TerminatorKind::Goto { target } => { - let inst_id = self.get_or_gen_bb_label(target); - self.spirv.builder.branch(inst_id).unwrap(); - } - ref thing => println!("Unknown terminator: {:?}", thing), + TyKind::Uint(ty) => { + let size = ty.bit_width().expect("isize not supported yet"); + ctx.spirv.type_int(size as u32, 0) } - } - - fn trans_rvalue(&mut self, expr: &Rvalue<'tcx>) -> Word { - match expr { - Rvalue::Use(operand) => self.trans_operand(operand), - Rvalue::BinaryOp(_op, left, right) => { - // TODO: properly implement - let left = self.trans_operand(left); - let right = self.trans_operand(right); - let result_type = self.spirv.type_int(32, 1); - self.spirv - .builder - .i_add(result_type, None, left, right) - .unwrap() - } - thing => { - println!("Unknown rvalue: {:?}", thing); - self.spirv.builder.id() - } - } - } - - fn trans_operand(&mut self, operand: &Operand<'tcx>) -> Word { - match operand { - Operand::Copy(place) | Operand::Move(place) => { - if place.projection.len() != 0 { - println!( - "Place projections aren't supported yet (operand): {:?}", - place.projection - ); - } - // This probably needs to be fixed, forward-references might be a thing - *self.locals.get(&place.local).expect("Undefined local") - } - Operand::Constant(constant) => { - println!("Unimplemented Operand::Constant: {:?}", constant); - self.spirv.builder.id() - } + TyKind::Float(ty) => ctx.spirv.type_float(ty.bit_width() as u32), + ref thing => { + println!("Unknown type: {:?}", thing); + ctx.spirv.builder.id() + } + } +} + +fn trans_stmt<'tcx>(ctx: &mut Context<'tcx>, stmt: &Statement<'tcx>) { + match &stmt.kind { + StatementKind::Assign(place_and_rval) => { + // can't destructure this since it's a Box<(place, rvalue)> + let place = place_and_rval.0; + let rvalue = &place_and_rval.1; + + if place.projection.len() != 0 { + println!( + "Place projections aren't supported yet (assignment): {:?}", + place.projection + ); + } + + let expr = trans_rvalue(ctx, rvalue); + ctx.locals.def(place.local, expr); + } + // ignore StorageLive/Dead for now + StatementKind::StorageLive(_local) => (), + StatementKind::StorageDead(_local) => (), + thing => println!("Unknown statement: {:?}", thing), + } +} + +fn trans_terminator<'tcx>(ctx: &mut Context<'tcx>, term: &Terminator<'tcx>) { + match term.kind { + TerminatorKind::Return => { + if ctx.current_function_is_void { + ctx.spirv.builder.ret().unwrap(); + } else { + // local 0 is return value + ctx.spirv + .builder + .ret_value(ctx.locals.get(0u32.into())) + .unwrap(); + } + } + TerminatorKind::Assert { target, .. } => { + // ignore asserts for now, just do direct goto + let inst_id = ctx.get_basic_block(target); + ctx.spirv.builder.branch(inst_id).unwrap(); + } + TerminatorKind::Goto { target } => { + let inst_id = ctx.get_basic_block(target); + ctx.spirv.builder.branch(inst_id).unwrap(); + } + ref thing => println!("Unknown terminator: {:?}", thing), + } +} + +fn trans_rvalue<'tcx>(ctx: &mut Context<'tcx>, expr: &Rvalue<'tcx>) -> Word { + match expr { + Rvalue::Use(operand) => trans_operand(ctx, operand), + Rvalue::BinaryOp(_op, left, right) => { + // TODO: properly implement + let left = trans_operand(ctx, left); + let right = trans_operand(ctx, right); + let result_type = ctx.spirv.type_int(32, 0); + ctx.spirv + .builder + .i_add(result_type, None, left, right) + .unwrap() + } + thing => { + println!("Unknown rvalue: {:?}", thing); + ctx.spirv.builder.id() + } + } +} + +fn trans_operand<'tcx>(ctx: &mut Context<'tcx>, operand: &Operand<'tcx>) -> Word { + match operand { + Operand::Copy(place) | Operand::Move(place) => { + if place.projection.len() != 0 { + println!( + "Place projections aren't supported yet (operand): {:?}", + place.projection + ); + } + ctx.locals.get(place.local) + } + Operand::Constant(constant) => { + println!("Unimplemented Operand::Constant: {:?}", constant); + ctx.spirv.builder.id() } } } diff --git a/tests/lib.rs b/tests/lib.rs index dc8f251bf7..9bcbd305c3 100644 --- a/tests/lib.rs +++ b/tests/lib.rs @@ -102,10 +102,15 @@ fn go(code: &str, expected: &str) { .expect("failed to parse spirv") .disassemble(); + assert_eq!(PrettyString(&output_disas), PrettyString(expected)); + + match Command::new("spirv-val").arg(&output).status() { + Ok(status) => assert!(status.success()), + Err(err) => eprint!("spirv-val tool not found, ignoring test: {}", err), + } + remove_file(input).expect("Failed to delete input file"); remove_file(output).expect("Failed to delete output file"); - - assert_eq!(PrettyString(&output_disas), PrettyString(expected)); } #[test] @@ -118,24 +123,25 @@ pub fn add_numbers(x: u32, y: u32) -> u32 { r"; SPIR-V ; Version: 1.5 ; Generator: rspirv -; Bound: 14 +; Bound: 13 +OpCapability Shader +OpCapability Linkage OpMemoryModel Logical GLSL450 %1 = OpTypeInt 32 0 %2 = OpTypeFunction %1 %1 %1 -%7 = OpTypeInt 32 1 %3 = OpFunction %1 None %2 %4 = OpFunctionParameter %1 %5 = OpFunctionParameter %1 %6 = OpLabel -%8 = OpIAdd %7 %4 %5 -OpReturn +%7 = OpIAdd %1 %4 %5 +OpReturnValue %7 OpFunctionEnd -%9 = OpFunction %1 None %2 +%8 = OpFunction %1 None %2 +%9 = OpFunctionParameter %1 %10 = OpFunctionParameter %1 -%11 = OpFunctionParameter %1 -%12 = OpLabel -%13 = OpIAdd %7 %10 %11 -OpReturn +%11 = OpLabel +%12 = OpIAdd %1 %9 %10 +OpReturnValue %12 OpFunctionEnd", ); }