diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs index 363a2667174..8d776642249 100644 --- a/library/std/src/lib.rs +++ b/library/std/src/lib.rs @@ -232,6 +232,7 @@ all(target_vendor = "fortanix", target_env = "sgx"), feature(slice_index_methods, coerce_unsized, sgx_platform) )] +#![cfg_attr(windows, feature(round_char_boundary))] // // Language features: #![feature(alloc_error_handler)] diff --git a/library/std/src/sys/windows/c.rs b/library/std/src/sys/windows/c.rs index f58dcf1287b..d1e6594b3f2 100644 --- a/library/std/src/sys/windows/c.rs +++ b/library/std/src/sys/windows/c.rs @@ -6,13 +6,15 @@ use crate::ffi::CStr; use crate::mem; -use crate::os::raw::{c_char, c_int, c_long, c_longlong, c_uint, c_ulong, c_ushort}; +use crate::os::raw::{c_char, c_long, c_longlong, c_uint, c_ulong, c_ushort}; use crate::os::windows::io::{BorrowedHandle, HandleOrInvalid, HandleOrNull}; use crate::ptr; use core::ffi::NonZero_c_ulong; use libc::{c_void, size_t, wchar_t}; +pub use crate::os::raw::c_int; + #[path = "c/errors.rs"] // c.rs is included from two places so we need to specify this mod errors; pub use errors::*; @@ -47,16 +49,19 @@ pub type ACCESS_MASK = DWORD; pub type LPBOOL = *mut BOOL; pub type LPBYTE = *mut BYTE; +pub type LPCCH = *const CHAR; pub type LPCSTR = *const CHAR; +pub type LPCWCH = *const WCHAR; pub type LPCWSTR = *const WCHAR; +pub type LPCVOID = *const c_void; pub type LPDWORD = *mut DWORD; pub type LPHANDLE = *mut HANDLE; pub type LPOVERLAPPED = *mut OVERLAPPED; pub type LPPROCESS_INFORMATION = *mut PROCESS_INFORMATION; pub type LPSECURITY_ATTRIBUTES = *mut SECURITY_ATTRIBUTES; pub type LPSTARTUPINFO = *mut STARTUPINFO; +pub type LPSTR = *mut CHAR; pub type LPVOID = *mut c_void; -pub type LPCVOID = *const c_void; pub type LPWCH = *mut WCHAR; pub type LPWIN32_FIND_DATAW = *mut WIN32_FIND_DATAW; pub type LPWSADATA = *mut WSADATA; @@ -132,6 +137,10 @@ pub const MAX_PATH: usize = 260; pub const FILE_TYPE_PIPE: u32 = 3; +pub const CP_UTF8: DWORD = 65001; +pub const MB_ERR_INVALID_CHARS: DWORD = 0x08; +pub const WC_ERR_INVALID_CHARS: DWORD = 0x80; + #[repr(C)] #[derive(Copy)] pub struct WIN32_FIND_DATAW { @@ -1155,6 +1164,25 @@ extern "system" { lpFilePart: *mut LPWSTR, ) -> DWORD; pub fn GetFileAttributesW(lpFileName: LPCWSTR) -> DWORD; + + pub fn MultiByteToWideChar( + CodePage: UINT, + dwFlags: DWORD, + lpMultiByteStr: LPCCH, + cbMultiByte: c_int, + lpWideCharStr: LPWSTR, + cchWideChar: c_int, + ) -> c_int; + pub fn WideCharToMultiByte( + CodePage: UINT, + dwFlags: DWORD, + lpWideCharStr: LPCWCH, + cchWideChar: c_int, + lpMultiByteStr: LPSTR, + cbMultiByte: c_int, + lpDefaultChar: LPCCH, + lpUsedDefaultChar: LPBOOL, + ) -> c_int; } #[link(name = "ws2_32")] diff --git a/library/std/src/sys/windows/stdio.rs b/library/std/src/sys/windows/stdio.rs index c2cd48470bd..32c6ccffb7a 100644 --- a/library/std/src/sys/windows/stdio.rs +++ b/library/std/src/sys/windows/stdio.rs @@ -169,14 +169,27 @@ fn write( } fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result { + debug_assert!(!utf8.is_empty()); + let mut utf16 = [MaybeUninit::::uninit(); MAX_BUFFER_SIZE / 2]; - let mut len_utf16 = 0; - for (chr, dest) in utf8.encode_utf16().zip(utf16.iter_mut()) { - *dest = MaybeUninit::new(chr); - len_utf16 += 1; - } - // Safety: We've initialized `len_utf16` values. - let utf16: &[u16] = unsafe { MaybeUninit::slice_assume_init_ref(&utf16[..len_utf16]) }; + let utf8 = &utf8[..utf8.floor_char_boundary(utf16.len())]; + + let utf16: &[u16] = unsafe { + // Note that this theoretically checks validity twice in the (most common) case + // where the underlying byte sequence is valid utf-8 (given the check in `write()`). + let result = c::MultiByteToWideChar( + c::CP_UTF8, // CodePage + c::MB_ERR_INVALID_CHARS, // dwFlags + utf8.as_ptr() as c::LPCCH, // lpMultiByteStr + utf8.len() as c::c_int, // cbMultiByte + utf16.as_mut_ptr() as c::LPWSTR, // lpWideCharStr + utf16.len() as c::c_int, // cchWideChar + ); + assert!(result != 0, "Unexpected error in MultiByteToWideChar"); + + // Safety: MultiByteToWideChar initializes `result` values. + MaybeUninit::slice_assume_init_ref(&utf16[..result as usize]) + }; let mut written = write_u16s(handle, &utf16)?; @@ -189,8 +202,8 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result= 0xDCEE && first_char_remaining <= 0xDFFF { + let first_code_unit_remaining = utf16[written]; + if first_code_unit_remaining >= 0xDCEE && first_code_unit_remaining <= 0xDFFF { // low surrogate // We just hope this works, and give up otherwise let _ = write_u16s(handle, &utf16[written..written + 1]); @@ -212,6 +225,7 @@ fn write_valid_utf8_to_console(handle: c::HANDLE, utf8: &str) -> io::Result io::Result { + debug_assert!(data.len() < u32::MAX as usize); let mut written = 0; cvt(unsafe { c::WriteConsoleW( @@ -365,26 +379,32 @@ fn read_u16s(handle: c::HANDLE, buf: &mut [MaybeUninit]) -> io::Result io::Result { - let mut written = 0; - for chr in char::decode_utf16(utf16.iter().cloned()) { - match chr { - Ok(chr) => { - chr.encode_utf8(&mut utf8[written..]); - written += chr.len_utf8(); - } - Err(_) => { - // We can't really do any better than forget all data and return an error. - return Err(io::const_io_error!( - io::ErrorKind::InvalidData, - "Windows stdin in console mode does not support non-UTF-16 input; \ - encountered unpaired surrogate", - )); - } - } + debug_assert!(utf16.len() <= c::c_int::MAX as usize); + debug_assert!(utf8.len() <= c::c_int::MAX as usize); + + let result = unsafe { + c::WideCharToMultiByte( + c::CP_UTF8, // CodePage + c::WC_ERR_INVALID_CHARS, // dwFlags + utf16.as_ptr(), // lpWideCharStr + utf16.len() as c::c_int, // cchWideChar + utf8.as_mut_ptr() as c::LPSTR, // lpMultiByteStr + utf8.len() as c::c_int, // cbMultiByte + ptr::null(), // lpDefaultChar + ptr::null_mut(), // lpUsedDefaultChar + ) + }; + if result == 0 { + // We can't really do any better than forget all data and return an error. + Err(io::const_io_error!( + io::ErrorKind::InvalidData, + "Windows stdin in console mode does not support non-UTF-16 input; \ + encountered unpaired surrogate", + )) + } else { + Ok(result as usize) } - Ok(written) } impl IncompleteUtf8 {