Conversation

@bob-carpenter (Collaborator)

This is a first draft of a tool to run sampling in multiple threads and asynchronously monitor R-hat for completion.

I haven't added any doc yet, in the hope that the code is actually a bit easier to understand on its own.

I've more or less hit a fixed point with LLM code review. ChatGPT 5 is still suggesting the following two relevant changes, which look reasonable. It's still on me to slow down the controller busy spin---it has suggested like 10 different alternatives and it won't take "I'd rather spin the controller than have the chains run too long" for an answer. It's also very offended I just let the R-hat be NaN initially and wants me to put all sorts of guards in place.

Below are targeted changes that improve allocation behavior, responsiveness, and clarity while keeping your non-blocking controller design.

1) Eliminate per-iteration allocations (hot path)

Reuse a per-chain scratch buffer and pass by std::span. This removes an O(D) allocation and a move per draw.

// Sample: accept span to avoid constructing a fresh std::vector each iter
void append_draw(std::size_t iteration, double logp, std::span<const double> draw) {
  logp_.emplace_back(logp);
  theta_.insert(theta_.end(), draw.begin(), draw.end());
}

// ChainTask: reuse one buffer
class ChainTask {
  // ...
  std::vector<double> scratch_;
public:
  ChainTask(std::size_t chain_id, std::size_t draws_per_chain, Sampler& sampler,
            Queue& q, std::latch& start_gate)
      : chain_id_(chain_id), draws_per_chain_(draws_per_chain), sampler_(sampler),
        sample_(chain_id, sampler.dim(), draws_per_chain),
        q_(q), start_gate_(start_gate), scratch_(sampler.dim()) {}

  void operator()(std::stop_token st) {
    initiated_qos();
    start_gate_.arrive_and_wait();
    for (std::size_t iter = 0; iter < draws_per_chain_; ++iter) {
      double logp = sampler_(scratch_);                    // changed
      logp_stats_.push(logp);
      sample_.append_draw(iter, logp, scratch_);           // changed
      q_.emplace(logp_stats_.count(), logp_stats_.mean(),  // unchanged
                 logp_stats_.sample_variance());
      if (st.stop_requested()) break;
    }
  }
  // ...
};
// StandardNormalSampler: fill into caller-provided buffer
double operator()(std::span<double> out) noexcept {
  double log_density = 0.0;
  for (std::size_t i = 0; i < dim_; ++i) {
    double x = normal_dist_(engine_);
    out[i] = x;
    log_density += -0.5 * x * x;
  }
  return log_density;
}

2) Make the ring buffer explicitly SPSC and cheaper to index

Use a power-of-two capacity and a bit mask; store elements in std::array to avoid default-constructing/relocating a std::vector<T> payload.

#include <array>
#include <atomic>
#include <bit>
#include <cstddef>
#include <new>  // std::hardware_destructive_interference_size

template <class T, std::size_t CapacityPow2>
class alignas(std::hardware_destructive_interference_size) SpscRingBuffer {
  static_assert(std::has_single_bit(CapacityPow2),
                "Capacity must be a power of two.");
public:
  bool emplace(const T& x) noexcept {
    auto w = write_.load(std::memory_order_relaxed);
    auto next = (w + 1) & (CapacityPow2 - 1);
    if (next == read_.load(std::memory_order_acquire)) return false;
    data_[w] = x;                                         // trivially cheap
    write_.store(next, std::memory_order_release);
    return true;
  }
  template <class... Args>
  bool emplace(Args&&... args) noexcept {
    auto w = write_.load(std::memory_order_relaxed);
    auto next = (w + 1) & (CapacityPow2 - 1);
    if (next == read_.load(std::memory_order_acquire)) return false;
    data_[w] = T(std::forward<Args>(args)...);
    write_.store(next, std::memory_order_release);
    return true;
  }
  bool pop(T& out) noexcept {
    auto r = read_.load(std::memory_order_relaxed);
    if (r == write_.load(std::memory_order_acquire)) return false;
    out = std::move(data_[r]);
    read_.store((r + 1) & (CapacityPow2 - 1), std::memory_order_release);
    return true;
  }
  constexpr std::size_t capacity() const noexcept { return CapacityPow2; }
private:
  std::array<T, CapacityPow2> data_{};
  alignas(std::hardware_destructive_interference_size) std::atomic<std::size_t> read_{0};
  alignas(std::hardware_destructive_interference_size) std::atomic<std::size_t> write_{0};
};

// Replace aliases
constexpr std::size_t RING_CAPACITY = 64;  // generous to avoid producer back-pressure
using Queue = SpscRingBuffer<SampleStats, RING_CAPACITY>;
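For orientation, here is a minimal usage sketch of how the two ends of such a buffer interact in the latest-value design discussed below. It assumes SampleStats is a small trivially copyable struct constructible from (count, mean, variance), matching the emplace call in the ChainTask snippet; treat it as illustrative rather than the PR's actual wiring.

// Producer side (one chain thread per Queue): non-blocking. If the buffer is
// momentarily full the update is simply dropped, which is tolerable here
// because the stats are cumulative and a newer one follows shortly.
inline void publish(Queue& q, std::size_t count, double mean, double var) {
  q.emplace(count, mean, var);  // assumes SampleStats is constructible from these
}

// Consumer side (controller thread): drain to the most recent value without
// blocking; returns false if nothing new has arrived since the last poll.
inline bool poll_latest(Queue& q, SampleStats& out) {
  bool got = false;
  for (SampleStats u; q.pop(u); ) { out = u; got = true; }
  return got;
}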

@WardBrian (Collaborator) left a comment

Overall I find this pretty easy to read, but a few things did jump out at me

@bob-carpenter (Collaborator, Author)

Thanks, Brian!

Some quick answers in order of priority.

  • The big question is why it takes so long to get beyond NaN. It only requires two non-identical draws per chain to be picked up by the controller. Maybe the controller needs to try harder to read? I should also introduce a bit of a delay into the simulation---when we plug in Stan, it's not going to generate draws this fast and this problem may be largely mitigated. Also need to experiment with 4 threads. I'm just exercising my new Mac. I have a feeling part of the problem may be from having non-performance cores assigned to some of the threads---there's a huge disparity in the number of iterations per thread across runs. Sometimes it's consistent within a run and sometimes off by an order of magnitude or more.

  • I can't figure out why the behavior is different if I explicitly join. My understanding is that jthreads are implicitly joined at the end of scope. I need to understand this much better.

  • I plan to use Steve's more focused SPSC for this task, but I could also use Rigtorp's. I doubt the ring buffer's going to be the bottleneck. I just wanted to start with one file. I have a really hard time getting the devops going for these projects, which is why I'm also not using Eigen. The next round should replace most of the std::vector uses with Eigen operations.

  • I forgot to put all my usual hooks into emacs on the new machine! I had run clang-format, but I need to figure out how to configure it to insist on braces. I sort of like the code w/o braces if it fits on one line.

  • I'll also clean up the over-allocation that GPT flagged.

  • I'll make all the minor code changes suggested. I wound up writing some things generally bottom-up then not using everything in them in the end. I tried to prune back, but obviously didn't catch everything.

@WardBrian (Collaborator)

> The big question is why it takes so long to get beyond NaN. It only requires two non-identical draws per chain to be picked up by the controller. Maybe the controller needs to try harder to read? I should also introduce a bit of a delay into the simulation---when we plug in Stan, it's not going to generate draws this fast and this problem may be largely mitigated. Also need to experiment with 4 threads. I'm just exercising my new Mac. I have a feeling part of the problem may be from having non-performance cores assigned to some of the threads---there's a huge disparity in the number of iterations per thread across runs. Sometimes it's consistent within a run and sometimes off by an order of magnitude or more.

I think the reason this was happening is that my laptop only has 8 cores. I added some logging (which is of course fraught, as it might change the timings!) to the pop-while loop in the controller thread, and it looks like it's basically always taking 64 elements from the first few chains and 0 from the rest. The chains almost never seem to context switch between each other, so you need to pick M very carefully. If I set it to 6, it works better.

It's also interesting to note that, basically no matter what I make RING_CAPACITY, the while loop pops approximately that many elements per chain each time. It seems likely to me that this is an artifact of how cheap the 'sampler' is here.

@WardBrian (Collaborator)

With M=16, I can also get it to work reasonably well by inserting something like this inside the loop each sampler thread has:

      if ((iter + 1) % 100 == 0) {
        std::this_thread::yield();
      }

This allows other chains to progress before the chain hits draws_per_chain and ends

@WardBrian (Collaborator)

> I can't figure out why the behavior is different if I explicitly join. My understanding is that jthreads are implicitly joined at the end of scope. I need to understand this much better.

I figured it out -- unless you put a join there, the scope ends basically immediately after the threads are created, and that triggers the stop token being set in the controller thread. So your options are:

  1. call controller.join()
  2. just ignore the stop token in the controller thread, make it while (true). You don't really want that thread to be cancellable, so I prefer this option honestly.

You can then delete that for loop
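A self-contained sketch of the jthread behavior being described here (not the PR's code), showing why the stop token trips at end of scope and what an explicit join() changes:

#include <chrono>
#include <iostream>
#include <stop_token>
#include <thread>

int main() {
  std::jthread controller([](std::stop_token st) {
    for (int iter = 0; iter < 100; ++iter) {
      if (st.stop_requested()) {
        std::cout << "stopped early at iter " << iter << '\n';
        return;
      }
      std::this_thread::sleep_for(std::chrono::milliseconds(1));
    }
    std::cout << "ran to completion\n";
  });

  // With no join() here, main reaches the end of scope almost immediately;
  // ~jthread() then calls request_stop() before join(), so the lambda almost
  // always reports "stopped early".
  //
  // controller.join();  // option 1: main waits for the loop to finish, and the
                         // destructor never gets a chance to request a stop mid-run
}

Option 2 (ignoring the token and relying on the loop's own exit condition) works for the same reason: the destructor's request_stop() becomes irrelevant and only its join() matters.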

@bob-carpenter (Collaborator, Author)

Thanks---that was all super useful. I at least made an attempt at fixing everything you suggested in the review, made the controller.join() fix, and added the yield.

I also added a 1 nanosecond sleep to the normal sampler and now the sampling is very fast and very even across 16 threads.

Comment on lines 280 to 284
if (r_hat <= rhat_threshold) {
  for (auto& w : workers) {
    w.request_stop();
  }
  break;
(Collaborator)

Since this is the only way out of the loop, this will also want some check for whether all the workers have hit their max number of iterations.
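A hedged sketch of one way to add that second exit condition, reusing the counts[] the controller already fills from the queues; the helper name and draws_per_chain are assumptions about the surrounding code:

#include <algorithm>
#include <cstddef>
#include <vector>

// Hypothetical helper: true once every chain has produced its full budget.
inline bool all_chains_exhausted(const std::vector<std::size_t>& counts,
                                 std::size_t draws_per_chain) {
  return std::all_of(counts.begin(), counts.end(),
                     [=](std::size_t c) { return c >= draws_per_chain; });
}

// The break in the controller loop would then become something like:
//   if (r_hat <= rhat_threshold || all_chains_exhausted(counts, draws_per_chain)) {
//     for (auto& w : workers) w.request_stop();  // harmless for workers that already finished
//     break;
//   }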

Comment on lines 311 to 314
std::jthread controller([&](std::stop_token st) {
  interactive_qos();
  controller_loop(st, queues, workers, rhat_threshold, start_gate);
});
(Collaborator)

Is there a reason to run the controller on its own thread rather than on the main thread here?
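For reference, the main-thread variant being asked about would look roughly like this; controller_loop's signature is taken from the snippet above, and the default-constructed stop token (which can never be signalled) is an assumption:

// Instead of spawning a std::jthread for the controller:
interactive_qos();
controller_loop(std::stop_token{}, queues, workers, rhat_threshold, start_gate);
// main simply blocks here until the loop breaks (e.g. on the R-hat check),
// which also sidesteps the end-of-scope stop-token issue discussed earlier.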

@WardBrian (Collaborator)

I think that std::this_thread::sleep_for as you're using it is equivalent to yielding every iteration, which means it isn't a great simulation of a more expensive sampler
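One way to simulate a more expensive sampler without handing the core back to the scheduler is to spin on a clock deadline instead of sleeping; a minimal sketch (the budget value is arbitrary):

#include <chrono>

// Burn roughly `budget` of CPU time without yielding; unlike sleep_for, the
// scheduler is not invited to swap the thread out at each call.
inline void simulate_work(std::chrono::nanoseconds budget) {
  const auto deadline = std::chrono::steady_clock::now() + budget;
  while (std::chrono::steady_clock::now() < deadline) {
    // busy wait
  }
}

// e.g. call simulate_work(std::chrono::microseconds(50)) inside the sampler's
// operator() to stand in for a real gradient evaluation.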

@SteveBronder (Collaborator) left a comment

Few comments, nothing major! If there is anything you want me to look at that I missed, lmk and I'm happy to review again.

Comment on lines 290 to 298
for (std::size_t m = 0; m < M; ++m) {
  bool popped = false;
  SampleStats u;
  while (queues[m].pop(u)) {
    chain_means[m] = u.sample_mean;
    chain_variances[m] = u.sample_var;
    counts[m] = u.count;
  }
}
(Collaborator)

Is your intention here to busy spin until this particular queue can pop off a new item?

(Collaborator, Author)

No. The intention is to keep popping until there's nothing left to pop. We always want to read the latest. Is there a better way to do this? And it's not critical that we read at least one thing per m.

sampler_.sample(sample_.draws(), logp);
sample_.append_logp(logp);
logp_stats_.push(logp);
q_.emplace(logp_stats_.sample_stats());
(Collaborator)

Since you are doing a busy spin on master for each read, but not busy spinning to write, I think this will throw away most of your results. I would throw a busy spin here as well until we start letting things run more asynchronously.

Suggested change:
- q_.emplace(logp_stats_.sample_stats());
+ while(!q_.emplace(logp_stats_.sample_stats())) {};

@bob-carpenter (Collaborator, Author) commented Nov 13, 2025

If I do that, the algorithm reliably hangs. I'm not busy-spinning the controller loop in the main body---it only updates the chain means, variances, and counts if it can pop them from the ring buffer. It keeps popping while it can to get to the most recent.

The individual chains busy spin after they have hit their max warmup to make sure their final update gets registered. I haven't seen that be a problem yet, but I keep thinking there may be more hidden race conditions.

I feel like an alternative might be to add another latch that's used to make everyone stop.

@bob-carpenter (Collaborator, Author) commented Nov 12, 2025

Thanks, @SteveBronder. I'll hopefully be able to get to these tomorrow. I'll follow up if I have questions.

I pushed a bunch of changes this morning, so it's good you're looking at the new one. I was getting segfaults in higher dimensions, went to the dentist, and it occurred to me that the parallel memory access problem was the RAII for std::vector. Turns out I wasn't compiling with -pthread. Adding that cleaned it up. Should I instead be doing all my fixed allocation outside and just passing references to it into the threads so they never have to allocate? It's a one-time .reserve() call for the results.

I'm going to leave the simple queue in place here, but the plan is to work on improvements to that later since it's a drop-in replacement. The way I'm using the queue, we only ever need the very latest value. The statistics that get pushed are cumulative means and variances for the chains, so old ones don't matter once we've updated. Intuitively, we don't want to evaluate if 89 iterations were enough to converge when we've already generated the 90th iteration.
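Since only per-chain counts, running means, and running variances cross the queue, the controller's convergence check reduces to the classic potential scale reduction factor. For reference, a sketch of the basic (non-split) form under the simplifying assumption of a common draw count n across chains; the PR's actual computation may differ (split chains, unequal counts, rank normalization):

#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Basic potential scale reduction factor from per-chain summary statistics.
inline double basic_rhat(const std::vector<double>& chain_means,
                         const std::vector<double>& chain_variances,
                         std::size_t n) {
  const std::size_t M = chain_means.size();
  if (M < 2 || n < 2) {
    return std::numeric_limits<double>::quiet_NaN();  // matches the NaN-at-start behavior
  }
  double grand_mean = 0.0;
  for (double m : chain_means) grand_mean += m;
  grand_mean /= static_cast<double>(M);

  double B = 0.0;  // between-chain variance estimate, scaled by n
  for (double m : chain_means) B += (m - grand_mean) * (m - grand_mean);
  B *= static_cast<double>(n) / static_cast<double>(M - 1);

  double W = 0.0;  // mean within-chain variance
  for (double v : chain_variances) W += v;
  W /= static_cast<double>(M);

  const double n_d = static_cast<double>(n);
  const double var_plus = (n_d - 1.0) / n_d * W + B / n_d;
  return std::sqrt(var_plus / W);  // NaN/inf until the chains have any spread
}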

@SteveBronder (Collaborator)

> Should I instead be doing all my fixed allocation outside and just passing references to it into the threads so they never have to allocate? It's a one-time .reserve() call for the results.

I like passing around an allocator for this. I'll look deeper at the code for what kind in particular. My guess is something like a fixed size memory pool like std::pmr::unsynchronized_pool_resource per thread would work nicely.
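Roughly what that could look like per chain; a sketch only, with hypothetical names, showing one unsynchronized pool owned by a single chain thread:

#include <cstddef>
#include <memory_resource>
#include <vector>

// One pool per chain thread, so the lack of synchronization is safe: only
// that thread ever allocates from it.
struct ChainStorage {
  std::pmr::unsynchronized_pool_resource pool;
  std::pmr::vector<double> theta{&pool};
  std::pmr::vector<double> logp{&pool};

  ChainStorage(std::size_t draws_per_chain, std::size_t dim) {
    theta.reserve(draws_per_chain * dim);  // single up-front allocation
    logp.reserve(draws_per_chain);
  }
};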

> I'm going to leave the simple queue in place here, but the plan is to work on improvements to that later since it's a drop-in replacement. The way I'm using the queue, we only ever need the very latest value. The statistics that get pushed are cumulative means and variances for the chains, so old ones don't matter once we've updated. Intuitively, we don't want to evaluate if 89 iterations were enough to converge when we've already generated the 90th iteration.

That's fine. Since you always want the very latest value, we may still want to have a function like force_push that forces the new value to be added to the queue. But yeah, I agree it is fine as is.
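One caveat on force_push as a method on the ring itself: in an SPSC queue only the consumer may advance the read index, so the producer can't safely evict the oldest element when the buffer is full. Since only the latest value matters here anyway, a hedged alternative is a single latest-value mailbox per chain; SampleStats is assumed trivially copyable, and at roughly 24 bytes the atomic will typically not be lock-free, but it stays correct:

#include <atomic>

struct LatestStats {
  std::atomic<SampleStats> value{SampleStats{}};

  void force_push(const SampleStats& s) {          // producer: always wins
    value.store(s, std::memory_order_release);
  }
  SampleStats latest() const {                     // consumer: newest snapshot;
    return value.load(std::memory_order_acquire);  // count == 0 means "nothing yet"
  }
};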
