Conversation

@ariawisp (Author)

This PR makes constructing tensor handles and choosing a vectorization factor safer and more ergonomic, especially for host wrappers and FFI, by surfacing misuse as clear, early errors instead of relying on scattered `unsafe` blocks and implicit preconditions.

Behavior

  • Safe constructors validate shape/stride basics; a checked conversion verifies the requested factor is supported by the runtime (see the sketch after this list).
  • The checked path intentionally does not enforce inner‑most contiguity or divisibility: kernels may vectorize on other axes or handle pitched/tail cases internally.
  • Existing unsafe entry points remain available; no changes to kernel or runtime ABIs.
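
A minimal sketch of the checked path from an external host wrapper, assuming `try_from_parts` mirrors the parameters of the existing `from_raw_parts` and that the error types expose `Display` (as the commit notes state); the locals `binding`, `strides`, `shape`, `elem_size`, and `line_size` are placeholders:

// Hypothetical call site inside a host-side routine that returns
// Result<_, String>; errors are stringified via their Display impls.
let handle = TensorHandleRef::<R>::try_from_parts(&binding, &strides, &shape, elem_size)
    .map_err(|e| format!("invalid tensor handle: {e}"))?;
let arg = handle.try_as_tensor_arg(line_size)
    .map_err(|e| format!("unsupported vectorization: {e}"))?;
// `arg` can now be passed to a kernel launch on runtime `R` as usual.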

Impact

  • Reduces boilerplate and the likelihood of undefined behavior in host integrations.
  • Provides consistent, actionable errors when a vectorization factor is not supported.

Validation

  • The CubeCL workspace validates cleanly: audit, format, clippy, unit tests, and docs all pass.

Notes

  • Includes a tiny macOS‑only clippy cleanup (WGSL SafeTanh) to satisfy workspace -D warnings.

The PR has also been validated against Burn, with no compilation or test errors.

@ariawisp force-pushed the safe-tensor-handle branch 6 times, most recently from 27f4c2e to 8e4b170 on September 10, 2025 at 23:46
@nathanielsimard (Member) left a comment:

I'm not convinced this is a good idea; those calculations are 100% wasteful with Burn, since validations are done at the tensor API level (user-facing) and not internally. We assume data consistency internally, which reduces error handling to a minimum. Otherwise, we would have the same error handling at multiple levels of abstractions, which is very wasteful and adds complexity.

Do you have opinions on this, @wingertge?

Comment on lines 219 to 253
/// Try to convert the handle into a [tensor argument](TensorArg), validating
/// that the requested vectorization factor is supported by the runtime. This
/// does not enforce inner-most contiguity or alignment requirements, as
/// kernels may legally vectorize along axes other than the innermost.
pub fn try_as_tensor_arg(
    &'a self,
    vectorization: u8,
) -> Result<TensorArg<'a, R>, TensorArgError> {
    if !R::supported_line_sizes().contains(&vectorization) {
        return Err(TensorArgError::UnsupportedVectorization {
            requested: vectorization,
            supported: R::supported_line_sizes(),
        });
    }
    Ok(self.as_tensor_arg(vectorization))
}

The problem I see here is that the validation is already done in 100% of cases in Burn and CubeCL when we choose the line size, not when applying it! It's kind of wasteful to do it multiple times, especially since we iterate over a list each time. I don't mind too much having try functions, but it would not be good practice to use them; validation should be done before creating the tensor argument.
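
For contrast, a sketch of the selection-time pattern described here: validate once when choosing the line size, then use the unchecked fast path (the divisibility filter is an illustrative stand-in for CubeCL's actual selection logic):

// Pick the largest runtime-supported line size that divides the innermost
// extent, falling back to 1; the check runs once, at selection time.
let line_size = R::supported_line_sizes()
    .iter()
    .copied()
    .filter(|&ls| shape[shape.len() - 1] % ls as usize == 0)
    .max()
    .unwrap_or(1);
// The hot path can then assume the factor is valid.
let arg = handle.as_tensor_arg(line_size);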

Comment on lines +263 to +317
if shape.len() != strides.len() {
    return Err(TensorHandleError::RankMismatch {
        shape_rank: shape.len(),
        stride_rank: strides.len(),
    });
}
if elem_size == 0 {
    return Err(TensorHandleError::ElemSizeZero);
}
// Disallow zero strides when the corresponding dimension extent is > 1
// (broadcasted dims with extent 1 are allowed).
for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
    if s == 0 && d > 1 {
        return Err(TensorHandleError::ZeroStride { axis: i });
    }
}

The same thing here: most of those things are validated in other places, and it's kind of wasteful to do those validations multiple times.

@wingertge (Collaborator) commented:

I actually think it could be a good idea, but specifically gated behind cfg!(debug_assertions). These checks are not that expensive (you're only iterating over a handful of elements here), and would be useful to just assert for sanity checks while running unoptimized builds.
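
A sketch of that suggestion, gating the checks so they run only in unoptimized builds (the messages and exact placement are illustrative):

// In release builds cfg!(debug_assertions) is false and the whole block
// compiles away; in debug builds it catches handle misuse early.
if cfg!(debug_assertions) {
    assert_eq!(shape.len(), strides.len(), "rank mismatch");
    assert!(elem_size > 0, "element size must be non-zero");
    for (i, (&s, &d)) in strides.iter().zip(shape.iter()).enumerate() {
        assert!(s != 0 || d <= 1, "zero stride on axis {i} with extent > 1");
    }
}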

@ariawisp force-pushed the safe-tensor-handle branch 2 times, most recently from 58e6a34 to 714498a on September 11, 2025 at 15:06
@ariawisp (Author) commented on Sep 11, 2025:

Per feedback, I reverted internal hot paths to the fast path (as_tensor_arg) and moved validation to selection time. I added debug_assert! guards in the constructors so debug builds catch mistakes with zero release overhead. The try_* APIs are kept only for external entry points—e.g., language bindings, tools/CLI, or test harnesses—where we construct handles from raw parts and prefer Result-based errors over relying on unsafe preconditions. Docs now call this out explicitly. Workspace build and clippy are green. If you’d like, I can also gate the try_* APIs behind a feature flag, but this seems like the right safety/perf balance.

@nathanielsimard (Member) left a comment:

That seems like the right balance. The CI should be fixed before merging.

@ariawisp force-pushed the safe-tensor-handle branch 2 times, most recently from b52c9ee to 77556c2 on September 19, 2025 at 22:18
…t across crates

- TensorHandleRef::{try_from_parts, try_from_typed}
- TensorHandleRef::try_as_tensor_arg (validates runtime-supported vectorization only)
- Errors: #[non_exhaustive], Display impls; UnsupportedVectorization { requested, supported }
- Adopt try_as_tensor_arg in attention/matmul/convolution/reduce/std
- Runtime tests for handle validation and unsupported vectorization factors

core(tensor): avoid redundant checks in hot paths; use debug_asserts and clarify try_* docs

internal: use direct as_tensor_arg in internal launch paths; reserve try_* for FFI/tests
@ariawisp (Author) commented:

The branch has been rebased on main (cf40b4e), and the CI workflow has been confirmed to pass on my fork.
