Create cubecl-scan crate #863
Draft
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR adds a new
cubecl-scan
crate to implement the associative scan (inclusive/exclusive) operation as an optimized GPU primitive.Rough Overview
The goal is to implement something very much like the JAX associative_scan operation. The user should be able to provide an arbitrary associative operator based on which the implementation should compute the scan operation across a tensor.
Current Status
So far, I have implemented a somewhat functional version of the decoupled lookback algorithm. It lives in
src/kernels/decoupled_lookback.rs
and is currently functional incubecl-wgsl
andcubecl-cpu
. Unfortunately, incubecl-cuda
it currently hangs indefinitely, and I suspect it is related to incorrect memory ordering of the atomic stores in the aggregates and flags.This has sat around on my local disk for too long and I have recently encountered a case in a personal project where I would like to use
burn
, but require the presence of a fast scan implementation. I am sharing this as a draft PR in the hopes that others might look at it as well.The goal is to also implement a naive version of the scan operation using multiple passes for targets that do not have forward-progress guarantees (questionable for Vulkan already) or lack the necessary atomic operations.
Notes
It was hard to get the memory ordering to be roughly what I want due to being unable to specify atomic ordering modes like in standard Rust atomics and memory/compiler barriers. A hacky solution for now was to use atomic values for the aggregates. However, that means that it is impossible to use aggregate types that do not work with atomics. The algorithm itself does not need them to be atomic, and I would like to get rid of that as well. For example, the Linear Recurrent Unit proposed by DeepMind requires the scan to be performed over vectors, which would not be supported by the current implementation.
Related Issues
tracel-ai/burn#3806