24
24
import re
25
25
from typing import Any , Protocol
26
26
27
+ from absl import flags
27
28
from absl import logging
28
29
import jax
29
30
from recml .core .utils import types
@@ -162,12 +163,23 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
162
163
Defaults to False.
163
164
seed: An optional seed to use for deterministic shuffling / preprocessing.
164
165
Defaults to None.
165
- tf_data_service_address: An optional URI of a tf.data service to offload
166
- preprocessing onto during training. The URI should be in the format
167
- "protocol://address", e.g. "grpc://tf-data-service:5050". If `None` no
168
- data service will be applied.
166
+ enable_tf_data_service: Whether to apply tf.data service for this dataset.
167
+ If True, flag `tf_data_service_address` must be set.
169
168
tf_data_service_policy: Sharding policy to use for tf.data service when it
170
169
is enabled.
170
+ tf_data_service_job_name: Job name to use for tf.data service. If None, the
171
+ default job name will be used.
172
+ offload_preprocessing_to_tf_data_service: Whether to offload preprocessing
173
+ to tf.data service. If True, enable_tf_data_service must also be True, and
174
+ the preprocessing transformation will be offloaded to tf data service
175
+ workers. Otherwise, the preprocessing transformation will be applied on
176
+ the host CPU. If tf data service is not enabled, this arg must be set to
177
+ False. Defaults to False.
178
+ tf_data_service_replicate_on_split: Whether to replicate the file dataset on
179
+ split when distributing data to tf.data service workers. Note: it could be
180
+ used in the case where multiple datasets are processed together under
181
+ `Dynamic` mode. The dataset with `tf_data_service_replicate_on_split`
182
+ enabled is equivalent to having that dataset processed as `Off` mode.
171
183
feature_spec: A mapping of feature keys to `FixedLenFeature`,
172
184
`VarLenFeature`, `SparseFeature`, or `RaggedFeature` values. This will be
173
185
used to parse the TF examples, or as context_features spec to parse TF
@@ -208,7 +220,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
208
220
tensorflow.
209
221
debug: An optional boolean indicating whether to debug input boundedness. If
210
222
`True`, the dataset will consist of a single batch that's cached and
211
- infinitely repeated
223
+ infinitely repeated.
212
224
"""
213
225
214
226
cache_reading : bool = False
@@ -231,10 +243,12 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
231
243
readahead : str | None = None
232
244
group_uris_by_dir : bool = False
233
245
seed : int | None = None
234
- tf_data_service_address : str | None = None
246
+ enable_tf_data_service : bool = False
247
+ tf_data_service_job_name : str | None = None
235
248
tf_data_service_policy : tf .data .experimental .service .ShardingPolicy = (
236
249
tf .data .experimental .service .ShardingPolicy .OFF
237
250
)
251
+ offload_preprocessing_to_tf_data_service : bool = False
238
252
feature_spec : Mapping [str , IO_Feature ] | None = None
239
253
sequence_feature_spec : Mapping [str , IO_Feature ] | None = None
240
254
tf_transform_output : TFTransformOutput | None = None
@@ -246,14 +260,26 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
246
260
sharding_info : DatasetShardingInfo = dataclasses .field (
247
261
default_factory = DatasetShardingInfo
248
262
)
263
+ tf_data_service_replicate_on_split : bool = False
249
264
debug : bool = False
250
265
251
266
def __post_init__(self):
  """Validates the tf.data service configuration after dataclass init.

  Raises:
    ValueError: If `enable_tf_data_service` is True but the
      `tf_data_service_address` flag is unset, or `seed` / `sharding` are
      set (both are incompatible with the data service); or if
      `offload_preprocessing_to_tf_data_service` is True while the data
      service is disabled.
  """
  if not self.enable_tf_data_service:
    # Preprocessing cannot be offloaded to a service that is disabled.
    if self.offload_preprocessing_to_tf_data_service:
      raise ValueError(
          "`offload_preprocessing_to_tf_data_service` must be False when"
          " `enable_tf_data_service` is False."
      )
    return

  # The service endpoint is supplied via an absl flag, not a field.
  if flags.FLAGS.tf_data_service_address is None:
    raise ValueError(
        "Flag `tf_data_service_address` must be set when"
        " `enable_tf_data_service` is True."
    )
  if self.seed is not None:
    raise ValueError("`seed` must be None for data service.")
  if self.sharding:
    raise ValueError("`sharding` must be set to False for data service.")
257
283
258
284
@functools .cached_property
259
285
def tfds_metadata (self ) -> TFDSMetadata | None :
@@ -464,6 +490,9 @@ def _file_group_reader(file_group: str) -> tf.data.Dataset:
464
490
# Create a dataset of file / file group uris.
465
491
dataset = tf .data .Dataset .from_tensor_slices (uris )
466
492
493
+ if self .tf_data_service_replicate_on_split :
494
+ dataset = tf .data .apply_rewrite (dataset , rewrite = "replicate_on_split" )
495
+
467
496
# Repeat the dataset. We might need to repeat the dataset here in case the
468
497
# issue is encountered: internal screenshot link:6jAKKoEMT3afXRe
469
498
# even we do have enough shards for the input data.
def _maybe_apply_tf_data_service(
    self, dataset: tf.data.Dataset
) -> tf.data.Dataset:
  """Distributes `dataset` through tf.data service when it is enabled.

  Args:
    dataset: The dataset to (possibly) distribute.

  Returns:
    The dataset distributed via tf.data service, or the input dataset
    unchanged when `enable_tf_data_service` is False.
  """
  if not self.enable_tf_data_service:
    return dataset

  # The service endpoint comes from the shared absl flag rather than a
  # per-factory attribute; __post_init__ guarantees it is set here.
  service_address = flags.FLAGS.tf_data_service_address
  replica_batch_size = self.sharding_info.per_process_batch_size(
      self.global_batch_size
  )
  logging.info(
      "Applying tf.data service with address %s and per replica batch"
      " size %s",
      service_address,
      replica_batch_size,
  )
  # Fall back to a shared job name so concurrent consumers join one job
  # when no explicit name is configured.
  distribute_transform = tf.data.experimental.service.distribute(
      processing_mode=self.tf_data_service_policy,
      service=service_address,
      job_name=self.tf_data_service_job_name
      or "tf_data_service_shared_job_name",
  )
  return dataset.apply(distribute_transform)
555
587
@@ -566,12 +598,18 @@ def make(self) -> tf.data.Dataset:
566
598
dataset = self ._parse_dataset (dataset )
567
599
# Apply filters to the batched dataset.
568
600
dataset = self ._maybe_filter_dataset (dataset )
569
- # Apply data service.
570
- dataset = self ._maybe_apply_tf_data_service (dataset )
601
+ # Apply TF Data service before preprocessing.
602
+ if not self .offload_preprocessing_to_tf_data_service :
603
+ dataset = self ._maybe_apply_tf_data_service (dataset )
604
+
571
605
# Apply transformations on the dataset.
572
606
for fn in self .map_fns :
573
607
dataset = dataset .map (fn , num_parallel_calls = self .num_parallel_threads )
574
608
609
+ # Apply TF Data Service after preprocessing.
610
+ if self .offload_preprocessing_to_tf_data_service :
611
+ dataset = self ._maybe_apply_tf_data_service (dataset )
612
+
575
613
if self .debug :
576
614
dataset = dataset .take (1 ).cache ().repeat ()
577
615
0 commit comments