Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,9 @@ num_cpus = "^1.14.0"
[target.'cfg(any(target_os = "android", target_os = "linux", target_os = "macos", target_os = "freebsd", target_os = "netbsd"))'.dependencies]
libc = "^0.2.30"

[target.'cfg(target_os = "windows")'.dependencies]
winapi = { version = "^0.3.9", features = ["processthreadsapi", "winbase"] }
[target.'cfg(target_os = "windows")'.dependencies.windows-sys]
features = [
"Win32_System_SystemInformation",
"Win32_System_Threading",
]
version = "^0.60"
122 changes: 74 additions & 48 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -216,70 +216,94 @@ fn set_for_current_helper(core_id: CoreId) -> bool {
}

#[cfg(target_os = "windows")]
extern crate winapi;
extern crate windows_sys;

#[cfg(target_os = "windows")]
mod windows {
use winapi::shared::basetsd::{DWORD_PTR, PDWORD_PTR};
use winapi::um::processthreadsapi::{GetCurrentProcess, GetCurrentThread};
use winapi::um::winbase::{GetProcessAffinityMask, SetThreadAffinityMask};
use std::ptr;

use windows_sys::Win32::System::{
SystemInformation::GROUP_AFFINITY,
Threading::{
GetActiveProcessorCount, GetCurrentProcess, GetCurrentThread, GetProcessGroupAffinity,
SetThreadGroupAffinity,
},
};

use super::CoreId;

pub fn get_core_ids() -> Option<Vec<CoreId>> {
if let Some(mask) = get_affinity_mask() {
// Find all active cores in the bitmask.
let mut core_ids: Vec<CoreId> = Vec::new();
let group_list = unsafe { get_group_list() };
let group_list = match group_list {
Some(group_list) => group_list,
None => return None,
};

for i in 0..64 as u64 {
let test_mask = 1 << i;
let core_ids = unsafe {
group_list
.into_iter()
.map(|group| GetActiveProcessorCount(group))
.sum()
};

if (mask & test_mask) == test_mask {
core_ids.push(CoreId { id: i as usize });
}
}
let core_ids: Vec<CoreId> = (0..core_ids)
.into_iter()
.map(|n| CoreId { id: n as usize })
.collect();

Some(core_ids)
}
else {
None
}
Some(core_ids)
}

pub fn set_for_current(core_id: CoreId) -> bool {
// Convert `CoreId` back into mask.
let mask: u64 = 1 << core_id.id;

// Set core affinity for current thread.
let res = unsafe {
SetThreadAffinityMask(
GetCurrentThread(),
mask as DWORD_PTR
)
let group_list = unsafe { get_group_list() };
let group_list = match group_list {
Some(group_list) => group_list,
None => return false,
};
res != 0
}

fn get_affinity_mask() -> Option<u64> {
let mut system_mask: usize = 0;
let mut process_mask: usize = 0;
let mut id = core_id.id;

// Convert `CoreId` to (group, mask).
for group in group_list {
let count = unsafe { GetActiveProcessorCount(group) } as usize;
if id < count {
// If the core_id is within this group, set the mask.
let affinity = GROUP_AFFINITY {
Mask: 1 << id,
Group: group,
..Default::default()
};
return unsafe {
SetThreadGroupAffinity(GetCurrentThread(), &affinity, ptr::null_mut()) != 0
};
}

id -= count;
}

let res = unsafe {
GetProcessAffinityMask(
GetCurrentProcess(),
&mut process_mask as PDWORD_PTR,
&mut system_mask as PDWORD_PTR
)
};
// If we reach here, it means the core_id is out of bounds.
false
}

// Successfully retrieved affinity mask
if res != 0 {
Some(process_mask as u64)
unsafe fn get_group_list() -> Option<Vec<u16>> {
let current_process = GetCurrentProcess();
let mut group_len = 0;
let ret = GetProcessGroupAffinity(current_process, &mut group_len, ptr::null_mut());
if group_len == 0 || ret == 0 {
return None;
}
// Failed to retrieve affinity mask
else {
None

let group_list: Vec<u16> = vec![0; group_len as usize];
let ret = GetProcessGroupAffinity(
current_process,
&mut group_len,
group_list.as_ptr() as *mut _,
);
if ret == 0 {
return None;
}

Some(group_list)
}

#[cfg(test)]
Expand All @@ -293,8 +317,10 @@ mod windows {
match get_core_ids() {
Some(set) => {
assert_eq!(set.len(), num_cpus::get());
},
None => { assert!(false); },
}
None => {
assert!(false);
}
}
}

Expand All @@ -304,7 +330,7 @@ mod windows {

assert!(ids.len() > 0);

assert_ne!(set_for_current(ids[0]), 0);
assert_ne!(set_for_current(ids[0]), false);
}
}
}
Expand Down