Skip to content

Commit 1d2865a

Browse files
authored
Update explicit layouts to use JAX 0.6.2 layout.Format. (#125)
There was a rename in JAX from `jax.experimental.layout.Layout` to `jax.experimental.layout.Format`, and JAX arrays have an attribute change `jax.Array.layout -> jax.Array.format`.
1 parent 1ed84ae commit 1d2865a

File tree

3 files changed

+4
-4
lines changed

3 files changed

+4
-4
lines changed

keras_rs/src/layers/embedding/distributed_embedding_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -560,10 +560,10 @@ def test_correctness(
560560
# Determine explicit shardings/layouts for jit compilation
561561
# (required for sparsecore computations).
562562
trainable_layouts = keras.tree.map_structure(
563-
lambda x: x.value.layout, layer.trainable_variables
563+
lambda x: x.value.format, layer.trainable_variables
564564
)
565565
non_trainable_layouts = keras.tree.map_structure(
566-
lambda x: x.value.layout, layer.non_trainable_variables
566+
lambda x: x.value.format, layer.non_trainable_variables
567567
)
568568
# Input/output data involved in sparsecore operations are
569569
# sharded across all sparse-core-capable devices.

requirements-jax-cuda.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ torch>=2.1.0
88
# Jax with cuda support.
99
# Keep same version as Keras repo.
1010
--find-links https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
11-
jax[cuda12_pip]==0.6.2
11+
jax[cuda12]==0.6.2
1212

1313
# Support for large embeddings.
1414
jax-tpu-embedding;sys_platform == 'linux' and platform_machine == 'x86_64'

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ tensorflow~=2.18.0;sys_platform == 'darwin'
77
torch>=2.1.0
88

99
# Jax.
10-
jax[cpu]
10+
jax[cpu]>=0.6.2
1111
jax-tpu-embedding;sys_platform == 'linux' and platform_machine == 'x86_64'
1212

1313
# pre-commit checks (formatting, linting, etc.)

0 commit comments

Comments
 (0)