core: Safer tensor handles + checked vectorization #877
base: main
Conversation
Force-pushed from 27f4c2e to 8e4b170.
I'm not convinced this is a good idea; those calculations are 100% wasteful with Burn, since validations are done at the tensor API level (user-facing) and not internally. We assume data consistency internally, which reduces error handling to a minimum. Otherwise, we would have the same error handling at multiple levels of abstraction, which is very wasteful and adds complexity.
Do you have opinions on this, @wingertge?
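As an illustration of the layering described above, here is a minimal sketch with hypothetical function names (not CubeCL API): the user-facing entry point validates once, and internal launch code assumes the invariant, at most re-checking it in debug builds.

// Hypothetical sketch: validate once at the user-facing boundary and assume
// consistency internally; names are illustrative, not CubeCL API.
fn user_facing_entry(shape: &[usize], strides: &[usize]) -> Result<(), String> {
    // Validate once, where user input enters the system.
    if shape.len() != strides.len() {
        return Err(format!(
            "rank mismatch: {} shape dims vs {} strides",
            shape.len(),
            strides.len()
        ));
    }
    internal_launch(shape, strides);
    Ok(())
}

fn internal_launch(shape: &[usize], strides: &[usize]) {
    // Internal code trusts the invariant; at most a debug-only re-check.
    debug_assert_eq!(shape.len(), strides.len());
    // ... build launch arguments without further error handling ...
}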
/// Convert the handle into a [tensor argument](TensorArg) with basic safety checks
/// for vectorization compatibility.
/// Try to convert the handle into a tensor argument, validating that the
/// requested vectorization factor is supported by the runtime. This does not
/// enforce inner-most contiguity or alignment requirements as kernels may
/// legally vectorize along axes other than the innermost.
pub fn try_as_tensor_arg(
    &'a self,
    vectorization: u8,
) -> Result<TensorArg<'a, R>, TensorArgError> {
    if !R::supported_line_sizes().contains(&vectorization) {
        return Err(TensorArgError::UnsupportedVectorization {
            requested: vectorization,
            supported: R::supported_line_sizes(),
        });
    }
    Ok(self.as_tensor_arg(vectorization))
}
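For context, a hedged usage sketch of the proposed API. It assumes the TensorArgError shape shown in the diff, the Display impl mentioned later in the thread, and that Runtime, TensorHandleRef, and TensorArgError are in scope; the launch itself is omitted.

// Sketch only: `line_size` comes from caller input (e.g. across an FFI
// boundary), so it may not be one of the runtime's supported line sizes.
fn describe_launch_arg<'a, R: Runtime>(handle: &'a TensorHandleRef<'a, R>, line_size: u8) {
    match handle.try_as_tensor_arg(line_size) {
        Ok(_arg) => println!("line size {line_size} accepted"),
        Err(TensorArgError::UnsupportedVectorization { requested, supported, .. }) => {
            println!("line size {requested} rejected; runtime supports {supported:?}")
        }
        // The error enum is #[non_exhaustive] per the commit message, so a
        // wildcard arm is required; it falls back to the Display impl.
        Err(other) => println!("invalid tensor argument: {other}"),
    }
}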
The problem I see here is that the validation is done in 100% of cases in Burn and CubeCL when we choose the line size, not when applying it! It's kind of wasteful to do it multiple times, especially since we iterate over a list each time. I don't mind too much having a try function, but it would not be good practice to use it; validation should be done before creating the tensor argument.
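In other words, the line size is validated once at selection time, so applying it can stay on the fast path. A simplified sketch of that flow; pick_line_size is a hypothetical helper, and the real selection logic in Burn/CubeCL considers more than divisibility of the innermost extent.

// Simplified sketch: choose a valid line size once, up front, from the
// runtime's supported list, then apply it with the unchecked fast path.
fn pick_line_size<R: Runtime>(inner_extent: usize) -> u8 {
    R::supported_line_sizes()
        .iter()
        .copied()
        .filter(|&ls| inner_extent % ls as usize == 0)
        .max()
        .unwrap_or(1)
}

// let line_size = pick_line_size::<R>(shape[shape.len() - 1]);
// let arg = handle.as_tensor_arg(line_size); // fast path, already validated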
if shape.len() != strides.len() {
    return Err(TensorHandleError::RankMismatch {
        shape_rank: shape.len(),
        stride_rank: strides.len(),
    });
}
if elem_size == 0 {
    return Err(TensorHandleError::ElemSizeZero);
}
// Disallow zero strides when the corresponding dimension extent > 1
// (broadcasted dims with extent 1 are allowed).
for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
    if s == 0 && d > 1 {
        return Err(TensorHandleError::ZeroStride { axis: i });
    }
}
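For concreteness, a standalone mirror of these checks with example layouts, so the accepted and rejected cases (especially the broadcast-stride one) are visible without a live runtime handle. This is illustrative only, not the PR's code.

// Illustrative stand-in for the checks above (not the PR's code).
fn validate_layout(shape: &[usize], strides: &[usize], elem_size: usize) -> Result<(), String> {
    if shape.len() != strides.len() {
        return Err(format!("rank mismatch: {} vs {}", shape.len(), strides.len()));
    }
    if elem_size == 0 {
        return Err("element size must be non-zero".into());
    }
    for (axis, (&stride, &extent)) in strides.iter().zip(shape).enumerate() {
        if stride == 0 && extent > 1 {
            return Err(format!("zero stride on axis {axis} with extent {extent}"));
        }
    }
    Ok(())
}

// validate_layout(&[2, 3, 4], &[12, 4], 4)    -> Err: rank mismatch (3 dims, 2 strides)
// validate_layout(&[2, 3, 4], &[0, 4, 1], 4)  -> Err: zero stride on an axis of extent 2
// validate_layout(&[1, 3, 4], &[0, 4, 1], 4)  -> Ok:  broadcast dim of extent 1 may use stride 0
// validate_layout(&[2, 3, 4], &[12, 4, 1], 0) -> Err: zero element size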
The same thing here: most of those things are validated in other places, and it's kind of wasteful to do those validations multiple times.
I actually think it could be a good idea, but specifically gated behind […]
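The gate the reviewer meant is cut off above; the follow-up commit mentions debug_asserts, so one plausible reading is checks that compile away in release builds. A minimal sketch under that assumption:

// Sketch only: same layout checks, but free in release builds. Whether the
// intended gate was debug_assertions or a dedicated feature flag is not
// visible in the truncated comment above.
fn check_layout_debug(shape: &[usize], strides: &[usize], elem_size: usize) {
    debug_assert_eq!(shape.len(), strides.len(), "shape/stride rank mismatch");
    debug_assert_ne!(elem_size, 0, "element size must be non-zero");
    if cfg!(debug_assertions) {
        for (axis, (&stride, &extent)) in strides.iter().zip(shape).enumerate() {
            assert!(
                stride != 0 || extent <= 1,
                "zero stride on axis {axis} with extent {extent}"
            );
        }
    }
}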
Force-pushed from 58e6a34 to 714498a.
Per feedback, I reverted internal hot paths to the fast path (as_tensor_arg).
That seems like the right balance. The CI should be fixed before merging.
Force-pushed from b52c9ee to 77556c2.
…t across crates
- TensorHandleRef::{try_from_parts, try_from_typed}
- TensorHandleRef::try_as_tensor_arg (validates runtime-supported vectorization only)
- Errors: #[non_exhaustive], Display impls; UnsupportedVectorization { requested, supported }
- Adopt try_as_tensor_arg in attention/matmul/convolution/reduce/std
- Runtime tests for handle validation and unsupported vectorization factors

core(tensor): avoid redundant checks in hot paths; use debug_asserts and clarify try_* docs

internal: use direct as_tensor_arg in internal launch paths; reserve try_* for FFI/tests
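Because the error enums are marked #[non_exhaustive], downstream matches need a wildcard arm. A hedged sketch: variant names are taken from the diff and commit message, and the fallback relies on the Display impls mentioned there.

// Sketch: matching a #[non_exhaustive] error from another crate requires a
// wildcard arm, so variants can be added later without breaking callers.
fn report(err: &TensorHandleError) -> String {
    match err {
        TensorHandleError::RankMismatch { shape_rank, stride_rank, .. } => {
            format!("rank mismatch: {shape_rank} shape dims vs {stride_rank} strides")
        }
        TensorHandleError::ElemSizeZero => "element size must be non-zero".into(),
        TensorHandleError::ZeroStride { axis, .. } => {
            format!("zero stride on axis {axis} with extent > 1")
        }
        // Falls back to the Display impl for any future variants.
        other => other.to_string(),
    }
}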
Force-pushed from 77556c2 to 3a98f13.
The branch has been rebased on main (cf40b4e), and the CI workflow passes on my fork.
This PR makes constructing tensor handles and choosing a vectorization factor safer and more ergonomic, especially for host wrappers and FFI, by surfacing misuse as clear, early errors instead of relying on scattered unsafe and assumptions.

Behavior
The unsafe entry points remain available; there are no changes to kernel or runtime ABIs.

Impact

Validation

Notes
-D warnings. The PR has been validated with Burn; no compilation or test errors.