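There is not currently any way to do this, but starting in JAX v0.8.0, `jax.enable_x64` can be used as a context manager to turn on 64-bit mode locally:

```python
import jax
import jax.numpy as jnp

x = jnp.zeros(10)          # created outside the context: default 32-bit dtype
with jax.enable_x64():
    y = jnp.arange(10)     # created inside the context: 64-bit integer default
print(x.dtype, y.dtype)  # float32 int64
```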
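A related workaround, sketched under the assumption that a process-wide switch is acceptable: turn on 64-bit mode globally with `jax.config.update("jax_enable_x64", True)` and pin floating-point arrays to float32 explicitly. This is not an int64-only mode; any float array created without an explicit dtype still defaults to float64, as `unpinned` shows. The variable names are hypothetical.

```python
import jax
import jax.numpy as jnp

# Process-wide switch; ideally set at startup, before creating any arrays.
# It changes the default dtypes of BOTH integers and floats to 64-bit.
jax.config.update("jax_enable_x64", True)

# Integer defaults are now int64, so large token counts don't overflow.
tokens_seen = jnp.arange(10)

# Pin floats to float32 explicitly to avoid paying for float64.
params = jnp.zeros(10, dtype=jnp.float32)

# Python scalars are weakly typed, so this does not upcast params.
scaled = params * 0.5

# The cost of this workaround: unpinned floats silently become float64.
unpinned = jnp.zeros(10)

print(tokens_seen.dtype, params.dtype, scaled.dtype, unpinned.dtype)
# int64 float32 float32 float64
```

In practice this means auditing every float-producing call (constructors, initializers, data loading) for an explicit `dtype`, which is exactly the burden an int64-only flag would remove.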
Currently, the only way to enable int64 is `jax.enable_x64`, which also switches on float64. Especially now that training schedules and similar quantities may depend on a token budget that can exceed the int32 range (max 2^31 − 1), it would be useful to have a way to enable only int64 without also enabling float64.
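Until an int64-only mode exists, one way to count past the int32 limit without enabling x64 at all is to store the counter as two `uint32` words. This is a minimal sketch, not a JAX API: `add_tokens`, `tokens_as_float`, and `warmup_fraction` are hypothetical helpers. It relies on JAX's unsigned integer addition wrapping modulo 2^32 and detects the carry by checking whether the low word got smaller after the addition.

```python
import numpy as np
import jax
import jax.numpy as jnp

_TWO_32 = 4294967296.0  # 2**32 as a Python float (exact in float32)

@jax.jit
def add_tokens(hi, lo, n):
    """Hypothetical helper: add n (< 2**32) to a counter stored as uint32 (hi, lo)."""
    n = jnp.asarray(n).astype(jnp.uint32)
    new_lo = lo + n                            # wraps modulo 2**32
    carry = (new_lo < lo).astype(jnp.uint32)   # wrapped iff the low word shrank
    return hi + carry, new_lo

def tokens_as_float(hi, lo):
    """Hypothetical helper: approximate token count as float32."""
    return hi.astype(jnp.float32) * _TWO_32 + lo.astype(jnp.float32)

def warmup_fraction(hi, lo, warmup_tokens):
    """Hypothetical linear-warmup schedule driven by the 64-bit counter."""
    return jnp.minimum(1.0, tokens_as_float(hi, lo) / warmup_tokens)

# Start 5 tokens below the 32-bit limit, then add 10.
hi = jnp.asarray(np.uint32(0))
lo = jnp.asarray(np.uint32(2**32 - 5))
hi, lo = add_tokens(hi, lo, 10)

total = int(hi) * 2**32 + int(lo)
print(total)  # 4294967301

approx = tokens_as_float(hi, lo)
frac = warmup_fraction(hi, lo, 2.0**33)  # warmup over 2**33 tokens
```

Converting to float32 is lossy above 2^24 (near 2^32 adjacent float32 values are 512 apart), which is usually fine for a learning-rate schedule but not for exact bookkeeping, so the `(hi, lo)` pair should stay the source of truth.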