Skip to content

Commit 6b1f7ed

Browse files
committed
Fix attach on Windows when runtime is enabled.
1 parent 568b9a0 commit 6b1f7ed

File tree

2 files changed

+101
-2
lines changed

2 files changed

+101
-2
lines changed

src/driver/iocp/mod.rs

Lines changed: 53 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,24 @@
11
use crate::driver::{Entry, Poller};
22
use crossbeam_queue::SegQueue;
33
use std::{
4+
ffi::c_void,
45
io,
56
os::windows::{
67
io::HandleOrNull,
78
prelude::{AsRawHandle, OwnedHandle, RawHandle},
89
},
9-
ptr::null_mut,
10+
ptr::{null, null_mut},
1011
task::Poll,
1112
time::Duration,
1213
};
1314
use windows_sys::Win32::{
14-
Foundation::{GetLastError, ERROR_HANDLE_EOF, INVALID_HANDLE_VALUE, WAIT_TIMEOUT},
15+
Foundation::{
16+
GetLastError, RtlNtStatusToDosError, ERROR_HANDLE_EOF, INVALID_HANDLE_VALUE, NTSTATUS,
17+
STATUS_SUCCESS, WAIT_TIMEOUT,
18+
},
1519
System::{
1620
Threading::INFINITE,
21+
WindowsProgramming::{FILE_INFORMATION_CLASS, IO_STATUS_BLOCK},
1722
IO::{
1823
CreateIoCompletionPort, GetQueuedCompletionStatus, PostQueuedCompletionStatus,
1924
OVERLAPPED,
@@ -91,8 +96,54 @@ impl Driver {
9196
}
9297
}
9398

99+
fn deattach_iocp(fd: RawFd) -> io::Result<()> {
100+
#[link(name = "ntdll")]
101+
extern "system" {
102+
fn NtSetInformationFile(
103+
FileHandle: usize,
104+
IoStatusBlock: *mut IO_STATUS_BLOCK,
105+
FileInformation: *const c_void,
106+
Length: u32,
107+
FileInformationClass: FILE_INFORMATION_CLASS,
108+
) -> NTSTATUS;
109+
}
110+
#[allow(non_upper_case_globals)]
111+
const FileReplaceCompletionInformation: FILE_INFORMATION_CLASS = 61;
112+
#[repr(C)]
113+
#[allow(non_camel_case_types)]
114+
#[allow(non_snake_case)]
115+
struct FILE_COMPLETION_INFORMATION {
116+
Port: usize,
117+
Key: *const c_void,
118+
}
119+
120+
let mut block = unsafe { std::mem::zeroed() };
121+
let info = FILE_COMPLETION_INFORMATION {
122+
Port: 0,
123+
Key: null(),
124+
};
125+
unsafe {
126+
NtSetInformationFile(
127+
fd as _,
128+
&mut block,
129+
&info as *const _ as _,
130+
std::mem::size_of_val(&info) as _,
131+
FileReplaceCompletionInformation,
132+
)
133+
};
134+
let res = unsafe { block.Anonymous.Status };
135+
if res != STATUS_SUCCESS {
136+
Err(io::Error::from_raw_os_error(unsafe {
137+
RtlNtStatusToDosError(res) as _
138+
}))
139+
} else {
140+
Ok(())
141+
}
142+
}
143+
94144
impl Poller for Driver {
95145
fn attach(&self, fd: RawFd) -> io::Result<()> {
146+
deattach_iocp(fd)?;
96147
let port = unsafe { CreateIoCompletionPort(fd as _, self.port.as_raw_handle() as _, 0, 0) };
97148
if port == 0 {
98149
Err(io::Error::last_os_error())

src/driver/mod.rs

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,54 @@ cfg_if::cfg_if! {
1717
/// It contains some low-level actions of completion-based IO.
1818
///
1919
/// You don't need them unless you are controlling a [`Driver`] yourself.
20+
///
21+
/// # Examples
22+
///
23+
/// ```
24+
/// use compio::{
25+
/// buf::IntoInner,
26+
/// driver::{AsRawFd, Driver, Poller},
27+
/// net::UdpSocket,
28+
/// op,
29+
/// };
30+
/// use std::net::SocketAddr;
31+
///
32+
/// let first_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
33+
/// let second_addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
34+
///
35+
/// // bind sockets
36+
/// let socket = UdpSocket::bind(first_addr).unwrap();
37+
/// let first_addr = socket.local_addr().unwrap();
38+
/// let other_socket = UdpSocket::bind(second_addr).unwrap();
39+
/// let second_addr = other_socket.local_addr().unwrap();
40+
///
41+
/// // connect sockets
42+
/// socket.connect(second_addr).unwrap();
43+
/// other_socket.connect(first_addr).unwrap();
44+
///
45+
/// let driver = Driver::new().unwrap();
46+
/// driver.attach(socket.as_raw_fd()).unwrap();
47+
/// driver.attach(other_socket.as_raw_fd()).unwrap();
48+
///
49+
/// // write data
50+
/// let mut op = op::Send::new(socket.as_raw_fd(), "hello world");
51+
/// unsafe { driver.push(&mut op, 1) }.unwrap();
52+
/// let entry = driver.poll(None).unwrap();
53+
/// assert_eq!(entry.user_data(), 1);
54+
/// entry.into_result().unwrap();
55+
///
56+
/// // read data
57+
/// let buf = Vec::with_capacity(32);
58+
/// let mut op = op::Recv::new(other_socket.as_raw_fd(), buf);
59+
/// unsafe { driver.push(&mut op, 2) }.unwrap();
60+
/// let entry = driver.poll(None).unwrap();
61+
/// assert_eq!(entry.user_data(), 2);
62+
/// let n_bytes = entry.into_result().unwrap();
63+
/// let mut buf = op.into_inner().into_inner();
64+
/// unsafe { buf.set_len(n_bytes) };
65+
///
66+
/// assert_eq!(buf, b"hello world");
67+
/// ```
2068
pub trait Poller {
2169
/// Attach an fd to the driver.
2270
fn attach(&self, fd: RawFd) -> io::Result<()>;

0 commit comments

Comments
 (0)