Make SpirvValue contain an enum for its value (#219)

Ashley Hauck 2020-11-11 09:23:07 +01:00 committed by GitHub
parent c2ccdbe6ef
commit 9c8cec3639
10 changed files with 530 additions and 236 deletions
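
In short: `SpirvValue` previously exposed its SPIR-V result id as a bare `def: Word` field; this commit wraps that id in a `SpirvValueKind` enum so that "constant pointer" meta-values can be told apart from real ids, and every `.def` field access becomes a `.def(self)` or `.def_cx(cx)` call. A minimal before/after sketch, simplified from the diff below (`Word` is rspirv's id type; `SpirvValueOld` is a made-up name for the pre-commit struct):

use rspirv::spirv::Word;

// Before this commit: every value carried a concrete result id.
pub struct SpirvValueOld {
    pub def: Word,
    pub ty: Word,
}

// After: the id is wrapped in an enum, so a value may instead be a
// ConstantPointer with no directly usable id (see the enum's doc
// comment later in this diff).
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvValueKind {
    Def(Word),
    ConstantPointer(Word),
}

#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct SpirvValue {
    pub kind: SpirvValueKind,
    pub ty: Word,
}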

View File

@ -24,7 +24,7 @@ macro_rules! simple_op {
assert_ty_eq!(self, lhs.ty, rhs.ty); assert_ty_eq!(self, lhs.ty, rhs.ty);
let result_type = lhs.ty; let result_type = lhs.ty;
self.emit() self.emit()
.$inst_name(result_type, None, lhs.def, rhs.def) .$inst_name(result_type, None, lhs.def(self), rhs.def(self))
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} }
@ -36,7 +36,7 @@ macro_rules! simple_op_unchecked_type {
($func_name:ident, $inst_name:ident) => { ($func_name:ident, $inst_name:ident) => {
fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value { fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
self.emit() self.emit()
.$inst_name(lhs.ty, None, lhs.def, rhs.def) .$inst_name(lhs.ty, None, lhs.def(self), rhs.def(self))
.unwrap() .unwrap()
.with_type(lhs.ty) .with_type(lhs.ty)
} }
@ -47,7 +47,7 @@ macro_rules! simple_uni_op {
($func_name:ident, $inst_name:ident) => { ($func_name:ident, $inst_name:ident) => {
fn $func_name(&mut self, val: Self::Value) -> Self::Value { fn $func_name(&mut self, val: Self::Value) -> Self::Value {
self.emit() self.emit()
.$inst_name(val.ty, None, val.def) .$inst_name(val.ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(val.ty) .with_type(val.ty)
} }
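
Since `def` is now a method taking the builder (resolving a `ConstantPointer` needs builder state for diagnostics), these macros thread `self` through at every use. For illustration only, a hypothetical expansion of `simple_op!(add, i_add)` after this change:

fn add(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value {
    assert_ty_eq!(self, lhs.ty, rhs.ty);
    let result_type = lhs.ty;
    // `def(self)` resolves the value to a plain result id; for a
    // ConstantPointer meta-value it reports an error (or substitutes a
    // zombied OpUndef in system crates) instead of returning a real id.
    self.emit()
        .i_add(result_type, None, lhs.def(self), rhs.def(self))
        .unwrap()
        .with_type(result_type)
}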
@ -133,7 +133,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
let semantics = self.constant_u32(semantics.bits()); let semantics = self.constant_u32(semantics.bits());
if invalid_seq_cst { if invalid_seq_cst {
self.zombie( self.zombie(
semantics.def, semantics.def(self),
"Cannot use AtomicOrdering=SequentiallyConsistent on Vulkan memory model. Check if AcquireRelease fits your needs.", "Cannot use AtomicOrdering=SequentiallyConsistent on Vulkan memory model. Check if AcquireRelease fits your needs.",
); );
} }
@ -145,24 +145,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
SpirvType::Void => self.fatal("memset invalid on void pattern"), SpirvType::Void => self.fatal("memset invalid on void pattern"),
SpirvType::Bool => self.fatal("memset invalid on bool pattern"), SpirvType::Bool => self.fatal("memset invalid on bool pattern"),
SpirvType::Integer(width, _signedness) => match width { SpirvType::Integer(width, _signedness) => match width {
8 => self.constant_u8(fill_byte).def, 8 => self.constant_u8(fill_byte).def(self),
16 => self.constant_u16(memset_fill_u16(fill_byte)).def, 16 => self.constant_u16(memset_fill_u16(fill_byte)).def(self),
32 => self.constant_u32(memset_fill_u32(fill_byte)).def, 32 => self.constant_u32(memset_fill_u32(fill_byte)).def(self),
64 => self.constant_u64(memset_fill_u64(fill_byte)).def, 64 => self.constant_u64(memset_fill_u64(fill_byte)).def(self),
_ => self.fatal(&format!( _ => self.fatal(&format!(
"memset on integer width {} not implemented yet", "memset on integer width {} not implemented yet",
width width
)), )),
}, },
SpirvType::Float(width) => match width { SpirvType::Float(width) => match width {
32 => { 32 => self
self.constant_f32(f32::from_bits(memset_fill_u32(fill_byte))) .constant_f32(f32::from_bits(memset_fill_u32(fill_byte)))
.def .def(self),
} 64 => self
64 => { .constant_f64(f64::from_bits(memset_fill_u64(fill_byte)))
self.constant_f64(f64::from_bits(memset_fill_u64(fill_byte))) .def(self),
.def
}
_ => self.fatal(&format!( _ => self.fatal(&format!(
"memset on float width {} not implemented yet", "memset on float width {} not implemented yet",
width width
@ -173,13 +171,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
SpirvType::Vector { element, count } => { SpirvType::Vector { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
self.constant_composite(ty.clone().def(self), vec![elem_pat; count as usize]) self.constant_composite(ty.clone().def(self), vec![elem_pat; count as usize])
.def .def(self)
} }
SpirvType::Array { element, count } => { SpirvType::Array { element, count } => {
let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte);
let count = self.builder.lookup_const_u64(count).unwrap() as usize; let count = self.builder.lookup_const_u64(count).unwrap() as usize;
self.constant_composite(ty.clone().def(self), vec![elem_pat; count]) self.constant_composite(ty.clone().def(self), vec![elem_pat; count])
.def .def(self)
} }
SpirvType::RuntimeArray { .. } => { SpirvType::RuntimeArray { .. } => {
self.fatal("memset on runtime arrays not implemented yet") self.fatal("memset on runtime arrays not implemented yet")
@ -371,7 +369,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn new_block<'b>(cx: &'a Self::CodegenCx, llfn: Self::Function, _name: &'b str) -> Self { fn new_block<'b>(cx: &'a Self::CodegenCx, llfn: Self::Function, _name: &'b str) -> Self {
let cursor_fn = cx.builder.select_function_by_id(llfn.def); let cursor_fn = cx.builder.select_function_by_id(llfn.def_cx(cx));
let label = cx.emit_with_cursor(cursor_fn).begin_block(None).unwrap(); let label = cx.emit_with_cursor(cursor_fn).begin_block(None).unwrap();
let cursor = cx.builder.select_block_by_id(label); let cursor = cx.builder.select_block_by_id(label);
Self { Self {
@ -388,7 +386,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
Self { Self {
cx, cx,
cursor: Default::default(), cursor: Default::default(),
current_fn: Default::default(), current_fn: 0.with_type(0),
basic_block: Default::default(), basic_block: Default::default(),
current_span: Default::default(), current_span: Default::default(),
} }
@ -446,7 +444,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} }
fn ret(&mut self, value: Self::Value) { fn ret(&mut self, value: Self::Value) {
self.emit().ret_value(value.def).unwrap(); self.emit().ret_value(value.def(self)).unwrap();
} }
fn br(&mut self, dest: Self::BasicBlock) { fn br(&mut self, dest: Self::BasicBlock) {
@ -464,7 +462,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
else_llbb: Self::BasicBlock, else_llbb: Self::BasicBlock,
) { ) {
self.emit() self.emit()
.branch_conditional(cond.def, then_llbb, else_llbb, empty()) .branch_conditional(cond.def(self), then_llbb, else_llbb, empty())
.unwrap() .unwrap()
} }
@ -547,7 +545,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let cases = cases let cases = cases
.map(|(i, b)| (construct_case(self, signed, i), b)) .map(|(i, b)| (construct_case(self, signed, i), b))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.emit().switch(v.def, else_llbb, cases).unwrap() self.emit().switch(v.def(self), else_llbb, cases).unwrap()
} }
fn invoke( fn invoke(
@ -606,8 +604,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
assert_ty_eq!(self, lhs.ty, rhs.ty); assert_ty_eq!(self, lhs.ty, rhs.ty);
let ty = lhs.ty; let ty = lhs.ty;
match self.lookup_type(ty) { match self.lookup_type(ty) {
SpirvType::Integer(_, _) => self.emit().bitwise_and(ty, None, lhs.def, rhs.def), SpirvType::Integer(_, _) => {
SpirvType::Bool => self.emit().logical_and(ty, None, lhs.def, rhs.def), self.emit()
.bitwise_and(ty, None, lhs.def(self), rhs.def(self))
}
SpirvType::Bool => self
.emit()
.logical_and(ty, None, lhs.def(self), rhs.def(self)),
o => self.fatal(&format!( o => self.fatal(&format!(
"and() not implemented for type {}", "and() not implemented for type {}",
o.debug(ty, self) o.debug(ty, self)
@ -620,8 +623,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
assert_ty_eq!(self, lhs.ty, rhs.ty); assert_ty_eq!(self, lhs.ty, rhs.ty);
let ty = lhs.ty; let ty = lhs.ty;
match self.lookup_type(ty) { match self.lookup_type(ty) {
SpirvType::Integer(_, _) => self.emit().bitwise_or(ty, None, lhs.def, rhs.def), SpirvType::Integer(_, _) => {
SpirvType::Bool => self.emit().logical_or(ty, None, lhs.def, rhs.def), self.emit()
.bitwise_or(ty, None, lhs.def(self), rhs.def(self))
}
SpirvType::Bool => self
.emit()
.logical_or(ty, None, lhs.def(self), rhs.def(self)),
o => self.fatal(&format!( o => self.fatal(&format!(
"or() not implemented for type {}", "or() not implemented for type {}",
o.debug(ty, self) o.debug(ty, self)
@ -634,8 +642,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
assert_ty_eq!(self, lhs.ty, rhs.ty); assert_ty_eq!(self, lhs.ty, rhs.ty);
let ty = lhs.ty; let ty = lhs.ty;
match self.lookup_type(ty) { match self.lookup_type(ty) {
SpirvType::Integer(_, _) => self.emit().bitwise_xor(ty, None, lhs.def, rhs.def), SpirvType::Integer(_, _) => {
SpirvType::Bool => self.emit().logical_not_equal(ty, None, lhs.def, rhs.def), self.emit()
.bitwise_xor(ty, None, lhs.def(self), rhs.def(self))
}
SpirvType::Bool => {
self.emit()
.logical_not_equal(ty, None, lhs.def(self), rhs.def(self))
}
o => self.fatal(&format!( o => self.fatal(&format!(
"xor() not implemented for type {}", "xor() not implemented for type {}",
o.debug(ty, self) o.debug(ty, self)
@ -646,12 +660,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} }
fn not(&mut self, val: Self::Value) -> Self::Value { fn not(&mut self, val: Self::Value) -> Self::Value {
match self.lookup_type(val.ty) { match self.lookup_type(val.ty) {
SpirvType::Integer(_, _) => self.emit().not(val.ty, None, val.def), SpirvType::Integer(_, _) => self.emit().not(val.ty, None, val.def(self)),
SpirvType::Bool => { SpirvType::Bool => {
let true_ = self.constant_bool(true); let true_ = self.constant_bool(true);
// intel-compute-runtime doesn't like OpLogicalNot // intel-compute-runtime doesn't like OpLogicalNot
self.emit() self.emit()
.logical_not_equal(val.ty, None, val.def, true_.def) .logical_not_equal(val.ty, None, val.def(self), true_.def(self))
} }
o => self.fatal(&format!( o => self.fatal(&format!(
"not() not implemented for type {}", "not() not implemented for type {}",
@ -676,7 +690,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
OverflowOp::Mul => (self.mul(lhs, rhs), fals), OverflowOp::Mul => (self.mul(lhs, rhs), fals),
}; };
self.zombie( self.zombie(
result.1.def, result.1.def(self),
match oop { match oop {
OverflowOp::Add => "checked add is not supported yet", OverflowOp::Add => "checked add is not supported yet",
OverflowOp::Sub => "checked sub is not supported yet", OverflowOp::Sub => "checked sub is not supported yet",
@ -751,8 +765,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} }
fn load(&mut self, ptr: Self::Value, _align: Align) -> Self::Value { fn load(&mut self, ptr: Self::Value, _align: Align) -> Self::Value {
// See comment on `register_constant_pointer` // See comment on `SpirvValueKind::ConstantPointer`
if let Some(value) = self.lookup_constant_pointer(ptr) { if let Some(value) = ptr.const_ptr_val(self) {
return value; return value;
} }
let ty = match self.lookup_type(ptr.ty) { let ty = match self.lookup_type(ptr.ty) {
@ -766,7 +780,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
)), )),
}; };
self.emit() self.emit()
.load(ty, None, ptr.def, None, empty()) .load(ty, None, ptr.def(self), None, empty())
.unwrap() .unwrap()
.with_type(ty) .with_type(ty)
} }
@ -774,7 +788,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn volatile_load(&mut self, ptr: Self::Value) -> Self::Value { fn volatile_load(&mut self, ptr: Self::Value) -> Self::Value {
// TODO: Implement this // TODO: Implement this
let result = self.load(ptr, Align::from_bytes(0).unwrap()); let result = self.load(ptr, Align::from_bytes(0).unwrap());
self.zombie(result.def, "volatile load is not supported yet"); self.zombie(result.def(self), "volatile load is not supported yet");
result result
} }
@ -794,10 +808,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let semantics = self.ordering_to_semantics_def(order); let semantics = self.ordering_to_semantics_def(order);
let result = self let result = self
.emit() .emit()
.atomic_load(ty, None, ptr.def, memory.def, semantics.def) .atomic_load(
ty,
None,
ptr.def(self),
memory.def(self),
semantics.def(self),
)
.unwrap() .unwrap()
.with_type(ty); .with_type(ty);
self.validate_atomic(ty, result.def); self.validate_atomic(ty, result.def(self));
result result
} }
@ -887,7 +907,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
)), )),
}; };
assert_ty_eq!(self, ptr_elem_ty, val.ty); assert_ty_eq!(self, ptr_elem_ty, val.ty);
self.emit().store(ptr.def, val.def, None, empty()).unwrap(); self.emit()
.store(ptr.def(self), val.def(self), None, empty())
.unwrap();
val val
} }
@ -928,9 +950,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// TODO: Default to device scope // TODO: Default to device scope
let memory = self.constant_u32(Scope::Device as u32); let memory = self.constant_u32(Scope::Device as u32);
let semantics = self.ordering_to_semantics_def(order); let semantics = self.ordering_to_semantics_def(order);
self.validate_atomic(val.ty, ptr.def); self.validate_atomic(val.ty, ptr.def(self));
self.emit() self.emit()
.atomic_store(ptr.def, memory.def, semantics.def, val.def) .atomic_store(
ptr.def(self),
memory.def(self),
semantics.def(self),
val.def(self),
)
.unwrap(); .unwrap();
} }
@ -972,9 +999,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
if idx > u32::MAX as u64 { if idx > u32::MAX as u64 {
self.fatal("struct_gep bigger than u32::MAX"); self.fatal("struct_gep bigger than u32::MAX");
} }
let index_const = self.constant_u32(idx as u32).def; let index_const = self.constant_u32(idx as u32).def(self);
self.emit() self.emit()
.access_chain(result_type, None, ptr.def, [index_const].iter().cloned()) .access_chain(
result_type,
None,
ptr.def(self),
[index_const].iter().cloned(),
)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} }
@ -1003,7 +1035,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
val val
} else { } else {
self.emit() self.emit()
.convert_f_to_u(dest_ty, None, val.def) .convert_f_to_u(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1014,7 +1046,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
val val
} else { } else {
self.emit() self.emit()
.convert_f_to_s(dest_ty, None, val.def) .convert_f_to_s(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1025,7 +1057,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
val val
} else { } else {
self.emit() self.emit()
.convert_u_to_f(dest_ty, None, val.def) .convert_u_to_f(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1036,7 +1068,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
val val
} else { } else {
self.emit() self.emit()
.convert_s_to_f(dest_ty, None, val.def) .convert_s_to_f(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1047,7 +1079,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
val val
} else { } else {
self.emit() self.emit()
.f_convert(dest_ty, None, val.def) .f_convert(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1058,7 +1090,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
val val
} else { } else {
self.emit() self.emit()
.f_convert(dest_ty, None, val.def) .f_convert(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1077,10 +1109,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} else { } else {
let result = self let result = self
.emit() .emit()
.convert_ptr_to_u(dest_ty, None, val.def) .convert_ptr_to_u(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty); .with_type(dest_ty);
self.zombie_convert_ptr_to_u(result.def); self.zombie_convert_ptr_to_u(result.def(self));
result result
} }
} }
@ -1098,10 +1130,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} else { } else {
let result = self let result = self
.emit() .emit()
.convert_u_to_ptr(dest_ty, None, val.def) .convert_u_to_ptr(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty); .with_type(dest_ty);
self.zombie_convert_u_to_ptr(result.def); self.zombie_convert_u_to_ptr(result.def(self));
result result
} }
} }
@ -1112,13 +1144,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} else { } else {
let result = self let result = self
.emit() .emit()
.bitcast(dest_ty, None, val.def) .bitcast(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty); .with_type(dest_ty);
let val_is_ptr = matches!(self.lookup_type(val.ty), SpirvType::Pointer{..}); let val_is_ptr = matches!(self.lookup_type(val.ty), SpirvType::Pointer{..});
let dest_is_ptr = matches!(self.lookup_type(dest_ty), SpirvType::Pointer{..}); let dest_is_ptr = matches!(self.lookup_type(dest_ty), SpirvType::Pointer{..});
if val_is_ptr || dest_is_ptr { if val_is_ptr || dest_is_ptr {
self.zombie_bitcast_ptr(result.def, val.ty, dest_ty); self.zombie_bitcast_ptr(result.def(self), val.ty, dest_ty);
} }
result result
} }
@ -1136,7 +1168,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
SpirvType::Integer(dest_width, dest_signedness), SpirvType::Integer(dest_width, dest_signedness),
) if val_width == dest_width && val_signedness != dest_signedness => self ) if val_width == dest_width && val_signedness != dest_signedness => self
.emit() .emit()
.bitcast(dest_ty, None, val.def) .bitcast(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty), .with_type(dest_ty),
// width change, and optional sign change // width change, and optional sign change
@ -1144,9 +1176,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// spir-v spec doesn't seem to say that signedness needs to match the operands, only that the signedness // spir-v spec doesn't seem to say that signedness needs to match the operands, only that the signedness
// of the destination type must match the instruction's signedness. // of the destination type must match the instruction's signedness.
if dest_signedness { if dest_signedness {
self.emit().s_convert(dest_ty, None, val.def) self.emit().s_convert(dest_ty, None, val.def(self))
} else { } else {
self.emit().u_convert(dest_ty, None, val.def) self.emit().u_convert(dest_ty, None, val.def(self))
} }
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
@ -1157,7 +1189,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let if_true = self.constant_int(dest_ty, 1); let if_true = self.constant_int(dest_ty, 1);
let if_false = self.constant_int(dest_ty, 0); let if_false = self.constant_int(dest_ty, 0);
self.emit() self.emit()
.select(dest_ty, None, val.def, if_true.def, if_false.def) .select(
dest_ty,
None,
val.def(self),
if_true.def(self),
if_false.def(self),
)
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1165,7 +1203,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
// spir-v doesn't have a direct conversion instruction, glslang emits OpINotEqual // spir-v doesn't have a direct conversion instruction, glslang emits OpINotEqual
let zero = self.constant_int(val.ty, 0); let zero = self.constant_int(val.ty, 0);
self.emit() self.emit()
.i_not_equal(dest_ty, None, val.def, zero.def) .i_not_equal(dest_ty, None, val.def(self), zero.def(self))
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} }
@ -1196,10 +1234,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} else if let Some(indices) = self.try_pointercast_via_gep(val_pointee, dest_pointee) { } else if let Some(indices) = self.try_pointercast_via_gep(val_pointee, dest_pointee) {
let indices = indices let indices = indices
.into_iter() .into_iter()
.map(|idx| self.constant_u32(idx).def) .map(|idx| self.constant_u32(idx).def(self))
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.emit() self.emit()
.access_chain(dest_ty, None, val.def, indices) .access_chain(dest_ty, None, val.def(self), indices)
.unwrap() .unwrap()
.with_type(dest_ty) .with_type(dest_ty)
} else if self } else if self
@ -1211,10 +1249,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} else { } else {
let result = self let result = self
.emit() .emit()
.bitcast(dest_ty, None, val.def) .bitcast(dest_ty, None, val.def(self))
.unwrap() .unwrap()
.with_type(dest_ty); .with_type(dest_ty);
self.zombie_bitcast_ptr(result.def, val.ty, dest_ty); self.zombie_bitcast_ptr(result.def(self), val.ty, dest_ty);
result result
} }
} }
@ -1226,71 +1264,126 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let b = SpirvType::Bool.def(self); let b = SpirvType::Bool.def(self);
match self.lookup_type(lhs.ty) { match self.lookup_type(lhs.ty) {
SpirvType::Integer(_, _) => match op { SpirvType::Integer(_, _) => match op {
IntEQ => self.emit().i_equal(b, None, lhs.def, rhs.def), IntEQ => self.emit().i_equal(b, None, lhs.def(self), rhs.def(self)),
IntNE => self.emit().i_not_equal(b, None, lhs.def, rhs.def), IntNE => self
IntUGT => self.emit().u_greater_than(b, None, lhs.def, rhs.def), .emit()
IntUGE => self.emit().u_greater_than_equal(b, None, lhs.def, rhs.def), .i_not_equal(b, None, lhs.def(self), rhs.def(self)),
IntULT => self.emit().u_less_than(b, None, lhs.def, rhs.def), IntUGT => self
IntULE => self.emit().u_less_than_equal(b, None, lhs.def, rhs.def), .emit()
IntSGT => self.emit().s_greater_than(b, None, lhs.def, rhs.def), .u_greater_than(b, None, lhs.def(self), rhs.def(self)),
IntSGE => self.emit().s_greater_than_equal(b, None, lhs.def, rhs.def), IntUGE => self
IntSLT => self.emit().s_less_than(b, None, lhs.def, rhs.def), .emit()
IntSLE => self.emit().s_less_than_equal(b, None, lhs.def, rhs.def), .u_greater_than_equal(b, None, lhs.def(self), rhs.def(self)),
IntULT => self
.emit()
.u_less_than(b, None, lhs.def(self), rhs.def(self)),
IntULE => self
.emit()
.u_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
IntSGT => self
.emit()
.s_greater_than(b, None, lhs.def(self), rhs.def(self)),
IntSGE => self
.emit()
.s_greater_than_equal(b, None, lhs.def(self), rhs.def(self)),
IntSLT => self
.emit()
.s_less_than(b, None, lhs.def(self), rhs.def(self)),
IntSLE => self
.emit()
.s_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
}, },
SpirvType::Pointer { .. } => match op { SpirvType::Pointer { .. } => match op {
IntEQ => { IntEQ => {
if self.emit().version().unwrap() > (1, 3) { if self.emit().version().unwrap() > (1, 3) {
self.emit().ptr_equal(b, None, lhs.def, rhs.def) self.emit().ptr_equal(b, None, lhs.def(self), rhs.def(self))
} else { } else {
let int_ty = self.type_usize(); let int_ty = self.type_usize();
let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); let lhs = self
.emit()
.convert_ptr_to_u(int_ty, None, lhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(lhs); self.zombie_convert_ptr_to_u(lhs);
let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); let rhs = self
.emit()
.convert_ptr_to_u(int_ty, None, rhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(rhs); self.zombie_convert_ptr_to_u(rhs);
self.emit().i_not_equal(b, None, lhs, rhs) self.emit().i_not_equal(b, None, lhs, rhs)
} }
} }
IntNE => { IntNE => {
if self.emit().version().unwrap() > (1, 3) { if self.emit().version().unwrap() > (1, 3) {
self.emit().ptr_not_equal(b, None, lhs.def, rhs.def) self.emit()
.ptr_not_equal(b, None, lhs.def(self), rhs.def(self))
} else { } else {
let int_ty = self.type_usize(); let int_ty = self.type_usize();
let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); let lhs = self
.emit()
.convert_ptr_to_u(int_ty, None, lhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(lhs); self.zombie_convert_ptr_to_u(lhs);
let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); let rhs = self
.emit()
.convert_ptr_to_u(int_ty, None, rhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(rhs); self.zombie_convert_ptr_to_u(rhs);
self.emit().i_not_equal(b, None, lhs, rhs) self.emit().i_not_equal(b, None, lhs, rhs)
} }
} }
IntUGT => { IntUGT => {
let int_ty = self.type_usize(); let int_ty = self.type_usize();
let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); let lhs = self
.emit()
.convert_ptr_to_u(int_ty, None, lhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(lhs); self.zombie_convert_ptr_to_u(lhs);
let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); let rhs = self
.emit()
.convert_ptr_to_u(int_ty, None, rhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(rhs); self.zombie_convert_ptr_to_u(rhs);
self.emit().u_greater_than(b, None, lhs, rhs) self.emit().u_greater_than(b, None, lhs, rhs)
} }
IntUGE => { IntUGE => {
let int_ty = self.type_usize(); let int_ty = self.type_usize();
let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); let lhs = self
.emit()
.convert_ptr_to_u(int_ty, None, lhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(lhs); self.zombie_convert_ptr_to_u(lhs);
let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); let rhs = self
.emit()
.convert_ptr_to_u(int_ty, None, rhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(rhs); self.zombie_convert_ptr_to_u(rhs);
self.emit().u_greater_than_equal(b, None, lhs, rhs) self.emit().u_greater_than_equal(b, None, lhs, rhs)
} }
IntULT => { IntULT => {
let int_ty = self.type_usize(); let int_ty = self.type_usize();
let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); let lhs = self
.emit()
.convert_ptr_to_u(int_ty, None, lhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(lhs); self.zombie_convert_ptr_to_u(lhs);
let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); let rhs = self
.emit()
.convert_ptr_to_u(int_ty, None, rhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(rhs); self.zombie_convert_ptr_to_u(rhs);
self.emit().u_less_than(b, None, lhs, rhs) self.emit().u_less_than(b, None, lhs, rhs)
} }
IntULE => { IntULE => {
let int_ty = self.type_usize(); let int_ty = self.type_usize();
let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); let lhs = self
.emit()
.convert_ptr_to_u(int_ty, None, lhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(lhs); self.zombie_convert_ptr_to_u(lhs);
let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); let rhs = self
.emit()
.convert_ptr_to_u(int_ty, None, rhs.def(self))
.unwrap();
self.zombie_convert_ptr_to_u(rhs); self.zombie_convert_ptr_to_u(rhs);
self.emit().u_less_than_equal(b, None, lhs, rhs) self.emit().u_less_than_equal(b, None, lhs, rhs)
} }
@ -1300,44 +1393,48 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
IntSLE => self.fatal("TODO: pointer operator IntSLE not implemented yet"), IntSLE => self.fatal("TODO: pointer operator IntSLE not implemented yet"),
}, },
SpirvType::Bool => match op { SpirvType::Bool => match op {
IntEQ => self.emit().logical_equal(b, None, lhs.def, rhs.def), IntEQ => self
IntNE => self.emit().logical_not_equal(b, None, lhs.def, rhs.def), .emit()
.logical_equal(b, None, lhs.def(self), rhs.def(self)),
IntNE => self
.emit()
.logical_not_equal(b, None, lhs.def(self), rhs.def(self)),
// x > y => x && !y // x > y => x && !y
IntUGT => { IntUGT => {
// intel-compute-runtime doesn't like OpLogicalNot // intel-compute-runtime doesn't like OpLogicalNot
let true_ = self.constant_bool(true); let true_ = self.constant_bool(true);
let rhs = self let rhs = self
.emit() .emit()
.logical_not_equal(b, None, rhs.def, true_.def) .logical_not_equal(b, None, rhs.def(self), true_.def(self))
.unwrap(); .unwrap();
self.emit().logical_and(b, None, lhs.def, rhs) self.emit().logical_and(b, None, lhs.def(self), rhs)
} }
// x >= y => x || !y // x >= y => x || !y
IntUGE => { IntUGE => {
let true_ = self.constant_bool(true); let true_ = self.constant_bool(true);
let rhs = self let rhs = self
.emit() .emit()
.logical_not_equal(b, None, rhs.def, true_.def) .logical_not_equal(b, None, rhs.def(self), true_.def(self))
.unwrap(); .unwrap();
self.emit().logical_or(b, None, lhs.def, rhs) self.emit().logical_or(b, None, lhs.def(self), rhs)
} }
// x < y => !x && y // x < y => !x && y
IntULT => { IntULT => {
let true_ = self.constant_bool(true); let true_ = self.constant_bool(true);
let lhs = self let lhs = self
.emit() .emit()
.logical_not_equal(b, None, lhs.def, true_.def) .logical_not_equal(b, None, lhs.def(self), true_.def(self))
.unwrap(); .unwrap();
self.emit().logical_and(b, None, lhs, rhs.def) self.emit().logical_and(b, None, lhs, rhs.def(self))
} }
// x <= y => !x || y // x <= y => !x || y
IntULE => { IntULE => {
let true_ = self.constant_bool(true); let true_ = self.constant_bool(true);
let lhs = self let lhs = self
.emit() .emit()
.logical_not_equal(b, None, lhs.def, true_.def) .logical_not_equal(b, None, lhs.def(self), true_.def(self))
.unwrap(); .unwrap();
self.emit().logical_or(b, None, lhs, rhs.def) self.emit().logical_or(b, None, lhs, rhs.def(self))
} }
IntSGT => self.fatal("TODO: boolean operator IntSGT not implemented yet"), IntSGT => self.fatal("TODO: boolean operator IntSGT not implemented yet"),
IntSGE => self.fatal("TODO: boolean operator IntSGE not implemented yet"), IntSGE => self.fatal("TODO: boolean operator IntSGE not implemented yet"),
@ -1360,26 +1457,45 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
match op { match op {
RealPredicateFalse => return self.cx.constant_bool(false), RealPredicateFalse => return self.cx.constant_bool(false),
RealPredicateTrue => return self.cx.constant_bool(true), RealPredicateTrue => return self.cx.constant_bool(true),
RealOEQ => self.emit().f_ord_equal(b, None, lhs.def, rhs.def), RealOEQ => self
RealOGT => self.emit().f_ord_greater_than(b, None, lhs.def, rhs.def), .emit()
.f_ord_equal(b, None, lhs.def(self), rhs.def(self)),
RealOGT => self
.emit()
.f_ord_greater_than(b, None, lhs.def(self), rhs.def(self)),
RealOGE => self RealOGE => self
.emit() .emit()
.f_ord_greater_than_equal(b, None, lhs.def, rhs.def), .f_ord_greater_than_equal(b, None, lhs.def(self), rhs.def(self)),
RealOLT => self.emit().f_ord_less_than(b, None, lhs.def, rhs.def), RealOLT => self
RealOLE => self.emit().f_ord_less_than_equal(b, None, lhs.def, rhs.def),
RealONE => self.emit().f_ord_not_equal(b, None, lhs.def, rhs.def),
RealORD => self.emit().ordered(b, None, lhs.def, rhs.def),
RealUNO => self.emit().unordered(b, None, lhs.def, rhs.def),
RealUEQ => self.emit().f_unord_equal(b, None, lhs.def, rhs.def),
RealUGT => self.emit().f_unord_greater_than(b, None, lhs.def, rhs.def),
RealUGE => self
.emit() .emit()
.f_unord_greater_than_equal(b, None, lhs.def, rhs.def), .f_ord_less_than(b, None, lhs.def(self), rhs.def(self)),
RealULT => self.emit().f_unord_less_than(b, None, lhs.def, rhs.def), RealOLE => self
.emit()
.f_ord_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
RealONE => self
.emit()
.f_ord_not_equal(b, None, lhs.def(self), rhs.def(self)),
RealORD => self.emit().ordered(b, None, lhs.def(self), rhs.def(self)),
RealUNO => self.emit().unordered(b, None, lhs.def(self), rhs.def(self)),
RealUEQ => self
.emit()
.f_unord_equal(b, None, lhs.def(self), rhs.def(self)),
RealUGT => self
.emit()
.f_unord_greater_than(b, None, lhs.def(self), rhs.def(self)),
RealUGE => {
self.emit()
.f_unord_greater_than_equal(b, None, lhs.def(self), rhs.def(self))
}
RealULT => self
.emit()
.f_unord_less_than(b, None, lhs.def(self), rhs.def(self)),
RealULE => self RealULE => self
.emit() .emit()
.f_unord_less_than_equal(b, None, lhs.def, rhs.def), .f_unord_less_than_equal(b, None, lhs.def(self), rhs.def(self)),
RealUNE => self.emit().f_unord_not_equal(b, None, lhs.def, rhs.def), RealUNE => self
.emit()
.f_unord_not_equal(b, None, lhs.def(self), rhs.def(self)),
} }
.unwrap() .unwrap()
.with_type(b) .with_type(b)
@ -1411,19 +1527,31 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}; };
let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self)); let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self));
if src_element_size.is_some() && src_element_size == const_size.map(Size::from_bytes) { if src_element_size.is_some() && src_element_size == const_size.map(Size::from_bytes) {
if let Some(const_value) = self.lookup_constant_pointer(src) { // See comment on `SpirvValueKind::ConstantPointer`
if let Some(const_value) = src.const_ptr_val(self) {
self.store(const_value, dst, Align::from_bytes(0).unwrap()); self.store(const_value, dst, Align::from_bytes(0).unwrap());
} else { } else {
self.emit() self.emit()
.copy_memory(dst.def, src.def, None, None, empty()) .copy_memory(dst.def(self), src.def(self), None, None, empty())
.unwrap(); .unwrap();
} }
} else { } else {
self.emit() self.emit()
.copy_memory_sized(dst.def, src.def, size.def, None, None, empty()) .copy_memory_sized(
dst.def(self),
src.def(self),
size.def(self),
None,
None,
empty(),
)
.unwrap(); .unwrap();
if !self.builder.has_capability(Capability::Addresses) { if !self.builder.has_capability(Capability::Addresses) {
self.zombie(dst.def, "OpCopyMemorySized without OpCapability Addresses") self.zombie(
dst.def(self),
"OpCopyMemorySized without OpCapability Addresses",
)
} }
} }
} }
@ -1464,7 +1592,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
let elem_ty_spv = self.lookup_type(elem_ty); let elem_ty_spv = self.lookup_type(elem_ty);
let pat = match self.builder.lookup_const_u64(fill_byte) { let pat = match self.builder.lookup_const_u64(fill_byte) {
Some(fill_byte) => self.memset_const_pattern(&elem_ty_spv, fill_byte as u8), Some(fill_byte) => self.memset_const_pattern(&elem_ty_spv, fill_byte as u8),
None => self.memset_dynamic_pattern(&elem_ty_spv, fill_byte.def), None => self.memset_dynamic_pattern(&elem_ty_spv, fill_byte.def(self)),
} }
.with_type(elem_ty); .with_type(elem_ty);
match self.builder.lookup_const_u64(size) { match self.builder.lookup_const_u64(size) {
@ -1482,7 +1610,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
assert_ty_eq!(self, then_val.ty, else_val.ty); assert_ty_eq!(self, then_val.ty, else_val.ty);
let result_type = then_val.ty; let result_type = then_val.ty;
self.emit() self.emit()
.select(result_type, None, cond.def, then_val.def, else_val.def) .select(
result_type,
None,
cond.def(self),
then_val.def(self),
else_val.def(self),
)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} }
@ -1503,12 +1637,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
Some(const_index) => self.emit().composite_extract( Some(const_index) => self.emit().composite_extract(
result_type, result_type,
None, None,
vec.def, vec.def(self),
[const_index as u32].iter().cloned(), [const_index as u32].iter().cloned(),
), ),
None => self None => {
.emit() self.emit()
.vector_extract_dynamic(result_type, None, vec.def, idx.def), .vector_extract_dynamic(result_type, None, vec.def(self), idx.def(self))
}
} }
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
@ -1521,10 +1656,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
} }
.def(self); .def(self);
if self.builder.lookup_const(elt).is_some() { if self.builder.lookup_const(elt).is_some() {
self.constant_composite(result_type, vec![elt.def; num_elts]) self.constant_composite(result_type, vec![elt.def(self); num_elts])
} else { } else {
self.emit() self.emit()
.composite_construct(result_type, None, std::iter::repeat(elt.def).take(num_elts)) .composite_construct(
result_type,
None,
std::iter::repeat(elt.def(self)).take(num_elts),
)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} }
@ -1539,7 +1678,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
)), )),
}; };
self.emit() self.emit()
.composite_extract(result_type, None, agg_val.def, [idx as u32].iter().cloned()) .composite_extract(
result_type,
None,
agg_val.def(self),
[idx as u32].iter().cloned(),
)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} }
@ -1555,8 +1699,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.composite_insert( .composite_insert(
agg_val.ty, agg_val.ty,
None, None,
elt.def, elt.def(self),
agg_val.def, agg_val.def(self),
[idx as u32].iter().cloned(), [idx as u32].iter().cloned(),
) )
.unwrap() .unwrap()
@ -1638,7 +1782,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
}; };
assert_ty_eq!(self, dst_pointee_ty, cmp.ty); assert_ty_eq!(self, dst_pointee_ty, cmp.ty);
assert_ty_eq!(self, dst_pointee_ty, src.ty); assert_ty_eq!(self, dst_pointee_ty, src.ty);
self.validate_atomic(dst_pointee_ty, dst.def); self.validate_atomic(dst_pointee_ty, dst.def(self));
// TODO: Default to device scope // TODO: Default to device scope
let memory = self.constant_u32(Scope::Device as u32); let memory = self.constant_u32(Scope::Device as u32);
let semantics_equal = self.ordering_to_semantics_def(order); let semantics_equal = self.ordering_to_semantics_def(order);
@ -1648,12 +1792,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
.atomic_compare_exchange( .atomic_compare_exchange(
src.ty, src.ty,
None, None,
dst.def, dst.def(self),
memory.def, memory.def(self),
semantics_equal.def, semantics_equal.def(self),
semantics_unequal.def, semantics_unequal.def(self),
src.def, src.def(self),
cmp.def, cmp.def(self),
) )
.unwrap() .unwrap()
.with_type(src.ty) .with_type(src.ty)
@ -1677,24 +1821,94 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
)), )),
}; };
assert_ty_eq!(self, dst_pointee_ty, src.ty); assert_ty_eq!(self, dst_pointee_ty, src.ty);
self.validate_atomic(dst_pointee_ty, dst.def); self.validate_atomic(dst_pointee_ty, dst.def(self));
// TODO: Default to device scope // TODO: Default to device scope
let memory = self.constant_u32(Scope::Device as u32).def; let memory = self.constant_u32(Scope::Device as u32).def(self);
let semantics = self.ordering_to_semantics_def(order).def; let semantics = self.ordering_to_semantics_def(order).def(self);
let mut emit = self.emit(); let mut emit = self.emit();
use AtomicRmwBinOp::*; use AtomicRmwBinOp::*;
match op { match op {
AtomicXchg => emit.atomic_exchange(src.ty, None, dst.def, memory, semantics, src.def), AtomicXchg => emit.atomic_exchange(
AtomicAdd => emit.atomic_i_add(src.ty, None, dst.def, memory, semantics, src.def), src.ty,
AtomicSub => emit.atomic_i_sub(src.ty, None, dst.def, memory, semantics, src.def), None,
AtomicAnd => emit.atomic_and(src.ty, None, dst.def, memory, semantics, src.def), dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicAdd => emit.atomic_i_add(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicSub => emit.atomic_i_sub(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicAnd => emit.atomic_and(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicNand => self.fatal("atomic nand is not supported"), AtomicNand => self.fatal("atomic nand is not supported"),
AtomicOr => emit.atomic_or(src.ty, None, dst.def, memory, semantics, src.def), AtomicOr => emit.atomic_or(
AtomicXor => emit.atomic_xor(src.ty, None, dst.def, memory, semantics, src.def), src.ty,
AtomicMax => emit.atomic_s_max(src.ty, None, dst.def, memory, semantics, src.def), None,
AtomicMin => emit.atomic_s_min(src.ty, None, dst.def, memory, semantics, src.def), dst.def(self),
AtomicUMax => emit.atomic_u_max(src.ty, None, dst.def, memory, semantics, src.def), memory,
AtomicUMin => emit.atomic_u_min(src.ty, None, dst.def, memory, semantics, src.def), semantics,
src.def(self),
),
AtomicXor => emit.atomic_xor(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicMax => emit.atomic_s_max(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicMin => emit.atomic_s_min(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicUMax => emit.atomic_u_max(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
AtomicUMin => emit.atomic_u_min(
src.ty,
None,
dst.def(self),
memory,
semantics,
src.def(self),
),
} }
.unwrap() .unwrap()
.with_type(src.ty) .with_type(src.ty)
@ -1703,8 +1917,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
fn atomic_fence(&mut self, order: AtomicOrdering, _scope: SynchronizationScope) { fn atomic_fence(&mut self, order: AtomicOrdering, _scope: SynchronizationScope) {
// Ignore sync scope (it only has "single thread" and "cross thread") // Ignore sync scope (it only has "single thread" and "cross thread")
// TODO: Default to device scope // TODO: Default to device scope
let memory = self.constant_u32(Scope::Device as u32).def; let memory = self.constant_u32(Scope::Device as u32).def(self);
let semantics = self.ordering_to_semantics_def(order).def; let semantics = self.ordering_to_semantics_def(order).def(self);
self.emit().memory_barrier(memory, semantics).unwrap(); self.emit().memory_barrier(memory, semantics).unwrap();
} }
@ -1759,9 +1973,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> {
for (argument, argument_type) in args.iter().zip(argument_types) { for (argument, argument_type) in args.iter().zip(argument_types) {
assert_ty_eq!(self, argument.ty, argument_type); assert_ty_eq!(self, argument.ty, argument_type);
} }
let args = args.iter().map(|arg| arg.def).collect::<Vec<_>>(); let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>();
self.emit() self.emit()
.function_call(result_type, None, llfn.def, args) .function_call(result_type, None, llfn.def(self), args)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} }

View File

@ -62,7 +62,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
None, None,
glsl, glsl,
op as u32, op as u32,
args.iter().map(|a| Operand::IdRef(a.def)), args.iter().map(|a| Operand::IdRef(a.def(self))),
) )
.unwrap() .unwrap()
.with_type(args[0].ty) .with_type(args[0].ty)
@ -77,7 +77,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
None, None,
opencl, opencl,
op as u32, op as u32,
args.iter().map(|a| Operand::IdRef(a.def)), args.iter().map(|a| Operand::IdRef(a.def(self))),
) )
.unwrap() .unwrap()
.with_type(args[0].ty) .with_type(args[0].ty)

View File

@ -99,7 +99,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
)), )),
}; };
// TODO: Implement this // TODO: Implement this
self.zombie(result.def, "saturating_add is not implemented yet"); self.zombie(result.def(self), "saturating_add is not implemented yet");
result result
} }
sym::saturating_sub => { sym::saturating_sub => {
@ -115,7 +115,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
)), )),
}; };
// TODO: Implement this // TODO: Implement this
self.zombie(result.def, "saturating_sub is not implemented yet"); self.zombie(result.def(self), "saturating_sub is not implemented yet");
result result
} }
@ -326,7 +326,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
.u_count_leading_zeros_intel( .u_count_leading_zeros_intel(
args[0].immediate().ty, args[0].immediate().ty,
None, None,
args[0].immediate().def, args[0].immediate().def(self),
) )
.unwrap() .unwrap()
.with_type(args[0].immediate().ty) .with_type(args[0].immediate().ty)
@ -340,7 +340,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
.u_count_trailing_zeros_intel( .u_count_trailing_zeros_intel(
args[0].immediate().ty, args[0].immediate().ty,
None, None,
args[0].immediate().def, args[0].immediate().def(self),
) )
.unwrap() .unwrap()
.with_type(args[0].immediate().ty) .with_type(args[0].immediate().ty)
@ -349,12 +349,12 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> {
sym::ctpop => self sym::ctpop => self
.emit() .emit()
.bit_count(args[0].immediate().ty, None, args[0].immediate().def) .bit_count(args[0].immediate().ty, None, args[0].immediate().def(self))
.unwrap() .unwrap()
.with_type(args[0].immediate().ty), .with_type(args[0].immediate().ty),
sym::bitreverse => self sym::bitreverse => self
.emit() .emit()
.bit_reverse(args[0].immediate().ty, None, args[0].immediate().def) .bit_reverse(args[0].immediate().ty, None, args[0].immediate().def(self))
.unwrap() .unwrap()
.with_type(args[0].immediate().ty), .with_type(args[0].immediate().ty),

View File

@ -110,7 +110,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
)), )),
}; };
for index in indices.iter().cloned().skip(1) { for index in indices.iter().cloned().skip(1) {
result_indices.push(index.def); result_indices.push(index.def(self));
result_pointee_type = match self.lookup_type(result_pointee_type) { result_pointee_type = match self.lookup_type(result_pointee_type) {
SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => element, SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => element,
_ => self.fatal(&format!( _ => self.fatal(&format!(
@ -127,12 +127,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
if self.builder.lookup_const_u64(indices[0]) == Some(0) { if self.builder.lookup_const_u64(indices[0]) == Some(0) {
if is_inbounds { if is_inbounds {
self.emit() self.emit()
.in_bounds_access_chain(result_type, None, ptr.def, result_indices) .in_bounds_access_chain(result_type, None, ptr.def(self), result_indices)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} else { } else {
self.emit() self.emit()
.access_chain(result_type, None, ptr.def, result_indices) .access_chain(result_type, None, ptr.def(self), result_indices)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} }
@ -142,15 +142,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.in_bounds_ptr_access_chain( .in_bounds_ptr_access_chain(
result_type, result_type,
None, None,
ptr.def, ptr.def(self),
indices[0].def, indices[0].def(self),
result_indices, result_indices,
) )
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
} else { } else {
self.emit() self.emit()
.ptr_access_chain(result_type, None, ptr.def, indices[0].def, result_indices) .ptr_access_chain(
result_type,
None,
ptr.def(self),
indices[0].def(self),
result_indices,
)
.unwrap() .unwrap()
.with_type(result_type) .with_type(result_type)
}; };
@ -159,7 +165,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
.has_capability(rspirv::spirv::Capability::Addresses); .has_capability(rspirv::spirv::Capability::Addresses);
if !has_addresses { if !has_addresses {
self.zombie( self.zombie(
result.def, result.def(self),
"OpPtrAccessChain without OpCapability Addresses", "OpPtrAccessChain without OpCapability Addresses",
); );
} }
@ -193,7 +199,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
// So we need to check for zero shift, and don't use the shift result if it is. // So we need to check for zero shift, and don't use the shift result if it is.
let mask_is_zero = self let mask_is_zero = self
.emit() .emit()
.i_not_equal(bool, None, mask_shift.def, zero.def) .i_not_equal(bool, None, mask_shift.def(self), zero.def(self))
.unwrap() .unwrap()
.with_type(bool); .with_type(bool);
self.select(mask_is_zero, value, or) self.select(mask_is_zero, value, or)
@ -275,7 +281,7 @@ impl<'a, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'tcx> {
dst: PlaceRef<'tcx, Self::Value>, dst: PlaceRef<'tcx, Self::Value>,
) { ) {
fn next<'a, 'tcx>(bx: &mut Builder<'a, 'tcx>, idx: &mut usize) -> SpirvValue { fn next<'a, 'tcx>(bx: &mut Builder<'a, 'tcx>, idx: &mut usize) -> SpirvValue {
let val = bx.function_parameter_values.borrow()[&bx.current_fn.def][*idx]; let val = bx.function_parameter_values.borrow()[&bx.current_fn.def(bx)][*idx];
*idx += 1; *idx += 1;
val val
} }
@ -333,7 +339,7 @@ impl<'a, 'tcx> AbiBuilderMethods<'tcx> for Builder<'a, 'tcx> {
fn apply_attrs_callsite(&mut self, _fn_abi: &FnAbi<'tcx, Ty<'tcx>>, _callsite: Self::Value) {} fn apply_attrs_callsite(&mut self, _fn_abi: &FnAbi<'tcx, Ty<'tcx>>, _callsite: Self::Value) {}
fn get_param(&self, index: usize) -> Self::Value { fn get_param(&self, index: usize) -> Self::Value {
self.function_parameter_values.borrow()[&self.current_fn.def][index] self.function_parameter_values.borrow()[&self.current_fn.def(self)][index]
} }
} }

View File

@ -1,3 +1,6 @@
use crate::builder;
use crate::codegen_cx::CodegenCx;
use crate::spirv_type::SpirvType;
use bimap::BiHashMap; use bimap::BiHashMap;
use rspirv::dr::{Block, Builder, Module, Operand}; use rspirv::dr::{Block, Builder, Module, Operand};
use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word}; use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word};
@ -6,19 +9,99 @@ use rustc_middle::bug;
use std::cell::{RefCell, RefMut}; use std::cell::{RefCell, RefMut};
use std::{fs::File, io::Write, path::Path}; use std::{fs::File, io::Write, path::Path};
#[derive(Copy, Clone, Debug, Default, Ord, PartialOrd, Eq, PartialEq, Hash)] #[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub enum SpirvValueKind {
Def(Word),
/// There are a fair number of places where `rustc_codegen_ssa` creates a pointer to something
/// that cannot be pointed to in SPIR-V. For example, constant values are frequently emitted as
/// a pointer to constant memory, and then dereferenced where they're used. Functions are the
/// same way: when compiling a call, the function's pointer is loaded, then dereferenced, then
/// called. Directly translating these constructs is impossible, because SPIR-V doesn't allow
/// pointers to constants or function pointers. So, instead, we create this ConstantPointer
/// "meta-value": using it directly is an error; however, if it is dereferenced, the "load"
/// is instead a no-op that returns the underlying value directly.
ConstantPointer(Word),
}
#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)]
pub struct SpirvValue { pub struct SpirvValue {
pub def: Word, pub kind: SpirvValueKind,
pub ty: Word, pub ty: Word,
} }
impl SpirvValue {
pub fn const_ptr_val(self, cx: &CodegenCx<'_>) -> Option<Self> {
match self.kind {
SpirvValueKind::ConstantPointer(word) => {
let ty = match cx.lookup_type(self.ty) {
SpirvType::Pointer {
storage_class: _,
pointee,
} => pointee,
ty => bug!("load called on variable that wasn't a pointer: {:?}", ty),
};
Some(word.with_type(ty))
}
SpirvValueKind::Def(_) => None,
}
}
// Important: we *cannot* use bx.emit() here, because this is called in
// contexts where the emitter is already locked. Doing so may cause subtle
// rare bugs.
pub fn def(self, bx: &builder::Builder<'_, '_>) -> Word {
match self.kind {
SpirvValueKind::Def(word) => word,
SpirvValueKind::ConstantPointer(_) => {
if bx.is_system_crate() {
*bx.zombie_undefs_for_system_constant_pointers
.borrow()
.get(&self.ty)
.expect("ConstantPointer didn't go through proper undef registration")
} else {
bx.err("Cannot use this pointer directly, it must be dereferenced first");
// Because we never get beyond compilation (into e.g. linking),
// emitting an invalid ID reference here is OK.
0
}
}
}
}
// def and def_cx are separated, because Builder has a span associated with
// what it's currently emitting.
pub fn def_cx(self, cx: &CodegenCx<'_>) -> Word {
match self.kind {
SpirvValueKind::Def(word) => word,
SpirvValueKind::ConstantPointer(_) => {
if cx.is_system_crate() {
*cx.zombie_undefs_for_system_constant_pointers
.borrow()
.get(&self.ty)
.expect("ConstantPointer didn't go through proper undef registration")
} else {
cx.tcx
.sess
.err("Cannot use this pointer directly, it must be dereferenced first");
// Because we never get beyond compilation (into e.g. linking),
// emitting an invalid ID reference here is OK.
0
}
}
}
}
}
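
For context on how `const_ptr_val` is consumed: the `load` change earlier in this commit checks it first, so dereferencing a `ConstantPointer` never emits an `OpLoad`. A condensed sketch of that call site, with the pointee-type lookup and error paths elided and `pointee_ty` assumed to be already looked up:

fn load_sketch(bx: &mut builder::Builder<'_, '_>, ptr: SpirvValue, pointee_ty: Word) -> SpirvValue {
    // ConstantPointer: the "load" is a no-op returning the wrapped value.
    if let Some(value) = ptr.const_ptr_val(bx) {
        return value;
    }
    // Real pointer: emit an actual OpLoad.
    bx.emit()
        .load(pointee_ty, None, ptr.def(bx), None, std::iter::empty())
        .unwrap()
        .with_type(pointee_ty)
}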
pub trait SpirvValueExt { pub trait SpirvValueExt {
fn with_type(self, ty: Word) -> SpirvValue; fn with_type(self, ty: Word) -> SpirvValue;
} }
impl SpirvValueExt for Word { impl SpirvValueExt for Word {
fn with_type(self, ty: Word) -> SpirvValue { fn with_type(self, ty: Word) -> SpirvValue {
SpirvValue { def: self, ty } SpirvValue {
kind: SpirvValueKind::Def(self),
ty,
}
} }
} }
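
The `zombie_undefs_for_system_constant_pointers` map consulted by `def` and `def_cx` has to be populated when the `ConstantPointer` is created; per the `static_addr_of` hunk below, that now happens in `make_constant_pointer`, which is not shown in this diff. Purely as a hypothetical sketch of the invariant it maintains (one zombied `OpUndef` per pointer type, so a stray direct use never survives to linking):

// Hypothetical sketch; the real make_constant_pointer is not in this diff.
fn make_constant_pointer(cx: &CodegenCx<'_>, value: SpirvValue) -> SpirvValue {
    let ptr_ty = cx.type_ptr_to(value.ty);
    if cx.is_system_crate() {
        // Ensure a zombied OpUndef exists for this pointer type, so that
        // def()/def_cx() have an id to hand back if the pointer is ever
        // used directly instead of being dereferenced.
        cx.zombie_undefs_for_system_constant_pointers
            .borrow_mut()
            .entry(ptr_ty)
            .or_insert_with(|| {
                let undef = cx.emit_global().undef(ptr_ty, None);
                cx.zombie_no_span(undef, "constant pointer used directly");
                undef
            });
    }
    SpirvValue {
        kind: SpirvValueKind::ConstantPointer(value.def_cx(cx)),
        ty: ptr_ty,
    }
}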

View File

@ -70,7 +70,7 @@ impl<'tcx> CodegenCx<'tcx> {
}, },
SpirvType::Integer(128, _) => { SpirvType::Integer(128, _) => {
let result = self.undef(ty); let result = self.undef(ty);
self.zombie_no_span(result.def, "u128 constant"); self.zombie_no_span(result.def_cx(self), "u128 constant");
result result
} }
other => self.tcx.sess.fatal(&format!( other => self.tcx.sess.fatal(&format!(
@ -169,7 +169,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
let len = s.as_str().len(); let len = s.as_str().len();
let ty = self.type_ptr_to(self.layout_of(self.tcx.types.str_).spirv_type(self)); let ty = self.type_ptr_to(self.layout_of(self.tcx.types.str_).spirv_type(self));
let result = self.undef(ty); let result = self.undef(ty);
self.zombie_no_span(result.def, "constant string"); self.zombie_no_span(result.def_cx(self), "constant string");
(result, self.const_usize(len as u64)) (result, self.const_usize(len as u64))
} }
fn const_struct(&self, elts: &[Self::Value], _packed: bool) -> Self::Value { fn const_struct(&self, elts: &[Self::Value], _packed: bool) -> Self::Value {
@ -185,7 +185,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
field_names: None, field_names: None,
} }
.def(self); .def(self);
self.constant_composite(struct_ty, elts.iter().map(|f| f.def).collect()) self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect())
} }
fn const_to_opt_uint(&self, v: Self::Value) -> Option<u64> { fn const_to_opt_uint(&self, v: Self::Value) -> Option<u64> {
@ -301,7 +301,10 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
) => { ) => {
if a_space != b_space { if a_space != b_space {
// TODO: Emit the correct type that is passed into this function. // TODO: Emit the correct type that is passed into this function.
self.zombie_no_span(value.def, "invalid pointer space in constant"); self.zombie_no_span(
value.def_cx(self),
"invalid pointer space in constant",
);
} }
assert_ty_eq!(self, a, b); assert_ty_eq!(self, a, b);
} }
@ -330,8 +333,8 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> {
val val
} else { } else {
// constant ptrcast is not supported in spir-v // constant ptrcast is not supported in spir-v
let result = val.def.with_type(ty); let result = val.def_cx(self).with_type(ty);
self.zombie_no_span(result.def, "const_ptrcast"); self.zombie_no_span(result.def_cx(self), "const_ptrcast");
result result
} }
} }
@ -395,7 +398,7 @@ impl<'tcx> CodegenCx<'tcx> {
let mut total_offset_end = total_offset_start; let mut total_offset_end = total_offset_start;
values.push( values.push(
self.create_const_alloc2(alloc, &mut total_offset_end, ty) self.create_const_alloc2(alloc, &mut total_offset_end, ty)
.def, .def_cx(self),
); );
occupied_spaces.push(total_offset_start..total_offset_end); occupied_spaces.push(total_offset_start..total_offset_end);
} }
@ -417,7 +420,10 @@ impl<'tcx> CodegenCx<'tcx> {
SpirvType::Array { element, count } => { SpirvType::Array { element, count } => {
let count = self.builder.lookup_const_u64(count).unwrap() as usize; let count = self.builder.lookup_const_u64(count).unwrap() as usize;
let values = (0..count) let values = (0..count)
.map(|_| self.create_const_alloc2(alloc, offset, element).def) .map(|_| {
self.create_const_alloc2(alloc, offset, element)
.def_cx(self)
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
self.constant_composite(ty, values) self.constant_composite(ty, values)
} }
@ -427,7 +433,10 @@ impl<'tcx> CodegenCx<'tcx> {
.expect("create_const_alloc: Vectors must be sized"); .expect("create_const_alloc: Vectors must be sized");
let final_offset = *offset + total_size; let final_offset = *offset + total_size;
let values = (0..count) let values = (0..count)
.map(|_| self.create_const_alloc2(alloc, offset, element).def) .map(|_| {
self.create_const_alloc2(alloc, offset, element)
.def_cx(self)
})
.collect::<Vec<_>>(); .collect::<Vec<_>>();
assert!(*offset <= final_offset); assert!(*offset <= final_offset);
// Vectors sometimes have padding at the end (e.g. vec3), skip over it. // Vectors sometimes have padding at the end (e.g. vec3), skip over it.
@ -437,7 +446,10 @@ impl<'tcx> CodegenCx<'tcx> {
SpirvType::RuntimeArray { element } => { SpirvType::RuntimeArray { element } => {
let mut values = Vec::new(); let mut values = Vec::new();
while offset.bytes_usize() != alloc.len() { while offset.bytes_usize() != alloc.len() {
values.push(self.create_const_alloc2(alloc, offset, element).def); values.push(
self.create_const_alloc2(alloc, offset, element)
.def_cx(self),
);
} }
let result = self.constant_composite(ty, values); let result = self.constant_composite(ty, values);
// TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv: // TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv:
@@ -452,7 +464,7 @@ impl<'tcx> CodegenCx<'tcx> {
*data = *c + asdf->y[*c]; *data = *c + asdf->y[*c];
} }
*/ */
self.zombie_no_span(result.def, "constant runtime array value"); self.zombie_no_span(result.def_cx(self), "constant runtime array value");
result result
} }
SpirvType::Pointer { .. } => { SpirvType::Pointer { .. } => {
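Each arm of `create_const_alloc2` above lowers its elements to raw ids with `def_cx(self)` and funnels them into `constant_composite`. A minimal sketch of that helper, assuming it wraps rspirv's `dr::Builder::constant_composite` (an assumption; the real helper in `builder_spirv.rs` also caches constants, so identical composites share one id):

```rust
// Sketch under the assumptions above; not the verbatim helper.
impl<'tcx> CodegenCx<'tcx> {
    pub fn constant_composite(&self, ty: Word, constituents: Vec<Word>) -> SpirvValue {
        self.emit_global()
            .constant_composite(ty, constituents)
            .with_type(ty)
    }
}
```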
View File
@@ -82,7 +82,7 @@ impl<'tcx> CodegenCx<'tcx> {
// This can happen if we call a blocklisted function in another crate. // This can happen if we call a blocklisted function in another crate.
let result = self.undef(function_type); let result = self.undef(function_type);
// TODO: Span info here // TODO: Span info here
self.zombie_no_span(result.def, "called blocklisted fn"); self.zombie_no_span(result.def_cx(self), "called blocklisted fn");
return result; return result;
} }
let mut emit = self.emit_global(); let mut emit = self.emit_global();
@@ -132,7 +132,7 @@ impl<'tcx> CodegenCx<'tcx> {
let span = self.tcx.def_span(def_id); let span = self.tcx.def_span(def_id);
let g = self.declare_global(span, self.layout_of(ty).spirv_type(self)); let g = self.declare_global(span, self.layout_of(ty).spirv_type(self));
self.instances.borrow_mut().insert(instance, g); self.instances.borrow_mut().insert(instance, g);
self.set_linkage(g.def, sym.to_string(), LinkageType::Import); self.set_linkage(g.def_cx(self), sym.to_string(), LinkageType::Import);
g g
} }
@@ -147,7 +147,7 @@ impl<'tcx> CodegenCx<'tcx> {
.variable(ptr_ty, None, StorageClass::Function, None) .variable(ptr_ty, None, StorageClass::Function, None)
.with_type(ptr_ty); .with_type(ptr_ty);
// TODO: These should be StorageClass::Private, so just zombie for now. // TODO: These should be StorageClass::Private, so just zombie for now.
self.zombie_with_span(result.def, span, "Globals are not supported yet"); self.zombie_with_span(result.def_cx(self), span, "Globals are not supported yet");
result result
} }
} }
@@ -182,7 +182,7 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
self.instances.borrow_mut().insert(instance, g); self.instances.borrow_mut().insert(instance, g);
if let Some(linkage) = linkage { if let Some(linkage) = linkage {
self.set_linkage(g.def, symbol_name.to_string(), linkage); self.set_linkage(g.def_cx(self), symbol_name.to_string(), linkage);
} }
} }
@@ -229,7 +229,7 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> {
impl<'tcx> StaticMethods for CodegenCx<'tcx> { impl<'tcx> StaticMethods for CodegenCx<'tcx> {
fn static_addr_of(&self, cv: Self::Value, _align: Align, _kind: Option<&str>) -> Self::Value { fn static_addr_of(&self, cv: Self::Value, _align: Align, _kind: Option<&str>) -> Self::Value {
self.register_constant_pointer(cv) self.make_constant_pointer(cv)
} }
fn codegen_static(&self, def_id: DefId, _is_mutable: bool) { fn codegen_static(&self, def_id: DefId, _is_mutable: bool) {
@@ -260,7 +260,8 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> {
} }
assert_ty_eq!(self, value_ty, v.ty); assert_ty_eq!(self, value_ty, v.ty);
self.builder.set_global_initializer(g.def, v.def); self.builder
.set_global_initializer(g.def_cx(self), v.def_cx(self));
} }
/// Mark the given global value as "used", to prevent a backend from potentially removing a /// Mark the given global value as "used", to prevent a backend from potentially removing a
View File
@@ -107,7 +107,7 @@ impl<'tcx> CodegenCx<'tcx> {
emit.function_call( emit.function_call(
entry_func_return, entry_func_return,
None, None,
entry_func.def, entry_func.def_cx(self),
arguments.iter().map(|&(a, _)| a), arguments.iter().map(|&(a, _)| a),
) )
.unwrap(); .unwrap();
@@ -234,7 +234,7 @@ impl<'tcx> CodegenCx<'tcx> {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
emit.begin_block(None).unwrap(); emit.begin_block(None).unwrap();
let call_result = emit let call_result = emit
.function_call(entry_func_return, None, entry_func.def, arguments) .function_call(entry_func_return, None, entry_func.def_cx(self), arguments)
.unwrap(); .unwrap();
if self.lookup_type(entry_func_return) == SpirvType::Void { if self.lookup_type(entry_func_return) == SpirvType::Void {
emit.ret().unwrap(); emit.ret().unwrap();
View File
@@ -4,11 +4,10 @@ mod entry;
mod type_; mod type_;
use crate::builder::ExtInst; use crate::builder::ExtInst;
use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueExt}; use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueKind};
use crate::finalizing_passes::export_zombies; use crate::finalizing_passes::export_zombies;
use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache}; use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache};
use crate::symbols::Symbols; use crate::symbols::Symbols;
use bimap::BiHashMap;
use rspirv::dr::{Module, Operand}; use rspirv::dr::{Module, Operand};
use rspirv::spirv::{Decoration, LinkageType, MemoryModel, StorageClass, Word}; use rspirv::spirv::{Decoration, LinkageType, MemoryModel, StorageClass, Word};
use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind}; use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind};
@@ -52,9 +51,7 @@ pub struct CodegenCx<'tcx> {
/// Cache of all the builtin symbols we need /// Cache of all the builtin symbols we need
pub sym: Box<Symbols>, pub sym: Box<Symbols>,
pub really_unsafe_ignore_bitcasts: RefCell<HashSet<SpirvValue>>, pub really_unsafe_ignore_bitcasts: RefCell<HashSet<SpirvValue>>,
/// Functions created in `get_fn_addr`, and values created in `static_addr_of`. pub zombie_undefs_for_system_constant_pointers: RefCell<HashMap<Word, Word>>,
/// left: the OpUndef pseudo-pointer. right: the concrete value.
constant_pointers: RefCell<BiHashMap<SpirvValue, SpirvValue>>,
/// Some runtimes (e.g. intel-compute-runtime) disallow atomics on i8 and i16, even though it's allowed by the spec. /// Some runtimes (e.g. intel-compute-runtime) disallow atomics on i8 and i16, even though it's allowed by the spec.
/// This enables/disables them. /// This enables/disables them.
pub i8_i16_atomics_allowed: bool, pub i8_i16_atomics_allowed: bool,
@@ -104,7 +101,7 @@ impl<'tcx> CodegenCx<'tcx> {
kernel_mode, kernel_mode,
sym, sym,
really_unsafe_ignore_bitcasts: Default::default(), really_unsafe_ignore_bitcasts: Default::default(),
constant_pointers: Default::default(), zombie_undefs_for_system_constant_pointers: Default::default(),
i8_i16_atomics_allowed: false, i8_i16_atomics_allowed: false,
} }
} }
@@ -185,47 +182,28 @@ impl<'tcx> CodegenCx<'tcx> {
) )
} }
/// Function pointer registration: /// See note on `SpirvValueKind::ConstantPointer`
/// LLVM, and therefore `codegen_ssa`, is very murky with function values vs. function pub fn make_constant_pointer(&self, value: SpirvValue) -> SpirvValue {
/// pointers. So, `codegen_ssa` has a pattern where *even for direct function calls*, it uses
/// `get_fn_*addr*`, and then uses that function *pointer* when calling
/// `BuilderMethods::call()`. However, spir-v doesn't support function pointers! So, instead,
/// when `get_fn_addr` is called, we register a "token" (via `OpUndef`), storing it in a
/// dictionary. Then, when `BuilderMethods::call()` is called, and it's calling a function
/// pointer, we check the dictionary, and if it is, invoke the function directly. It's kind of
/// conceptually similar to a constexpr deref, except specialized to just functions. We also
/// use the same system with `static_addr_of`: we don't support creating addresses of arbitrary
/// constant values, so instead, we register the constant, and whenever we load something, we
/// check if the pointer we're loading is a special `OpUndef` token: if so, we directly
/// reference the registered value.
pub fn register_constant_pointer(&self, value: SpirvValue) -> SpirvValue {
// If we've already registered this value, return it.
if let Some(undef) = self.constant_pointers.borrow().get_by_right(&value) {
return *undef;
}
let ty = SpirvType::Pointer { let ty = SpirvType::Pointer {
storage_class: StorageClass::Function, storage_class: StorageClass::Function,
pointee: value.ty, pointee: value.ty,
} }
.def(self); .def(self);
// We want a unique ID for these undefs, so don't use the caching system. if self.is_system_crate() {
let result = self.emit_global().undef(ty, None).with_type(ty); // Create these undefs up front instead of on demand in SpirvValue::def because
// It's obviously invalid, so zombie it. Zombie it in user code as well, because it's valid there too. // SpirvValue::def can't use cx.emit()
// It'd be realllly nice to give this a Span, the UX of this is horrible... self.zombie_undefs_for_system_constant_pointers
self.zombie_even_in_user_code(result.def, "dynamic use of address of constant"); .borrow_mut()
self.constant_pointers .entry(ty)
.borrow_mut() .or_insert_with(|| {
.insert_no_overwrite(result, value) // We want a unique ID for these undefs, so don't use the caching system.
.unwrap(); self.emit_global().undef(ty, None)
result });
} }
SpirvValue {
/// See comment on `register_constant_pointer` kind: SpirvValueKind::ConstantPointer(value.def_cx(self)),
pub fn lookup_constant_pointer(&self, pointer: SpirvValue) -> Option<SpirvValue> { ty,
self.constant_pointers }
.borrow()
.get_by_left(&pointer)
.cloned()
} }
} }
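With the `constant_pointers` map gone, resolution happens at the point where a raw id is actually needed. A sketch of how `def_cx` presumably handles the `ConstantPointer` case, inferred from `make_constant_pointer` above (constants are built outside any function, where `cx.emit()` is unavailable, hence the undefs pre-created under the `is_system_crate()` guard); the real implementation in `builder_spirv.rs` may differ:

```rust
// Sketch inferred from this diff; not the verbatim implementation.
impl SpirvValue {
    pub fn def_cx(self, cx: &CodegenCx<'_>) -> Word {
        match self.kind {
            SpirvValueKind::Def(id) => id,
            SpirvValueKind::ConstantPointer(_pointee) => {
                // A constant pointer is being used dynamically, which SPIR-V
                // cannot express: hand back the OpUndef pre-created for this
                // pointer type and zombie it, turning any surviving use into
                // a link-time error.
                let undef = *cx
                    .zombie_undefs_for_system_constant_pointers
                    .borrow()
                    .get(&self.ty)
                    .expect("make_constant_pointer pre-created this undef");
                cx.zombie_no_span(undef, "dynamic use of address of constant");
                undef
            }
        }
    }
}
```

Well-behaved consumers never hit the zombie path: a load through such a pointer can peel the `ConstantPointer` payload off and use the pointed-to value directly, and `BuilderMethods::call()` does the same for the function "pointers" returned by `get_fn_addr` below.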
@@ -286,7 +264,7 @@ impl<'tcx> MiscMethods<'tcx> for CodegenCx<'tcx> {
fn get_fn_addr(&self, instance: Instance<'tcx>) -> Self::Value { fn get_fn_addr(&self, instance: Instance<'tcx>) -> Self::Value {
let function = self.get_fn_ext(instance); let function = self.get_fn_ext(instance);
self.register_constant_pointer(function) self.make_constant_pointer(function)
} }
fn eh_personality(&self) -> Self::Value { fn eh_personality(&self) -> Self::Value {
View File
@@ -148,7 +148,7 @@ impl SpirvType {
.sizeof(cx) .sizeof(cx)
.expect("Element of sized array must be sized") .expect("Element of sized array must be sized")
.bytes(); .bytes();
let result = cx.emit_global().type_array(element, count.def); let result = cx.emit_global().type_array(element, count.def_cx(cx));
if !cx.kernel_mode { if !cx.kernel_mode {
// TODO: kernel mode can't do this?? // TODO: kernel mode can't do this??
cx.emit_global().decorate( cx.emit_global().decorate(