Lower latency associative scan option #10599
Unanswered
oliverdutton asked this question in Ideas
Replies: 2 comments
Looks like a lot of people would be interested (myself included). Are you still interested in creating such a PR?
Cool, I'll put together a PR in the next couple of days. I will make the code clearer and generalise it to accept trees of arrays. I'm pretty sure that the `if r_max > n` branch will never get executed.
In one of my problems the implementation was bottlenecked by a cumulative matmul. JAX has a handy implementation of a work-efficient associative scan for this, `lax.associative_scan`. This reduces the procedure from N sequential steps to 2 log_2(N) - 2 steps. There is a work-inefficient implementation that reduces this further to log_2(N) steps, shown below, which is faster for small problem sizes where the GPU is not saturated.

Would anyone else be interested in having a work-inefficient option in `lax.associative_scan`? If so, I can put together a pull request.
See:
https://developer.nvidia.com/gpugems/gpugems3/part-vi-gpu-computing/chapter-39-parallel-prefix-sum-scan-cuda
http://www.cs.cmu.edu/~guyb/papers/Ble93.pdf
https://en.wikipedia.org/wiki/Prefix_sum
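
For readers who want to try the idea, here is a minimal sketch of a work-inefficient (Hillis-Steele style) inclusive scan in JAX. It is not the code from the original post or from any PR. The name `hillis_steele_scan` is hypothetical, and the sketch assumes a single array rather than a tree of arrays, with `fn` associative and batched over the leading axis, following the same `fn(earlier, later)` convention as `lax.associative_scan`. It takes ceil(log_2(N)) combine steps, at the cost of O(N log N) applications of `fn` instead of O(N).

```python
import jax
import jax.numpy as jnp
from jax import lax


def hillis_steele_scan(fn, elems, axis=0):
    """Inclusive scan using ceil(log2(N)) combine steps (work-inefficient).

    Hypothetical helper, not part of JAX. `fn` must be associative and must
    act elementwise over the leading axis of its two arguments.
    Assumes `elems` is a single array, not a pytree.
    """
    elems = jnp.moveaxis(elems, axis, 0)
    n = elems.shape[0]
    shift = 1
    # The loop runs at trace time; n is static, so this also works under jit.
    while shift < n:
        # Combine element i - shift (left operand) with element i (right
        # operand), which keeps the order correct for non-commutative fn.
        combined = fn(elems[:-shift], elems[shift:])
        # The first `shift` entries are already complete prefixes.
        elems = jnp.concatenate([elems[:shift], combined], axis=0)
        shift *= 2
    return jnp.moveaxis(elems, 0, axis)
```

For the cumulative matmul case, `hillis_steele_scan(jnp.matmul, mats)` with `mats` of shape `(N, d, d)` should agree with `lax.associative_scan(jnp.matmul, mats)` up to floating-point error. Whether it is actually faster depends on problem size and hardware, so it is worth benchmarking both.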