Various printing helpers and ergonomic refactorings

This commit is contained in:
khyperia 2020-08-31 11:05:08 +02:00
parent 65e80668d7
commit c64f433135
6 changed files with 115 additions and 44 deletions

View File

@ -65,7 +65,7 @@ fn memset_dynamic_scalar<'a, 'spv, 'tcx>(
) -> Word {
let composite_type = SpirvType::Vector {
element: SpirvType::Integer(8, false).def(builder),
count: builder.constant_u32(byte_width as u32),
count: builder.constant_u32(byte_width as u32).def,
}
.def(builder);
let composite = builder
@ -296,7 +296,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
} => {
let fields = field_types
.iter()
.map(|&f| self.cx.lookup_type(f).debug(self.cx))
.map(|&f| self.cx.debug_type(f))
.collect::<Vec<_>>();
f.debug_struct("Adt")
.field("field_types", &fields)
@ -305,7 +305,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
}
SpirvType::Vector { element, count } => f
.debug_struct("Vector")
.field("element", &self.cx.lookup_type(element).debug(self.cx))
.field("element", &self.cx.debug_type(element))
.field(
"count",
&self
@ -317,7 +317,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
.finish(),
SpirvType::Array { element, count } => f
.debug_struct("Array")
.field("element", &self.cx.lookup_type(element).debug(self.cx))
.field("element", &self.cx.debug_type(element))
.field(
"count",
&self
@ -333,7 +333,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
} => f
.debug_struct("Pointer")
.field("storage_class", &storage_class)
.field("pointee", &self.cx.lookup_type(pointee).debug(self.cx))
.field("pointee", &self.cx.debug_type(pointee))
.finish(),
SpirvType::Function {
return_type,
@ -341,7 +341,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
} => {
let args = arguments
.iter()
.map(|&a| self.cx.lookup_type(a).debug(self.cx))
.map(|&a| self.cx.debug_type(a))
.collect::<Vec<_>>();
f.debug_struct("Function")
.field("return_type", &self.cx.lookup_type(return_type))
@ -352,6 +352,70 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
}
}
impl fmt::Display for SpirvTypePrinter<'_, '_, '_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.ty {
SpirvType::Void => f.write_str("void"),
SpirvType::Bool => f.write_str("bool"),
SpirvType::Integer(width, signedness) => {
let prefix = if signedness { "i" } else { "u" };
write!(f, "{}{}", prefix, width)
}
SpirvType::Float(width) => write!(f, "f{}", width),
SpirvType::Adt {
ref field_types,
field_offsets: _,
} => {
f.write_str("struct { ")?;
for (index, &field) in field_types.iter().enumerate() {
let suffix = if index + 1 == field_types.len() {
""
} else {
", "
};
write!(f, "{}{}", self.cx.debug_type(field), suffix)?;
}
f.write_str(" }")
}
SpirvType::Vector { element, count } => {
let elem = self.cx.debug_type(element);
let len = self.cx.builder.lookup_const_u64(count);
let len = len.expect("Vector type has invalid count value");
write!(f, "vec<{}, {}>", elem, len)
}
SpirvType::Array { element, count } => {
let elem = self.cx.debug_type(element);
let len = self.cx.builder.lookup_const_u64(count);
let len = len.expect("Array type has invalid count value");
write!(f, "[{}; {}]", elem, len)
}
SpirvType::Pointer {
storage_class,
pointee,
} => {
let pointee = self.cx.debug_type(pointee);
write!(f, "*{{{:?}}} {}", storage_class, pointee)
}
SpirvType::Function {
return_type,
ref arguments,
} => {
f.write_str("fn(")?;
for (index, &arg) in arguments.iter().enumerate() {
let suffix = if index + 1 == arguments.len() {
""
} else {
", "
};
write!(f, "{}{}", self.cx.debug_type(arg), suffix)?;
}
let ret_type = self.cx.debug_type(return_type);
write!(f, ") -> {}", ret_type)
}
}
}
}
// returns (function_type, return_type, argument_types)
pub fn trans_fnabi<'spv, 'tcx>(
cx: &CodegenCx<'spv, 'tcx>,
@ -545,7 +609,7 @@ fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>
FieldsShape::Union(_field_count) => {
assert_ne!(ty.size.bytes(), 0);
let byte = SpirvType::Integer(8, false).def(cx);
let count = cx.constant_u32(ty.size.bytes() as u32);
let count = cx.constant_u32(ty.size.bytes() as u32).def;
SpirvType::Array {
element: byte,
count,
@ -558,7 +622,7 @@ fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>
let nonzero_count = if count == 0 { 1 } else { count };
// TODO: Assert stride is same as spirv's stride?
let element_type = trans_type(cx, ty.field(cx, 0));
let count_const = cx.constant_u32(nonzero_count as u32);
let count_const = cx.constant_u32(nonzero_count as u32).def;
SpirvType::Array {
element: element_type,
count: count_const,

View File

@ -22,7 +22,7 @@ macro_rules! assert_ty_eq {
assert_eq!(
$left,
$right,
"Expected types to be equal:\n{:#?}\n==\n{:#?}",
"Expected types to be equal:\n{}\n==\n{}",
$codegen_cx.debug_type($left),
$codegen_cx.debug_type($right)
)
@ -332,6 +332,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
} => pointee,
ty => panic!("store called on variable that wasn't a pointer: {:?}", ty),
};
println!("ptr={} val={}", ptr.def, val.def);
assert_ty_eq!(self, ptr_elem_ty, val.ty);
self.emit().store(ptr.def, val.def, None, empty()).unwrap();
val
@ -384,8 +385,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
pointee: result_pointee_type,
}
.def(self);
let u64 = SpirvType::Integer(64, false).def(self);
let index_const = self.builder.constant_u64(u64, idx);
let index_const = self.constant_u64(idx).def;
self.emit()
.access_chain(result_type, None, ptr.def, [index_const].iter().cloned())
.unwrap()
@ -622,7 +622,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
let elem_ty = match self.lookup_type(ptr.ty) {
SpirvType::Pointer { pointee, .. } => pointee,
_ => panic!(
"memset called on non-pointer type: {:?}",
"memset called on non-pointer type: {}",
self.debug_type(ptr.ty)
),
};
@ -648,9 +648,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
self.store(pat, ptr, Align::from_bytes(0).unwrap());
} else {
for index in 0..size {
let u32 = SpirvType::Integer(32, false).def(self);
let const_index =
self.builder.constant_u32(u32, index as u32).with_type(u32);
let const_index = self.constant_u32(index as u32);
let gep_ptr = self.gep(ptr, &[const_index]);
self.store(pat, gep_ptr, Align::from_bytes(0).unwrap());
}
@ -672,9 +670,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
self.store(pat, ptr, Align::from_bytes(0).unwrap());
} else {
for index in 0..size {
let u32 = SpirvType::Integer(32, false).def(self);
let const_index =
self.builder.constant_u32(u32, index as u32).with_type(u32);
let const_index = self.constant_u32(index as u32);
let gep_ptr = self.gep(ptr, &[const_index]);
self.store(pat, gep_ptr, Align::from_bytes(0).unwrap());
}
@ -723,8 +719,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
}
fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value {
let u32 = SpirvType::Integer(32, false).def(self);
let count = self.builder.constant_u32(u32, num_elts as u32);
let count = self.constant_u32(num_elts as u32).def;
let result_type = SpirvType::Vector {
element: elt.ty,
count,

View File

@ -61,7 +61,7 @@ impl<'a, 'spv, 'tcx> Builder<'a, 'spv, 'tcx> {
result_pointee_type = match self.lookup_type(result_pointee_type) {
SpirvType::Array { element, count: _ } => element,
_ => panic!(
"GEP not implemented for type {:?}",
"GEP not implemented for type {}",
self.debug_type(result_pointee_type)
),
};

View File

@ -72,7 +72,6 @@ impl BuilderSpirv {
}
/// Helper function useful to place right before a crash, to debug the module state.
#[allow(dead_code)]
pub fn dump_module(&self, path: impl AsRef<Path>) {
let mut module = self.builder.borrow().module_ref().clone();
let mut header = rspirv::dr::ModuleHeader::new(0);

View File

@ -35,7 +35,7 @@ macro_rules! assert_ty_eq {
assert_eq!(
$left,
$right,
"Expected types to be equal:\n{:#?}\n==\n{:#?}",
"Expected types to be equal:\n{}\n==\n{}",
$codegen_cx.debug_type($left),
$codegen_cx.debug_type($right)
)
@ -123,7 +123,6 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
}
// Useful for printing out types when debugging
#[allow(dead_code)]
pub fn debug_type<'cx>(&'cx self, ty: Word) -> SpirvTypePrinter<'cx, 'spv, 'tcx> {
self.lookup_type(ty).debug(self)
}
@ -139,38 +138,37 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
// Presumably these methods will get used eventually, so allow(dead_code) to not have to rewrite when needed.
#[allow(dead_code)]
pub fn constant_u8(&self, val: u32) -> Word {
pub fn constant_u8(&self, val: u32) -> SpirvValue {
let ty = SpirvType::Integer(8, false).def(self);
self.builder.constant_u32(ty, val)
self.builder.constant_u32(ty, val).with_type(ty)
}
#[allow(dead_code)]
pub fn constant_u16(&self, val: u32) -> Word {
pub fn constant_u16(&self, val: u32) -> SpirvValue {
let ty = SpirvType::Integer(16, false).def(self);
self.builder.constant_u32(ty, val)
self.builder.constant_u32(ty, val).with_type(ty)
}
pub fn constant_u32(&self, val: u32) -> Word {
pub fn constant_u32(&self, val: u32) -> SpirvValue {
let ty = SpirvType::Integer(32, false).def(self);
self.builder.constant_u32(ty, val)
self.builder.constant_u32(ty, val).with_type(ty)
}
#[allow(dead_code)]
pub fn constant_u64(&self, val: u64) -> Word {
pub fn constant_u64(&self, val: u64) -> SpirvValue {
let ty = SpirvType::Integer(64, false).def(self);
self.builder.constant_u64(ty, val)
self.builder.constant_u64(ty, val).with_type(ty)
}
#[allow(dead_code)]
pub fn constant_f32(&self, val: f32) -> Word {
pub fn constant_f32(&self, val: f32) -> SpirvValue {
let ty = SpirvType::Float(32).def(self);
self.builder.constant_f32(ty, val)
self.builder.constant_f32(ty, val).with_type(ty)
}
#[allow(dead_code)]
pub fn constant_f64(&self, val: f64) -> Word {
pub fn constant_f64(&self, val: f64) -> SpirvValue {
let ty = SpirvType::Float(64).def(self);
self.builder.constant_f64(ty, val)
self.builder.constant_f64(ty, val).with_type(ty)
}
}
@ -651,7 +649,17 @@ impl<'spv, 'tcx> ConstMethods<'tcx> for CodegenCx<'spv, 'tcx> {
if u > u64::MAX as u128 {
panic!("u128 literals not supported yet: {}", u);
}
self.builder.constant_u64(t, u as u64).with_type(t)
match self.lookup_type(t) {
SpirvType::Integer(width, _) => {
if width > 32 {
self.builder.constant_u64(t, u as u64).with_type(t)
} else {
assert!(u <= u32::MAX as u128);
self.builder.constant_u32(t, u as u32).with_type(t)
}
}
other => panic!("const_uint_big invalid on type {}", other.debug(self)),
}
}
fn const_bool(&self, val: bool) -> Self::Value {
let bool = SpirvType::Bool.def(self);

View File

@ -339,8 +339,8 @@ impl ExtraBackendMethods for SsaBackend {
// attributes::sanitize(&cx, SanitizerSet::empty(), entry);
}
};
if option_env!("DUMP_MODULE_ON_PANIC").is_some() {
let module_dumper = DumpModuleOnPanic { cx: &cx };
if let Some(path) = option_env!("DUMP_MODULE_ON_PANIC") {
let module_dumper = DumpModuleOnPanic { cx: &cx, path };
do_codegen();
drop(module_dumper)
} else {
@ -371,15 +371,20 @@ impl ExtraBackendMethods for SsaBackend {
}
}
struct DumpModuleOnPanic<'cx, 'spv, 'tcx> {
struct DumpModuleOnPanic<'a, 'cx, 'spv, 'tcx> {
cx: &'cx CodegenCx<'spv, 'tcx>,
path: &'a str,
}
impl<'cx, 'spv, 'tcx> Drop for DumpModuleOnPanic<'cx, 'spv, 'tcx> {
impl Drop for DumpModuleOnPanic<'_, '_, '_, '_> {
fn drop(&mut self) {
if std::thread::panicking() {
// can also use dump_module with a path here to write it to disk
println!("{}", self.cx.builder.dump_module_str());
let path: &std::path::Path = self.path.as_ref();
if path.has_root() {
self.cx.builder.dump_module(path);
} else {
println!("{}", self.cx.builder.dump_module_str());
}
}
}
}