diff --git a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs index 3179030f40..1e2abbd4f6 100644 --- a/crates/rustc_codegen_spirv/src/builder/builder_methods.rs +++ b/crates/rustc_codegen_spirv/src/builder/builder_methods.rs @@ -24,7 +24,7 @@ macro_rules! simple_op { assert_ty_eq!(self, lhs.ty, rhs.ty); let result_type = lhs.ty; self.emit() - .$inst_name(result_type, None, lhs.def, rhs.def) + .$inst_name(result_type, None, lhs.def(self), rhs.def(self)) .unwrap() .with_type(result_type) } @@ -36,7 +36,7 @@ macro_rules! simple_op_unchecked_type { ($func_name:ident, $inst_name:ident) => { fn $func_name(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value { self.emit() - .$inst_name(lhs.ty, None, lhs.def, rhs.def) + .$inst_name(lhs.ty, None, lhs.def(self), rhs.def(self)) .unwrap() .with_type(lhs.ty) } @@ -47,7 +47,7 @@ macro_rules! simple_uni_op { ($func_name:ident, $inst_name:ident) => { fn $func_name(&mut self, val: Self::Value) -> Self::Value { self.emit() - .$inst_name(val.ty, None, val.def) + .$inst_name(val.ty, None, val.def(self)) .unwrap() .with_type(val.ty) } @@ -133,7 +133,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { let semantics = self.constant_u32(semantics.bits()); if invalid_seq_cst { self.zombie( - semantics.def, + semantics.def(self), "Cannot use AtomicOrdering=SequentiallyConsistent on Vulkan memory model. Check if AcquireRelease fits your needs.", ); } @@ -145,24 +145,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { SpirvType::Void => self.fatal("memset invalid on void pattern"), SpirvType::Bool => self.fatal("memset invalid on bool pattern"), SpirvType::Integer(width, _signedness) => match width { - 8 => self.constant_u8(fill_byte).def, - 16 => self.constant_u16(memset_fill_u16(fill_byte)).def, - 32 => self.constant_u32(memset_fill_u32(fill_byte)).def, - 64 => self.constant_u64(memset_fill_u64(fill_byte)).def, + 8 => self.constant_u8(fill_byte).def(self), + 16 => self.constant_u16(memset_fill_u16(fill_byte)).def(self), + 32 => self.constant_u32(memset_fill_u32(fill_byte)).def(self), + 64 => self.constant_u64(memset_fill_u64(fill_byte)).def(self), _ => self.fatal(&format!( "memset on integer width {} not implemented yet", width )), }, SpirvType::Float(width) => match width { - 32 => { - self.constant_f32(f32::from_bits(memset_fill_u32(fill_byte))) - .def - } - 64 => { - self.constant_f64(f64::from_bits(memset_fill_u64(fill_byte))) - .def - } + 32 => self + .constant_f32(f32::from_bits(memset_fill_u32(fill_byte))) + .def(self), + 64 => self + .constant_f64(f64::from_bits(memset_fill_u64(fill_byte))) + .def(self), _ => self.fatal(&format!( "memset on float width {} not implemented yet", width @@ -173,13 +171,13 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { SpirvType::Vector { element, count } => { let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); self.constant_composite(ty.clone().def(self), vec![elem_pat; count as usize]) - .def + .def(self) } SpirvType::Array { element, count } => { let elem_pat = self.memset_const_pattern(&self.lookup_type(element), fill_byte); let count = self.builder.lookup_const_u64(count).unwrap() as usize; self.constant_composite(ty.clone().def(self), vec![elem_pat; count]) - .def + .def(self) } SpirvType::RuntimeArray { .. 
} => { self.fatal("memset on runtime arrays not implemented yet") @@ -371,7 +369,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn new_block<'b>(cx: &'a Self::CodegenCx, llfn: Self::Function, _name: &'b str) -> Self { - let cursor_fn = cx.builder.select_function_by_id(llfn.def); + let cursor_fn = cx.builder.select_function_by_id(llfn.def_cx(cx)); let label = cx.emit_with_cursor(cursor_fn).begin_block(None).unwrap(); let cursor = cx.builder.select_block_by_id(label); Self { @@ -388,7 +386,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { Self { cx, cursor: Default::default(), - current_fn: Default::default(), + current_fn: 0.with_type(0), basic_block: Default::default(), current_span: Default::default(), } @@ -446,7 +444,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn ret(&mut self, value: Self::Value) { - self.emit().ret_value(value.def).unwrap(); + self.emit().ret_value(value.def(self)).unwrap(); } fn br(&mut self, dest: Self::BasicBlock) { @@ -464,7 +462,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { else_llbb: Self::BasicBlock, ) { self.emit() - .branch_conditional(cond.def, then_llbb, else_llbb, empty()) + .branch_conditional(cond.def(self), then_llbb, else_llbb, empty()) .unwrap() } @@ -547,7 +545,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let cases = cases .map(|(i, b)| (construct_case(self, signed, i), b)) .collect::<Vec<_>>(); - self.emit().switch(v.def, else_llbb, cases).unwrap() + self.emit().switch(v.def(self), else_llbb, cases).unwrap() } fn invoke( @@ -606,8 +604,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { assert_ty_eq!(self, lhs.ty, rhs.ty); let ty = lhs.ty; match self.lookup_type(ty) { - SpirvType::Integer(_, _) => self.emit().bitwise_and(ty, None, lhs.def, rhs.def), - SpirvType::Bool => self.emit().logical_and(ty, None, lhs.def, rhs.def), + SpirvType::Integer(_, _) => { + self.emit() + .bitwise_and(ty, None, lhs.def(self), rhs.def(self)) + } + SpirvType::Bool => self + .emit() + .logical_and(ty, None, lhs.def(self), rhs.def(self)), o => self.fatal(&format!( "and() not implemented for type {}", o.debug(ty, self) @@ -620,8 +623,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { assert_ty_eq!(self, lhs.ty, rhs.ty); let ty = lhs.ty; match self.lookup_type(ty) { - SpirvType::Integer(_, _) => self.emit().bitwise_or(ty, None, lhs.def, rhs.def), - SpirvType::Bool => self.emit().logical_or(ty, None, lhs.def, rhs.def), + SpirvType::Integer(_, _) => { + self.emit() + .bitwise_or(ty, None, lhs.def(self), rhs.def(self)) + } + SpirvType::Bool => self + .emit() + .logical_or(ty, None, lhs.def(self), rhs.def(self)), o => self.fatal(&format!( "or() not implemented for type {}", o.debug(ty, self) @@ -634,8 +642,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { assert_ty_eq!(self, lhs.ty, rhs.ty); let ty = lhs.ty; match self.lookup_type(ty) { - SpirvType::Integer(_, _) => self.emit().bitwise_xor(ty, None, lhs.def, rhs.def), - SpirvType::Bool => self.emit().logical_not_equal(ty, None, lhs.def, rhs.def), + SpirvType::Integer(_, _) => { + self.emit() + .bitwise_xor(ty, None, lhs.def(self), rhs.def(self)) + } + SpirvType::Bool => { + self.emit() + .logical_not_equal(ty, None, lhs.def(self), rhs.def(self)) + } o => self.fatal(&format!( "xor() not implemented for type {}", o.debug(ty, self) @@ -646,12 +660,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn not(&mut self,
val: Self::Value) -> Self::Value { match self.lookup_type(val.ty) { - SpirvType::Integer(_, _) => self.emit().not(val.ty, None, val.def), + SpirvType::Integer(_, _) => self.emit().not(val.ty, None, val.def(self)), SpirvType::Bool => { let true_ = self.constant_bool(true); // intel-compute-runtime doesn't like OpLogicalNot self.emit() - .logical_not_equal(val.ty, None, val.def, true_.def) + .logical_not_equal(val.ty, None, val.def(self), true_.def(self)) } o => self.fatal(&format!( "not() not implemented for type {}", @@ -676,7 +690,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { OverflowOp::Mul => (self.mul(lhs, rhs), fals), }; self.zombie( - result.1.def, + result.1.def(self), match oop { OverflowOp::Add => "checked add is not supported yet", OverflowOp::Sub => "checked sub is not supported yet", @@ -751,8 +765,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } fn load(&mut self, ptr: Self::Value, _align: Align) -> Self::Value { - // See comment on `register_constant_pointer` - if let Some(value) = self.lookup_constant_pointer(ptr) { + // See comment on `SpirvValueKind::ConstantPointer` + if let Some(value) = ptr.const_ptr_val(self) { return value; } let ty = match self.lookup_type(ptr.ty) { @@ -766,7 +780,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )), }; self.emit() - .load(ty, None, ptr.def, None, empty()) + .load(ty, None, ptr.def(self), None, empty()) .unwrap() .with_type(ty) } @@ -774,7 +788,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn volatile_load(&mut self, ptr: Self::Value) -> Self::Value { // TODO: Implement this let result = self.load(ptr, Align::from_bytes(0).unwrap()); - self.zombie(result.def, "volatile load is not supported yet"); + self.zombie(result.def(self), "volatile load is not supported yet"); result } @@ -794,10 +808,16 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let semantics = self.ordering_to_semantics_def(order); let result = self .emit() - .atomic_load(ty, None, ptr.def, memory.def, semantics.def) + .atomic_load( + ty, + None, + ptr.def(self), + memory.def(self), + semantics.def(self), + ) .unwrap() .with_type(ty); - self.validate_atomic(ty, result.def); + self.validate_atomic(ty, result.def(self)); result } @@ -887,7 +907,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )), }; assert_ty_eq!(self, ptr_elem_ty, val.ty); - self.emit().store(ptr.def, val.def, None, empty()).unwrap(); + self.emit() + .store(ptr.def(self), val.def(self), None, empty()) + .unwrap(); val } @@ -928,9 +950,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // TODO: Default to device scope let memory = self.constant_u32(Scope::Device as u32); let semantics = self.ordering_to_semantics_def(order); - self.validate_atomic(val.ty, ptr.def); + self.validate_atomic(val.ty, ptr.def(self)); self.emit() - .atomic_store(ptr.def, memory.def, semantics.def, val.def) + .atomic_store( + ptr.def(self), + memory.def(self), + semantics.def(self), + val.def(self), + ) .unwrap(); } @@ -972,9 +999,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { if idx > u32::MAX as u64 { self.fatal("struct_gep bigger than u32::MAX"); } - let index_const = self.constant_u32(idx as u32).def; + let index_const = self.constant_u32(idx as u32).def(self); self.emit() - .access_chain(result_type, None, ptr.def, [index_const].iter().cloned()) + .access_chain( + result_type, + None, + ptr.def(self), + [index_const].iter().cloned(), + ) .unwrap() 
.with_type(result_type) } @@ -1003,7 +1035,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { val } else { self.emit() - .convert_f_to_u(dest_ty, None, val.def) + .convert_f_to_u(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty) } @@ -1014,7 +1046,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { val } else { self.emit() - .convert_f_to_s(dest_ty, None, val.def) + .convert_f_to_s(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty) } @@ -1025,7 +1057,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { val } else { self.emit() - .convert_u_to_f(dest_ty, None, val.def) + .convert_u_to_f(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty) } @@ -1036,7 +1068,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { val } else { self.emit() - .convert_s_to_f(dest_ty, None, val.def) + .convert_s_to_f(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty) } @@ -1047,7 +1079,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { val } else { self.emit() - .f_convert(dest_ty, None, val.def) + .f_convert(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty) } @@ -1058,7 +1090,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { val } else { self.emit() - .f_convert(dest_ty, None, val.def) + .f_convert(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty) } @@ -1077,10 +1109,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } else { let result = self .emit() - .convert_ptr_to_u(dest_ty, None, val.def) + .convert_ptr_to_u(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty); - self.zombie_convert_ptr_to_u(result.def); + self.zombie_convert_ptr_to_u(result.def(self)); result } } @@ -1098,10 +1130,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } else { let result = self .emit() - .convert_u_to_ptr(dest_ty, None, val.def) + .convert_u_to_ptr(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty); - self.zombie_convert_u_to_ptr(result.def); + self.zombie_convert_u_to_ptr(result.def(self)); result } } @@ -1112,13 +1144,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } else { let result = self .emit() - .bitcast(dest_ty, None, val.def) + .bitcast(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty); let val_is_ptr = matches!(self.lookup_type(val.ty), SpirvType::Pointer{..}); let dest_is_ptr = matches!(self.lookup_type(dest_ty), SpirvType::Pointer{..}); if val_is_ptr || dest_is_ptr { - self.zombie_bitcast_ptr(result.def, val.ty, dest_ty); + self.zombie_bitcast_ptr(result.def(self), val.ty, dest_ty); } result } @@ -1136,7 +1168,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { SpirvType::Integer(dest_width, dest_signedness), ) if val_width == dest_width && val_signedness != dest_signedness => self .emit() - .bitcast(dest_ty, None, val.def) + .bitcast(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty), // width change, and optional sign change @@ -1144,9 +1176,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // spir-v spec doesn't seem to say that signedness needs to match the operands, only that the signedness // of the destination type must match the instruction's signedness. 
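// (OpSConvert and OpUConvert truncate identically when narrowing; they differ only on widening, where OpSConvert sign-extends and OpUConvert zero-extends, hence the choice below keys off dest_signedness alone.)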
if dest_signedness { - self.emit().s_convert(dest_ty, None, val.def) + self.emit().s_convert(dest_ty, None, val.def(self)) } else { - self.emit().u_convert(dest_ty, None, val.def) + self.emit().u_convert(dest_ty, None, val.def(self)) } .unwrap() .with_type(dest_ty) } @@ -1157,7 +1189,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let if_true = self.constant_int(dest_ty, 1); let if_false = self.constant_int(dest_ty, 0); self.emit() - .select(dest_ty, None, val.def, if_true.def, if_false.def) + .select( + dest_ty, + None, + val.def(self), + if_true.def(self), + if_false.def(self), + ) .unwrap() .with_type(dest_ty) } @@ -1165,7 +1203,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { // spir-v doesn't have a direct conversion instruction; glslang emits OpINotEqual let zero = self.constant_int(val.ty, 0); self.emit() - .i_not_equal(dest_ty, None, val.def, zero.def) + .i_not_equal(dest_ty, None, val.def(self), zero.def(self)) .unwrap() .with_type(dest_ty) } @@ -1196,10 +1234,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } else if let Some(indices) = self.try_pointercast_via_gep(val_pointee, dest_pointee) { let indices = indices .into_iter() - .map(|idx| self.constant_u32(idx).def) + .map(|idx| self.constant_u32(idx).def(self)) .collect::<Vec<_>>(); self.emit() - .access_chain(dest_ty, None, val.def, indices) + .access_chain(dest_ty, None, val.def(self), indices) .unwrap() .with_type(dest_ty) } else if self @@ -1211,10 +1249,10 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } else { let result = self .emit() - .bitcast(dest_ty, None, val.def) + .bitcast(dest_ty, None, val.def(self)) .unwrap() .with_type(dest_ty); - self.zombie_bitcast_ptr(result.def, val.ty, dest_ty); + self.zombie_bitcast_ptr(result.def(self), val.ty, dest_ty); result } } @@ -1226,71 +1264,126 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let b = SpirvType::Bool.def(self); match self.lookup_type(lhs.ty) { SpirvType::Integer(_, _) => match op { - IntEQ => self.emit().i_equal(b, None, lhs.def, rhs.def), - IntNE => self.emit().i_not_equal(b, None, lhs.def, rhs.def), - IntUGT => self.emit().u_greater_than(b, None, lhs.def, rhs.def), - IntUGE => self.emit().u_greater_than_equal(b, None, lhs.def, rhs.def), - IntULT => self.emit().u_less_than(b, None, lhs.def, rhs.def), - IntULE => self.emit().u_less_than_equal(b, None, lhs.def, rhs.def), - IntSGT => self.emit().s_greater_than(b, None, lhs.def, rhs.def), - IntSGE => self.emit().s_greater_than_equal(b, None, lhs.def, rhs.def), - IntSLT => self.emit().s_less_than(b, None, lhs.def, rhs.def), - IntSLE => self.emit().s_less_than_equal(b, None, lhs.def, rhs.def), + IntEQ => self.emit().i_equal(b, None, lhs.def(self), rhs.def(self)), + IntNE => self + .emit() + .i_not_equal(b, None, lhs.def(self), rhs.def(self)), + IntUGT => self + .emit() + .u_greater_than(b, None, lhs.def(self), rhs.def(self)), + IntUGE => self + .emit() + .u_greater_than_equal(b, None, lhs.def(self), rhs.def(self)), + IntULT => self + .emit() + .u_less_than(b, None, lhs.def(self), rhs.def(self)), + IntULE => self + .emit() + .u_less_than_equal(b, None, lhs.def(self), rhs.def(self)), + IntSGT => self + .emit() + .s_greater_than(b, None, lhs.def(self), rhs.def(self)), + IntSGE => self + .emit() + .s_greater_than_equal(b, None, lhs.def(self), rhs.def(self)), + IntSLT => self + .emit() + .s_less_than(b, None, lhs.def(self), rhs.def(self)), + IntSLE => self + .emit() + .s_less_than_equal(b, None, lhs.def(self), rhs.def(self)), },
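// Pointer comparisons: OpPtrEqual/OpPtrNotEqual require SPIR-V 1.4+, so older targets fall back to OpConvertPtrToU plus an integer compare (each conversion is zombied via zombie_convert_ptr_to_u).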
SpirvType::Pointer { .. } => match op { IntEQ => { if self.emit().version().unwrap() > (1, 3) { - self.emit().ptr_equal(b, None, lhs.def, rhs.def) + self.emit().ptr_equal(b, None, lhs.def(self), rhs.def(self)) } else { let int_ty = self.type_usize(); - let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); + let lhs = self + .emit() + .convert_ptr_to_u(int_ty, None, lhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(lhs); - let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); + let rhs = self + .emit() + .convert_ptr_to_u(int_ty, None, rhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(rhs); self.emit().i_equal(b, None, lhs, rhs) } } IntNE => { if self.emit().version().unwrap() > (1, 3) { - self.emit().ptr_not_equal(b, None, lhs.def, rhs.def) + self.emit() + .ptr_not_equal(b, None, lhs.def(self), rhs.def(self)) } else { let int_ty = self.type_usize(); - let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); + let lhs = self + .emit() + .convert_ptr_to_u(int_ty, None, lhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(lhs); - let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); + let rhs = self + .emit() + .convert_ptr_to_u(int_ty, None, rhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(rhs); self.emit().i_not_equal(b, None, lhs, rhs) } } IntUGT => { let int_ty = self.type_usize(); - let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); + let lhs = self + .emit() + .convert_ptr_to_u(int_ty, None, lhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(lhs); - let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); + let rhs = self + .emit() + .convert_ptr_to_u(int_ty, None, rhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(rhs); self.emit().u_greater_than(b, None, lhs, rhs) } IntUGE => { let int_ty = self.type_usize(); - let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); + let lhs = self + .emit() + .convert_ptr_to_u(int_ty, None, lhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(lhs); - let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); + let rhs = self + .emit() + .convert_ptr_to_u(int_ty, None, rhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(rhs); self.emit().u_greater_than_equal(b, None, lhs, rhs) } IntULT => { let int_ty = self.type_usize(); - let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); + let lhs = self + .emit() + .convert_ptr_to_u(int_ty, None, lhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(lhs); - let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); + let rhs = self + .emit() + .convert_ptr_to_u(int_ty, None, rhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(rhs); self.emit().u_less_than(b, None, lhs, rhs) } IntULE => { let int_ty = self.type_usize(); - let lhs = self.emit().convert_ptr_to_u(int_ty, None, lhs.def).unwrap(); + let lhs = self + .emit() + .convert_ptr_to_u(int_ty, None, lhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(lhs); - let rhs = self.emit().convert_ptr_to_u(int_ty, None, rhs.def).unwrap(); + let rhs = self + .emit() + .convert_ptr_to_u(int_ty, None, rhs.def(self)) + .unwrap(); self.zombie_convert_ptr_to_u(rhs); self.emit().u_less_than_equal(b, None, lhs, rhs) } @@ -1300,44 +1393,48 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { IntSLE => self.fatal("TODO: pointer operator IntSLE not implemented yet"), }, SpirvType::Bool => match op { - IntEQ => self.emit().logical_equal(b,
None, lhs.def, rhs.def), - IntNE => self.emit().logical_not_equal(b, None, lhs.def, rhs.def), + IntEQ => self + .emit() + .logical_equal(b, None, lhs.def(self), rhs.def(self)), + IntNE => self + .emit() + .logical_not_equal(b, None, lhs.def(self), rhs.def(self)), // x > y => x && !y IntUGT => { // intel-compute-runtime doesn't like OpLogicalNot let true_ = self.constant_bool(true); let rhs = self .emit() - .logical_not_equal(b, None, rhs.def, true_.def) + .logical_not_equal(b, None, rhs.def(self), true_.def(self)) .unwrap(); - self.emit().logical_and(b, None, lhs.def, rhs) + self.emit().logical_and(b, None, lhs.def(self), rhs) } // x >= y => x || !y IntUGE => { let true_ = self.constant_bool(true); let rhs = self .emit() - .logical_not_equal(b, None, rhs.def, true_.def) + .logical_not_equal(b, None, rhs.def(self), true_.def(self)) .unwrap(); - self.emit().logical_or(b, None, lhs.def, rhs) + self.emit().logical_or(b, None, lhs.def(self), rhs) } // x < y => !x && y IntULT => { let true_ = self.constant_bool(true); let lhs = self .emit() - .logical_not_equal(b, None, lhs.def, true_.def) + .logical_not_equal(b, None, lhs.def(self), true_.def(self)) .unwrap(); - self.emit().logical_and(b, None, lhs, rhs.def) + self.emit().logical_and(b, None, lhs, rhs.def(self)) } // x <= y => !x || y IntULE => { let true_ = self.constant_bool(true); let lhs = self .emit() - .logical_not_equal(b, None, lhs.def, true_.def) + .logical_not_equal(b, None, lhs.def(self), true_.def(self)) .unwrap(); - self.emit().logical_or(b, None, lhs, rhs.def) + self.emit().logical_or(b, None, lhs, rhs.def(self)) } IntSGT => self.fatal("TODO: boolean operator IntSGT not implemented yet"), IntSGE => self.fatal("TODO: boolean operator IntSGE not implemented yet"), @@ -1360,26 +1457,45 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { match op { RealPredicateFalse => return self.cx.constant_bool(false), RealPredicateTrue => return self.cx.constant_bool(true), - RealOEQ => self.emit().f_ord_equal(b, None, lhs.def, rhs.def), - RealOGT => self.emit().f_ord_greater_than(b, None, lhs.def, rhs.def), + RealOEQ => self + .emit() + .f_ord_equal(b, None, lhs.def(self), rhs.def(self)), + RealOGT => self + .emit() + .f_ord_greater_than(b, None, lhs.def(self), rhs.def(self)), RealOGE => self .emit() - .f_ord_greater_than_equal(b, None, lhs.def, rhs.def), - RealOLT => self.emit().f_ord_less_than(b, None, lhs.def, rhs.def), - RealOLE => self.emit().f_ord_less_than_equal(b, None, lhs.def, rhs.def), - RealONE => self.emit().f_ord_not_equal(b, None, lhs.def, rhs.def), - RealORD => self.emit().ordered(b, None, lhs.def, rhs.def), - RealUNO => self.emit().unordered(b, None, lhs.def, rhs.def), - RealUEQ => self.emit().f_unord_equal(b, None, lhs.def, rhs.def), - RealUGT => self.emit().f_unord_greater_than(b, None, lhs.def, rhs.def), - RealUGE => self + .f_ord_greater_than_equal(b, None, lhs.def(self), rhs.def(self)), + RealOLT => self .emit() - .f_unord_greater_than_equal(b, None, lhs.def, rhs.def), - RealULT => self.emit().f_unord_less_than(b, None, lhs.def, rhs.def), + .f_ord_less_than(b, None, lhs.def(self), rhs.def(self)), + RealOLE => self + .emit() + .f_ord_less_than_equal(b, None, lhs.def(self), rhs.def(self)), + RealONE => self + .emit() + .f_ord_not_equal(b, None, lhs.def(self), rhs.def(self)), + RealORD => self.emit().ordered(b, None, lhs.def(self), rhs.def(self)), + RealUNO => self.emit().unordered(b, None, lhs.def(self), rhs.def(self)), + RealUEQ => self + .emit() + .f_unord_equal(b, None, lhs.def(self), rhs.def(self)),
RealUGT => self + .emit() + .f_unord_greater_than(b, None, lhs.def(self), rhs.def(self)), + RealUGE => { + self.emit() + .f_unord_greater_than_equal(b, None, lhs.def(self), rhs.def(self)) + } + RealULT => self + .emit() + .f_unord_less_than(b, None, lhs.def(self), rhs.def(self)), RealULE => self .emit() - .f_unord_less_than_equal(b, None, lhs.def, rhs.def), - RealUNE => self.emit().f_unord_not_equal(b, None, lhs.def, rhs.def), + .f_unord_less_than_equal(b, None, lhs.def(self), rhs.def(self)), + RealUNE => self + .emit() + .f_unord_not_equal(b, None, lhs.def(self), rhs.def(self)), } .unwrap() .with_type(b) @@ -1411,19 +1527,31 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { }; let src_element_size = src_pointee.and_then(|p| self.lookup_type(p).sizeof(self)); if src_element_size.is_some() && src_element_size == const_size.map(Size::from_bytes) { - if let Some(const_value) = self.lookup_constant_pointer(src) { + // See comment on `SpirvValueKind::ConstantPointer` + + if let Some(const_value) = src.const_ptr_val(self) { self.store(const_value, dst, Align::from_bytes(0).unwrap()); } else { self.emit() - .copy_memory(dst.def, src.def, None, None, empty()) + .copy_memory(dst.def(self), src.def(self), None, None, empty()) .unwrap(); } } else { self.emit() - .copy_memory_sized(dst.def, src.def, size.def, None, None, empty()) + .copy_memory_sized( + dst.def(self), + src.def(self), + size.def(self), + None, + None, + empty(), + ) .unwrap(); if !self.builder.has_capability(Capability::Addresses) { - self.zombie(dst.def, "OpCopyMemorySized without OpCapability Addresses") + self.zombie( + dst.def(self), + "OpCopyMemorySized without OpCapability Addresses", + ) } } } @@ -1464,7 +1592,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { let elem_ty_spv = self.lookup_type(elem_ty); let pat = match self.builder.lookup_const_u64(fill_byte) { Some(fill_byte) => self.memset_const_pattern(&elem_ty_spv, fill_byte as u8), - None => self.memset_dynamic_pattern(&elem_ty_spv, fill_byte.def), + None => self.memset_dynamic_pattern(&elem_ty_spv, fill_byte.def(self)), } .with_type(elem_ty); match self.builder.lookup_const_u64(size) { @@ -1482,7 +1610,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { assert_ty_eq!(self, then_val.ty, else_val.ty); let result_type = then_val.ty; self.emit() - .select(result_type, None, cond.def, then_val.def, else_val.def) + .select( + result_type, + None, + cond.def(self), + then_val.def(self), + else_val.def(self), + ) .unwrap() .with_type(result_type) } @@ -1503,12 +1637,13 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { Some(const_index) => self.emit().composite_extract( result_type, None, - vec.def, + vec.def(self), [const_index as u32].iter().cloned(), ), - None => self - .emit() - .vector_extract_dynamic(result_type, None, vec.def, idx.def), + None => { + self.emit() + .vector_extract_dynamic(result_type, None, vec.def(self), idx.def(self)) + } } .unwrap() .with_type(result_type) @@ -1521,10 +1656,14 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { } .def(self); if self.builder.lookup_const(elt).is_some() { - self.constant_composite(result_type, vec![elt.def; num_elts]) + self.constant_composite(result_type, vec![elt.def(self); num_elts]) } else { self.emit() - .composite_construct(result_type, None, std::iter::repeat(elt.def).take(num_elts)) + .composite_construct( + result_type, + None, + std::iter::repeat(elt.def(self)).take(num_elts), + ) .unwrap() .with_type(result_type) } @@ -1539,7 
+1678,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )), }; self.emit() - .composite_extract(result_type, None, agg_val.def, [idx as u32].iter().cloned()) + .composite_extract( + result_type, + None, + agg_val.def(self), + [idx as u32].iter().cloned(), + ) .unwrap() .with_type(result_type) } @@ -1555,8 +1699,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .composite_insert( agg_val.ty, None, - elt.def, - agg_val.def, + elt.def(self), + agg_val.def(self), [idx as u32].iter().cloned(), ) .unwrap() @@ -1638,7 +1782,7 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { }; assert_ty_eq!(self, dst_pointee_ty, cmp.ty); assert_ty_eq!(self, dst_pointee_ty, src.ty); - self.validate_atomic(dst_pointee_ty, dst.def); + self.validate_atomic(dst_pointee_ty, dst.def(self)); // TODO: Default to device scope let memory = self.constant_u32(Scope::Device as u32); let semantics_equal = self.ordering_to_semantics_def(order); @@ -1648,12 +1792,12 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { .atomic_compare_exchange( src.ty, None, - dst.def, - memory.def, - semantics_equal.def, - semantics_unequal.def, - src.def, - cmp.def, + dst.def(self), + memory.def(self), + semantics_equal.def(self), + semantics_unequal.def(self), + src.def(self), + cmp.def(self), ) .unwrap() .with_type(src.ty) @@ -1677,24 +1821,94 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { )), }; assert_ty_eq!(self, dst_pointee_ty, src.ty); - self.validate_atomic(dst_pointee_ty, dst.def); + self.validate_atomic(dst_pointee_ty, dst.def(self)); // TODO: Default to device scope - let memory = self.constant_u32(Scope::Device as u32).def; - let semantics = self.ordering_to_semantics_def(order).def; + let memory = self.constant_u32(Scope::Device as u32).def(self); + let semantics = self.ordering_to_semantics_def(order).def(self); let mut emit = self.emit(); use AtomicRmwBinOp::*; match op { - AtomicXchg => emit.atomic_exchange(src.ty, None, dst.def, memory, semantics, src.def), - AtomicAdd => emit.atomic_i_add(src.ty, None, dst.def, memory, semantics, src.def), - AtomicSub => emit.atomic_i_sub(src.ty, None, dst.def, memory, semantics, src.def), - AtomicAnd => emit.atomic_and(src.ty, None, dst.def, memory, semantics, src.def), + AtomicXchg => emit.atomic_exchange( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicAdd => emit.atomic_i_add( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicSub => emit.atomic_i_sub( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicAnd => emit.atomic_and( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), AtomicNand => self.fatal("atomic nand is not supported"), - AtomicOr => emit.atomic_or(src.ty, None, dst.def, memory, semantics, src.def), - AtomicXor => emit.atomic_xor(src.ty, None, dst.def, memory, semantics, src.def), - AtomicMax => emit.atomic_s_max(src.ty, None, dst.def, memory, semantics, src.def), - AtomicMin => emit.atomic_s_min(src.ty, None, dst.def, memory, semantics, src.def), - AtomicUMax => emit.atomic_u_max(src.ty, None, dst.def, memory, semantics, src.def), - AtomicUMin => emit.atomic_u_min(src.ty, None, dst.def, memory, semantics, src.def), + AtomicOr => emit.atomic_or( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicXor => emit.atomic_xor( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicMax => 
emit.atomic_s_max( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicMin => emit.atomic_s_min( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicUMax => emit.atomic_u_max( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), + AtomicUMin => emit.atomic_u_min( + src.ty, + None, + dst.def(self), + memory, + semantics, + src.def(self), + ), } .unwrap() .with_type(src.ty) @@ -1703,8 +1917,8 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { fn atomic_fence(&mut self, order: AtomicOrdering, _scope: SynchronizationScope) { // Ignore sync scope (it only has "single thread" and "cross thread") // TODO: Default to device scope - let memory = self.constant_u32(Scope::Device as u32).def; - let semantics = self.ordering_to_semantics_def(order).def; + let memory = self.constant_u32(Scope::Device as u32).def(self); + let semantics = self.ordering_to_semantics_def(order).def(self); self.emit().memory_barrier(memory, semantics).unwrap(); } @@ -1759,9 +1973,9 @@ impl<'a, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'tcx> { for (argument, argument_type) in args.iter().zip(argument_types) { assert_ty_eq!(self, argument.ty, argument_type); } - let args = args.iter().map(|arg| arg.def).collect::<Vec<_>>(); + let args = args.iter().map(|arg| arg.def(self)).collect::<Vec<_>>(); self.emit() - .function_call(result_type, None, llfn.def, args) + .function_call(result_type, None, llfn.def(self), args) .unwrap() .with_type(result_type) } diff --git a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs index f845226031..5b69368de3 100644 --- a/crates/rustc_codegen_spirv/src/builder/ext_inst.rs +++ b/crates/rustc_codegen_spirv/src/builder/ext_inst.rs @@ -62,7 +62,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { None, glsl, op as u32, - args.iter().map(|a| Operand::IdRef(a.def)), + args.iter().map(|a| Operand::IdRef(a.def(self))), ) .unwrap() .with_type(args[0].ty) @@ -77,7 +77,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { None, opencl, op as u32, - args.iter().map(|a| Operand::IdRef(a.def)), + args.iter().map(|a| Operand::IdRef(a.def(self))), ) .unwrap() .with_type(args[0].ty) diff --git a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs index a783ab533d..349df7328d 100644 --- a/crates/rustc_codegen_spirv/src/builder/intrinsics.rs +++ b/crates/rustc_codegen_spirv/src/builder/intrinsics.rs @@ -99,7 +99,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { )), }; // TODO: Implement this - self.zombie(result.def, "saturating_add is not implemented yet"); + self.zombie(result.def(self), "saturating_add is not implemented yet"); result } sym::saturating_sub => { @@ -115,7 +115,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { )), }; // TODO: Implement this - self.zombie(result.def, "saturating_sub is not implemented yet"); + self.zombie(result.def(self), "saturating_sub is not implemented yet"); result } @@ -326,7 +326,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { .u_count_leading_zeros_intel( args[0].immediate().ty, None, - args[0].immediate().def, + args[0].immediate().def(self), ) .unwrap() .with_type(args[0].immediate().ty) @@ -340,7 +340,7 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { .u_count_trailing_zeros_intel( args[0].immediate().ty, None, - args[0].immediate().def, + args[0].immediate().def(self), ) .unwrap()
.with_type(args[0].immediate().ty) @@ -349,12 +349,12 @@ impl<'a, 'tcx> IntrinsicCallMethods<'tcx> for Builder<'a, 'tcx> { sym::ctpop => self .emit() - .bit_count(args[0].immediate().ty, None, args[0].immediate().def) + .bit_count(args[0].immediate().ty, None, args[0].immediate().def(self)) .unwrap() .with_type(args[0].immediate().ty), sym::bitreverse => self .emit() - .bit_reverse(args[0].immediate().ty, None, args[0].immediate().def) + .bit_reverse(args[0].immediate().ty, None, args[0].immediate().def(self)) .unwrap() .with_type(args[0].immediate().ty), diff --git a/crates/rustc_codegen_spirv/src/builder/mod.rs b/crates/rustc_codegen_spirv/src/builder/mod.rs index e115966fd3..fb2ab79074 100644 --- a/crates/rustc_codegen_spirv/src/builder/mod.rs +++ b/crates/rustc_codegen_spirv/src/builder/mod.rs @@ -110,7 +110,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { )), }; for index in indices.iter().cloned().skip(1) { - result_indices.push(index.def); + result_indices.push(index.def(self)); result_pointee_type = match self.lookup_type(result_pointee_type) { SpirvType::Array { element, .. } | SpirvType::RuntimeArray { element } => element, _ => self.fatal(&format!( @@ -127,12 +127,12 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { if self.builder.lookup_const_u64(indices[0]) == Some(0) { if is_inbounds { self.emit() - .in_bounds_access_chain(result_type, None, ptr.def, result_indices) + .in_bounds_access_chain(result_type, None, ptr.def(self), result_indices) .unwrap() .with_type(result_type) } else { self.emit() - .access_chain(result_type, None, ptr.def, result_indices) + .access_chain(result_type, None, ptr.def(self), result_indices) .unwrap() .with_type(result_type) } @@ -142,15 +142,21 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .in_bounds_ptr_access_chain( result_type, None, - ptr.def, - indices[0].def, + ptr.def(self), + indices[0].def(self), result_indices, ) .unwrap() .with_type(result_type) } else { self.emit() - .ptr_access_chain(result_type, None, ptr.def, indices[0].def, result_indices) + .ptr_access_chain( + result_type, + None, + ptr.def(self), + indices[0].def(self), + result_indices, + ) .unwrap() .with_type(result_type) }; @@ -159,7 +165,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { .has_capability(rspirv::spirv::Capability::Addresses); if !has_addresses { self.zombie( - result.def, + result.def(self), "OpPtrAccessChain without OpCapability Addresses", ); } @@ -193,7 +199,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> { // So we need to check for zero shift, and don't use the shift result if it is. 
let mask_is_zero = self .emit() - .i_not_equal(bool, None, mask_shift.def, zero.def) + .i_not_equal(bool, None, mask_shift.def(self), zero.def(self)) .unwrap() .with_type(bool); self.select(mask_is_zero, value, or) @@ -275,7 +281,7 @@ impl<'a, 'tcx> ArgAbiMethods<'tcx> for Builder<'a, 'tcx> { dst: PlaceRef<'tcx, Self::Value>, ) { fn next<'a, 'tcx>(bx: &mut Builder<'a, 'tcx>, idx: &mut usize) -> SpirvValue { - let val = bx.function_parameter_values.borrow()[&bx.current_fn.def][*idx]; + let val = bx.function_parameter_values.borrow()[&bx.current_fn.def(bx)][*idx]; *idx += 1; val } @@ -333,7 +339,7 @@ impl<'a, 'tcx> AbiBuilderMethods<'tcx> for Builder<'a, 'tcx> { fn apply_attrs_callsite(&mut self, _fn_abi: &FnAbi<'tcx, Ty<'tcx>>, _callsite: Self::Value) {} fn get_param(&self, index: usize) -> Self::Value { - self.function_parameter_values.borrow()[&self.current_fn.def][index] + self.function_parameter_values.borrow()[&self.current_fn.def(self)][index] } } diff --git a/crates/rustc_codegen_spirv/src/builder_spirv.rs b/crates/rustc_codegen_spirv/src/builder_spirv.rs index 8db21807f8..8a1072dd87 100644 --- a/crates/rustc_codegen_spirv/src/builder_spirv.rs +++ b/crates/rustc_codegen_spirv/src/builder_spirv.rs @@ -1,3 +1,6 @@ +use crate::builder; +use crate::codegen_cx::CodegenCx; +use crate::spirv_type::SpirvType; use bimap::BiHashMap; use rspirv::dr::{Block, Builder, Module, Operand}; use rspirv::spirv::{AddressingModel, Capability, MemoryModel, Op, Word}; @@ -6,19 +9,99 @@ use rustc_middle::bug; use std::cell::{RefCell, RefMut}; use std::{fs::File, io::Write, path::Path}; -#[derive(Copy, Clone, Debug, Default, Ord, PartialOrd, Eq, PartialEq, Hash)] +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] +pub enum SpirvValueKind { + Def(Word), + /// There are a fair number of places where `rustc_codegen_ssa` creates a pointer to something + /// that cannot be pointed to in SPIR-V. For example, constant values are frequently emitted as + /// a pointer to constant memory, and then dereferenced where they're used. Functions are the + /// same way: when compiling a call, the function's pointer is loaded, then dereferenced, then + /// called. Directly translating these constructs is impossible, because SPIR-V doesn't allow + /// pointers to constants or function pointers. So, instead, we create this ConstantPointer + /// "meta-value": directly using it is an error; however, if it is dereferenced, the "load" is + /// instead a no-op that returns the underlying value directly. + ConstantPointer(Word), +} + +#[derive(Copy, Clone, Debug, Ord, PartialOrd, Eq, PartialEq, Hash)] pub struct SpirvValue { - pub def: Word, + pub kind: SpirvValueKind, pub ty: Word, } +impl SpirvValue { + pub fn const_ptr_val(self, cx: &CodegenCx<'_>) -> Option<SpirvValue> { + match self.kind { + SpirvValueKind::ConstantPointer(word) => { + let ty = match cx.lookup_type(self.ty) { + SpirvType::Pointer { + storage_class: _, + pointee, + } => pointee, + ty => bug!("load called on variable that wasn't a pointer: {:?}", ty), + }; + Some(word.with_type(ty)) + } + SpirvValueKind::Def(_) => None, + } + } + + // Important: we *cannot* use bx.emit() here, because this is called in + // contexts where the emitter is already locked. Doing so may cause subtle + // rare bugs.
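+ // (For example, `atomic_rmw` in builder_methods.rs keeps the `RefMut` returned by `self.emit()` alive while calling `.def(self)` on its operands; a nested `emit()` here would be a `RefCell` double-borrow.)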
+ pub fn def(self, bx: &builder::Builder<'_, '_>) -> Word { + match self.kind { + SpirvValueKind::Def(word) => word, + SpirvValueKind::ConstantPointer(_) => { + if bx.is_system_crate() { + *bx.zombie_undefs_for_system_constant_pointers + .borrow() + .get(&self.ty) + .expect("ConstantPointer didn't go through proper undef registration") + } else { + bx.err("Cannot use this pointer directly, it must be dereferenced first"); + // Because we never get beyond compilation (into e.g. linking), + // emitting an invalid ID reference here is OK. + 0 + } + } + } + } + + // def and def_cx are separated, because Builder has a span associated with + // what it's currently emitting. + pub fn def_cx(self, cx: &CodegenCx<'_>) -> Word { + match self.kind { + SpirvValueKind::Def(word) => word, + SpirvValueKind::ConstantPointer(_) => { + if cx.is_system_crate() { + *cx.zombie_undefs_for_system_constant_pointers + .borrow() + .get(&self.ty) + .expect("ConstantPointer didn't go through proper undef registration") + } else { + cx.tcx + .sess + .err("Cannot use this pointer directly, it must be dereferenced first"); + // Because we never get beyond compilation (into e.g. linking), + // emitting an invalid ID reference here is OK. + 0 + } + } + } + } +} + pub trait SpirvValueExt { fn with_type(self, ty: Word) -> SpirvValue; } impl SpirvValueExt for Word { fn with_type(self, ty: Word) -> SpirvValue { - SpirvValue { def: self, ty } + SpirvValue { + kind: SpirvValueKind::Def(self), + ty, + } } } diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs index 5f4dc863e9..a538266752 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/constant.rs @@ -70,7 +70,7 @@ impl<'tcx> CodegenCx<'tcx> { }, SpirvType::Integer(128, _) => { let result = self.undef(ty); - self.zombie_no_span(result.def, "u128 constant"); + self.zombie_no_span(result.def_cx(self), "u128 constant"); result } other => self.tcx.sess.fatal(&format!( @@ -169,7 +169,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { let len = s.as_str().len(); let ty = self.type_ptr_to(self.layout_of(self.tcx.types.str_).spirv_type(self)); let result = self.undef(ty); - self.zombie_no_span(result.def, "constant string"); + self.zombie_no_span(result.def_cx(self), "constant string"); (result, self.const_usize(len as u64)) } fn const_struct(&self, elts: &[Self::Value], _packed: bool) -> Self::Value { @@ -185,7 +185,7 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { field_names: None, } .def(self); - self.constant_composite(struct_ty, elts.iter().map(|f| f.def).collect()) + self.constant_composite(struct_ty, elts.iter().map(|f| f.def_cx(self)).collect()) } fn const_to_opt_uint(&self, v: Self::Value) -> Option<u64> { @@ -301,7 +301,10 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { ) => { if a_space != b_space { // TODO: Emit the correct type that is passed into this function.
- self.zombie_no_span(value.def, "invalid pointer space in constant"); + self.zombie_no_span( + value.def_cx(self), + "invalid pointer space in constant", + ); } assert_ty_eq!(self, a, b); } @@ -330,8 +333,8 @@ impl<'tcx> ConstMethods<'tcx> for CodegenCx<'tcx> { val } else { // constant ptrcast is not supported in spir-v - let result = val.def.with_type(ty); - self.zombie_no_span(result.def, "const_ptrcast"); + let result = val.def_cx(self).with_type(ty); + self.zombie_no_span(result.def_cx(self), "const_ptrcast"); result } } @@ -395,7 +398,7 @@ impl<'tcx> CodegenCx<'tcx> { let mut total_offset_end = total_offset_start; values.push( self.create_const_alloc2(alloc, &mut total_offset_end, ty) - .def, + .def_cx(self), ); occupied_spaces.push(total_offset_start..total_offset_end); } @@ -417,7 +420,7 @@ impl<'tcx> CodegenCx<'tcx> { SpirvType::Array { element, count } => { let count = self.builder.lookup_const_u64(count).unwrap() as usize; let values = (0..count) - .map(|_| self.create_const_alloc2(alloc, offset, element).def) + .map(|_| { + self.create_const_alloc2(alloc, offset, element) + .def_cx(self) + }) .collect::<Vec<_>>(); self.constant_composite(ty, values) } @@ -427,7 +433,10 @@ impl<'tcx> CodegenCx<'tcx> { .expect("create_const_alloc: Vectors must be sized"); let final_offset = *offset + total_size; let values = (0..count) - .map(|_| self.create_const_alloc2(alloc, offset, element).def) + .map(|_| { + self.create_const_alloc2(alloc, offset, element) + .def_cx(self) + }) .collect::<Vec<_>>(); assert!(*offset <= final_offset); // Vectors sometimes have padding at the end (e.g. vec3); skip over it. @@ -437,7 +446,10 @@ impl<'tcx> CodegenCx<'tcx> { SpirvType::RuntimeArray { element } => { let mut values = Vec::new(); while offset.bytes_usize() != alloc.len() { - values.push(self.create_const_alloc2(alloc, offset, element).def); + values.push( + self.create_const_alloc2(alloc, offset, element) + .def_cx(self), + ); } let result = self.constant_composite(ty, values); // TODO: Figure out how to do this. Compiling the below crashes both clspv *and* llvm-spirv: @@ -452,7 +464,7 @@ impl<'tcx> CodegenCx<'tcx> { *data = *c + asdf->y[*c]; } */ - self.zombie_no_span(result.def, "constant runtime array value"); + self.zombie_no_span(result.def_cx(self), "constant runtime array value"); result } SpirvType::Pointer { .. } => { diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs index 9e24a029eb..53337813cd 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/declare.rs @@ -82,7 +82,7 @@ impl<'tcx> CodegenCx<'tcx> { // This can happen if we call a blocklisted function in another crate. let result = self.undef(function_type); // TODO: Span info here - self.zombie_no_span(result.def, "called blocklisted fn"); + self.zombie_no_span(result.def_cx(self), "called blocklisted fn"); return result; } let mut emit = self.emit_global(); @@ -132,7 +132,7 @@ impl<'tcx> CodegenCx<'tcx> { let span = self.tcx.def_span(def_id); let g = self.declare_global(span, self.layout_of(ty).spirv_type(self)); self.instances.borrow_mut().insert(instance, g); - self.set_linkage(g.def, sym.to_string(), LinkageType::Import); + self.set_linkage(g.def_cx(self), sym.to_string(), LinkageType::Import); g } @@ -147,7 +147,7 @@ impl<'tcx> CodegenCx<'tcx> { .variable(ptr_ty, None, StorageClass::Function, None) .with_type(ptr_ty); // TODO: These should be StorageClass::Private, so just zombie for now.
- self.zombie_with_span(result.def, span, "Globals are not supported yet"); + self.zombie_with_span(result.def_cx(self), span, "Globals are not supported yet"); result } } @@ -182,7 +182,7 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { self.instances.borrow_mut().insert(instance, g); if let Some(linkage) = linkage { - self.set_linkage(g.def, symbol_name.to_string(), linkage); + self.set_linkage(g.def_cx(self), symbol_name.to_string(), linkage); } } @@ -229,7 +229,7 @@ impl<'tcx> PreDefineMethods<'tcx> for CodegenCx<'tcx> { impl<'tcx> StaticMethods for CodegenCx<'tcx> { fn static_addr_of(&self, cv: Self::Value, _align: Align, _kind: Option<&str>) -> Self::Value { - self.register_constant_pointer(cv) + self.make_constant_pointer(cv) } fn codegen_static(&self, def_id: DefId, _is_mutable: bool) { @@ -260,7 +260,8 @@ impl<'tcx> StaticMethods for CodegenCx<'tcx> { } assert_ty_eq!(self, value_ty, v.ty); - self.builder.set_global_initializer(g.def, v.def); + self.builder + .set_global_initializer(g.def_cx(self), v.def_cx(self)); } /// Mark the given global value as "used", to prevent a backend from potentially removing a diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs index 4968d0c8d4..c2f602bcbf 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/entry.rs @@ -107,7 +107,7 @@ impl<'tcx> CodegenCx<'tcx> { emit.function_call( entry_func_return, None, - entry_func.def, + entry_func.def_cx(self), arguments.iter().map(|&(a, _)| a), ) .unwrap(); @@ -234,7 +234,7 @@ impl<'tcx> CodegenCx<'tcx> { .collect::<Vec<_>>(); emit.begin_block(None).unwrap(); let call_result = emit - .function_call(entry_func_return, None, entry_func.def, arguments) + .function_call(entry_func_return, None, entry_func.def_cx(self), arguments) .unwrap(); if self.lookup_type(entry_func_return) == SpirvType::Void { emit.ret().unwrap(); diff --git a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs index 37b17f7d6f..aa5c731c32 100644 --- a/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs +++ b/crates/rustc_codegen_spirv/src/codegen_cx/mod.rs @@ -4,11 +4,10 @@ mod entry; mod type_; use crate::builder::ExtInst; -use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueExt}; +use crate::builder_spirv::{BuilderCursor, BuilderSpirv, SpirvValue, SpirvValueKind}; use crate::finalizing_passes::export_zombies; use crate::spirv_type::{SpirvType, SpirvTypePrinter, TypeCache}; use crate::symbols::Symbols; -use bimap::BiHashMap; use rspirv::dr::{Module, Operand}; use rspirv::spirv::{Decoration, LinkageType, MemoryModel, StorageClass, Word}; use rustc_codegen_ssa::mir::debuginfo::{FunctionDebugContext, VariableKind}; @@ -52,9 +51,7 @@ pub struct CodegenCx<'tcx> { /// Cache of all the builtin symbols we need pub sym: Box<Symbols>, pub really_unsafe_ignore_bitcasts: RefCell<HashSet<SpirvValue>>, - /// Functions created in `get_fn_addr`, and values created in `static_addr_of`. - /// left: the OpUndef pseudo-pointer. right: the concrete value. - constant_pointers: RefCell<BiHashMap<SpirvValue, SpirvValue>>, + pub zombie_undefs_for_system_constant_pointers: RefCell<HashMap<Word, Word>>, /// Some runtimes (e.g. intel-compute-runtime) disallow atomics on i8 and i16, even though they're allowed by the spec. /// This enables/disables them.
pub i8_i16_atomics_allowed: bool, @@ -104,7 +101,7 @@ impl<'tcx> CodegenCx<'tcx> { kernel_mode, sym, really_unsafe_ignore_bitcasts: Default::default(), - constant_pointers: Default::default(), + zombie_undefs_for_system_constant_pointers: Default::default(), i8_i16_atomics_allowed: false, } } @@ -185,47 +182,28 @@ impl<'tcx> CodegenCx<'tcx> { ) } - /// Function pointer registration: - /// LLVM, and therefore `codegen_ssa`, is very murky with function values vs. function - /// pointers. So, `codegen_ssa` has a pattern where *even for direct function calls*, it uses - /// `get_fn_*addr*`, and then uses that function *pointer* when calling - /// `BuilderMethods::call()`. However, spir-v doesn't support function pointers! So, instead, - /// when `get_fn_addr` is called, we register a "token" (via `OpUndef`), storing it in a - /// dictionary. Then, when `BuilderMethods::call()` is called, and it's calling a function - /// pointer, we check the dictionary, and if it is, invoke the function directly. It's kind of - /// conceptually similar to a constexpr deref, except specialized to just functions. We also - /// use the same system with `static_addr_of`: we don't support creating addresses of arbitrary - /// constant values, so instead, we register the constant, and whenever we load something, we - /// check if the pointer we're loading is a special `OpUndef` token: if so, we directly - /// reference the registered value. - pub fn register_constant_pointer(&self, value: SpirvValue) -> SpirvValue { - // If we've already registered this value, return it. - if let Some(undef) = self.constant_pointers.borrow().get_by_right(&value) { - return *undef; - } + /// See note on `SpirvValueKind::ConstantPointer` + pub fn make_constant_pointer(&self, value: SpirvValue) -> SpirvValue { let ty = SpirvType::Pointer { storage_class: StorageClass::Function, pointee: value.ty, } .def(self); - // We want a unique ID for these undefs, so don't use the caching system. - let result = self.emit_global().undef(ty, None).with_type(ty); - // It's obviously invalid, so zombie it. Zombie it in user code as well, because it's valid there too. - // It'd be realllly nice to give this a Span, the UX of this is horrible... - self.zombie_even_in_user_code(result.def, "dynamic use of address of constant"); - self.constant_pointers - .borrow_mut() - .insert_no_overwrite(result, value) - .unwrap(); - result - } - - /// See comment on `register_constant_pointer` - pub fn lookup_constant_pointer(&self, pointer: SpirvValue) -> Option<SpirvValue> { - self.constant_pointers - .borrow() - .get_by_left(&pointer) - .cloned() + if self.is_system_crate() { + // Create these undefs up front instead of on demand in SpirvValue::def because + // SpirvValue::def can't use cx.emit() + self.zombie_undefs_for_system_constant_pointers + .borrow_mut() + .entry(ty) + .or_insert_with(|| { + // We want a unique ID for these undefs, so don't use the caching system.
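+ // (Keyed by pointer type, so every ConstantPointer of a given type shares one placeholder undef.)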
+ self.emit_global().undef(ty, None) + }); + } + SpirvValue { + kind: SpirvValueKind::ConstantPointer(value.def_cx(self)), + ty, + } } } @@ -286,7 +264,7 @@ impl<'tcx> MiscMethods<'tcx> for CodegenCx<'tcx> { fn get_fn_addr(&self, instance: Instance<'tcx>) -> Self::Value { let function = self.get_fn_ext(instance); - self.register_constant_pointer(function) + self.make_constant_pointer(function) } fn eh_personality(&self) -> Self::Value { diff --git a/crates/rustc_codegen_spirv/src/spirv_type.rs b/crates/rustc_codegen_spirv/src/spirv_type.rs index 61fe001f40..e0294d4026 100644 --- a/crates/rustc_codegen_spirv/src/spirv_type.rs +++ b/crates/rustc_codegen_spirv/src/spirv_type.rs @@ -148,7 +148,7 @@ impl SpirvType { .sizeof(cx) .expect("Element of sized array must be sized") .bytes(); - let result = cx.emit_global().type_array(element, count.def); + let result = cx.emit_global().type_array(element, count.def_cx(cx)); if !cx.kernel_mode { // TODO: kernel mode can't do this?? cx.emit_global().decorate(
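A minimal sketch of the call-site pattern this diff converges on (illustrative only; `bx` and `cx` stand for a surrounding `Builder`/`CodegenCx`, and `val`/`ptr` are assumed to be existing `SpirvValue`s):

    // Old: `val.def` field access. New: methods that can diagnose ConstantPointers.
    let id: Word = val.def(bx);     // inside a Builder (has span context)
    let id2: Word = val.def_cx(cx); // inside CodegenCx (no span available)
    // Loads fold a ConstantPointer to its underlying value instead of emitting OpLoad:
    if let Some(underlying) = ptr.const_ptr_val(cx) {
        // use `underlying` directly; no OpLoad is emitted
    }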