refactor: refine thread variant for windows

This commit is contained in:
discord9 2024-11-15 19:17:23 +08:00
parent 34ad8de02d
commit cecf2b3eae
4 changed files with 57 additions and 32 deletions

View File

@ -113,6 +113,11 @@ impl ThreadId {
self.0
}
/// Create a new thread id from a `u32` without checking if this thread exists.
pub fn new_unchecked(id: u32) -> Self {
Self(id)
}
pub const MAIN_THREAD: ThreadId = ThreadId(0);
}

View File

@ -7,6 +7,7 @@ use rustc_span::Symbol;
use self::shims::windows::handle::{Handle, PseudoHandle};
use crate::shims::os_str::bytes_to_os_str;
use crate::shims::windows::handle::HandleError;
use crate::shims::windows::*;
use crate::*;
@ -488,7 +489,7 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let thread_id =
this.CreateThread(security, stacksize, start, arg, flags, thread)?;
this.write_scalar(Handle::Thread(thread_id.to_u32()).to_scalar(this), dest)?;
this.write_scalar(Handle::Thread(thread_id).to_scalar(this), dest)?;
}
"WaitForSingleObject" => {
let [handle, timeout] =
@ -513,10 +514,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let handle = this.read_scalar(handle)?;
let name = this.read_wide_str(this.read_pointer(name)?)?;
let thread = match Handle::from_scalar(handle, this)? {
Some(Handle::Thread(thread)) => this.thread_id_try_from(thread),
Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
_ => this.invalid_handle("SetThreadDescription")?,
let thread = match Handle::try_from_scalar(handle, this)? {
Ok(Handle::Thread(thread)) => Ok(thread),
Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
Ok(_) | Err(HandleError::InvalidHandle) =>
this.invalid_handle("SetThreadDescription")?,
Err(HandleError::ThreadNotFound(e)) => Err(e),
};
let res = match thread {
Ok(thread) => {
@ -536,10 +539,12 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let handle = this.read_scalar(handle)?;
let name_ptr = this.deref_pointer(name_ptr)?; // the pointer where we should store the ptr to the name
let thread = match Handle::from_scalar(handle, this)? {
Some(Handle::Thread(thread)) => this.thread_id_try_from(thread),
Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
_ => this.invalid_handle("GetThreadDescription")?,
let thread = match Handle::try_from_scalar(handle, this)? {
Ok(Handle::Thread(thread)) => Ok(thread),
Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => Ok(this.active_thread()),
Ok(_) | Err(HandleError::InvalidHandle) =>
this.invalid_handle("GetThreadDescription")?,
Err(HandleError::ThreadNotFound(e)) => Err(e),
};
let (name, res) = match thread {
Ok(thread) => {

View File

@ -2,6 +2,7 @@ use std::mem::variant_count;
use rustc_abi::HasDataLayout;
use crate::concurrency::thread::ThreadNotFound;
use crate::*;
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@ -14,7 +15,7 @@ pub enum PseudoHandle {
pub enum Handle {
Null,
Pseudo(PseudoHandle),
Thread(u32),
Thread(ThreadId),
}
impl PseudoHandle {
@ -34,6 +35,14 @@ impl PseudoHandle {
}
}
/// Errors that can occur when constructing a [`Handle`] from a Scalar.
pub enum HandleError {
/// There is no thread with the given ID.
ThreadNotFound(ThreadNotFound),
/// Can't convert scalar to handle because it is structurally invalid.
InvalidHandle,
}
impl Handle {
const NULL_DISCRIMINANT: u32 = 0;
const PSEUDO_DISCRIMINANT: u32 = 1;
@ -51,7 +60,7 @@ impl Handle {
match self {
Self::Null => 0,
Self::Pseudo(pseudo_handle) => pseudo_handle.value(),
Self::Thread(thread) => thread,
Self::Thread(thread) => thread.to_u32(),
}
}
@ -95,7 +104,7 @@ impl Handle {
match discriminant {
Self::NULL_DISCRIMINANT if data == 0 => Some(Self::Null),
Self::PSEUDO_DISCRIMINANT => Some(Self::Pseudo(PseudoHandle::from_value(data)?)),
Self::THREAD_DISCRIMINANT => Some(Self::Thread(data)),
Self::THREAD_DISCRIMINANT => Some(Self::Thread(ThreadId::new_unchecked(data))),
_ => None,
}
}
@ -126,10 +135,14 @@ impl Handle {
Scalar::from_target_isize(signed_handle.into(), cx)
}
pub fn from_scalar<'tcx>(
/// Convert a scalar into a structured `Handle`.
/// Structurally invalid handles return [`HandleError::InvalidHandle`].
/// If the handle is structurally valid but semantically invalid, e.g. a for non-existent thread
/// ID, returns [`HandleError::ThreadNotFound`].
pub fn try_from_scalar<'tcx>(
handle: Scalar,
cx: &impl HasDataLayout,
) -> InterpResult<'tcx, Option<Self>> {
cx: &MiriInterpCx<'tcx>,
) -> InterpResult<'tcx, Result<Self, HandleError>> {
let sign_extended_handle = handle.to_target_isize(cx)?;
#[expect(clippy::cast_sign_loss)] // we want to lose the sign
@ -137,10 +150,20 @@ impl Handle {
signed_handle as u32
} else {
// if a handle doesn't fit in an i32, it isn't valid.
return interp_ok(None);
return interp_ok(Err(HandleError::InvalidHandle));
};
interp_ok(Self::from_packed(handle))
match Self::from_packed(handle) {
Some(Self::Thread(thread)) => {
// validate the thread id
match cx.machine.threads.thread_id_try_from(thread.to_u32()) {
Ok(id) => interp_ok(Ok(Self::Thread(id))),
Err(e) => interp_ok(Err(HandleError::ThreadNotFound(e))),
}
}
Some(handle) => interp_ok(Ok(handle)),
None => interp_ok(Err(HandleError::InvalidHandle)),
}
}
}
@ -158,14 +181,10 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let this = self.eval_context_mut();
let handle = this.read_scalar(handle_op)?;
let ret = match Handle::from_scalar(handle, this)? {
Some(Handle::Thread(thread)) => {
if let Ok(thread) = this.thread_id_try_from(thread) {
this.detach_thread(thread, /*allow_terminated_joined*/ true)?;
this.eval_windows("c", "TRUE")
} else {
this.invalid_handle("CloseHandle")?
}
let ret = match Handle::try_from_scalar(handle, this)? {
Ok(Handle::Thread(thread)) => {
this.detach_thread(thread, /*allow_terminated_joined*/ true)?;
this.eval_windows("c", "TRUE")
}
_ => this.invalid_handle("CloseHandle")?,
};

View File

@ -65,15 +65,11 @@ pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
let handle = this.read_scalar(handle_op)?;
let timeout = this.read_scalar(timeout_op)?.to_u32()?;
let thread = match Handle::from_scalar(handle, this)? {
Some(Handle::Thread(thread)) =>
match this.thread_id_try_from(thread) {
Ok(thread) => thread,
Err(_) => this.invalid_handle("WaitForSingleObject")?,
},
let thread = match Handle::try_from_scalar(handle, this)? {
Ok(Handle::Thread(thread)) => thread,
// Unlike on posix, the outcome of joining the current thread is not documented.
// On current Windows, it just deadlocks.
Some(Handle::Pseudo(PseudoHandle::CurrentThread)) => this.active_thread(),
Ok(Handle::Pseudo(PseudoHandle::CurrentThread)) => this.active_thread(),
_ => this.invalid_handle("WaitForSingleObject")?,
};