30 changes: 24 additions & 6 deletions shuttle/src/future/batch_semaphore.rs
@@ -1,7 +1,7 @@
//! A counting semaphore supporting both async and sync operations.
use crate::current;
use crate::runtime::execution::ExecutionState;
use crate::runtime::task::{clock::VectorClock, TaskId};
use crate::runtime::task::{clock::VectorClock, Event, TaskId};
use crate::runtime::thread;
use crate::sync::{ResourceSignature, ResourceType};
use std::cell::RefCell;
@@ -289,7 +289,6 @@ impl BatchSemaphoreState {
pub struct BatchSemaphore {
state: RefCell<BatchSemaphoreState>,
fairness: Fairness,
#[allow(unused)]
signature: ResourceSignature,
}

@@ -392,8 +391,9 @@ impl BatchSemaphore {

/// Closes the semaphore. This prevents the semaphore from issuing new
/// permits and notifies all pending waiters.
#[track_caller]
pub fn close(&self) {
thread::switch();
thread::switch(Event::batch_semaphore_rel(&self.signature));

self.init_object_id();
let mut state = self.state.borrow_mut();
@@ -436,8 +436,9 @@ impl BatchSemaphore {
/// If the permits are available, returns Ok(())
/// If the semaphore is closed, returns `Err(TryAcquireError::Closed)`
/// If there aren't enough permits, returns `Err(TryAcquireError::NoPermits)`
#[track_caller]
pub fn try_acquire(&self, num_permits: usize) -> Result<(), TryAcquireError> {
thread::switch();
thread::switch(Event::batch_semaphore_acq(&self.signature));

self.init_object_id();
let mut state = self.state.borrow_mut();
@@ -538,22 +539,25 @@ impl BatchSemaphore {
}

/// Acquire the specified number of permits (async API)
#[track_caller]
pub fn acquire(&self, num_permits: usize) -> Acquire<'_> {
// No switch here; switch should be triggered on polling future
self.init_object_id();
Acquire::new(self, num_permits)
}

/// Acquire the specified number of permits (blocking API)
#[track_caller]
pub fn acquire_blocking(&self, num_permits: usize) -> Result<(), AcquireError> {
// No switch here; switch should be triggered on polling future
self.init_object_id();
crate::future::block_on(self.acquire(num_permits))
}

/// Release `num_permits` back to the Semaphore
#[track_caller]
pub fn release(&self, num_permits: usize) {
thread::switch();
thread::switch(Event::batch_semaphore_rel(&self.signature));

self.init_object_id();
if num_permits == 0 {
@@ -645,6 +649,7 @@ pub struct Acquire<'a> {
}

impl<'a> Acquire<'a> {
#[track_caller]
fn new(semaphore: &'a BatchSemaphore, num_permits: usize) -> Self {
let waiter = Arc::new(Waiter::new(num_permits));
Self {
@@ -689,7 +694,7 @@ impl Future for Acquire<'_> {
let blocking_is_not_commutative = self.semaphore.fairness == Fairness::StrictlyFair;

if self.never_polled && (will_succeed || blocking_is_not_commutative) {
thread::switch();
thread::switch(Event::batch_semaphore_acq(&self.semaphore.signature));
}
self.never_polled = false;

@@ -778,12 +783,25 @@ impl Future for Acquire<'_> {
self.waiter.is_queued.store(true, Ordering::SeqCst);
}
trace!("Acquire::poll for waiter {:?} that is enqueued", self.waiter);

let event = Event::batch_semaphore_acq(&self.semaphore.signature);
// SAFETY: This is safe because the current task immediately suspends after this future
// returns Poll::Pending (src/future/mod.rs). Whenever a task resumes, the `next_event`
// is unset, so there is no opportunity to corrupt the reference to our signature while
// it is set as the `next_event`.
ExecutionState::with(|s| unsafe { s.current_mut().set_next_event(event) });
Poll::Pending
}
Err(TryAcquireError::Closed) => unreachable!(),
}
} else {
// No progress made, future is still pending.
let event = Event::batch_semaphore_acq(&self.semaphore.signature);
// SAFETY: This is safe because the current task immediately suspends after this future
// returns Poll::Pending (src/future/mod.rs). Whenever a task resumes, the `next_event`
// is unset, so there is no opportunity to corrupt the reference to our signature while
// it is set as the `next_event`.
ExecutionState::with(|s| unsafe { s.current_mut().set_next_event(event) });
Poll::Pending
}
}
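
To make the new control flow in `Acquire::poll` easier to review: the future now describes what it is about to block on before it suspends, and the runtime clears that description when the task resumes. Below is a minimal, self-contained sketch of that "record the event, then return `Poll::Pending`" shape. All names here (`PendingEvent`, `TaskSlot`, `WaitForPermit`, the `"user.rs:42"` location) are hypothetical stand-ins, not Shuttle's real types; the real code reaches the task through `ExecutionState::with` and the event borrows a `ResourceSignature`.

```rust
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

// Hypothetical stand-in for Shuttle's `Event`: what the task will block on, plus a source location.
#[derive(Debug, Clone)]
enum PendingEvent {
    SemaphoreAcquire { location: &'static str },
}

// Hypothetical stand-in for the per-task state that `set_next_event` / `unset_next_event` mutate.
#[derive(Default)]
struct TaskSlot {
    next_event: Option<PendingEvent>,
}

impl TaskSlot {
    // Record what the task is about to block on, just before it suspends...
    fn set_next_event(&mut self, event: PendingEvent) {
        self.next_event = Some(event);
    }
    // ...and clear it again, mirroring `unset_next_event` in the PR.
    fn unset_next_event(&mut self) {
        self.next_event = None;
    }
}

// A toy future following the same shape as `Acquire::poll`: if it cannot make progress,
// it records the event and only then returns `Poll::Pending`.
struct WaitForPermit<'a> {
    task: &'a mut TaskSlot,
    ready: bool,
}

impl Future for WaitForPermit<'_> {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<()> {
        if self.ready {
            // In the real runtime the event is cleared when the task resumes;
            // this toy clears it once the wait completes.
            self.task.unset_next_event();
            Poll::Ready(())
        } else {
            // Same shape as `Acquire::poll`: describe the blocking reason, then suspend.
            self.task
                .set_next_event(PendingEvent::SemaphoreAcquire { location: "user.rs:42" });
            Poll::Pending
        }
    }
}
```
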
2 changes: 1 addition & 1 deletion shuttle/src/future/mod.rs
@@ -249,7 +249,7 @@ pub fn block_on<F: Future>(future: F) -> F::Output {
Poll::Ready(result) => break result,
Poll::Pending => {
ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
thread::switch();
thread::switch_keeping_current_event();
}
}
}
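
The one-line change in `block_on` is the other half of that pattern: between polls, the driving loop must hand control back to the scheduler without installing a fresh event, otherwise it would overwrite what the pending future just recorded. Here is a hedged sketch of how the two switch flavors might differ, inferred from the new name and the SAFETY comments above rather than from the real `runtime::thread` implementation; all types are hypothetical.

```rust
// Hedged, self-contained sketch (hypothetical types, not shuttle::runtime::thread's API)
// contrasting the two switch flavors as this PR appears to use them.
#[derive(Debug, Clone)]
struct NextEvent(&'static str);

#[derive(Default)]
struct CurrentTask {
    next_event: Option<NextEvent>,
}

impl CurrentTask {
    /// `thread::switch(event)` flavor: the caller names what it is about to do, yields, and the
    /// event is cleared again by the time the task resumes (per the SAFETY comments above).
    fn switch(&mut self, event: NextEvent) {
        self.next_event = Some(event);
        self.yield_to_scheduler();
        self.next_event = None;
    }

    /// `thread::switch_keeping_current_event()` flavor used by `block_on`: the pending future
    /// already recorded its event during `poll`, so the switch must not replace or clear it.
    fn switch_keeping_current_event(&mut self) {
        self.yield_to_scheduler();
    }

    fn yield_to_scheduler(&mut self) {
        // Placeholder for the actual coroutine switch into the scheduler.
    }
}
```
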
22 changes: 15 additions & 7 deletions shuttle/src/runtime/execution.rs
@@ -2,7 +2,7 @@ use crate::runtime::failure::{init_panic_hook, persist_failure};
use crate::runtime::storage::{StorageKey, StorageMap};
use crate::runtime::task::clock::VectorClock;
use crate::runtime::task::labels::Labels;
use crate::runtime::task::{ChildLabelFn, Task, TaskId, TaskName, TaskSignature, DEFAULT_INLINE_TASKS};
use crate::runtime::task::{ChildLabelFn, Event, Task, TaskId, TaskName, TaskSignature, DEFAULT_INLINE_TASKS};
use crate::runtime::thread;
use crate::runtime::thread::continuation::PooledContinuation;
use crate::scheduler::{Schedule, Scheduler};
@@ -531,7 +531,8 @@ impl ExecutionState {
where
F: Future<Output = ()> + 'static,
{
thread::switch();
let signature = ExecutionState::with(|state| state.current_mut().signature.new_child(caller));
thread::switch(Event::Spawn(&signature));
let task_id = Self::with(|state| {
let schedule_len = CurrentSchedule::len();
let parent_span_id = state.top_level_span.id();
@@ -554,7 +555,7 @@
schedule_len,
tag,
Some(state.current().id()),
state.current_mut().signature.new_child(caller),
signature,
);

state.tasks.push(task);
@@ -574,7 +575,8 @@
mut initial_clock: Option<VectorClock>,
caller: &'static Location<'static>,
) -> TaskId {
thread::switch();
let signature = ExecutionState::with(|state| state.current_mut().signature.new_child(caller));
thread::switch(Event::Spawn(&signature));
let task_id = Self::with(|state| {
let parent_span_id = state.top_level_span.id();
let task_id = TaskId(state.tasks.len());
@@ -601,7 +603,7 @@
CurrentSchedule::len(),
tag,
Some(state.current().id()),
state.current_mut().signature.new_child(caller),
signature,
);
state.tasks.push(task);

@@ -658,8 +660,10 @@ impl ExecutionState {
if std::thread::panicking() && !state.in_cleanup {
return true;
}

debug_assert!(
matches!(state.current_task, ScheduledTask::Some(_)) && state.next_task == ScheduledTask::None,
matches!(state.current_task, ScheduledTask::Some(_) | ScheduledTask::Finished)
&& state.next_task == ScheduledTask::None,
"we're inside a task and scheduler should not yet have run"
);

@@ -750,6 +754,10 @@ impl ExecutionState {
self.tasks.get(id.0)
}

pub(crate) fn try_current_mut(&mut self) -> Option<&mut Task> {
self.tasks.get_mut(self.current_task.id()?.0)
}

pub(crate) fn in_cleanup(&self) -> bool {
self.in_cleanup
}
@@ -881,7 +889,7 @@ impl ExecutionState {
.scheduler
.borrow_mut()
.next_task(task_refs, self.current_task.id(), is_yielding)
.map(ScheduledTask::Some)
.map(|task| ScheduledTask::Some(task.id()))
.unwrap_or(ScheduledTask::Stopped);

// Tracing this `in_scope` is purely a matter of taste. We do it because
6 changes: 3 additions & 3 deletions shuttle/src/runtime/runner.rs
@@ -237,12 +237,12 @@ impl<S: Scheduler> Scheduler for PortfolioStoppableScheduler<S> {
}
}

fn next_task(
fn next_task<'a>(
&mut self,
runnable_tasks: &[&Task],
runnable_tasks: &'a [&'a Task],
current_task: Option<TaskId>,
is_yielding: bool,
) -> Option<TaskId> {
) -> Option<&'a Task> {
if self.stop_signal.load(Ordering::SeqCst) {
None
} else {
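
The `Scheduler::next_task` signature now hands back the chosen `&Task` instead of a bare `TaskId`, presumably so a scheduler can inspect per-task metadata such as the newly recorded event when making its choice; the executor converts back to an id at the call site in execution.rs. Below is a sketch with simplified stand-in types; `MiniTask`, `MiniTaskId`, `MiniScheduler`, `RoundRobin`, and `schedule_step` are hypothetical, not Shuttle's real trait or types.

```rust
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct MiniTaskId(usize);

struct MiniTask {
    id: MiniTaskId,
    next_event: &'static str, // stand-in for the richer `Event` the real Task now carries
}

impl MiniTask {
    fn id(&self) -> MiniTaskId {
        self.id
    }
}

trait MiniScheduler {
    // Returning a borrowed task (rather than a bare id) gives the scheduler access to fields
    // like `next_event` above while it decides.
    fn next_task<'a>(
        &mut self,
        runnable: &'a [&'a MiniTask],
        current: Option<MiniTaskId>,
    ) -> Option<&'a MiniTask>;
}

struct RoundRobin;

impl MiniScheduler for RoundRobin {
    fn next_task<'a>(
        &mut self,
        runnable: &'a [&'a MiniTask],
        current: Option<MiniTaskId>,
    ) -> Option<&'a MiniTask> {
        // Pick the first runnable task after the current one, wrapping around.
        let start = current
            .and_then(|cur| runnable.iter().position(|t| t.id() == cur))
            .map(|i| i + 1)
            .unwrap_or(0);
        runnable.get(start).or_else(|| runnable.first()).copied()
    }
}

fn schedule_step<S: MiniScheduler>(
    scheduler: &mut S,
    runnable: &[&MiniTask],
    current: Option<MiniTaskId>,
) -> Option<MiniTaskId> {
    // The executor still stores only an id, so it maps the borrowed task back to its id,
    // mirroring `.map(|task| ScheduledTask::Some(task.id()))` in execution.rs.
    scheduler.next_task(runnable, current).map(|task| task.id())
}
```
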
152 changes: 151 additions & 1 deletion shuttle/src/runtime/task/mod.rs
@@ -255,6 +255,132 @@ impl PartialEq for TaskSignature {

impl Eq for TaskSignature {}

pub(crate) type Loc = &'static Location<'static>;

#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub(crate) enum Event<'a> {
AtomicRead(&'a ResourceSignature, Loc),
AtomicWrite(&'a ResourceSignature, Loc),
AtomicReadWrite(&'a ResourceSignature, Loc),
BatchSemaphoreAcq(&'a ResourceSignature, Loc),
BatchSemaphoreRel(&'a ResourceSignature, Loc),
BarrierWait(&'a ResourceSignature, Loc),
CondvarWait(&'a ResourceSignature, Loc),
CondvarNotify(Loc),
Park(Loc),
Unpark(&'a TaskSignature, Loc),
ChannelSend(&'a ResourceSignature, Loc),
ChannelRecv(&'a ResourceSignature, Loc),
Spawn(&'a TaskSignature),
Yield(Loc),
Sleep(Loc),
Exit,
Join(&'a TaskSignature, Loc),
Unknown,
}

impl<'a> std::fmt::Display for Event<'a> {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Event::AtomicRead(_, loc) => write!(f, "AtomicRead at {}", loc),
Event::AtomicWrite(_, loc) => write!(f, "AtomicWrite at {}", loc),
Event::AtomicReadWrite(_, loc) => write!(f, "AtomicReadWrite at {}", loc),
Event::BatchSemaphoreAcq(_, loc) => write!(f, "BatchSemaphoreAcq at {}", loc),
Event::BatchSemaphoreRel(_, loc) => write!(f, "BatchSemaphoreRel at {}", loc),
Event::BarrierWait(_, loc) => write!(f, "BarrierWait at {}", loc),
Event::CondvarWait(_, loc) => write!(f, "CondvarWait at {}", loc),
Event::CondvarNotify(loc) => write!(f, "CondvarNotify at {}", loc),
Event::Park(loc) => write!(f, "Park at {}", loc),
Event::Unpark(_, loc) => write!(f, "Unpark at {}", loc),
Event::ChannelSend(_, loc) => write!(f, "ChannelSend at {}", loc),
Event::ChannelRecv(_, loc) => write!(f, "ChannelRecv at {}", loc),
Event::Spawn(sig) => write!(f, "Spawn at {}", sig.task_creation_stack.last().unwrap().0),
Event::Yield(loc) => write!(f, "Yield at {}", loc),
Event::Sleep(loc) => write!(f, "Sleep at {}", loc),
Event::Exit => write!(f, "Exit"),
Event::Join(_, loc) => write!(f, "Join at {}", loc),
Event::Unknown => write!(f, "Unknown"),
}
}
}

impl<'a> Event<'a> {
#[track_caller]
pub(crate) fn atomic_read(sig: &'a ResourceSignature) -> Self {
Self::AtomicRead(sig, Location::caller())
}

#[track_caller]
pub(crate) fn atomic_write(sig: &'a ResourceSignature) -> Self {
Self::AtomicWrite(sig, Location::caller())
}

#[track_caller]
pub(crate) fn atomic_read_write(sig: &'a ResourceSignature) -> Self {
Self::AtomicReadWrite(sig, Location::caller())
}

#[track_caller]
pub(crate) fn batch_semaphore_acq(sig: &'a ResourceSignature) -> Self {
Self::BatchSemaphoreAcq(sig, Location::caller())
}

#[track_caller]
pub(crate) fn batch_semaphore_rel(sig: &'a ResourceSignature) -> Self {
Self::BatchSemaphoreRel(sig, Location::caller())
}

#[track_caller]
pub(crate) fn barrier_wait(sig: &'a ResourceSignature) -> Self {
Self::BarrierWait(sig, Location::caller())
}

#[track_caller]
pub(crate) fn condvar_wait(sig: &'a ResourceSignature) -> Self {
Self::CondvarWait(sig, Location::caller())
}

#[track_caller]
pub(crate) fn condvar_notify() -> Self {
Self::CondvarNotify(Location::caller())
}

#[track_caller]
pub(crate) fn park() -> Self {
Self::Park(Location::caller())
}

#[track_caller]
pub(crate) fn unpark(sig: &'a TaskSignature) -> Self {
Self::Unpark(sig, Location::caller())
}

#[track_caller]
pub(crate) fn channel_send(sig: &'a ResourceSignature) -> Self {
Self::ChannelSend(sig, Location::caller())
}

#[track_caller]
pub(crate) fn channel_recv(sig: &'a ResourceSignature) -> Self {
Self::ChannelRecv(sig, Location::caller())
}

#[track_caller]
pub(crate) fn yield_now() -> Self {
Self::Yield(Location::caller())
}

#[track_caller]
pub(crate) fn sleep() -> Self {
Self::Sleep(Location::caller())
}

#[track_caller]
pub(crate) fn join(sig: &'a TaskSignature) -> Self {
Self::Join(sig, Location::caller())
}
}

/// A `Task` represents a user-level unit of concurrency. Each task has an `id` that is unique within
/// the execution, and a `state` reflecting whether the task is runnable (enabled) or not.
#[derive(Debug)]
@@ -276,6 +402,8 @@ pub struct Task {
// Remember whether the waker was invoked while we were running
woken: bool,

next_event: Event<'static>,

name: Option<String>,

local_storage: StorageMap,
@@ -351,6 +479,7 @@ impl Task {
waiter: None,
waker,
woken: false,
next_event: Event::Unknown,
detached: false,
park_state: ParkState::default(),
name,
@@ -425,7 +554,7 @@ impl Task {
let cx = &mut Context::from_waker(&waker);
while future.as_mut().poll(cx).is_pending() {
ExecutionState::with(|state| state.current_mut().sleep_unless_woken());
thread::switch();
thread::switch_keeping_current_event();
}
}),
stack_size,
@@ -679,6 +808,27 @@ impl Task {
}
)
}

/// Get the `next_event` with its lifetime narrowed to borrow from `self`. This prevents the
/// caller from borrowing it as `'static`; the `'static` lifetime is only used internally to
/// avoid borrow-checker issues and spurious lifetime annotations on the `Task` struct.
/// SAFETY: The borrowed lifetime must not extend past the point where the task resumes after a switch.
pub(crate) fn next_event(&self) -> &Event<'_> {
unsafe { std::mem::transmute(&self.next_event) }
}

/// Transmutes the references held inside an `Event` to the `'static` lifetime so the event can
/// be stored on the `Task` across a coroutine switch without borrow-checker issues.
/// SAFETY: The event must be unset with `unset_next_event` before the actual lifetime of its
/// references expires. In general, as long as the event is unset by the time the task is
/// resumed, this is safe.
pub(crate) unsafe fn set_next_event(&mut self, event: Event<'_>) {
Contributor review comment: "SAFETY comment describing how to make it safe. Also describe why unsafe is needed"

self.next_event = unsafe { std::mem::transmute::<Event<'_>, Event<'static>>(event) };
}

pub(crate) fn unset_next_event(&mut self) {
self.next_event = Event::Unknown;
}
}

#[derive(PartialEq, Eq, Clone, Copy, Debug)]
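
A note on the `#[track_caller]` attributes sprinkled through this diff: each `Event` constructor captures `Location::caller()`, and because the public entry points (`close`, `try_acquire`, `release`, `acquire`, ...) are annotated as well, the captured location resolves to the user's call site rather than a line inside the runtime. The following self-contained sketch, with hypothetical names (`MiniEvent`, `MiniSemaphore`), demonstrates the chaining.

```rust
use std::panic::Location;

#[derive(Debug)]
enum MiniEvent {
    SemaphoreRelease(&'static Location<'static>),
}

impl MiniEvent {
    #[track_caller]
    fn semaphore_release() -> Self {
        // Because this constructor and its caller are both `#[track_caller]`, this resolves to
        // the location of the outermost non-annotated caller (the call in `main` below).
        Self::SemaphoreRelease(Location::caller())
    }
}

struct MiniSemaphore;

impl MiniSemaphore {
    #[track_caller]
    fn release(&self) -> MiniEvent {
        // In the real code this event would be handed to `thread::switch`; here we just return it.
        MiniEvent::semaphore_release()
    }
}

fn main() {
    let sem = MiniSemaphore;
    let event = sem.release(); // the captured location points at this line
    println!("{event:?}");
}
```

Running this prints a `Location` pointing at the `sem.release()` line in `main`, which is the kind of user-facing location the `Event::fmt` output above ("BatchSemaphoreRel at <file:line>") is meant to show.
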