64 changes: 51 additions & 13 deletions recml/core/data/tf_dataset_factory.py
@@ -24,6 +24,7 @@
import re
from typing import Any, Protocol

from absl import flags
from absl import logging
import jax
from recml.core.utils import types
@@ -162,12 +163,23 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
Defaults to False.
seed: An optional seed to use for deterministic shuffling / preprocessing.
Defaults to None.
tf_data_service_address: An optional URI of a tf.data service to offload
preprocessing onto during training. The URI should be in the format
"protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
data service will be applied.
enable_tf_data_service: Whether to apply tf.data service for this dataset.
If True, the `tf_data_service_address` flag must be set.
tf_data_service_policy: Sharding policy to use for tf.data service when it
is enabled.
tf_data_service_job_name: Job name to use for tf.data service. If None, the
default job name will be used.
offload_preprocessing_to_tf_data_service: Whether to offload preprocessing
to tf.data service. If True, `enable_tf_data_service` must also be True,
and the preprocessing transformations will be offloaded to tf.data service
workers; otherwise they will be applied on the host CPU. If tf.data
service is not enabled, this arg must be set to False. Defaults to False.
tf_data_service_replicate_on_split: Whether to replicate the file dataset on
split when distributing data to tf.data service workers. Note: this can be
used when multiple datasets are processed together under the `DYNAMIC`
sharding policy; a dataset with `tf_data_service_replicate_on_split`
enabled is effectively processed as if its policy were `OFF`.
feature_spec: A mapping of feature keys to `FixedLenFeature`,
`VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +220,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
tensorflow.
debug: An optional boolean indicating whether to debug input boundedness. If
`True`, the dataset will consist of a single batch that's cached and
infinitely repeated
infinitely repeated.
"""

cache_reading: bool = False
@@ -231,10 +243,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
readahead: str | None = None
group_uris_by_dir: bool = False
seed: int | None = None
tf_data_service_address: str | None = None
enable_tf_data_service: bool = False
tf_data_service_job_name: str | None = None
tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
tf.data.experimental.service.ShardingPolicy.OFF
)
offload_preprocessing_to_tf_data_service: bool = False
feature_spec: Mapping[str, IO_Feature] | None = None
sequence_feature_spec: Mapping[str, IO_Feature] | None = None
tf_transform_output: TFTransformOutput | None = None
@@ -246,14 +260,26 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
sharding_info: DatasetShardingInfo = dataclasses.field(
default_factory=DatasetShardingInfo
)
tf_data_service_replicate_on_split: bool = False
debug: bool = False

def __post_init__(self):
if self.tf_data_service_address is not None:
if self.enable_tf_data_service:
if flags.FLAGS.tf_data_service_address is None:
raise ValueError(
"Flag `tf_data_service_address` must be set when"
" `enable_tf_data_service` is True."
)
if self.seed is not None:
raise ValueError("`seed` must be None for data service.")
if self.sharding:
raise ValueError("`sharding` must be set to False for data service.")
else:
if self.offload_preprocessing_to_tf_data_service:
raise ValueError(
"`offload_preprocessing_to_tf_data_service` must be False when"
" `enable_tf_data_service` is False."
)

@functools.cached_property
def tfds_metadata(self) -> TFDSMetadata | None:
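
For reviewers, a minimal configuration sketch of the new fields added above (hypothetical values; all other factory fields are omitted and assumed to be configured as before). The `tf_data_service_address` flag must also be passed to the binary, e.g. `--tf_data_service_address=grpc://tf-data-service:5050`.

```python
# Hypothetical usage sketch; only the data-service-related fields are shown.
factory = TFDatasetFactory(
    enable_tf_data_service=True,
    tf_data_service_job_name="my_shared_job",        # hypothetical job name
    offload_preprocessing_to_tf_data_service=True,   # run map_fns on service workers
    tf_data_service_replicate_on_split=False,
    seed=None,       # must be None when the service is enabled
    sharding=False,  # must be False when the service is enabled
)
dataset = factory.make()
```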
@@ -464,6 +490,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
# Create a dataset of file / file group uris.
dataset = tf.data.Dataset.from_tensor_slices(uris)

if self.tf_data_service_replicate_on_split:
dataset = tf.data.apply_rewrite(dataset, rewrite="replicate_on_split")

# Repeat the dataset. We might need to repeat it here in case the issue in
# internal screenshot link:6jAKKoEMT3afXRe is encountered, even when we do
# have enough shards for the input data.
@@ -533,23 +562,26 @@ def _maybe_apply_tf_data_service(
self, dataset: tf.data.Dataset
) -> tf.data.Dataset:
"""Applies the tf.data service to the dataset."""
if self.tf_data_service_address is None:
if not self.enable_tf_data_service:
return dataset

tf_data_service_address = flags.FLAGS.tf_data_service_address

per_proc_batch_size = self.sharding_info.per_process_batch_size(
self.global_batch_size
)
logging.info(
"Applying tf.data service with address %s and per replica batch"
" size %s",
self.tf_data_service_address,
tf_data_service_address,
per_proc_batch_size,
)
return dataset.apply(
tf.data.experimental.service.distribute(
processing_mode=self.tf_data_service_policy,
service=self.tf_data_service_address,
job_name=f"bs_{per_proc_batch_size}",
service=tf_data_service_address,
job_name=self.tf_data_service_job_name
or "tf_data_service_shared_job_name",
)
)
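
As context for the shared `job_name`: tf.data service consumers that pass the same job name read from a single job, so elements are split across them rather than re-produced for each consumer. A standalone sketch of the call this method issues (the address and dataset here are placeholders, not part of the change):

```python
import tensorflow as tf


def distribute_to_service(ds: tf.data.Dataset, address: str) -> tf.data.Dataset:
  # `address` is a placeholder, e.g. "grpc://tf-data-service:5050".
  return ds.apply(
      tf.data.experimental.service.distribute(
          processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
          service=address,
          # Consumers sharing this name share one job and split its output.
          job_name="tf_data_service_shared_job_name",
      )
  )
```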

@@ -566,12 +598,18 @@ def make(self) -> tf.data.Dataset:
dataset = self._parse_dataset(dataset)
# Apply filters to the batched dataset.
dataset = self._maybe_filter_dataset(dataset)
# Apply data service.
dataset = self._maybe_apply_tf_data_service(dataset)
# Apply tf.data service before preprocessing if preprocessing is not offloaded.
if not self.offload_preprocessing_to_tf_data_service:
dataset = self._maybe_apply_tf_data_service(dataset)

# Apply transformations on the dataset.
for fn in self.map_fns:
dataset = dataset.map(fn, num_parallel_calls=self.num_parallel_threads)

# Apply tf.data service after preprocessing if preprocessing is offloaded.
if self.offload_preprocessing_to_tf_data_service:
dataset = self._maybe_apply_tf_data_service(dataset)

if self.debug:
dataset = dataset.take(1).cache().repeat()
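
A condensed sketch of the ordering logic in `make` above, using stand-in names (`apply_service` for `_maybe_apply_tf_data_service`, `offload` for `offload_preprocessing_to_tf_data_service`): everything upstream of the service `distribute` transformation runs on tf.data service workers, while everything downstream runs on the consumer host, which is why the placement of the call matters.

```python
def _order_maps_and_service(dataset, map_fns, apply_service, offload):
  # Stand-in helper, not part of the change: shows why the service is applied
  # before or after the map functions.
  if not offload:
    dataset = apply_service(dataset)  # maps below stay on the host CPU
  for fn in map_fns:
    dataset = dataset.map(fn)
  if offload:
    dataset = apply_service(dataset)  # maps above now run on service workers
  return dataset
```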

16 changes: 8 additions & 8 deletions recml/core/ops/hstu_ops.py
@@ -125,9 +125,9 @@ def _apply_mask(
masks = []
if mask_ref is not None:
if k_in_lanes:
mask = pl.load(mask_ref, (slice(None), k_slice))
mask = mask_ref[:, k_slice]
else:
mask = pl.load(mask_ref, (k_slice, slice(None)))
mask = mask_ref[k_slice, :]

snm = jnp.where(should_not_mask, 1, 0)
masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
k_sequence = k_offset + jax.lax.broadcasted_iota(
jnp.int32, (k_slice.size, bq), 0
)
q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq]
q_sequence = q_sequence_ref[:1, :] # [1, bq]
q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@

if q_segment_ids_ref is not None:
if k_in_lanes:
kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice]
kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice]
repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
if rem:
raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@
if rem:
raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
kv_ids = pltpu.repeat(
pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
kv_segment_ids_ref[k_slice, :], repeats, axis=1
) # [k_slice, bq]
q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq]
q_ids = q_segment_ids_ref[:1, :] # [1, bq]
masks.append(q_ids == kv_ids)

if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

q = q_ref[...]
k = pl.load(k_ref, (slice_k, slice(None)))
k = k_ref[slice_k, :]
qk = jax.lax.dot_general(
q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
)
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
)

sv_dims = NN_DIM_NUMBERS
v = pl.load(v_ref, (slice_k, slice(None)))
v = v_ref[slice_k, :]

to_float32 = lambda x: x.astype(jnp.float32)
v = to_float32(v)
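
The hstu_ops changes replace `pl.load(ref, ...)` with the equivalent Pallas `Ref` indexing sugar; no behavior change is intended. A minimal, self-contained sketch (not from this file; run in interpret mode so no TPU is needed) showing both forms reading the same block:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl


def _first_row_kernel(x_ref, o_ref):
  # Indexing sugar, as used after this change.
  row = x_ref[:1, :]
  # Explicit pl.load of the same block, as in the old code.
  same_row = pl.load(x_ref, (pl.ds(0, 1), slice(None)))
  o_ref[:1, :] = row + same_row  # writes 2x the first row


x = jnp.arange(32.0).reshape(4, 8)
doubled_first_row = pl.pallas_call(
    _first_row_kernel,
    out_shape=jax.ShapeDtypeStruct((1, 8), jnp.float32),
    interpret=True,  # portable interpreter mode for this sketch
)(x)
```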