diff --git a/examples/ml_perf/configs/__init__.py b/examples/ml_perf/configs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/datasets/__init__.py b/examples/ml_perf/configs/datasets/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/datasets/dummy_dataset.py b/examples/ml_perf/configs/datasets/dummy_dataset.py new file mode 100644 index 00000000..aac66c57 --- /dev/null +++ b/examples/ml_perf/configs/datasets/dummy_dataset.py @@ -0,0 +1,166 @@ +from keras.utils import Config + +# === Dataset === +dataset_config = Config() +dataset_config.file_pattern = None +# Features +dataset_config.label = "clicked" +dataset_config.dense = [f"int-feature-{i}" for i in range(1, 14)] +dataset_config.lookup = [ + { + "name": "categorical-feature-14", + "vocabulary_size": 40000000, + "feature_list_length": 3, + "new_name": "cat_14", + }, + { + "name": "categorical-feature-15", + "vocabulary_size": 39060, + "feature_list_length": 2, + "new_name": "cat_15", + }, + { + "name": "categorical-feature-16", + "vocabulary_size": 17295, + "feature_list_length": 1, + "new_name": "cat_16", + }, + { + "name": "categorical-feature-17", + "vocabulary_size": 7424, + "feature_list_length": 2, + "new_name": "cat_17", + }, + { + "name": "categorical-feature-18", + "vocabulary_size": 20265, + "feature_list_length": 6, + "new_name": "cat_18", + }, + { + "name": "categorical-feature-19", + "vocabulary_size": 3, + "feature_list_length": 1, + "new_name": "cat_19", + }, + { + "name": "categorical-feature-20", + "vocabulary_size": 7122, + "feature_list_length": 1, + "new_name": "cat_20", + }, + { + "name": "categorical-feature-21", + "vocabulary_size": 1543, + "feature_list_length": 1, + "new_name": "cat_21", + }, + { + "name": "categorical-feature-22", + "vocabulary_size": 63, + "feature_list_length": 1, + "new_name": "cat_22", + }, + { + "name": "categorical-feature-23", + "vocabulary_size": 40000000, + "feature_list_length": 7, + "new_name": "cat_23", + }, + { + "name": "categorical-feature-24", + "vocabulary_size": 3067956, + "feature_list_length": 3, + "new_name": "cat_24", + }, + { + "name": "categorical-feature-25", + "vocabulary_size": 405282, + "feature_list_length": 8, + "new_name": "cat_25", + }, + { + "name": "categorical-feature-26", + "vocabulary_size": 10, + "feature_list_length": 1, + "new_name": "cat_26", + }, + { + "name": "categorical-feature-27", + "vocabulary_size": 2209, + "feature_list_length": 6, + "new_name": "cat_27", + }, + { + "name": "categorical-feature-28", + "vocabulary_size": 11938, + "feature_list_length": 9, + "new_name": "cat_28", + }, + { + "name": "categorical-feature-29", + "vocabulary_size": 155, + "feature_list_length": 5, + "new_name": "cat_29", + }, + { + "name": "categorical-feature-30", + "vocabulary_size": 4, + "feature_list_length": 1, + "new_name": "cat_30", + }, + { + "name": "categorical-feature-31", + "vocabulary_size": 976, + "feature_list_length": 1, + "new_name": "cat_31", + }, + { + "name": "categorical-feature-32", + "vocabulary_size": 14, + "feature_list_length": 1, + "new_name": "cat_32", + }, + { + "name": "categorical-feature-33", + "vocabulary_size": 40000000, + "feature_list_length": 12, + "new_name": "cat_33", + }, + { + "name": "categorical-feature-34", + "vocabulary_size": 40000000, + "feature_list_length": 100, + "new_name": "cat_34", + }, + { + "name": "categorical-feature-35", + "vocabulary_size": 40000000, + "feature_list_length": 27, + "new_name": "cat_35", + }, + { + 
"name": "categorical-feature-36", + "vocabulary_size": 590152, + "feature_list_length": 10, + "new_name": "cat_36", + }, + { + "name": "categorical-feature-37", + "vocabulary_size": 12973, + "feature_list_length": 3, + "new_name": "cat_37", + }, + { + "name": "categorical-feature-38", + "vocabulary_size": 108, + "feature_list_length": 1, + "new_name": "cat_38", + }, + { + "name": "categorical-feature-39", + "vocabulary_size": 36, + "feature_list_length": 1, + "new_name": "cat_39", + }, +] diff --git a/examples/ml_perf/configs/models/__init__.py b/examples/ml_perf/configs/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/models/default_model.py b/examples/ml_perf/configs/models/default_model.py new file mode 100644 index 00000000..1a07e9d9 --- /dev/null +++ b/examples/ml_perf/configs/models/default_model.py @@ -0,0 +1,19 @@ +from keras.utils import Config + +# === Model === +model_config = Config() +# Embedding +model_config.embedding_dim = 128 +model_config.allow_id_dropping = True +model_config.embedding_threshold = 21000 +model_config.max_ids_per_partition = 4096 +model_config.max_unique_ids_per_partition = 2048 +model_config.learning_rate = 0.005 + +# MLP +model_config.bottom_mlp_dims = [512, 256, 128] +model_config.top_mlp_dims = [1024, 1024, 512, 256, 1] + +# DCN +model_config.num_dcn_layers = 3 +model_config.dcn_projection_dim = 512 diff --git a/examples/ml_perf/configs/training/__init__.py b/examples/ml_perf/configs/training/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/examples/ml_perf/configs/training/default_training.py b/examples/ml_perf/configs/training/default_training.py new file mode 100644 index 00000000..b758bc59 --- /dev/null +++ b/examples/ml_perf/configs/training/default_training.py @@ -0,0 +1,7 @@ +from keras.utils import Config + +# === Training Hyperparameters === +training_config = Config() +training_config.learning_rate = 0.005 +training_config.global_batch_size = 128 +training_config.num_epochs = 1 diff --git a/examples/ml_perf/configs/v6e_16.py b/examples/ml_perf/configs/v6e_16.py new file mode 100644 index 00000000..4b6df8df --- /dev/null +++ b/examples/ml_perf/configs/v6e_16.py @@ -0,0 +1,16 @@ +from keras.utils import Config + +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config + +config = Config() + +config.experiment_name = "v6e_16" +config.model_dir = "./v6e_16" + +config.dataset = dataset_config +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/ml_perf/configs/v6e_8.py b/examples/ml_perf/configs/v6e_8.py new file mode 100644 index 00000000..fcd81e39 --- /dev/null +++ b/examples/ml_perf/configs/v6e_8.py @@ -0,0 +1,16 @@ +from keras.utils import Config + +from .datasets.dummy_dataset import dataset_config +from .models.default_model import model_config +from .training.default_training import training_config + +config = Config() + +config.experiment_name = "v6e_8" +config.model_dir = "./v6e_8" + +config.dataset = dataset_config +config.model = model_config +config.training = training_config + +config.freeze() diff --git a/examples/ml_perf/configs/v6e_8_full_dataset.py b/examples/ml_perf/configs/v6e_8_full_dataset.py new file mode 100644 index 00000000..8489b084 --- /dev/null +++ b/examples/ml_perf/configs/v6e_8_full_dataset.py @@ -0,0 +1,27 @@ +from keras.utils import Config + +from .datasets.dummy_dataset import 
dataset_config
+from .models.default_model import model_config
+from .training.default_training import training_config
+
+config = Config()
+
+config.experiment_name = "v6e_8_full_dataset"
+config.model_dir = "./v6e_8_full_dataset"
+
+config.dataset = dataset_config
+config.dataset.file_pattern = (
+    "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/"
+    "train-00000-of-01024tfrecord"
+)
+config.dataset.val_file_pattern = (
+    "gs://qinyiyan-vm/mlperf-dataset/criteo_merge_balanced_4224/"
+    "train-00000-of-01024tfrecord"
+)
+# The files at this path already contain pre-batched examples.
+config.dataset.file_batch_size = 4224
+config.model = model_config
+config.training = training_config
+config.training.global_batch_size = 256
+
+config.freeze()
diff --git a/examples/ml_perf/dataloader.py b/examples/ml_perf/dataloader.py
new file mode 100644
index 00000000..ce2e7286
--- /dev/null
+++ b/examples/ml_perf/dataloader.py
@@ -0,0 +1,200 @@
+import numpy as np
+import tensorflow as tf
+
+
+class DataLoader:
+    def __init__(
+        self,
+        file_pattern,
+        batch_size,
+        file_batch_size,
+        dense_features,
+        large_emb_features,
+        small_emb_features,
+        label,
+        training=False,
+    ):
+        # Passed attributes.
+        self.file_pattern = file_pattern
+        self.batch_size = batch_size
+        self.file_batch_size = file_batch_size
+        self.dense_features = dense_features
+        self.large_emb_features = large_emb_features
+        self.small_emb_features = small_emb_features
+        self.label = label
+        self.training = training
+
+        # Derived attributes.
+        self._return_dummy_dataset = file_pattern is None
+
+    def _get_dummy_batch(self):
+        """Returns a dummy batch of data in the final desired structure."""
+
+        # Labels
+        data = {
+            "clicked": np.random.randint(
+                0, 2, size=(self.batch_size,), dtype=np.int64
+            )
+        }
+
+        # Dense features
+        dense_input_list = [
+            np.random.uniform(0.0, 0.9, size=(self.batch_size, 1)).astype(
+                np.float32
+            )
+            for _ in range(len(self.dense_features))
+        ]
+        data["dense_input"] = np.concatenate(dense_input_list, axis=-1)
+
+        # Large embedding features
+        large_emb_inputs = {}
+        for large_emb_feature in self.large_emb_features:
+            name = large_emb_feature["name"]
+            new_name = large_emb_feature.get("new_name", name)
+            vocabulary_size = large_emb_feature["vocabulary_size"]
+            feature_list_length = large_emb_feature["feature_list_length"]
+
+            large_emb_inputs[f"{new_name}_id"] = np.random.randint(
+                low=0,
+                high=vocabulary_size,
+                size=(self.batch_size, feature_list_length),
+                dtype=np.int64,
+            )
+
+        data["large_emb_inputs"] = large_emb_inputs
+
+        # Small embedding features
+        small_emb_inputs = {}
+        for small_emb_feature in self.small_emb_features:
+            name = small_emb_feature["name"]
+            new_name = small_emb_feature.get("new_name", name)
+            vocabulary_size = small_emb_feature["vocabulary_size"]
+            feature_list_length = small_emb_feature["feature_list_length"]
+
+            small_emb_inputs[f"{new_name}_id"] = np.random.randint(
+                low=0,
+                high=vocabulary_size,
+                size=(self.batch_size, feature_list_length),
+                dtype=np.int64,
+            )
+
+        if small_emb_inputs:
+            data["small_emb_inputs"] = small_emb_inputs
+
+        return data
+
+    def _create_dummy_dataset(self):
+        """Creates a TF dummy dataset (randomly initialised)."""
+        dummy_data = self._get_dummy_batch()
+
+        # Separate labels from features to create a `(features, labels)` tuple.
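+        # The element structure below mirrors what `_preprocess` emits for
+        # real TFRecords, so downstream code is agnostic to the data source.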
+ labels = dummy_data.pop("clicked") + features = dummy_data + + dataset = tf.data.Dataset.from_tensors((features, labels)).repeat(512) + return dataset + + def _get_feature_spec(self): + feature_spec = { + self.label: tf.io.FixedLenFeature( + [self.file_batch_size], + dtype=tf.int64, + ) + } + + for dense_feat in self.dense_features: + feature_spec[dense_feat] = tf.io.FixedLenFeature( + [self.file_batch_size], + dtype=tf.float32, + ) + + for emb_feat in self.large_emb_features + self.small_emb_features: + name = emb_feat["name"] + feature_spec[name] = tf.io.FixedLenFeature( + [self.file_batch_size], + dtype=tf.string, + ) + + return feature_spec + + def _preprocess(self, example): + # Read example. + feature_spec = self._get_feature_spec() + example = tf.io.parse_single_example(example, feature_spec) + + # Dense features + dense_input = tf.concat( + [ + tf.reshape(example[dense_feature], [self.file_batch_size, 1]) + for dense_feature in self.dense_features + ], + axis=-1, + ) + + def _get_emb_inputs(emb_features): + emb_inputs = {} + for emb_feature in emb_features: + name = emb_feature["name"] + new_name = emb_feature.get("new_name", name) + feature_list_length = emb_feature["feature_list_length"] + + raw_values = tf.io.decode_raw(example[name], tf.int64) + raw_values = tf.reshape( + raw_values, [self.file_batch_size, feature_list_length] + ) + emb_inputs[f"{new_name}_id"] = raw_values + return emb_inputs + + # Embedding/lookup features + large_emb_inputs = _get_emb_inputs(self.large_emb_features) + small_emb_inputs = _get_emb_inputs(self.small_emb_features) + + # Labels + labels = tf.reshape(example[self.label], [self.file_batch_size]) + + x = { + "dense_input": dense_input, + "large_emb_inputs": large_emb_inputs, + } + if small_emb_inputs: + x["small_emb_inputs"] = small_emb_inputs + + return (x, labels) + + def create_dataset(self, process_id=0, num_processes=1, shuffle_buffer=256): + if self._return_dummy_dataset: + return self._create_dummy_dataset() + + dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=False) + + # Shard the dataset across hosts/workers. + # TODO: Do we need to do this if we are distributing the dataset + # manually using distribution.distribute_dataset(...)? + if num_processes > 1: + dataset = dataset.shard(num_processes, process_id) + + dataset = tf.data.TFRecordDataset( + dataset, + buffer_size=None, + num_parallel_reads=tf.data.AUTOTUNE, + ) + + # Process example. + dataset = dataset.map( + self._preprocess, num_parallel_calls=tf.data.AUTOTUNE + ) + dataset = dataset.unbatch() + + # Shuffle dataset if in training mode. 
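+        # Note: shuffling happens after `unbatch()`, so the buffer holds
+        # individual examples rather than `file_batch_size`-sized batches.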
+ if self.training and shuffle_buffer and shuffle_buffer > 0: + dataset = dataset.shuffle(shuffle_buffer) + + dataset = dataset.batch( + self.batch_size, + drop_remainder=True, + num_parallel_calls=tf.data.AUTOTUNE, + ) + + dataset = dataset.prefetch(tf.data.AUTOTUNE) + + return dataset diff --git a/examples/ml_perf/main.py b/examples/ml_perf/main.py new file mode 100644 index 00000000..3f661a8c --- /dev/null +++ b/examples/ml_perf/main.py @@ -0,0 +1,261 @@ +import argparse +import importlib +import os + +os.environ["KERAS_BACKEND"] = "jax" + +import keras + +import keras_rs + +from .dataloader import DataLoader +from .model import DLRMDCNV2 + +SEED = 1337 + + +def main( + file_pattern, + val_file_pattern, + dense_features, + large_emb_features, + small_emb_features, + label, + shuffle_buffer, + embedding_dim, + allow_id_dropping, + max_ids_per_partition, + max_unique_ids_per_partition, + embedding_learning_rate, + bottom_mlp_dims, + top_mlp_dims, + num_dcn_layers, + dcn_projection_dim, + learning_rate, + global_batch_size, + file_batch_size, + num_epochs, +): + # Set DDP as Keras distribution strategy + devices = keras.distribution.list_devices(device_type="tpu") + distribution = keras.distribution.DataParallel(devices=devices) + keras.distribution.set_distribution(distribution) + num_processes = distribution._num_process + + per_host_batch_size = global_batch_size // num_processes + + # === Distributed embeddings' configs for lookup features === + # Set XLA flags. + os.environ["XLA_FLAGS"] = ( + "--xla_sparse_core_max_ids_per_partition_per_sample=" + f"{max_ids_per_partition} " + "--xla_sparse_core_max_unique_ids_per_partition_per_sample=" + f"{max_unique_ids_per_partition}" + ) + feature_configs = {} + for large_emb_feature in large_emb_features: + feature_name = large_emb_feature["new_name"] + vocabulary_size = large_emb_feature["vocabulary_size"] + feature_list_length = large_emb_feature["feature_list_length"] + + table_config = keras_rs.layers.TableConfig( + name=f"{feature_name}_table", + vocabulary_size=vocabulary_size, + embedding_dim=embedding_dim, + # TODO(abheesht): Verify. + initializer=keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + seed=SEED, + ), + optimizer=keras.optimizers.Adagrad( + learning_rate=embedding_learning_rate + ), + combiner="sum", + placement="sparsecore", + # TODO: These two args are not getting passed down to + # `jax-tpu-embedding` properly, seems like. + max_ids_per_partition=max_ids_per_partition, + max_unique_ids_per_partition=max_unique_ids_per_partition, + ) + feature_configs[f"{feature_name}_id"] = keras_rs.layers.FeatureConfig( + name=feature_name, + table=table_config, + # TODO: Verify whether it should be `(bsz, 1)` or + # `(bsz, feature_list_length)`. + input_shape=(per_host_batch_size, feature_list_length), + output_shape=(per_host_batch_size, embedding_dim), + ) + + # === Instantiate model === + # We instantiate the model first, because we need to preprocess large + # embedding feature inputs using the distributed embedding layer defined + # inside the model class. 
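+    # Note: weights are created lazily; the explicit call on the first batch
+    # (see the generator loop below) builds the model before `fit` is invoked.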
+ print("===== Initialising model =====") + model = DLRMDCNV2( + large_emb_feature_configs=feature_configs, + small_emb_features=small_emb_features, + embedding_dim=embedding_dim, + bottom_mlp_dims=bottom_mlp_dims, + top_mlp_dims=top_mlp_dims, + num_dcn_layers=num_dcn_layers, + dcn_projection_dim=dcn_projection_dim, + seed=SEED, + dtype="float32", + name="dlrm_dcn_v2", + ) + model.compile( + loss=keras.losses.BinaryCrossentropy(), + optimizer=keras.optimizers.Adagrad(learning_rate=learning_rate), + metrics=[keras.metrics.BinaryAccuracy()], + ) + + # === Load dataset === + print("===== Loading dataset =====") + train_ds = DataLoader( + file_pattern=file_pattern, + batch_size=global_batch_size, + file_batch_size=file_batch_size, + dense_features=dense_features, + large_emb_features=large_emb_features, + small_emb_features=small_emb_features, + label=label, + training=True, + ).create_dataset( + process_id=distribution._process_id, + num_processes=num_processes, + shuffle_buffer=shuffle_buffer, + ) + # For the multi-host case, the dataset has to be distributed manually. + # See note here: + # https://github.com/keras-team/keras-rs/blob/main/keras_rs/src/layers/embedding/base_distributed_embedding.py#L352-L363. + if num_processes > 1: + train_ds = distribution.distribute_dataset(train_ds) + # eval_ds = distribution.distribute_dataset(eval_ds) + distribution.auto_shard_dataset = False + + # Print one sample. + for element in train_ds.take(1): + print(">>> train sample", element[0]) + + def generator(dataset, training=False): + """Converts tf.data Dataset to a Python generator and preprocesses + large embedding features. + """ + for features, labels in dataset: + preprocessed_large_embeddings = model.embedding_layer.preprocess( + features["large_emb_inputs"], training=training + ) + + x = { + "dense_input": features["dense_input"], + "large_emb_inputs": preprocessed_large_embeddings, + "small_emb_inputs": features["small_emb_inputs"], + } + y = labels + yield (x, y) + + train_generator = generator(train_ds, training=True) + # eval_generator = generator(eval_ds, training=False) + for first_batch in train_generator: + model(first_batch[0]) + break + + # Train the model. + model.fit(train_generator, epochs=num_epochs) + + +if __name__ == "__main__": + keras.config.disable_traceback_filtering() + + print("====== Launching train script =======") + parser = argparse.ArgumentParser( + description=( + "Benchmark the DLRM-DCNv2 model on the Criteo dataset (MLPerf)" + ) + ) + parser.add_argument( + "--config_name", type=str, help="Name of the `.py` config file." 
+ ) + args = parser.parse_args() + + print(f"===== Reading config from {args.config_name} ======") + config = importlib.import_module( + f".configs.{args.config_name}", package=__package__ + ).config + + # === Unpack args from config === + + # == Dataset config == + ds_cfg = config["dataset"] + # File path + file_pattern = ds_cfg["file_pattern"] + val_file_pattern = ds_cfg.get("val_file_pattern", None) + # File batch size + file_batch_size = ds_cfg.get("file_batch_size", None) + # Shuffling + shuffle_buffer = ds_cfg.get("shuffle_buffer", None) + # Features + label = ds_cfg["label"] + dense_features = ds_cfg["dense"] + emb_features = ds_cfg["lookup"] + + # == Model config == + model_cfg = config["model"] + # Embedding + embedding_dim = model_cfg["embedding_dim"] + allow_id_dropping = model_cfg["allow_id_dropping"] + embedding_threshold = model_cfg["embedding_threshold"] + max_ids_per_partition = model_cfg["max_ids_per_partition"] + max_unique_ids_per_partition = model_cfg["max_unique_ids_per_partition"] + embedding_learning_rate = model_cfg["learning_rate"] + # MLP + bottom_mlp_dims = model_cfg["bottom_mlp_dims"] + top_mlp_dims = model_cfg["top_mlp_dims"] + # DCN + num_dcn_layers = model_cfg["num_dcn_layers"] + dcn_projection_dim = model_cfg["dcn_projection_dim"] + + # == Training config == + training_cfg = config["training"] + learning_rate = training_cfg["learning_rate"] + global_batch_size = training_cfg["global_batch_size"] + num_epochs = training_cfg["num_epochs"] + + # For features which have vocabulary_size < embedding_threshold, we can + # just do a normal dense lookup for those instead of having distributed + # embeddings. We could ideally pass `placement = default_device` to + # `keras_rs.layers.TableConfig` directly (and wouldn't have to do this + # separation of features), but doing it that way will necessarily require + # a separate optimiser for the embedding layer. + small_emb_features = [] + large_emb_features = [] + for emb_feature in emb_features: + if emb_feature["vocabulary_size"] < embedding_threshold: + small_emb_features.append(emb_feature) + else: + large_emb_features.append(emb_feature) + + main( + file_pattern, + val_file_pattern, + dense_features, + large_emb_features, + small_emb_features, + label, + shuffle_buffer, + embedding_dim, + allow_id_dropping, + max_ids_per_partition, + max_unique_ids_per_partition, + embedding_learning_rate, + bottom_mlp_dims, + top_mlp_dims, + num_dcn_layers, + dcn_projection_dim, + learning_rate, + global_batch_size, + file_batch_size, + num_epochs, + ) diff --git a/examples/ml_perf/model.py b/examples/ml_perf/model.py new file mode 100644 index 00000000..4f84bbbf --- /dev/null +++ b/examples/ml_perf/model.py @@ -0,0 +1,323 @@ +from typing import Any, TypeAlias + +import keras +from keras import ops + +import keras_rs + +Tensor: TypeAlias = Any + + +def _clone_initializer( + initializer: keras.initializers.Initializer, + seed: int | keras.random.SeedGenerator, +): + """Clones the provided initializer with a new seed. + + This function creates a new instance of a Keras initializer from an + existing one, but with a different seed. This is useful for ensuring + different weights in a model are initialized with different seeds. + + Args: + initializer: a keras.initializers.Initializer instance. The initializer + to be cloned. + seed: int, or a keras.random.SeedGenerator() instance. The random seed. + + Returns: + A new `keras.initializers.Initializer` instance configured with the + provided seed. 
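+
+    Example (illustrative):
+
+    >>> init = keras.initializers.GlorotUniform(seed=1)
+    >>> init_2 = _clone_initializer(init, seed=2)  # same config, new seed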
+ """ + config = initializer.get_config() + config.pop("seed") + config = {**config, "seed": seed} + initializer_class: type[keras.initializers.Initializer] = ( + initializer.__class__ + ) + return initializer_class.from_config(config) + + +class DLRMDCNV2(keras.Model): + def __init__( + self, + large_emb_feature_configs: dict[str, keras_rs.layers.FeatureConfig], + small_emb_features: list, + embedding_dim: int, + bottom_mlp_dims: list[int], + top_mlp_dims: list[int], + num_dcn_layers: int, + dcn_projection_dim: int, + seed: int | keras.random.SeedGenerator | None = None, + dtype: str | None = None, + name: str | None = None, + **kwargs: Any, + ): + """DLRM-DCNv2 model. + + The model processes two types of input features: + 1. Dense Features: Continuous-valued features that are processed by + a multi-layer perceptron (the "bottom MLP"). + 2. Lookup Features: High-cardinality categorical features that are + first mapped into low-dimensional embedding vectors using the + `keras_rs.layers.DistributedEmbedding` layer. This layer is highly + optimized for large-scale recommendation models, especially on TPUs + with SparseCore, as it can shard large embedding tables across + multiple accelerator chips for improved performance. On other + hardware (GPUs, CPUs), it functions like a standard embedding layer. + + The output of the bottom MLP and the embedding vectors are then + concatenated and fed into a DCN block for learning feature interactions. + The output of the DCN block is then processed by another MLP + (the "top MLP") to produce a final prediction. + + Args: + large_emb_feature_configs: A dictionary with features names as keys + and `keras_rs.layers.FeatureConfig` objects as values. These + configs link features to their corresponding embedding tables + (`keras_rs.layers.TableConfig`), specifying parameters like + vocabulary size, embedding dimension, and hardware placement + strategy. + bottom_mlp_dims: A list of integers specifying the number of units + in each layer of the bottom MLP. + top_mlp_dims: A list of integers specifying the number of units in + each layer of the top MLP. The last value is the final output + dimension (e.g., 1 for binary classification). + num_dcn_layers: The number of feature-crossing layers in the DCNv2 + block. + dcn_projection_dim: The projection dimension used within each DCNv2 + cross-layer. + seed: The random seed. + dtype: Optional dtype. + name: The name of the layer. 
+ """ + super().__init__(dtype=dtype, name=name, **kwargs) + self.seed = seed + + # === Layers ==== + + # Bottom MLP for encoding dense features + self.bottom_mlp = keras.Sequential( + self._get_mlp_layers( + dims=bottom_mlp_dims, + intermediate_activation="relu", + final_activation="relu", + ), + name="bottom_mlp", + ) + # Distributed embeddings for large embedding tables + self.embedding_layer = keras_rs.layers.DistributedEmbedding( + feature_configs=large_emb_feature_configs, + table_stacking="auto", + dtype=dtype, + name="embedding_layer", + ) + # Embedding layers for small embedding tables + self.small_embedding_layers = None + if small_emb_features: + self.small_embedding_layers = [ + keras.layers.Embedding( + input_dim=small_emb_feature["vocabulary_size"], + output_dim=embedding_dim, + embeddings_initializer=keras.initializers.LecunNormal( + seed=self.seed, + ), + name=f"small_embedding_layer_{i}", + ) + for i, small_emb_feature in enumerate(small_emb_features) + ] + # DCN for "interactions" + self.dcn_block = DCNBlock( + num_layers=num_dcn_layers, + projection_dim=dcn_projection_dim, + seed=seed, + dtype=dtype, + name="dcn_block", + ) + # Top MLP for predictions + self.top_mlp = keras.Sequential( + self._get_mlp_layers( + dims=top_mlp_dims, + intermediate_activation="relu", + final_activation="sigmoid", + ), + name="top_mlp", + ) + + # === Passed attributes === + self.large_emb_feature_configs = large_emb_feature_configs + self.small_emb_features = small_emb_features + self.embedding_dim = embedding_dim + self.bottom_mlp_dims = bottom_mlp_dims + self.top_mlp_dims = top_mlp_dims + self.num_dcn_layers = num_dcn_layers + self.dcn_projection_dim = dcn_projection_dim + + def call(self, inputs: dict[str, Tensor]) -> Tensor: + """Forward pass of the model. + + Args: + inputs: A dictionary containing `"dense_features"` and + `"preprocessed_large_emb_features"` as keys. + """ + # Inputs + dense_input = inputs["dense_input"] + large_emb_inputs = inputs["large_emb_inputs"] + + # Embed features. + dense_output = self.bottom_mlp(dense_input) + # jax.debug.print("dense_ouput {}", dense_output.shape) + large_embeddings = self.embedding_layer(large_emb_inputs) + small_embeddings = None + if self.small_emb_features: + small_embeddings = [] + small_emb_inputs = inputs["small_emb_inputs"] + for small_emb_input, embedding_layer in zip( + small_emb_inputs.values(), self.small_embedding_layers + ): + embedding = embedding_layer(small_emb_input) + embedding = ops.sum(embedding, axis=-2) + small_embeddings.append(embedding) + + small_embeddings = ops.concatenate(small_embeddings, axis=-1) + + # Interaction + to_concatenate = [dense_output, *large_embeddings.values()] + if small_embeddings is not None: + to_concatenate += [small_embeddings] + x = ops.concatenate(to_concatenate, axis=-1) + x = self.dcn_block(x) + + # Predictions + outputs = self.top_mlp(x) + return outputs + + def _get_mlp_layers( + self, + dims: list[int], + intermediate_activation: str | keras.layers.Activation, + final_activation: str | keras.layers.Activation, + ) -> list[keras.layers.Layer]: + """Creates a list of Dense layers. + + Args: + dims: list. Output dimensions of the dense layers to be created. + intermediate_activation: string or `keras.layers.Activation`. The + activation to be used in all layers, save the last. + final_activation: str or `keras.layers.Activation`. The activation + to be used in the last layer. + + Returns: + A list of `keras.layers.Dense` layers. 
+ """ + initializer = keras.initializers.VarianceScaling( + scale=1.0, + mode="fan_in", + distribution="uniform", + seed=self.seed, + ) + + layers = [ + keras.layers.Dense( + units=dim, + activation=intermediate_activation, + kernel_initializer=_clone_initializer( + initializer, seed=self.seed + ), + bias_initializer=_clone_initializer( + initializer, seed=self.seed + ), + dtype=self.dtype, + ) + for dim in dims[:-1] + ] + layers += [ + keras.layers.Dense( + units=dims[-1], + activation=final_activation, + kernel_initializer=_clone_initializer( + initializer, seed=self.seed + ), + bias_initializer=_clone_initializer( + initializer, seed=self.seed + ), + dtype=self.dtype, + ) + ] + return layers + + def get_config(self): + """Returns the config of the model.""" + config = super().get_config() + config.update( + { + "large_emb_feature_configs": self.large_emb_feature_configs, + "small_emb_features": self.small_emb_features, + "embedding_dim": self.embedding_dim, + "bottom_mlp_dims": self.bottom_mlp_dims, + "top_mlp_dims": self.top_mlp_dims, + "num_dcn_layers": self.num_dcn_layers, + "dcn_projection_dim": self.dcn_projection_dim, + "seed": self.seed, + } + ) + return config + + +class DCNBlock(keras.layers.Layer): + def __init__( + self, + num_layers: int, + projection_dim: int, + seed: int | keras.random.SeedGenerator, + dtype: str | None = None, + name: str | None = None, + **kwargs, + ): + """ + A block of Deep & Cross Network V2 (DCNv2) layers. + + This layer implements the "cross network" part of the DCNv2 architecture + by stacking multiple `keras_rs.layers.FeatureCross` layers, which learn + feature interactions. + + Args: + num_layers: The number of `FeatureCross` layers to stack. + projection_dim: The dimensionality of the low-rank projection used + within each cross layer. + seed: The random seed for initializers. + dtype: Optional dtype. + name: The name of the layer. + """ + super().__init__(dtype=dtype, name=name, **kwargs) + + # Layers + self.layers = [ + keras_rs.layers.FeatureCross( + projection_dim=projection_dim, + kernel_initializer=keras.initializers.GlorotUniform(seed=seed), + bias_initializer="zeros", + dtype=dtype, + ) + for _ in range(num_layers) + ] + + # Passed attributes + self.num_layers = num_layers + self.projection_dim = projection_dim + self.seed = seed + + def call(self, x0): + xl = x0 + for layer in self.layers: + xl = layer(x0, xl) + return xl + + def get_config(self): + config = super().get_config() + config.update( + { + "num_layers": self.num_layers, + "projection_dim": self.projection_dim, + "seed": self.seed, + } + ) + return config diff --git a/examples/ml_perf/run.sh b/examples/ml_perf/run.sh new file mode 100644 index 00000000..7a774221 --- /dev/null +++ b/examples/ml_perf/run.sh @@ -0,0 +1,171 @@ +#!/bin/bash +set -euo pipefail + +# ============================================================================== +# Script Configuration & Argument Handling +# ============================================================================== +# This script accepts up to four optional arguments: +# 1. --accelerator-type (default: v6e-8, options: v6e-8, v6e-16) +# 2. --zone (default: us-east5-a) +# 3. --project (default: tpu-prod-env-one-vm) +# 4. 
--config-name (default: derived from accelerator type, e.g., v6e_8)
+
+# Defaults
+ACCELERATOR_TYPE="v6e-8"
+ZONE="us-east5-a"
+PROJECT="tpu-prod-env-one-vm"
+USER_CONFIG_NAME=""
+
+# ==============================================================================
+# Argument Parsing
+# ==============================================================================
+
+show_help() {
+cat << EOF
+Usage: $0 [--accelerator-type <type>] [--zone <zone>] [--project <project>] [--config-name <name>]
+Options:
+  --accelerator-type   The type of TPU accelerator (default: v6e-8). Options: v6e-8, v6e-16.
+  --zone               The GCP zone for the TPU VM (default: us-east5-a).
+  --project            The GCP project ID (default: tpu-prod-env-one-vm).
+  --config-name        The specific configuration name to use for the training script
+                       (default: derived from accelerator type, e.g., v6e_8).
+  -h, --help           Show this help message.
+EOF
+}
+
+
+while [[ "$#" -gt 0 ]]; do
+    case $1 in
+        --accelerator-type) ACCELERATOR_TYPE="$2"; shift ;;
+        --zone) ZONE="$2"; shift ;;
+        --project) PROJECT="$2"; shift ;;
+        --config-name) USER_CONFIG_NAME="$2"; shift ;;
+        -h|--help) show_help; exit 0 ;;
+        *) echo "Unknown parameter passed: $1"; show_help; exit 1 ;;
+    esac
+    shift
+done
+
+# Validate the provided accelerator type
+if [[ "${ACCELERATOR_TYPE}" != "v6e-8" && "${ACCELERATOR_TYPE}" != "v6e-16" ]]; then
+    echo "Error: Invalid accelerator type '${ACCELERATOR_TYPE}'." >&2
+    show_help
+    exit 1
+fi
+
+# ==============================================================================
+# Environment Variables
+# ==============================================================================
+export TPU_NAME="abheesht-mlperf-${ACCELERATOR_TYPE}"
+export ZONE
+export PROJECT
+
+# Use the user-provided config name if it exists, otherwise derive it.
+if [[ -n "${USER_CONFIG_NAME}" ]]; then
+    export CONFIG_NAME=${USER_CONFIG_NAME}
+else
+    export CONFIG_NAME=${ACCELERATOR_TYPE//-/_}
+fi
+
+echo ">>> Using Configuration:"
+echo "    Accelerator: ${ACCELERATOR_TYPE}"
+echo "    TPU Name:    ${TPU_NAME}"
+echo "    Zone:        ${ZONE}"
+echo "    Project:     ${PROJECT}"
+echo "    Config Name: ${CONFIG_NAME}"
+echo "--------------------------------------------------"
+
+
+# ==============================================================================
+# TPU VM Creation
+# ==============================================================================
+echo ">>> Checking for existing TPU VM: ${TPU_NAME}..."
+if gcloud alpha compute tpus tpu-vm describe ${TPU_NAME} --zone=${ZONE} --project=${PROJECT} &> /dev/null; then
+    echo ">>> TPU VM '${TPU_NAME}' already exists. Skipping creation."
+else
+    echo ">>> Creating TPU VM: ${TPU_NAME} with accelerator ${ACCELERATOR_TYPE}..."
+    gcloud alpha compute tpus tpu-vm create ${TPU_NAME} \
+        --zone=${ZONE} \
+        --accelerator-type=${ACCELERATOR_TYPE} \
+        --version=v2-alpha-tpuv6e \
+        --project=${PROJECT} \
+        --metadata=enable-oslogin=TRUE \
+        --scopes=https://www.googleapis.com/auth/cloud-platform
+fi
+
+
+# ==============================================================================
+# Setup Python venv on all workers
+# ==============================================================================
+echo ">>> Checking for Python virtual environment..."
+gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \
+    --project ${PROJECT} \
+    --zone ${ZONE} \
+    --worker=all \
+    --command="sudo apt-get update && sudo apt install -y python3.12-venv && if [ !
-d '.keras-env' ]; then echo '>>> Creating .keras-env...'; python3.12 -m venv .keras-env; else echo '>>> .keras-env already exists.'; fi" + + +# ============================================================================== +# Clone/Update KerasRS and Install Dependencies +# ============================================================================== +echo ">>> Cloning or updating KerasRS and installing dependencies..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command=" + set -e # Ensure script exits on error + source .keras-env/bin/activate + + if [ ! -d 'keras-rs' ]; then + echo '>>> Cloning keras-rs repository...' + git clone https://github.com/abheesht17/keras-rs.git + cd keras-rs + git checkout ml-perf + else + echo '>>> keras-rs repository exists. Pulling latest changes...' + cd keras-rs + git checkout ml-perf # Ensure we are on the correct branch + git pull + fi + + echo '>>> Installing/updating dependencies...' + pip install -e . + pip uninstall -y tensorflow keras + pip install git+https://github.com/keras-team/keras.git + pip install jax-tpu-embedding tensorflow-cpu + " + + +# ============================================================================== +# Install TPU-compatible JAX +# ============================================================================== +echo ">>> Re-installing JAX for TPU compatibility..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && pip uninstall -y jax jaxlib && pip install -U 'jax[tpu]' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html" + + +# ============================================================================== +# Verify Installation +# ============================================================================== +echo ">>> Verifying JAX installation..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && echo 'import jax; print(jax.devices())' > script.py && python3.12 script.py" + + +# ============================================================================== +# Run Training Script +# ============================================================================== +echo ">>> Running the main script with config for ${ACCELERATOR_TYPE}..." +gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} \ + --project ${PROJECT} \ + --zone ${ZONE} \ + --worker=all \ + --command="source .keras-env/bin/activate && cd keras-rs && python3.12 -m examples.ml_perf.main --config_name ${CONFIG_NAME}" + +echo ">>> Script finished." 
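+
+# Example invocations (assuming the defaults above; the TPU and project names
+# are placeholders specific to this setup):
+#   bash examples/ml_perf/run.sh
+#   bash examples/ml_perf/run.sh --accelerator-type v6e-16 --config-name v6e_16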
diff --git a/keras_rs/src/layers/embedding/jax/distributed_embedding.py b/keras_rs/src/layers/embedding/jax/distributed_embedding.py index 2562a8be..72f504af 100644 --- a/keras_rs/src/layers/embedding/jax/distributed_embedding.py +++ b/keras_rs/src/layers/embedding/jax/distributed_embedding.py @@ -15,7 +15,6 @@ table_stacking as jte_table_stacking, ) from jax_tpu_embedding.sparsecore.utils import utils as jte_utils -from keras.src import backend from keras_rs.src import types from keras_rs.src.layers.embedding import base_distributed_embedding @@ -247,23 +246,6 @@ def _create_sparsecore_distribution( ) return sparsecore_distribution, sparsecore_layout - def _create_cpu_distribution( - self, cpu_axis_name: str = "cpu" - ) -> tuple[ - keras.distribution.ModelParallel, keras.distribution.TensorLayout - ]: - """Share a variable across all CPU processes.""" - cpu_devices = jax.devices("cpu") - device_mesh = keras.distribution.DeviceMesh( - (len(cpu_devices),), [cpu_axis_name], cpu_devices - ) - replicated_layout = keras.distribution.TensorLayout([], device_mesh) - layout_map = keras.distribution.LayoutMap(device_mesh=device_mesh) - cpu_distribution = keras.distribution.ModelParallel( - layout_map=layout_map - ) - return cpu_distribution, replicated_layout - def _add_sparsecore_weight( self, name: str, @@ -405,11 +387,6 @@ def sparsecore_build( self._sparsecore_layout = sparsecore_layout self._sparsecore_distribution = sparsecore_distribution - # Distribution for CPU operations. - cpu_distribution, cpu_layout = self._create_cpu_distribution() - self._cpu_distribution = cpu_distribution - self._cpu_layout = cpu_layout - mesh = sparsecore_distribution.device_mesh.backend_mesh global_device_count = mesh.devices.size num_sc_per_device = jte_utils.num_sparsecores_per_device( @@ -466,10 +443,6 @@ def sparsecore_build( # Collect all stacked tables. table_specs = embedding_utils.get_table_specs(feature_specs) table_stacks = embedding_utils.get_table_stacks(table_specs) - stacked_table_specs = { - stack_name: stack[0].stacked_table_spec - for stack_name, stack in table_stacks.items() - } # Create variables for all stacked tables and slot variables. with sparsecore_distribution.scope(): @@ -502,50 +475,6 @@ def sparsecore_build( ) self._iterations.overwrite_with_gradient = True - with cpu_distribution.scope(): - # Create variables to track static buffer size and max IDs for each - # table during preprocessing. These variables are shared across all - # processes on CPU. We don't add these via `add_weight` because we - # can't have them passed to the training function. 
- replicated_zeros_initializer = ShardedInitializer( - "zeros", cpu_layout - ) - - with backend.name_scope(self.name, caller=self): - self._preprocessing_buffer_size = { - table_name: backend.Variable( - initializer=replicated_zeros_initializer, - shape=(), - dtype=backend.standardize_dtype("int32"), - trainable=False, - name=table_name + ":preprocessing:buffer_size", - ) - for table_name in stacked_table_specs.keys() - } - self._preprocessing_max_unique_ids_per_partition = { - table_name: backend.Variable( - shape=(), - name=table_name - + ":preprocessing:max_unique_ids_per_partition", - initializer=replicated_zeros_initializer, - dtype=backend.standardize_dtype("int32"), - trainable=False, - ) - for table_name in stacked_table_specs.keys() - } - - self._preprocessing_max_ids_per_partition = { - table_name: backend.Variable( - shape=(), - name=table_name - + ":preprocessing:max_ids_per_partition", - initializer=replicated_zeros_initializer, - dtype=backend.standardize_dtype("int32"), - trainable=False, - ) - for table_name in stacked_table_specs.keys() - } - self._config = jte_embedding_lookup.EmbeddingLookupConfiguration( feature_specs, mesh=mesh, @@ -660,76 +589,35 @@ def _sparsecore_preprocess( mesh.devices.item(0) ) - # Get current buffer size/max_ids. - previous_max_ids_per_partition = keras.tree.map_structure( - lambda max_ids_per_partition: max_ids_per_partition.value.item(), - self._preprocessing_max_ids_per_partition, - ) - previous_max_unique_ids_per_partition = keras.tree.map_structure( - lambda max_unique_ids_per_partition: ( - max_unique_ids_per_partition.value.item() - ), - self._preprocessing_max_unique_ids_per_partition, - ) - previous_buffer_size = keras.tree.map_structure( - lambda buffer_size: buffer_size.value.item(), - self._preprocessing_buffer_size, - ) - preprocessed, stats = embedding_utils.stack_and_shard_samples( self._config.feature_specs, samples, local_device_count, global_device_count, num_sc_per_device, - static_buffer_size=previous_buffer_size, ) - # Extract max unique IDs and buffer sizes. - # We need to replicate this value across all local CPU devices. if training: + # Synchronize input statistics across all devices and update the + # underlying stacked tables specs in the feature specs. + prev_stats = embedding_utils.get_stacked_table_stats( + self._config.feature_specs + ) + + # Take the maximum with existing stats. + stats = keras.tree.map_structure(max, prev_stats, stats) + + # Flatten the stats so we can more efficiently transfer them + # between hosts. We use jax.tree because we will later need to + # unflatten. + flat_stats, stats_treedef = jax.tree.flatten(stats) + + # In the case of multiple local CPU devices per host, we need to + # replicate the stats to placate JAX collectives. num_local_cpu_devices = jax.local_device_count("cpu") - local_max_ids_per_partition = { - table_name: np.repeat( - # Maximum across all partitions and previous max. - np.maximum( - np.max(elems), - previous_max_ids_per_partition[table_name], - ), - num_local_cpu_devices, - ) - for table_name, elems in stats.max_ids_per_partition.items() - } - local_max_unique_ids_per_partition = { - name: np.repeat( - # Maximum across all partitions and previous max. - np.maximum( - np.max(elems), - previous_max_unique_ids_per_partition[name], - ), - num_local_cpu_devices, - ) - for name, elems in stats.max_unique_ids_per_partition.items() - } - local_buffer_size = { - table_name: np.repeat( - np.maximum( - np.max( - # Round values up to the next multiple of 8. 
- # Currently using this as a proxy for the actual - # required buffer size. - ((elems + 7) // 8) * 8 - ) - * global_device_count - * num_sc_per_device - * local_device_count - * num_sc_per_device, - previous_buffer_size[table_name], - ), - num_local_cpu_devices, - ) - for table_name, elems in stats.max_ids_per_partition.items() - } + tiled_stats = np.tile( + np.array(flat_stats, dtype=np.int32), (num_local_cpu_devices, 1) + ) # Aggregate variables across all processes/devices. max_across_cpus = jax.pmap( @@ -737,48 +625,24 @@ def _sparsecore_preprocess( x, "all_cpus" ), axis_name="all_cpus", - devices=self._cpu_layout.device_mesh.backend_mesh.devices, - ) - new_max_ids_per_partition = max_across_cpus( - local_max_ids_per_partition - ) - new_max_unique_ids_per_partition = max_across_cpus( - local_max_unique_ids_per_partition + backend="cpu", ) - new_buffer_size = max_across_cpus(local_buffer_size) - - # Assign new preprocessing parameters. - with self._cpu_distribution.scope(): - # For each process, all max ids/buffer sizes are replicated - # across all local devices. Take the value from the first - # device. - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_max_ids_per_partition, - new_max_ids_per_partition, - ) - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_max_unique_ids_per_partition, - new_max_unique_ids_per_partition, - ) - keras.tree.map_structure( - lambda var, values: var.assign(values[0]), - self._preprocessing_buffer_size, - new_buffer_size, - ) - # Update parameters in the underlying feature specs. - int_max_ids_per_partition = keras.tree.map_structure( - lambda varray: varray.item(), new_max_ids_per_partition - ) - int_max_unique_ids_per_partition = keras.tree.map_structure( - lambda varray: varray.item(), - new_max_unique_ids_per_partition, + flat_stats = max_across_cpus(tiled_stats)[0].tolist() + stats = jax.tree.unflatten(stats_treedef, flat_stats) + + # Update configuration and repeat preprocessing if stats changed. + if stats != prev_stats: + embedding_utils.update_stacked_table_stats( + self._config.feature_specs, stats ) - embedding_utils.update_stacked_table_specs( + + # Re-execute preprocessing with consistent input statistics. + preprocessed, _ = embedding_utils.stack_and_shard_samples( self._config.feature_specs, - int_max_ids_per_partition, - int_max_unique_ids_per_partition, + samples, + local_device_count, + global_device_count, + num_sc_per_device, ) return {"inputs": preprocessed} diff --git a/keras_rs/src/layers/embedding/jax/embedding_utils.py b/keras_rs/src/layers/embedding/jax/embedding_utils.py index 393c197c..38e69f7d 100644 --- a/keras_rs/src/layers/embedding/jax/embedding_utils.py +++ b/keras_rs/src/layers/embedding/jax/embedding_utils.py @@ -35,6 +35,12 @@ class ShardedCooMatrix(NamedTuple): values: ArrayLike +class InputStatsPerTable(NamedTuple): + max_ids_per_partition: int + max_unique_ids_per_partition: int + required_buffer_size_per_device: int + + def _round_up_to_multiple(value: int, multiple: int) -> int: return ((value + multiple - 1) // multiple) * multiple @@ -335,19 +341,47 @@ def get_table_stacks( return stacked_table_specs -def update_stacked_table_specs( +def get_stacked_table_stats( feature_specs: Nested[FeatureSpec], - max_ids_per_partition: Mapping[str, int], - max_unique_ids_per_partition: Mapping[str, int], +) -> dict[str, InputStatsPerTable]: + """Extracts the stacked-table input statistics from the feature specs. 
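+
+    The returned values mirror the fields of `InputStatsPerTable` defined
+    above: `max_ids_per_partition`, `max_unique_ids_per_partition` and
+    `required_buffer_size_per_device`.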
+
+    Args:
+        feature_specs: Feature specs from which to extract the statistics.
+
+    Returns:
+        A mapping of stacked table names to input statistics per table.
+    """
+    stacked_table_specs: dict[str, StackedTableSpec] = {}
+    for feature_spec in jax.tree.flatten(feature_specs)[0]:
+        feature_spec = typing.cast(FeatureSpec, feature_spec)
+        stacked_table_spec = typing.cast(
+            StackedTableSpec, feature_spec.table_spec.stacked_table_spec
+        )
+        stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
+
+    stats: dict[str, InputStatsPerTable] = {}
+    for stacked_table_spec in stacked_table_specs.values():
+        buffer_size = stacked_table_spec.suggested_coo_buffer_size_per_device
+        buffer_size = buffer_size or 0
+        stats[stacked_table_spec.stack_name] = InputStatsPerTable(
+            max_ids_per_partition=stacked_table_spec.max_ids_per_partition,
+            max_unique_ids_per_partition=stacked_table_spec.max_unique_ids_per_partition,
+            required_buffer_size_per_device=buffer_size,
+        )
+
+    return stats
+
+
+def update_stacked_table_stats(
+    feature_specs: Nested[FeatureSpec],
+    stats: Mapping[str, InputStatsPerTable],
 ) -> None:
-    """Updates properties in the supplied feature specs.
+    """Updates stacked-table input properties in the supplied feature specs.
 
     Args:
         feature_specs: Feature specs to update in-place.
-        max_ids_per_partition: Mapping of table stack name to
-            new `max_ids_per_partition` for the stack.
-        max_unique_ids_per_partition: Mapping of table stack name to
-            new `max_unique_ids_per_partition` for the stack.
+        stats: Per-stacked-table input statistics.
     """
     # Collect table specs and stacked table specs.
     table_specs: dict[str, TableSpec] = {}
@@ -363,18 +397,17 @@
         stacked_table_specs[stacked_table_spec.stack_name] = stacked_table_spec
 
     # Replace fields in the stacked_table_specs.
-    stacked_table_specs = {
-        stack_name: dataclasses.replace(
+    stack_names = stacked_table_specs.keys()
+    for stack_name in stack_names:
+        stack_stats = stats[stack_name]
+        stacked_table_spec = stacked_table_specs[stack_name]
+        buffer_size = stack_stats.required_buffer_size_per_device or None
+        stacked_table_specs[stack_name] = dataclasses.replace(
            stacked_table_spec,
-            max_ids_per_partition=max_ids_per_partition[
-                stacked_table_spec.stack_name
-            ],
-            max_unique_ids_per_partition=max_unique_ids_per_partition[
-                stacked_table_spec.stack_name
-            ],
+            max_ids_per_partition=stack_stats.max_ids_per_partition,
+            max_unique_ids_per_partition=stack_stats.max_unique_ids_per_partition,
+            suggested_coo_buffer_size_per_device=buffer_size,
        )
-        for stack_name, stacked_table_spec in stacked_table_specs.items()
-    }
 
     # Insert new stacked tables into tables.
     for table_spec in table_specs.values():
@@ -534,7 +567,7 @@
     global_device_count: int,
     num_sc_per_device: int,
     static_buffer_size: int | Mapping[str, int] | None = None,
-) -> tuple[dict[str, ShardedCooMatrix], embedding.SparseDenseMatmulInputStats]:
+) -> tuple[dict[str, ShardedCooMatrix], dict[str, InputStatsPerTable]]:
     """Prepares input samples for use in embedding lookups.
 
     Args:
@@ -544,8 +577,8 @@
         global_device_count: Number of global JAX devices.
         num_sc_per_device: Number of sparsecores per device.
         static_buffer_size: The static buffer size to use for the samples.
-            Defaults to None, in which case an upper-bound for the buffer size
-            will be automatically determined.
+            Defaults to None, in which case an upper-bound for the buffer size
+            will be automatically determined.
Returns: The preprocessed inputs, and statistics useful for updating FeatureSpecs @@ -579,6 +612,7 @@ def collect_tokens_and_weights( ) out: dict[str, ShardedCooMatrix] = {} + out_stats: dict[str, InputStatsPerTable] = {} tables_names = preprocessed_inputs.lhs_row_pointers.keys() for table_name in tables_names: shard_ends = preprocessed_inputs.lhs_row_pointers[table_name] @@ -592,5 +626,17 @@ def collect_tokens_and_weights( row_ids=preprocessed_inputs.lhs_sample_ids[table_name], values=preprocessed_inputs.lhs_gains[table_name], ) + out_stats[table_name] = InputStatsPerTable( + max_ids_per_partition=np.max( + stats.max_ids_per_partition[table_name] + ), + max_unique_ids_per_partition=np.max( + stats.max_unique_ids_per_partition[table_name] + ), + required_buffer_size_per_device=np.max( + stats.required_buffer_size_per_sc[table_name] + ) + * num_sc_per_device, + ) - return out, stats + return out, out_stats