mirror of
https://github.com/EmbarkStudios/rust-gpu.git
synced 2024-11-22 14:56:27 +00:00
Various printing helpers and ergonomic refactorings
This commit is contained in:
parent
65e80668d7
commit
c64f433135
@ -65,7 +65,7 @@ fn memset_dynamic_scalar<'a, 'spv, 'tcx>(
|
||||
) -> Word {
|
||||
let composite_type = SpirvType::Vector {
|
||||
element: SpirvType::Integer(8, false).def(builder),
|
||||
count: builder.constant_u32(byte_width as u32),
|
||||
count: builder.constant_u32(byte_width as u32).def,
|
||||
}
|
||||
.def(builder);
|
||||
let composite = builder
|
||||
@ -296,7 +296,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
} => {
|
||||
let fields = field_types
|
||||
.iter()
|
||||
.map(|&f| self.cx.lookup_type(f).debug(self.cx))
|
||||
.map(|&f| self.cx.debug_type(f))
|
||||
.collect::<Vec<_>>();
|
||||
f.debug_struct("Adt")
|
||||
.field("field_types", &fields)
|
||||
@ -305,7 +305,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
}
|
||||
SpirvType::Vector { element, count } => f
|
||||
.debug_struct("Vector")
|
||||
.field("element", &self.cx.lookup_type(element).debug(self.cx))
|
||||
.field("element", &self.cx.debug_type(element))
|
||||
.field(
|
||||
"count",
|
||||
&self
|
||||
@ -317,7 +317,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
.finish(),
|
||||
SpirvType::Array { element, count } => f
|
||||
.debug_struct("Array")
|
||||
.field("element", &self.cx.lookup_type(element).debug(self.cx))
|
||||
.field("element", &self.cx.debug_type(element))
|
||||
.field(
|
||||
"count",
|
||||
&self
|
||||
@ -333,7 +333,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
} => f
|
||||
.debug_struct("Pointer")
|
||||
.field("storage_class", &storage_class)
|
||||
.field("pointee", &self.cx.lookup_type(pointee).debug(self.cx))
|
||||
.field("pointee", &self.cx.debug_type(pointee))
|
||||
.finish(),
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
@ -341,7 +341,7 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
} => {
|
||||
let args = arguments
|
||||
.iter()
|
||||
.map(|&a| self.cx.lookup_type(a).debug(self.cx))
|
||||
.map(|&a| self.cx.debug_type(a))
|
||||
.collect::<Vec<_>>();
|
||||
f.debug_struct("Function")
|
||||
.field("return_type", &self.cx.lookup_type(return_type))
|
||||
@ -352,6 +352,70 @@ impl fmt::Debug for SpirvTypePrinter<'_, '_, '_> {
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for SpirvTypePrinter<'_, '_, '_> {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self.ty {
|
||||
SpirvType::Void => f.write_str("void"),
|
||||
SpirvType::Bool => f.write_str("bool"),
|
||||
SpirvType::Integer(width, signedness) => {
|
||||
let prefix = if signedness { "i" } else { "u" };
|
||||
write!(f, "{}{}", prefix, width)
|
||||
}
|
||||
SpirvType::Float(width) => write!(f, "f{}", width),
|
||||
SpirvType::Adt {
|
||||
ref field_types,
|
||||
field_offsets: _,
|
||||
} => {
|
||||
f.write_str("struct { ")?;
|
||||
for (index, &field) in field_types.iter().enumerate() {
|
||||
let suffix = if index + 1 == field_types.len() {
|
||||
""
|
||||
} else {
|
||||
", "
|
||||
};
|
||||
write!(f, "{}{}", self.cx.debug_type(field), suffix)?;
|
||||
}
|
||||
f.write_str(" }")
|
||||
}
|
||||
SpirvType::Vector { element, count } => {
|
||||
let elem = self.cx.debug_type(element);
|
||||
let len = self.cx.builder.lookup_const_u64(count);
|
||||
let len = len.expect("Vector type has invalid count value");
|
||||
write!(f, "vec<{}, {}>", elem, len)
|
||||
}
|
||||
SpirvType::Array { element, count } => {
|
||||
let elem = self.cx.debug_type(element);
|
||||
let len = self.cx.builder.lookup_const_u64(count);
|
||||
let len = len.expect("Array type has invalid count value");
|
||||
write!(f, "[{}; {}]", elem, len)
|
||||
}
|
||||
SpirvType::Pointer {
|
||||
storage_class,
|
||||
pointee,
|
||||
} => {
|
||||
let pointee = self.cx.debug_type(pointee);
|
||||
write!(f, "*{{{:?}}} {}", storage_class, pointee)
|
||||
}
|
||||
SpirvType::Function {
|
||||
return_type,
|
||||
ref arguments,
|
||||
} => {
|
||||
f.write_str("fn(")?;
|
||||
for (index, &arg) in arguments.iter().enumerate() {
|
||||
let suffix = if index + 1 == arguments.len() {
|
||||
""
|
||||
} else {
|
||||
", "
|
||||
};
|
||||
write!(f, "{}{}", self.cx.debug_type(arg), suffix)?;
|
||||
}
|
||||
let ret_type = self.cx.debug_type(return_type);
|
||||
write!(f, ") -> {}", ret_type)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// returns (function_type, return_type, argument_types)
|
||||
pub fn trans_fnabi<'spv, 'tcx>(
|
||||
cx: &CodegenCx<'spv, 'tcx>,
|
||||
@ -545,7 +609,7 @@ fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>
|
||||
FieldsShape::Union(_field_count) => {
|
||||
assert_ne!(ty.size.bytes(), 0);
|
||||
let byte = SpirvType::Integer(8, false).def(cx);
|
||||
let count = cx.constant_u32(ty.size.bytes() as u32);
|
||||
let count = cx.constant_u32(ty.size.bytes() as u32).def;
|
||||
SpirvType::Array {
|
||||
element: byte,
|
||||
count,
|
||||
@ -558,7 +622,7 @@ fn trans_aggregate<'spv, 'tcx>(cx: &CodegenCx<'spv, 'tcx>, ty: TyAndLayout<'tcx>
|
||||
let nonzero_count = if count == 0 { 1 } else { count };
|
||||
// TODO: Assert stride is same as spirv's stride?
|
||||
let element_type = trans_type(cx, ty.field(cx, 0));
|
||||
let count_const = cx.constant_u32(nonzero_count as u32);
|
||||
let count_const = cx.constant_u32(nonzero_count as u32).def;
|
||||
SpirvType::Array {
|
||||
element: element_type,
|
||||
count: count_const,
|
||||
|
@ -22,7 +22,7 @@ macro_rules! assert_ty_eq {
|
||||
assert_eq!(
|
||||
$left,
|
||||
$right,
|
||||
"Expected types to be equal:\n{:#?}\n==\n{:#?}",
|
||||
"Expected types to be equal:\n{}\n==\n{}",
|
||||
$codegen_cx.debug_type($left),
|
||||
$codegen_cx.debug_type($right)
|
||||
)
|
||||
@ -332,6 +332,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
} => pointee,
|
||||
ty => panic!("store called on variable that wasn't a pointer: {:?}", ty),
|
||||
};
|
||||
println!("ptr={} val={}", ptr.def, val.def);
|
||||
assert_ty_eq!(self, ptr_elem_ty, val.ty);
|
||||
self.emit().store(ptr.def, val.def, None, empty()).unwrap();
|
||||
val
|
||||
@ -384,8 +385,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
pointee: result_pointee_type,
|
||||
}
|
||||
.def(self);
|
||||
let u64 = SpirvType::Integer(64, false).def(self);
|
||||
let index_const = self.builder.constant_u64(u64, idx);
|
||||
let index_const = self.constant_u64(idx).def;
|
||||
self.emit()
|
||||
.access_chain(result_type, None, ptr.def, [index_const].iter().cloned())
|
||||
.unwrap()
|
||||
@ -622,7 +622,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
let elem_ty = match self.lookup_type(ptr.ty) {
|
||||
SpirvType::Pointer { pointee, .. } => pointee,
|
||||
_ => panic!(
|
||||
"memset called on non-pointer type: {:?}",
|
||||
"memset called on non-pointer type: {}",
|
||||
self.debug_type(ptr.ty)
|
||||
),
|
||||
};
|
||||
@ -648,9 +648,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
self.store(pat, ptr, Align::from_bytes(0).unwrap());
|
||||
} else {
|
||||
for index in 0..size {
|
||||
let u32 = SpirvType::Integer(32, false).def(self);
|
||||
let const_index =
|
||||
self.builder.constant_u32(u32, index as u32).with_type(u32);
|
||||
let const_index = self.constant_u32(index as u32);
|
||||
let gep_ptr = self.gep(ptr, &[const_index]);
|
||||
self.store(pat, gep_ptr, Align::from_bytes(0).unwrap());
|
||||
}
|
||||
@ -672,9 +670,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
self.store(pat, ptr, Align::from_bytes(0).unwrap());
|
||||
} else {
|
||||
for index in 0..size {
|
||||
let u32 = SpirvType::Integer(32, false).def(self);
|
||||
let const_index =
|
||||
self.builder.constant_u32(u32, index as u32).with_type(u32);
|
||||
let const_index = self.constant_u32(index as u32);
|
||||
let gep_ptr = self.gep(ptr, &[const_index]);
|
||||
self.store(pat, gep_ptr, Align::from_bytes(0).unwrap());
|
||||
}
|
||||
@ -723,8 +719,7 @@ impl<'a, 'spv, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'spv, 'tcx> {
|
||||
}
|
||||
|
||||
fn vector_splat(&mut self, num_elts: usize, elt: Self::Value) -> Self::Value {
|
||||
let u32 = SpirvType::Integer(32, false).def(self);
|
||||
let count = self.builder.constant_u32(u32, num_elts as u32);
|
||||
let count = self.constant_u32(num_elts as u32).def;
|
||||
let result_type = SpirvType::Vector {
|
||||
element: elt.ty,
|
||||
count,
|
||||
|
@ -61,7 +61,7 @@ impl<'a, 'spv, 'tcx> Builder<'a, 'spv, 'tcx> {
|
||||
result_pointee_type = match self.lookup_type(result_pointee_type) {
|
||||
SpirvType::Array { element, count: _ } => element,
|
||||
_ => panic!(
|
||||
"GEP not implemented for type {:?}",
|
||||
"GEP not implemented for type {}",
|
||||
self.debug_type(result_pointee_type)
|
||||
),
|
||||
};
|
||||
|
@ -72,7 +72,6 @@ impl BuilderSpirv {
|
||||
}
|
||||
|
||||
/// Helper function useful to place right before a crash, to debug the module state.
|
||||
#[allow(dead_code)]
|
||||
pub fn dump_module(&self, path: impl AsRef<Path>) {
|
||||
let mut module = self.builder.borrow().module_ref().clone();
|
||||
let mut header = rspirv::dr::ModuleHeader::new(0);
|
||||
|
@ -35,7 +35,7 @@ macro_rules! assert_ty_eq {
|
||||
assert_eq!(
|
||||
$left,
|
||||
$right,
|
||||
"Expected types to be equal:\n{:#?}\n==\n{:#?}",
|
||||
"Expected types to be equal:\n{}\n==\n{}",
|
||||
$codegen_cx.debug_type($left),
|
||||
$codegen_cx.debug_type($right)
|
||||
)
|
||||
@ -123,7 +123,6 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
|
||||
}
|
||||
|
||||
// Useful for printing out types when debugging
|
||||
#[allow(dead_code)]
|
||||
pub fn debug_type<'cx>(&'cx self, ty: Word) -> SpirvTypePrinter<'cx, 'spv, 'tcx> {
|
||||
self.lookup_type(ty).debug(self)
|
||||
}
|
||||
@ -139,38 +138,37 @@ impl<'spv, 'tcx> CodegenCx<'spv, 'tcx> {
|
||||
|
||||
// Presumably these methods will get used eventually, so allow(dead_code) to not have to rewrite when needed.
|
||||
#[allow(dead_code)]
|
||||
pub fn constant_u8(&self, val: u32) -> Word {
|
||||
pub fn constant_u8(&self, val: u32) -> SpirvValue {
|
||||
let ty = SpirvType::Integer(8, false).def(self);
|
||||
self.builder.constant_u32(ty, val)
|
||||
self.builder.constant_u32(ty, val).with_type(ty)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn constant_u16(&self, val: u32) -> Word {
|
||||
pub fn constant_u16(&self, val: u32) -> SpirvValue {
|
||||
let ty = SpirvType::Integer(16, false).def(self);
|
||||
self.builder.constant_u32(ty, val)
|
||||
self.builder.constant_u32(ty, val).with_type(ty)
|
||||
}
|
||||
|
||||
pub fn constant_u32(&self, val: u32) -> Word {
|
||||
pub fn constant_u32(&self, val: u32) -> SpirvValue {
|
||||
let ty = SpirvType::Integer(32, false).def(self);
|
||||
self.builder.constant_u32(ty, val)
|
||||
self.builder.constant_u32(ty, val).with_type(ty)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn constant_u64(&self, val: u64) -> Word {
|
||||
pub fn constant_u64(&self, val: u64) -> SpirvValue {
|
||||
let ty = SpirvType::Integer(64, false).def(self);
|
||||
self.builder.constant_u64(ty, val)
|
||||
self.builder.constant_u64(ty, val).with_type(ty)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn constant_f32(&self, val: f32) -> Word {
|
||||
pub fn constant_f32(&self, val: f32) -> SpirvValue {
|
||||
let ty = SpirvType::Float(32).def(self);
|
||||
self.builder.constant_f32(ty, val)
|
||||
self.builder.constant_f32(ty, val).with_type(ty)
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn constant_f64(&self, val: f64) -> Word {
|
||||
pub fn constant_f64(&self, val: f64) -> SpirvValue {
|
||||
let ty = SpirvType::Float(64).def(self);
|
||||
self.builder.constant_f64(ty, val)
|
||||
self.builder.constant_f64(ty, val).with_type(ty)
|
||||
}
|
||||
}
|
||||
|
||||
@ -651,7 +649,17 @@ impl<'spv, 'tcx> ConstMethods<'tcx> for CodegenCx<'spv, 'tcx> {
|
||||
if u > u64::MAX as u128 {
|
||||
panic!("u128 literals not supported yet: {}", u);
|
||||
}
|
||||
self.builder.constant_u64(t, u as u64).with_type(t)
|
||||
match self.lookup_type(t) {
|
||||
SpirvType::Integer(width, _) => {
|
||||
if width > 32 {
|
||||
self.builder.constant_u64(t, u as u64).with_type(t)
|
||||
} else {
|
||||
assert!(u <= u32::MAX as u128);
|
||||
self.builder.constant_u32(t, u as u32).with_type(t)
|
||||
}
|
||||
}
|
||||
other => panic!("const_uint_big invalid on type {}", other.debug(self)),
|
||||
}
|
||||
}
|
||||
fn const_bool(&self, val: bool) -> Self::Value {
|
||||
let bool = SpirvType::Bool.def(self);
|
||||
|
@ -339,8 +339,8 @@ impl ExtraBackendMethods for SsaBackend {
|
||||
// attributes::sanitize(&cx, SanitizerSet::empty(), entry);
|
||||
}
|
||||
};
|
||||
if option_env!("DUMP_MODULE_ON_PANIC").is_some() {
|
||||
let module_dumper = DumpModuleOnPanic { cx: &cx };
|
||||
if let Some(path) = option_env!("DUMP_MODULE_ON_PANIC") {
|
||||
let module_dumper = DumpModuleOnPanic { cx: &cx, path };
|
||||
do_codegen();
|
||||
drop(module_dumper)
|
||||
} else {
|
||||
@ -371,15 +371,20 @@ impl ExtraBackendMethods for SsaBackend {
|
||||
}
|
||||
}
|
||||
|
||||
struct DumpModuleOnPanic<'cx, 'spv, 'tcx> {
|
||||
struct DumpModuleOnPanic<'a, 'cx, 'spv, 'tcx> {
|
||||
cx: &'cx CodegenCx<'spv, 'tcx>,
|
||||
path: &'a str,
|
||||
}
|
||||
|
||||
impl<'cx, 'spv, 'tcx> Drop for DumpModuleOnPanic<'cx, 'spv, 'tcx> {
|
||||
impl Drop for DumpModuleOnPanic<'_, '_, '_, '_> {
|
||||
fn drop(&mut self) {
|
||||
if std::thread::panicking() {
|
||||
// can also use dump_module with a path here to write it to disk
|
||||
println!("{}", self.cx.builder.dump_module_str());
|
||||
let path: &std::path::Path = self.path.as_ref();
|
||||
if path.has_root() {
|
||||
self.cx.builder.dump_module(path);
|
||||
} else {
|
||||
println!("{}", self.cx.builder.dump_module_str());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user