rhat monitor #38
base: main
Conversation
WardBrian
left a comment
Overall I find this pretty easy to read, but a few things did jump out at me
Thanks, Brian! Some quick answers in order of priority.
I think the reason this was happening is that my laptop only has 8 cores. I added some logging (which is of course fraught, as it might change the timings!) to the pop-while loop in the controller thread, and it looks like it's basically always taking 64 elements from the first few chains and 0 from the rest. The chains almost never seem to context switch between each other, so you need to pick … It's also interesting to note that this happens basically no matter what I make …
With M=16, I can also get it to work reasonably well by inserting something like this inside the loop each sampler thread has:

```cpp
if ((iter + 1) % 100 == 0) {
  std::this_thread::yield();
}
```

This allows other chains to progress before the chain hits …
I figured it out -- unless you put a join there, the scope ends basically immediately after the threads are created, and that triggers the stop token being set in the controller thread. So your options are:
You can then delete that for loop.
Thanks---that was all super useful. I at least made an attempt at fixing everything you suggested in the review. I made the … I also added a 1 nanosecond sleep to the normal sampler, and now the sampling is very fast and very even across 16 threads.
examples/rhat_monitor.cpp (outdated):

```cpp
if (r_hat <= rhat_threshold) {
  for (auto& w : workers) {
    w.request_stop();
  }
  break;
```
Since this is the only way out of the loop, this will also want some check for whether all the workers have hit their max number of iterations.
examples/rhat_monitor.cpp (outdated):

```cpp
std::jthread controller([&](std::stop_token st) {
  interactive_qos();
  controller_loop(st, queues, workers, rhat_threshold, start_gate);
});
```
is there a reason to run the controller on its own thread rather than on the main thread here?
I think that …
SteveBronder
left a comment
A few comments, nothing major! If there is anything you want me to look at that I missed, let me know and I'm happy to review again.
```cpp
for (std::size_t m = 0; m < M; ++m) {
  bool popped = false;
  SampleStats u;
  while (queues[m].pop(u)) {
    chain_means[m] = u.sample_mean;
    chain_variances[m] = u.sample_var;
    counts[m] = u.count;
  }
}
```
Is your intention here to busy spin until this particular queue can pop off a new item?
No. The intention is to keep popping until there's nothing left to pop. We always want to read the latest. Is there a better way to do this? And it's not critical that we read at least one thing per m.
examples/rhat_monitor.cpp (outdated):

```cpp
sampler_.sample(sample_.draws(), logp);
sample_.append_logp(logp);
logp_stats_.push(logp);
q_.emplace(logp_stats_.sample_stats());
```
Since you are doing a busy spin on master for each read, but not busy spinning to write, I think this will throw away most of your results. I would throw a busy spin here as well until we start letting things run more asynchronously.
Suggested change:

```diff
- q_.emplace(logp_stats_.sample_stats());
+ while (!q_.emplace(logp_stats_.sample_stats())) {}
```
If I do that, the algorithm reliably hangs. I'm not busy-spinning the controller loop in the main body---it only updates the chain means, variances, and counts if it can pop them from the ring buffer. It keeps popping while it can to get to the most recent.
The individual chains busy spin after they have hit their max warmup to make sure their final update gets registered. I haven't seen that be a problem yet, but I keep thinking there may be more hidden race conditions.
I feel like an alternative might be to use another latch to have everyone stop.
Thanks, @SteveBronder. I'll hopefully be able to get to these tomorrow, and I'll follow up if I have questions. I pushed a bunch of changes this morning, so it's good you're looking at the new one. I was getting segfaults in higher dimensions; I went to the dentist and it occurred to me the parallel memory access was the RAII for …

I'm going to leave the simple queue in place here, but the plan is to work on improvements to that later, since it's a drop-in replacement. The way I'm using the queue, we only ever need the very latest value. The statistics that get pushed are cumulative means and variances for the chains, so old ones don't matter once we've updated. Intuitively, we don't want to evaluate whether 89 iterations were enough to converge when we've already generated the 90th iteration.
I like passing around an allocator for this. I'll look deeper at the code for what kind in particular. My guess is something like a fixed-size memory pool like …
That's fine. Since you always want the very latest value, we may still want to have a function like …
This is a first draft of a tool to run sampling in multiple threads and asynchronously monitor R-hat for completion.
I haven't added any doc yet, in the hope that the code is actually a bit easier to understand on its own.
I've more or less hit a fixed point with LLM code review. ChatGPT 5 is still suggesting the following two relevant changes, which look reasonable. It's still on me to slow down the controller busy spin---it has suggested like 10 different alternatives and it won't take "I'd rather spin the controller than have the chains run too long" for an answer. It's also very offended I just let the R-hat be NaN initially and wants me to put all sorts of guards in place.