diff --git a/trax/optimizers/trainer.py b/trax/optimizers/trainer.py index d4761fc0a..4b1adbb54 100644 --- a/trax/optimizers/trainer.py +++ b/trax/optimizers/trainer.py @@ -440,8 +440,7 @@ def _collect_weights(self, layer): def _free_accelerators(self, exceptions=(), keep_constants=True): """Deletes all live buffers from accelerator with no safety guarantees.""" - backend = jax.lib.xla_bridge.get_backend() - live_buffers = backend.live_buffers() + live_buffers = jax.live_arrays() logging.info('Deleting %d live buffers.', len(live_buffers)) exceptions_buffers = [] for x in fastmath.tree_flatten(exceptions):