
Commit 1e676b7

RecML authors authored and committed
Reverts changelist 793734230
PiperOrigin-RevId: 814387092
1 parent 847628b commit 1e676b7

26 files changed: +4785 -460 lines

recml/core/data/tf_dataset_factory.py

Lines changed: 51 additions & 13 deletions
@@ -24,6 +24,7 @@
 import re
 from typing import Any, Protocol
 
+from absl import flags
 from absl import logging
 import jax
 from recml.core.utils import types
@@ -162,12 +163,23 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       Defaults to False.
     seed: An optional seed to use for deterministic shuffling / preprocessing.
       Defaults to None.
-    tf_data_service_address: An optional URI of a tf.data service to offload
-      preprocessing onto during training. The URI should be in the format
-      "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
-      data service will be applied.
+    enable_tf_data_service: Whether to apply tf.data service for this dataset.
+      If True, the flag `tf_data_service_address` must be set.
     tf_data_service_policy: Sharding policy to use for tf.data service when it
       is enabled.
+    tf_data_service_job_name: Job name to use for tf.data service. If None, the
+      default job name will be used.
+    offload_preprocessing_to_tf_data_service: Whether to offload preprocessing
+      to tf.data service. If True, `enable_tf_data_service` must also be True,
+      and the preprocessing transformation will be offloaded to tf.data service
+      workers. Otherwise, the preprocessing transformation will be applied on
+      the host CPU. If tf.data service is not enabled, this arg must be set to
+      False. Defaults to False.
+    tf_data_service_replicate_on_split: Whether to replicate the file dataset
+      on split when distributing data to tf.data service workers. Note: this
+      can be used when multiple datasets are processed together under `Dynamic`
+      mode. A dataset with `tf_data_service_replicate_on_split` enabled is
+      equivalent to having that dataset processed in `Off` mode.
     feature_spec: A mapping of feature keys to `FixedLenFeature`,
       `VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
       used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +220,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
       tensorflow.
     debug: An optional boolean indicating whether to debug input boundedness. If
       `True`, the dataset will consist of a single batch that's cached and
-      infinitely repeated
+      infinitely repeated.
   """
 
   cache_reading: bool = False
@@ -231,10 +243,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   readahead: str | None = None
   group_uris_by_dir: bool = False
   seed: int | None = None
-  tf_data_service_address: str | None = None
+  enable_tf_data_service: bool = False
+  tf_data_service_job_name: str | None = None
   tf_data_service_policy: tf.data.experimental.service.ShardingPolicy = (
       tf.data.experimental.service.ShardingPolicy.OFF
   )
+  offload_preprocessing_to_tf_data_service: bool = False
   feature_spec: Mapping[str, IO_Feature] | None = None
   sequence_feature_spec: Mapping[str, IO_Feature] | None = None
   tf_transform_output: TFTransformOutput | None = None
@@ -246,14 +260,26 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
   sharding_info: DatasetShardingInfo = dataclasses.field(
       default_factory=DatasetShardingInfo
   )
+  tf_data_service_replicate_on_split: bool = False
   debug: bool = False
 
   def __post_init__(self):
-    if self.tf_data_service_address is not None:
+    if self.enable_tf_data_service:
+      if flags.FLAGS.tf_data_service_address is None:
+        raise ValueError(
+            "Flag `tf_data_service_address` must be set when"
+            " `enable_tf_data_service` is True."
+        )
       if self.seed is not None:
         raise ValueError("`seed` must be None for data service.")
       if self.sharding:
         raise ValueError("`sharding` must be set to False for data service.")
+    else:
+      if self.offload_preprocessing_to_tf_data_service:
+        raise ValueError(
+            "`offload_preprocessing_to_tf_data_service` must be False when"
+            " `enable_tf_data_service` is False."
+        )
 
   @functools.cached_property
   def tfds_metadata(self) -> TFDSMetadata | None:
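For context, a minimal sketch of how the new validation in __post_init__ behaves. This is illustrative only: the field names come from this diff, the job name is a placeholder, and it assumes the absl flag tf_data_service_address has already been parsed (e.g. via app.run):

# Valid: service enabled, flag set, seed left unset.
factory = TFDatasetFactory(
    enable_tf_data_service=True,
    tf_data_service_job_name="my_shared_job",  # placeholder name
    offload_preprocessing_to_tf_data_service=True,
)

# Invalid: raises ValueError in __post_init__, since offloading
# preprocessing requires enable_tf_data_service=True.
factory = TFDatasetFactory(offload_preprocessing_to_tf_data_service=True)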
@@ -464,6 +490,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
     # Create a dataset of file / file group uris.
     dataset = tf.data.Dataset.from_tensor_slices(uris)
 
+    if self.tf_data_service_replicate_on_split:
+      dataset = tf.data.apply_rewrite(dataset, rewrite="replicate_on_split")
+
     # Repeat the dataset. We might need to repeat the dataset here in case the
     # issue is encountered: internal screenshot link:6jAKKoEMT3afXRe
     # even when we do have enough shards for the input data.
@@ -533,23 +562,26 @@ def _maybe_apply_tf_data_service(
       self, dataset: tf.data.Dataset
   ) -> tf.data.Dataset:
     """Applies the tf.data service to the dataset."""
-    if self.tf_data_service_address is None:
+    if not self.enable_tf_data_service:
       return dataset
 
+    tf_data_service_address = flags.FLAGS.tf_data_service_address
+
     per_proc_batch_size = self.sharding_info.per_process_batch_size(
         self.global_batch_size
     )
     logging.info(
         "Applying tf.data service with address %s and per replica batch"
         " size %s",
-        self.tf_data_service_address,
+        tf_data_service_address,
         per_proc_batch_size,
     )
     return dataset.apply(
         tf.data.experimental.service.distribute(
             processing_mode=self.tf_data_service_policy,
-            service=self.tf_data_service_address,
-            job_name=f"bs_{per_proc_batch_size}",
+            service=tf_data_service_address,
+            job_name=self.tf_data_service_job_name
+            or "tf_data_service_shared_job_name",
         )
     )
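A note on the job name change: tf.data service consumers that pass the same job_name to distribute() share a single job, so the dataset's elements are split across them instead of being produced once per consumer. A standalone sketch using the public API (the dispatcher address here is a placeholder, not a value from this repo):

import tensorflow as tf

ds = tf.data.Dataset.range(1024)
ds = ds.apply(
    tf.data.experimental.service.distribute(
        processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
        service="grpc://dispatcher:5050",  # placeholder address
        job_name="tf_data_service_shared_job_name",  # same name => shared job
    )
)

Under the old f"bs_{per_proc_batch_size}" scheme, job identity depended only on the per-process batch size; the new tf_data_service_job_name field makes the shared name explicit and overridable.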

@@ -566,12 +598,18 @@ def make(self) -> tf.data.Dataset:
     dataset = self._parse_dataset(dataset)
     # Apply filters to the batched dataset.
     dataset = self._maybe_filter_dataset(dataset)
-    # Apply data service.
-    dataset = self._maybe_apply_tf_data_service(dataset)
+    # Apply tf.data service before preprocessing.
+    if not self.offload_preprocessing_to_tf_data_service:
+      dataset = self._maybe_apply_tf_data_service(dataset)
 
     # Apply transformations on the dataset.
     for fn in self.map_fns:
       dataset = dataset.map(fn, num_parallel_calls=self.num_parallel_threads)
 
+    # Apply tf.data service after preprocessing.
+    if self.offload_preprocessing_to_tf_data_service:
+      dataset = self._maybe_apply_tf_data_service(dataset)
+
     if self.debug:
       dataset = dataset.take(1).cache().repeat()
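The placement of _maybe_apply_tf_data_service matters because tf.data service runs every stage applied before distribute(...) on its workers, while stages applied afterwards run on the consuming host. A toy sketch (the address and map functions are placeholders, not from this repo):

ds = tf.data.Dataset.range(1_000)
ds = ds.map(heavy_preprocess_fn)  # placeholder; runs on service workers
ds = ds.apply(
    tf.data.experimental.service.distribute(
        processing_mode=tf.data.experimental.service.ShardingPolicy.OFF,
        service="grpc://dispatcher:5050",  # placeholder
    )
)
ds = ds.map(light_fn)  # placeholder; runs on the consumer host

Hence offload_preprocessing_to_tf_data_service=True defers the service application until after the map_fns, moving them onto the workers; with False, the maps stay after distribute and run on the host CPU.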

recml/core/ops/hstu_ops.py

Lines changed: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]
 
     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@
     k_sequence = k_offset + jax.lax.broadcasted_iota(
         jnp.int32, (k_slice.size, bq), 0
     )
-    q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+    q_sequence = q_sequence_ref[:1, :]  # [1, bq]
     q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
 
     assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@
 
   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
     masks.append(q_ids == kv_ids)
 
   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
 
     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@
     )
 
     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]
 
     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
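For background on the hstu_ops.py changes: in Pallas, pl.load(ref, idx) and NumPy-style indexing on a ref are equivalent reads, with indexing being the preferred modern spelling; dynamic slices still use pl.ds, as in the k_ref[slice_k, :] lines above. A self-contained toy kernel illustrating the pattern (not code from this repo; interpret=True is used so the sketch runs without a TPU):

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def first_rows_kernel(x_ref, o_ref):
  # Equivalent to: o_ref[...] = pl.load(x_ref, (pl.ds(0, 4), slice(None)))
  o_ref[...] = x_ref[0:4, :]

x = jnp.arange(32, dtype=jnp.float32).reshape(8, 4)
out = pl.pallas_call(
    first_rows_kernel,
    out_shape=jax.ShapeDtypeStruct((4, 4), jnp.float32),
    interpret=True,  # interpreter mode so the sketch runs on CPU
)(x)
assert (out == x[:4]).all()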
