diff --git a/example/calculator.cpp b/example/calculator.cpp index 86ce388..73dd56f 100644 --- a/example/calculator.cpp +++ b/example/calculator.cpp @@ -51,14 +51,11 @@ int main(int argc, char** argv) std::cout << "Usage: mpcalculator \n"; return 1; } - int fd; - if (std::from_chars(argv[1], argv[1] + strlen(argv[1]), fd).ec != std::errc{}) { - std::cerr << argv[1] << " is not a number or is larger than an int\n"; - return 1; - } + mp::SocketId socket{mp::StartSpawned(argv[1])}; mp::EventLoop loop("mpcalculator", LogPrint); std::unique_ptr init = std::make_unique(); - mp::ServeStream(loop, fd, *init); + mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; + mp::ServeStream(loop, kj::mv(stream), *init); loop.loop(); return 0; } diff --git a/example/example.cpp b/example/example.cpp index 3831397..f6fe68e 100644 --- a/example/example.cpp +++ b/example/example.cpp @@ -25,14 +25,14 @@ namespace fs = std::filesystem; static auto Spawn(mp::EventLoop& loop, const std::string& process_argv0, const std::string& new_exe_name) { - int pid; - const int fd = mp::SpawnProcess(pid, [&](int fd) -> std::vector { + auto pair{mp::SocketPair()}; + mp::ProcessId pid{mp::SpawnProcess(pair[0], [&](mp::ConnectInfo info) -> std::vector { fs::path path = process_argv0; path.remove_filename(); path.append(new_exe_name); - return {path.string(), std::to_string(fd)}; - }); - return std::make_tuple(mp::ConnectStream(loop, fd), pid); + return {path.string(), std::move(info)}; + })}; + return std::make_tuple(mp::ConnectStream(loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(pair[1])), pid); } static void LogPrint(mp::LogMessage log_data) diff --git a/example/printer.cpp b/example/printer.cpp index 9150d59..03b67d3 100644 --- a/example/printer.cpp +++ b/example/printer.cpp @@ -44,14 +44,11 @@ int main(int argc, char** argv) std::cout << "Usage: mpprinter \n"; return 1; } - int fd; - if (std::from_chars(argv[1], argv[1] + 
strlen(argv[1]), fd).ec != std::errc{}) { - std::cerr << argv[1] << " is not a number or is larger than an int\n"; - return 1; - } + mp::SocketId socket{mp::StartSpawned(argv[1])}; mp::EventLoop loop("mpprinter", LogPrint); std::unique_ptr init = std::make_unique(); - mp::ServeStream(loop, fd, *init); + mp::Stream stream{loop.m_io_context.lowLevelProvider->wrapSocketFd(socket, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; + mp::ServeStream(loop, std::move(stream), *init); loop.loop(); return 0; } diff --git a/include/mp/proxy-io.h b/include/mp/proxy-io.h index f736746..8927808 100644 --- a/include/mp/proxy-io.h +++ b/include/mp/proxy-io.h @@ -185,6 +185,17 @@ class Logger std::string LongThreadName(const char* exe_name); +using Stream = kj::Own; + +inline SocketId StreamSocketId(const Stream& stream) +{ + if (stream) KJ_IF_MAYBE(fd, stream->getFd()) return *fd; +#ifdef WIN32 + if (stream) KJ_IF_MAYBE(handle, stream->getWin32Handle()) return reinterpret_cast(*handle); +#endif + throw std::logic_error("Stream socket unset"); +} + //! Event loop implementation. //! //! Cap'n Proto threading model is very simple: all I/O operations are @@ -283,11 +294,12 @@ class EventLoop //! Callback functions to run on async thread. std::optional m_async_fns MP_GUARDED_BY(m_mutex); - //! Pipe read handle used to wake up the event loop thread. - int m_wait_fd = -1; + //! Socket pair used to post and wait for wakeups to the event loop thread. + kj::Own m_wait_stream; + kj::Own m_post_stream; - //! Pipe write handle used to wake up the event loop thread. - int m_post_fd = -1; + //! Synchronous writer used to write to m_post_stream. + kj::Own m_post_writer; //! Number of clients holding references to ProxyServerBase objects that //! reference this event loop. @@ -679,13 +691,11 @@ struct ThreadContext //! over the stream. Also create a new Connection object embedded in the //! client that is freed when the client is closed. 
template -std::unique_ptr> ConnectStream(EventLoop& loop, int fd) +std::unique_ptr> ConnectStream(EventLoop& loop, kj::Own stream) { typename InitInterface::Client init_client(nullptr); std::unique_ptr connection; loop.sync([&] { - auto stream = - loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP); connection = std::make_unique(loop, kj::mv(stream)); init_client = connection->m_rpc_system->bootstrap(ServerVatId().vat_id).castAs(); Connection* connection_ptr = connection.get(); @@ -735,10 +745,9 @@ void _Listen(EventLoop& loop, kj::Own&& listener, InitIm //! Given stream file descriptor and an init object, handle requests on the //! stream by calling methods on the Init object. template -void ServeStream(EventLoop& loop, int fd, InitImpl& init) +void ServeStream(EventLoop& loop, kj::Own stream, InitImpl& init) { - _Serve( - loop, loop.m_io_context.lowLevelProvider->wrapSocketFd(fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP), init); + _Serve(loop, kj::mv(stream), init); } //! Given listening socket file descriptor and an init object, handle incoming diff --git a/include/mp/util.h b/include/mp/util.h index e5b4dd1..e3f8f58 100644 --- a/include/mp/util.h +++ b/include/mp/util.h @@ -19,6 +19,10 @@ #include #include +#ifdef WIN32 +#include +#endif + namespace mp { //! Generic utility functions used by capnp code. @@ -216,22 +220,44 @@ std::string ThreadName(const char* exe_name); //! errors in python unit tests. std::string LogEscape(const kj::StringTree& string, size_t max_size); +#ifdef WIN32 +using ProcessId = uintptr_t; +using SocketId = uintptr_t; +constexpr SocketId SocketError{INVALID_SOCKET}; +#else +using ProcessId = int; +using SocketId = int; +constexpr SocketId SocketError{-1}; +#endif + +//! Information about parent process passed to child process. On unix this is +//! just the inherited int file descriptor formatted as a string. On windows, +//! 
this is a path to a named pipe the parent process will write +//! WSADuplicateSocket info to. +using ConnectInfo = std::string; + //! Callback type used by SpawnProcess below. -using FdToArgsFn = std::function(int fd)>; +using ConnectInfoToArgsFn = std::function(const ConnectInfo&)>; + +//! Create a socket pair that can be used to communicate within a process or +//! between parent and child processes. +std::array SocketPair(); + +//! Spawn a new process that communicates with the current process over provided +//! socket argument. Calls connect_info_to_args callback with a connection +//! string that needs to be passed to the child process, and executes the +//! argv command line it returns. Returns child process id. +ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args); -//! Spawn a new process that communicates with the current process over a socket -//! pair. Returns pid through an output argument, and file descriptor for the -//! local side of the socket. Invokes fd_to_args callback with the remote file -//! descriptor number which returns the command line arguments that should be -//! used to execute the process, and which should have the remote file -//! descriptor embedded in whatever format the child process expects. -int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args); +//! Initialize spawned child process using the ConnectInfo string passed to it, +//! returning a socket id for communicating with the parent process. +SocketId StartSpawned(const ConnectInfo& connect_info); //! Call execvp with vector args. void ExecProcess(const std::vector& args); //! Wait for a process to exit and return its exit code. 
-int WaitProcess(int pid); +int WaitProcess(ProcessId pid); inline char* CharCast(char* c) { return c; } inline char* CharCast(unsigned char* c) { return (char*)c; } diff --git a/src/mp/proxy.cpp b/src/mp/proxy.cpp index 57545d3..8af57f8 100644 --- a/src/mp/proxy.cpp +++ b/src/mp/proxy.cpp @@ -30,12 +30,15 @@ #include #include #include -#include #include #include -#include #include +#ifndef WIN32 +#include +#include +#endif + namespace mp { thread_local ThreadContext g_thread_context; @@ -66,10 +69,9 @@ void EventLoopRef::reset(bool relock) MP_NO_TSA loop->m_num_clients -= 1; if (loop->done()) { loop->m_cv.notify_all(); - int post_fd{loop->m_post_fd}; loop_lock->unlock(); char buffer = 0; - KJ_SYSCALL(write(post_fd, &buffer, 1)); // NOLINT(bugprone-suspicious-semicolon) + loop->m_post_writer->write(&buffer, 1); // By default, do not try to relock `loop_lock` after writing, // because the event loop could wake up and destroy itself and the // mutex might no longer exist. @@ -96,6 +98,20 @@ Connection::~Connection() // after the calls finish. m_rpc_system.reset(); + // shutdownWrite is needed on Windows so pending data in the m_stream socket + // will be sent instead of discarded when m_stream is destroyed. On unix, + // this doesn't seem to be needed because data is sent more reliably. + // + // Sending pending data is important if the connection is a socketpair + // because when one side of the socketpair is closed, the other side doesn't + // seem to receive any onDisconnect event. So it is important for the other + // side to instead receive Cap'n Proto "release" messages (see `struct + // Release` in capnp/rpc.capnp) from local Client objects being + // destroyed so the remote side can free resources and shut down cleanly. + // Without this call, Server objects corresponding to the Client objects on + // the other side of the connection are not freed by Cap'n Proto. 
+ m_stream->shutdownWrite(); + // ProxyClient cleanup handlers are in sync list, and ProxyServer cleanup // handlers are in the async list. // @@ -192,6 +208,40 @@ void EventLoop::addAsyncCleanup(std::function fn) startAsyncThread(); } +#ifdef WIN32 +//! Synchronous socket output stream. Cap'n Proto library only provides limited +//! support for synchronous IO. It provides `FdOutputStream` which wraps unix +//! file descriptors and calls write() internally, and `HandleOutputStream` which +//! wraps windows HANDLE values and calls WriteFile() internally. This class +//! just provides analogous functionality wrapping SOCKET values and calls +//! send() internally. +class SocketOutputStream : public kj::OutputStream { +public: + explicit SocketOutputStream(SOCKET socket) : m_socket(socket) {} + + void write(const void* buffer, size_t size) override; + +private: + SOCKET m_socket; +}; + +static constexpr size_t WRITE_CLAMP_SIZE = 1u << 30; // 1GB clamp for Windows, like FdOutputStream + +void SocketOutputStream::write(const void* buffer, size_t size) { + const char* pos = reinterpret_cast(buffer); + + while (size > 0) { + int n = send(m_socket, pos, static_cast(kj::min(size, WRITE_CLAMP_SIZE)), 0); + + KJ_WIN32(n != SOCKET_ERROR, "send() failed"); + KJ_ASSERT(n > 0, "send() returned zero."); + + pos += n; + size -= n; + } +} +#endif + EventLoop::EventLoop(const char* exe_name, LogOptions log_opts, void* context) : m_exe_name(exe_name), m_io_context(kj::setupAsyncIo()), @@ -199,10 +249,18 @@ EventLoop::EventLoop(const char* exe_name, LogOptions log_opts, void* context) m_log_opts(std::move(log_opts)), m_context(context) { - int fds[2]; - KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, fds)); - m_wait_fd = fds[0]; - m_post_fd = fds[1]; + auto pipe = m_io_context.provider->newTwoWayPipe(); + m_wait_stream = kj::mv(pipe.ends[0]); + m_post_stream = kj::mv(pipe.ends[1]); + KJ_IF_MAYBE(fd, m_post_stream->getFd()) { + m_post_writer = kj::heap(*fd); +#ifdef WIN32 + } else 
KJ_IF_MAYBE(handle, m_post_stream->getWin32Handle()) { + m_post_writer = kj::heap(reinterpret_cast(*handle)); +#endif + } else { + throw std::logic_error("Could not get file descriptor for new pipe."); + } } EventLoop::~EventLoop() @@ -211,8 +269,8 @@ EventLoop::~EventLoop() const Lock lock(m_mutex); KJ_ASSERT(m_post_fn == nullptr); KJ_ASSERT(!m_async_fns); - KJ_ASSERT(m_wait_fd == -1); - KJ_ASSERT(m_post_fd == -1); + KJ_ASSERT(!m_wait_stream); + KJ_ASSERT(!m_post_stream); KJ_ASSERT(m_num_clients == 0); // Spin event loop. wait for any promises triggered by RPC shutdown. @@ -232,9 +290,7 @@ void EventLoop::loop() m_async_fns.emplace(); } - kj::Own wait_stream{ - m_io_context.lowLevelProvider->wrapSocketFd(m_wait_fd, kj::LowLevelAsyncIoProvider::TAKE_OWNERSHIP)}; - int post_fd{m_post_fd}; + kj::Own& wait_stream{m_wait_stream}; char buffer = 0; for (;;) { const size_t read_bytes = wait_stream->read(&buffer, 0, 1).wait(m_io_context.waitScope); @@ -246,7 +302,7 @@ void EventLoop::loop() m_cv.notify_all(); } else if (done()) { // Intentionally do not break if m_post_fn was set, even if done() - // would return true, to ensure that the EventLoopRef write(post_fd) + // would return true, to ensure that the EventLoopRef write(post_stream) // call always succeeds and the loop does not exit between the time // that the done condition is set and the write call is made. 
break; @@ -256,10 +312,9 @@ void EventLoop::loop() m_task_set.reset(); MP_LOG(*this, Log::Info) << "EventLoop::loop bye."; wait_stream = nullptr; - KJ_SYSCALL(::close(post_fd)); const Lock lock(m_mutex); - m_wait_fd = -1; - m_post_fd = -1; + m_wait_stream = nullptr; + m_post_stream = nullptr; m_async_fns.reset(); m_cv.notify_all(); } @@ -274,10 +329,9 @@ void EventLoop::post(kj::Function fn) EventLoopRef ref(*this, &lock); m_cv.wait(lock.m_lock, [this]() MP_REQUIRES(m_mutex) { return m_post_fn == nullptr; }); m_post_fn = &fn; - int post_fd{m_post_fd}; Unlock(lock, [&] { char buffer = 0; - KJ_SYSCALL(write(post_fd, &buffer, 1)); + m_post_writer->write(&buffer, 1); }); m_cv.wait(lock.m_lock, [this, &fn]() MP_REQUIRES(m_mutex) { return m_post_fn != &fn; }); } diff --git a/src/mp/util.cpp b/src/mp/util.cpp index 509913b..2dbf248 100644 --- a/src/mp/util.cpp +++ b/src/mp/util.cpp @@ -10,19 +10,27 @@ #include #include #include +#include #include #include #include #include -#include -#include -#include #include #include // NOLINT(misc-include-cleaner) // IWYU pragma: keep #include #include #include +#ifdef WIN32 +#include +#include +#else +#include +#include +#include +#include +#endif + #ifdef __linux__ #include #endif @@ -33,9 +41,15 @@ namespace fs = std::filesystem; +#ifdef WIN32 +// Forward-declare internal capnp function. +namespace kj { namespace _ { int win32Socketpair(SOCKET socks[2]); } } +#endif + namespace mp { namespace { +#ifndef WIN32 //! Return highest possible file descriptor. size_t MaxFd() { @@ -46,6 +60,7 @@ size_t MaxFd() return 1023; } } +#endif } // namespace @@ -67,6 +82,8 @@ std::string ThreadName(const char* exe_name) // the former are shorter and are the same as what gdb prints "LWP ...". 
#ifdef __linux__ buffer << syscall(SYS_gettid); +#elif defined(WIN32) + buffer << GetCurrentThreadId(); #elif defined(HAVE_PTHREAD_THREADID_NP) uint64_t tid = 0; pthread_threadid_np(NULL, &tid); @@ -104,32 +121,138 @@ std::string LogEscape(const kj::StringTree& string, size_t max_size) return result; } -int SpawnProcess(int& pid, FdToArgsFn&& fd_to_args) +std::array SocketPair() +{ +#ifdef WIN32 + SOCKET pair[2]; + KJ_WINSOCK(kj::_::win32Socketpair(pair)); +#else + int pair[2]; + KJ_SYSCALL(socketpair(AF_UNIX, SOCK_STREAM, 0, pair)); +#endif + return {pair[0], pair[1]}; +} + +//! Generate command line that the executable being invoked will split up using +//! the CommandLineToArgvW function, which expects arguments with spaces to be +//! quoted, quote characters to be backslash-escaped, and backslashes to also be +//! backslash-escaped, but only if they precede a quote character. +std::string CommandLineFromArgv(const std::vector& argv) { - int fds[2]; - if (socketpair(AF_UNIX, SOCK_STREAM, 0, fds) != 0) { - throw std::system_error(errno, std::system_category(), "socketpair"); + std::string out; + for (const auto& arg : argv) { + if (!out.empty()) out += " "; + if (!arg.empty() && arg.find_first_of(" \t\"") == std::string::npos) { + // Argument has no quotes or spaces so escaping not necessary. 
+ out += arg; + } else { + out += '"'; // Start with a quote + for (size_t i = 0; i < arg.size(); ++i) { + if (arg[i] == '\\') { + // Count consecutive backslashes + size_t backslash_count = 0; + while (i < arg.size() && arg[i] == '\\') { + ++backslash_count; + ++i; + } + if (i < arg.size() && arg[i] == '"') { + // Backslashes before a quote need to be doubled + out.append(backslash_count * 2 + 1, '\\'); + out.push_back('"'); + } else { + // Otherwise, backslashes remain as-is + out.append(backslash_count, '\\'); + --i; // Compensate for the outer loop's increment + } + } else if (arg[i] == '"') { + // Escape double quotes with a backslash + out.push_back('\\'); + out.push_back('"'); + } else { + out.push_back(arg[i]); + } + } + out += '"'; // End with a quote + } } + return out; +} - pid = fork(); +ProcessId SpawnProcess(SocketId socket, ConnectInfoToArgsFn&& connect_info_to_args) +{ +#ifndef WIN32 + int pid{fork()}; if (pid == -1) { throw std::system_error(errno, std::system_category(), "fork"); } - // Parent process closes the descriptor for socket 0, child closes the descriptor for socket 1. - if (close(fds[pid ? 0 : 1]) != 0) { - throw std::system_error(errno, std::system_category(), "close"); - } if (!pid) { // Child process must close all potentially open descriptors, except socket 0. const int maxFd = MaxFd(); for (int fd = 3; fd < maxFd; ++fd) { - if (fd != fds[0]) { + if (fd != socket) { close(fd); } } - ExecProcess(fd_to_args(fds[0])); + + int flags = fcntl(socket, F_GETFD); + if (flags == -1) throw std::system_error(errno, std::system_category(), "fcntl F_GETFD"); + if (flags & FD_CLOEXEC) { + flags &= ~FD_CLOEXEC; + if (fcntl(socket, F_SETFD, flags) == -1) throw std::system_error(errno, std::system_category(), "fcntl F_SETFD"); + } + + ExecProcess(connect_info_to_args(std::to_string(socket))); } - return fds[1]; + return pid; +#else + // Create windows pipe to send pipe.ends[0] over to child process. 
+ static std::atomic counter{1}; + ConnectInfo pipe_path{"\\\\.\\pipe\\mp-" + std::to_string(GetCurrentProcessId()) + "-" + std::to_string(counter.fetch_add(1))}; + HANDLE pipe{CreateNamedPipeA(pipe_path.c_str(), PIPE_ACCESS_OUTBOUND, PIPE_TYPE_MESSAGE | PIPE_WAIT, 1, 0, 0, 0, nullptr)}; + KJ_WIN32(pipe != INVALID_HANDLE_VALUE, "CreateNamedPipe failed"); + + // Start child process + std::string cmd{CommandLineFromArgv(connect_info_to_args(pipe_path))}; + STARTUPINFOA si{}; + si.cb = sizeof(si); + PROCESS_INFORMATION pi{}; + KJ_WIN32(CreateProcessA(nullptr, const_cast(cmd.c_str()), nullptr, nullptr, TRUE, 0, nullptr, nullptr, &si, &pi), "CreateProcess failed"); + CloseHandle(pi.hThread); // not needed + + // Duplicate socket for the child (now that we know its PID) + WSAPROTOCOL_INFO info{}; + KJ_WINSOCK(WSADuplicateSocket(socket, pi.dwProcessId, &info), "WSADuplicateSocket failed"); + + // Send socket to the child via the pipe + KJ_WIN32(ConnectNamedPipe(pipe, nullptr) || GetLastError() == ERROR_PIPE_CONNECTED, "ConnectNamedPipe failed"); + DWORD wr; + KJ_WIN32(WriteFile(pipe, &info, sizeof(info), &wr, nullptr) && wr == sizeof(info), "WriteFile(pipe) failed"); + CloseHandle(pipe); + + return reinterpret_cast(pi.hProcess); +#endif +} + +SocketId StartSpawned(const ConnectInfo& connect_info) +{ +#ifndef WIN32 + return std::stoi(connect_info); +#else + HANDLE pipe = CreateFileA(connect_info.c_str(), GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr); + KJ_WIN32(pipe != INVALID_HANDLE_VALUE, "CreateFile(pipe) failed"); + + WSAPROTOCOL_INFO info{}; + DWORD rd; + KJ_WIN32(ReadFile(pipe, &info, sizeof(info), &rd, nullptr) && rd == sizeof(info), "ReadFile(pipe) failed"); + CloseHandle(pipe); + + WSADATA dontcare; + KJ_WIN32(WSAStartup(MAKEWORD(2, 2), &dontcare) != 0, "WSAStartup() failed"); + + SOCKET socket{WSASocketA(FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, FROM_PROTOCOL_INFO, &info, 0, WSA_FLAG_OVERLAPPED | WSA_FLAG_NO_HANDLE_INHERIT)}; + KJ_WINSOCK(socket, 
"WSASocket(FROM_PROTOCOL_INFO) failed"); + return socket; +#endif } void ExecProcess(const std::vector& args) @@ -149,13 +272,22 @@ void ExecProcess(const std::vector& args) } } -int WaitProcess(int pid) +int WaitProcess(ProcessId pid) { +#ifndef WIN32 int status; if (::waitpid(pid, &status, 0 /* options */) != pid) { throw std::system_error(errno, std::system_category(), "waitpid"); } return status; +#else + HANDLE handle{reinterpret_cast(pid)}; + DWORD result{WaitForSingleObject(handle, INFINITE)}; + KJ_WIN32(result != WAIT_OBJECT_0, "WaitForSingleObject(child) failed"); + KJ_WIN32(GetExitCodeProcess(handle, &result), "GetExitCodeProcess failed"); + CloseHandle(handle); + return result; +#endif } } // namespace mp