40 changes: 31 additions & 9 deletions tensorflow_hub/keras_layer.py
@@ -25,10 +25,8 @@
 # pylint: disable=g-import-not-at-top
 # Use Keras 2.
 version_fn = getattr(tf.keras, "version", None)
-if version_fn and version_fn().startswith("3."):
-  import tf_keras as keras
-else:
-  keras = tf.keras
+# Always align with tf.keras to avoid mismatched Layer types
+keras = tf.keras
 
 # pylint: disable=g-direct-tensorflow-import
 from tensorflow.python.framework import smart_cond
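Note on this hunk: the new comment's rationale is that a single, unconditional alias keeps every layer in one `Layer` hierarchy. A minimal illustration of that property (not part of the diff; `Dense` is just an arbitrary example layer):

```python
import tensorflow as tf

# With the unconditional alias, anything built through `keras` is an
# instance of tf.keras.layers.Layer, so isinstance() checks inside
# TensorFlow code paths all see the same Layer class.
keras = tf.keras

layer = keras.layers.Dense(4)
assert isinstance(layer, tf.keras.layers.Layer)
```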
@@ -210,11 +208,34 @@ def _setup_layer(self, trainable=False, **kwargs):
       self.add_loss(self._call_loss_if_trainable(l))  # Supports callables.
 
   def _add_existing_weight(self, weight, trainable=None):
-    """Calls add_weight() to register but not create an existing weight."""
-    if trainable is None: trainable = weight.trainable
-    self.add_weight(name=weight.name, shape=weight.shape, dtype=weight.dtype,
-                    trainable=trainable, experimental_autocast=False,
-                    getter=lambda *_, **__: weight)
+    """Registers an existing tf.Variable with this layer."""
+    if trainable is None:
+      trainable = getattr(weight, "trainable", False)
+
+    # Create custom weight lists if they don't exist
+    if not hasattr(self, '_hub_trainable_weights'):
+      self._hub_trainable_weights = []
+    if not hasattr(self, '_hub_non_trainable_weights'):
+      self._hub_non_trainable_weights = []
+
+    # Add to appropriate list
+    if trainable:
+      self._hub_trainable_weights.append(weight)
+    else:
+      self._hub_non_trainable_weights.append(weight)
+  @property
+  def trainable_weights(self):
+    """Override to include hub weights."""
+    base_weights = super().trainable_weights
+    hub_weights = getattr(self, '_hub_trainable_weights', [])
+    return base_weights + hub_weights
+
+  @property
+  def non_trainable_weights(self):
+    """Override to include hub weights."""
+    base_weights = super().non_trainable_weights
+    hub_weights = getattr(self, '_hub_non_trainable_weights', [])
+    return base_weights + hub_weights
 
   def _call_loss_if_trainable(self, loss):
     """Returns `loss` conditioned on whether this layer is trainable."""
@@ -338,6 +359,7 @@ def get_config(self):
     if not isinstance(self._handle, str):
       # Need to raise this type in order for tf.saved_model.save() to fall back
       # to not using config, instead of crashing.
+      # TODO(b/134528831): Reconsider the usability implications.
       raise NotImplementedError(
           "Can only generate a valid config for `hub.KerasLayer(handle, ...)` "
           "that uses a string `handle`.\n\n"
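For context on the comment in this hunk: the saving code treats `NotImplementedError` from `get_config()` as "no serializable config" and falls back to tracing the layer. A hypothetical stand-in demonstrating that contract (assuming the fallback behavior the comment describes; `OpaqueLayer` is not part of tensorflow_hub):

```python
import tensorflow as tf

class OpaqueLayer(tf.keras.layers.Layer):
  """Like hub.KerasLayer with a non-string handle: no usable config."""

  def call(self, inputs):
    return inputs * 2.0

  def get_config(self):
    # Raising NotImplementedError (rather than any other exception type)
    # lets tf.saved_model.save() skip config-based serialization.
    raise NotImplementedError("No serializable config for this layer.")

inputs = tf.keras.Input(shape=(4,))
model = tf.keras.Model(inputs, OpaqueLayer()(inputs))
tf.saved_model.save(model, "/tmp/opaque_layer_model")  # Falls back, no crash.
```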