Split up some functions and structs

This commit is contained in:
khyperia 2020-08-11 15:06:25 +02:00
parent e0de7815ce
commit 3d43764fe1
7 changed files with 285 additions and 205 deletions

View File

@ -11,7 +11,6 @@ crate-type = ["dylib"]
[dependencies]
rspirv = { git = "https://github.com/gfx-rs/rspirv" }
spirv_headers = { git = "https://github.com/gfx-rs/rspirv" }
[dev-dependencies]
pretty_assertions = "0.6"

65
src/ctx.rs Normal file
View File

@ -0,0 +1,65 @@
mod local_tracker;
use crate::spirv_ctx::SpirvContext;
use local_tracker::LocalTracker;
use rspirv::binary::Assemble;
use rspirv::dr::Builder;
use rspirv::spirv::Word;
use rustc_middle::mir::BasicBlock;
use rustc_middle::ty::TyCtxt;
use std::collections::HashMap;
use std::hash::Hash;
pub struct Context<'tcx> {
pub spirv: SpirvContext,
pub tcx: TyCtxt<'tcx>,
pub current_function_is_void: bool,
basic_blocks: ForwardReference<BasicBlock>,
pub locals: LocalTracker,
}
impl<'tcx> Context<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>) -> Self {
let spirv = SpirvContext::new();
Self {
spirv,
tcx,
current_function_is_void: false,
basic_blocks: ForwardReference::new(),
locals: LocalTracker::new(),
}
}
pub fn assemble(self) -> Vec<u32> {
self.spirv.builder.module().assemble()
}
pub fn get_basic_block(&mut self, bb: BasicBlock) -> Word {
self.basic_blocks.get(&mut self.spirv.builder, bb)
}
pub fn clear_after_fn(&mut self) {
self.basic_blocks.clear();
self.locals.clear();
}
}
struct ForwardReference<T: Eq + Hash> {
values: HashMap<T, Word>,
}
impl<T: Eq + Hash> ForwardReference<T> {
fn new() -> Self {
Self {
values: HashMap::new(),
}
}
fn get(&mut self, builder: &mut Builder, key: T) -> Word {
*self.values.entry(key).or_insert_with(|| builder.id())
}
fn clear(&mut self) {
self.values.clear();
}
}

36
src/ctx/local_tracker.rs Normal file
View File

@ -0,0 +1,36 @@
use rspirv::spirv::Word;
use rustc_middle::mir::Local;
use std::collections::HashMap;
// TODO: Expand this to SSA-conversion
pub struct LocalTracker {
locals: HashMap<Local, Word>,
}
impl LocalTracker {
pub fn new() -> Self {
Self {
locals: HashMap::new(),
}
}
pub fn def(&mut self, local: Local, expr: Word) {
match self.locals.entry(local) {
std::collections::hash_map::Entry::Occupied(_) => {
println!("Non-SSA code not supported yet")
}
std::collections::hash_map::Entry::Vacant(entry) => {
entry.insert(expr);
}
}
}
pub fn get(&self, local: Local) -> Word {
// This probably needs to be fixed, forward-references might be a thing
*self.locals.get(&local).expect("Undefined local")
}
pub fn clear(&mut self) {
self.locals.clear();
}
}

View File

@ -29,6 +29,7 @@ use rustc_target::spec::Target;
use std::any::Any;
use std::path::Path;
mod ctx;
mod spirv_ctx;
mod trans;
@ -87,7 +88,7 @@ impl CodegenBackend for TheBackend {
&[]
};
let mut translator = trans::Translator::new(tcx);
let mut context = ctx::Context::new(tcx);
cgus.iter().for_each(|cgu| {
let cgu = tcx.codegen_unit(cgu.name());
@ -95,13 +96,13 @@ impl CodegenBackend for TheBackend {
for (mono_item, (_linkage, _visibility)) in mono_items {
match mono_item {
MonoItem::Fn(instance) => translator.trans_fn(instance),
MonoItem::Fn(instance) => trans::trans_fn(&mut context, instance),
thing => println!("Unknown MonoItem: {:?}", thing),
}
}
});
let module = translator.assemble();
let module = context.assemble();
Box::new(module)
}

View File

@ -1,5 +1,5 @@
use rspirv::dr::Builder;
use spirv_headers::{AddressingModel, MemoryModel, Word};
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Word};
use std::collections::HashMap;
macro_rules! impl_cache {
@ -34,6 +34,9 @@ struct Cache {
impl SpirvContext {
pub fn new() -> Self {
let mut builder = Builder::new();
builder.capability(Capability::Shader);
// Temp hack: Linkage allows us to get away with no OpEntryPoint
builder.capability(Capability::Linkage);
builder.memory_model(AddressingModel::Logical, MemoryModel::GLSL450);
SpirvContext {
builder,

View File

@ -1,207 +1,177 @@
use crate::spirv_ctx::SpirvContext;
use rspirv::binary::Assemble;
use crate::ctx::Context;
use rspirv::spirv::{FunctionControl, Word};
use rustc_middle::mir::{
BasicBlock, Body, Local, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind,
Body, Operand, Rvalue, Statement, StatementKind, Terminator, TerminatorKind,
};
use rustc_middle::ty::{Ty, TyCtxt, TyKind};
use spirv_headers::{FunctionControl, Word};
use std::collections::HashMap;
use rustc_middle::ty::{Ty, TyKind};
pub struct Translator<'tcx> {
spirv: SpirvContext,
tcx: TyCtxt<'tcx>,
basic_blocks: HashMap<BasicBlock, Word>,
locals: HashMap<Local, Word>,
}
pub fn trans_fn<'tcx>(ctx: &mut Context<'tcx>, instance: rustc_middle::ty::Instance<'tcx>) {
{
let mut mir = ::std::io::Cursor::new(Vec::new());
impl<'tcx> Translator<'tcx> {
pub fn new(tcx: TyCtxt<'tcx>) -> Self {
let spirv = SpirvContext::new();
Self {
spirv,
tcx,
basic_blocks: HashMap::new(),
locals: HashMap::new(),
}
}
pub fn assemble(self) -> Vec<u32> {
self.spirv.builder.module().assemble()
}
pub fn get_bb_label(&self, bb: BasicBlock) -> Option<Word> {
self.basic_blocks.get(&bb).cloned()
}
pub fn get_or_gen_bb_label(&mut self, bb: BasicBlock) -> Word {
let builder = &mut self.spirv.builder;
*self.basic_blocks.entry(bb).or_insert_with(|| builder.id())
}
pub fn def_local(&mut self, local: Local, expr: Word) {
match self.locals.entry(local) {
std::collections::hash_map::Entry::Occupied(_) => {
println!("Non-SSA code not supported yet")
}
std::collections::hash_map::Entry::Vacant(entry) => {
entry.insert(expr);
}
}
}
pub fn trans_fn(&mut self, instance: rustc_middle::ty::Instance<'tcx>) {
{
let mut mir = ::std::io::Cursor::new(Vec::new());
crate::rustc_mir::util::write_mir_pretty(self.tcx, Some(instance.def_id()), &mut mir)
.unwrap();
let s = String::from_utf8(mir.into_inner()).unwrap();
println!("{}", s);
}
let mir = self.tcx.instance_mir(instance.def);
self.trans_fn_header(mir);
for (bb, bb_data) in mir.basic_blocks().iter_enumerated() {
let label_id = self.get_bb_label(bb);
let result = self.spirv.builder.begin_block(label_id).unwrap();
if label_id.is_none() {
self.basic_blocks.insert(bb, result);
}
for stmt in &bb_data.statements {
self.trans_stmt(stmt);
}
self.trans_terminator(bb_data.terminator());
}
self.spirv.builder.end_function().unwrap();
self.basic_blocks.clear();
self.locals.clear();
}
// return_type, fn_type
fn trans_fn_header(&mut self, body: &Body<'tcx>) {
let return_type = self.trans_type(body.local_decls[0u32.into()].ty);
let params = (0..body.arg_count)
.map(|i| self.trans_type(body.local_decls[(i + 1).into()].ty))
.collect::<Vec<_>>();
// TODO: this clone is gross
let function_type = self.spirv.type_function(return_type, params.clone());
let function_id = None;
let control = FunctionControl::NONE;
// TODO: keep track of function IDs
let _ = self
.spirv
.builder
.begin_function(return_type, function_id, control, function_type)
crate::rustc_mir::util::write_mir_pretty(ctx.tcx, Some(instance.def_id()), &mut mir)
.unwrap();
for (i, &param_type) in params.iter().enumerate() {
let param_value = self.spirv.builder.function_parameter(param_type).unwrap();
self.def_local((i + 1).into(), param_value);
}
let s = String::from_utf8(mir.into_inner()).unwrap();
println!("{}", s);
}
fn trans_type(&mut self, ty: Ty<'tcx>) -> Word {
match ty.kind {
TyKind::Bool => self.spirv.type_bool(),
TyKind::Tuple(fields) if fields.len() == 0 => self.spirv.type_void(),
TyKind::Int(ty) => {
let size = ty.bit_width().expect("isize not supported yet");
self.spirv.type_int(size as u32, 1)
}
TyKind::Uint(ty) => {
let size = ty.bit_width().expect("isize not supported yet");
self.spirv.type_int(size as u32, 0)
}
TyKind::Float(ty) => self.spirv.type_float(ty.bit_width() as u32),
ref thing => {
println!("Unknown type: {:?}", thing);
self.spirv.builder.id()
}
let mir = ctx.tcx.instance_mir(instance.def);
trans_fn_header(ctx, mir);
for (bb, bb_data) in mir.basic_blocks().iter_enumerated() {
let label_id = ctx.get_basic_block(bb);
ctx.spirv.builder.begin_block(Some(label_id)).unwrap();
for stmt in &bb_data.statements {
trans_stmt(ctx, stmt);
}
trans_terminator(ctx, bb_data.terminator());
}
ctx.spirv.builder.end_function().unwrap();
fn trans_stmt(&mut self, stmt: &Statement<'tcx>) {
match &stmt.kind {
StatementKind::Assign(place_and_rval) => {
// can't destructure this since it's a Box<(place, rvalue)>
let place = place_and_rval.0;
let rvalue = &place_and_rval.1;
ctx.clear_after_fn();
}
if place.projection.len() != 0 {
println!(
"Place projections aren't supported yet (assignment): {:?}",
place.projection
);
}
// return_type, fn_type
fn trans_fn_header<'tcx>(ctx: &mut Context<'tcx>, body: &Body<'tcx>) {
let mir_return_type = body.local_decls[0u32.into()].ty;
let return_type = trans_type(ctx, mir_return_type);
ctx.current_function_is_void = if let TyKind::Tuple(fields) = &mir_return_type.kind {
fields.len() == 0
} else {
false
};
let params = (0..body.arg_count)
.map(|i| trans_type(ctx, body.local_decls[(i + 1).into()].ty))
.collect::<Vec<_>>();
let mut params_nonzero = params.clone();
if params_nonzero.is_empty() {
// spirv says take 1 argument of type void if no arguments
params_nonzero.push(ctx.spirv.type_void());
}
let function_type = ctx.spirv.type_function(return_type, params_nonzero);
let function_id = None;
let control = FunctionControl::NONE;
// TODO: keep track of function IDs
let _ = ctx
.spirv
.builder
.begin_function(return_type, function_id, control, function_type)
.unwrap();
let expr = self.trans_rvalue(rvalue);
self.def_local(place.local, expr);
}
// ignore StorageLive/Dead for now
StatementKind::StorageLive(_local) => (),
StatementKind::StorageDead(_local) => (),
thing => println!("Unknown statement: {:?}", thing),
for (i, &param_type) in params.iter().enumerate() {
let param_value = ctx.spirv.builder.function_parameter(param_type).unwrap();
ctx.locals.def((i + 1).into(), param_value);
}
}
fn trans_type<'tcx>(ctx: &mut Context<'tcx>, ty: Ty<'tcx>) -> Word {
match ty.kind {
TyKind::Bool => ctx.spirv.type_bool(),
TyKind::Tuple(fields) if fields.len() == 0 => ctx.spirv.type_void(),
TyKind::Int(ty) => {
let size = ty.bit_width().expect("isize not supported yet");
ctx.spirv.type_int(size as u32, 1)
}
}
fn trans_terminator(&mut self, term: &Terminator<'tcx>) {
match term.kind {
TerminatorKind::Return => self.spirv.builder.ret().unwrap(),
TerminatorKind::Assert { target, .. } => {
// ignore asserts for now, just do direct goto
let inst_id = self.get_or_gen_bb_label(target);
self.spirv.builder.branch(inst_id).unwrap();
}
TerminatorKind::Goto { target } => {
let inst_id = self.get_or_gen_bb_label(target);
self.spirv.builder.branch(inst_id).unwrap();
}
ref thing => println!("Unknown terminator: {:?}", thing),
TyKind::Uint(ty) => {
let size = ty.bit_width().expect("isize not supported yet");
ctx.spirv.type_int(size as u32, 0)
}
}
fn trans_rvalue(&mut self, expr: &Rvalue<'tcx>) -> Word {
match expr {
Rvalue::Use(operand) => self.trans_operand(operand),
Rvalue::BinaryOp(_op, left, right) => {
// TODO: properly implement
let left = self.trans_operand(left);
let right = self.trans_operand(right);
let result_type = self.spirv.type_int(32, 1);
self.spirv
.builder
.i_add(result_type, None, left, right)
.unwrap()
}
thing => {
println!("Unknown rvalue: {:?}", thing);
self.spirv.builder.id()
}
}
}
fn trans_operand(&mut self, operand: &Operand<'tcx>) -> Word {
match operand {
Operand::Copy(place) | Operand::Move(place) => {
if place.projection.len() != 0 {
println!(
"Place projections aren't supported yet (operand): {:?}",
place.projection
);
}
// This probably needs to be fixed, forward-references might be a thing
*self.locals.get(&place.local).expect("Undefined local")
}
Operand::Constant(constant) => {
println!("Unimplemented Operand::Constant: {:?}", constant);
self.spirv.builder.id()
}
TyKind::Float(ty) => ctx.spirv.type_float(ty.bit_width() as u32),
ref thing => {
println!("Unknown type: {:?}", thing);
ctx.spirv.builder.id()
}
}
}
fn trans_stmt<'tcx>(ctx: &mut Context<'tcx>, stmt: &Statement<'tcx>) {
match &stmt.kind {
StatementKind::Assign(place_and_rval) => {
// can't destructure this since it's a Box<(place, rvalue)>
let place = place_and_rval.0;
let rvalue = &place_and_rval.1;
if place.projection.len() != 0 {
println!(
"Place projections aren't supported yet (assignment): {:?}",
place.projection
);
}
let expr = trans_rvalue(ctx, rvalue);
ctx.locals.def(place.local, expr);
}
// ignore StorageLive/Dead for now
StatementKind::StorageLive(_local) => (),
StatementKind::StorageDead(_local) => (),
thing => println!("Unknown statement: {:?}", thing),
}
}
fn trans_terminator<'tcx>(ctx: &mut Context<'tcx>, term: &Terminator<'tcx>) {
match term.kind {
TerminatorKind::Return => {
if ctx.current_function_is_void {
ctx.spirv.builder.ret().unwrap();
} else {
// local 0 is return value
ctx.spirv
.builder
.ret_value(ctx.locals.get(0u32.into()))
.unwrap();
}
}
TerminatorKind::Assert { target, .. } => {
// ignore asserts for now, just do direct goto
let inst_id = ctx.get_basic_block(target);
ctx.spirv.builder.branch(inst_id).unwrap();
}
TerminatorKind::Goto { target } => {
let inst_id = ctx.get_basic_block(target);
ctx.spirv.builder.branch(inst_id).unwrap();
}
ref thing => println!("Unknown terminator: {:?}", thing),
}
}
fn trans_rvalue<'tcx>(ctx: &mut Context<'tcx>, expr: &Rvalue<'tcx>) -> Word {
match expr {
Rvalue::Use(operand) => trans_operand(ctx, operand),
Rvalue::BinaryOp(_op, left, right) => {
// TODO: properly implement
let left = trans_operand(ctx, left);
let right = trans_operand(ctx, right);
let result_type = ctx.spirv.type_int(32, 0);
ctx.spirv
.builder
.i_add(result_type, None, left, right)
.unwrap()
}
thing => {
println!("Unknown rvalue: {:?}", thing);
ctx.spirv.builder.id()
}
}
}
fn trans_operand<'tcx>(ctx: &mut Context<'tcx>, operand: &Operand<'tcx>) -> Word {
match operand {
Operand::Copy(place) | Operand::Move(place) => {
if place.projection.len() != 0 {
println!(
"Place projections aren't supported yet (operand): {:?}",
place.projection
);
}
ctx.locals.get(place.local)
}
Operand::Constant(constant) => {
println!("Unimplemented Operand::Constant: {:?}", constant);
ctx.spirv.builder.id()
}
}
}

View File

@ -102,10 +102,15 @@ fn go(code: &str, expected: &str) {
.expect("failed to parse spirv")
.disassemble();
assert_eq!(PrettyString(&output_disas), PrettyString(expected));
match Command::new("spirv-val").arg(&output).status() {
Ok(status) => assert!(status.success()),
Err(err) => eprint!("spirv-val tool not found, ignoring test: {}", err),
}
remove_file(input).expect("Failed to delete input file");
remove_file(output).expect("Failed to delete output file");
assert_eq!(PrettyString(&output_disas), PrettyString(expected));
}
#[test]
@ -118,24 +123,25 @@ pub fn add_numbers(x: u32, y: u32) -> u32 {
r"; SPIR-V
; Version: 1.5
; Generator: rspirv
; Bound: 14
; Bound: 13
OpCapability Shader
OpCapability Linkage
OpMemoryModel Logical GLSL450
%1 = OpTypeInt 32 0
%2 = OpTypeFunction %1 %1 %1
%7 = OpTypeInt 32 1
%3 = OpFunction %1 None %2
%4 = OpFunctionParameter %1
%5 = OpFunctionParameter %1
%6 = OpLabel
%8 = OpIAdd %7 %4 %5
OpReturn
%7 = OpIAdd %1 %4 %5
OpReturnValue %7
OpFunctionEnd
%9 = OpFunction %1 None %2
%8 = OpFunction %1 None %2
%9 = OpFunctionParameter %1
%10 = OpFunctionParameter %1
%11 = OpFunctionParameter %1
%12 = OpLabel
%13 = OpIAdd %7 %10 %11
OpReturn
%11 = OpLabel
%12 = OpIAdd %1 %9 %10
OpReturnValue %12
OpFunctionEnd",
);
}