Skip to content

Commit 10a95cc

Browse files
Jake VanderPlascopybara-github
authored andcommitted
Replace references to deprecated jax array attributes device_buffer and device_buffers
PiperOrigin-RevId: 588553845
1 parent d72bd65 commit 10a95cc

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

trax/optimizers/trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,11 @@ def _free_accelerators(self, exceptions=(), keep_constants=True):
445445
logging.info('Deleting %d live buffers.', len(live_buffers))
446446
exceptions_buffers = []
447447
for x in fastmath.tree_flatten(exceptions):
448-
if hasattr(x, 'device_buffer'): # DeviceArray
448+
if hasattr(x, 'addressable_shards'): # Array
449+
exceptions_buffers.extend(shard.data for shard in x.addressable_shards)
450+
elif hasattr(x, 'device_buffer'): # DeviceArray
449451
exceptions_buffers.append(x.device_buffer)
450-
if hasattr(x, 'device_buffers'): # ShardedDeviceArray
452+
elif hasattr(x, 'device_buffers'): # ShardedDeviceArray
451453
exceptions_buffers.extend(x.device_buffers)
452454
for b in live_buffers:
453455
should_delete = True

0 commit comments

Comments
 (0)