Distributed training makes it possible to train models quickly on larger
datasets. Distributed training in TF-DF relies on the TensorFlow
ParameterServerV2 distribution strategy or the Yggdrasil Decision Forests GRPC
distribute strategy. Only some of the TF-DF models support distributed training.

See the
[distributed training](https://github.com/google/yggdrasil-decision-forests/documentation/user_manual.md?#distributed-training)
section in the Yggdrasil Decision Forests user manual for details about the
available distributed training algorithms. When using distributed training with
TF Parameter Server in TF-DF, Yggdrasil Decision Forests is effectively running
the `TF_DIST` distribute implementation.

**Note:** Currently (Oct. 2021), the shared (i.e. non-monolithic) OSS build of
TF-DF does not support the TF ParameterServer distribution strategy. Please use
the Yggdrasil DF GRPC distribute strategy instead.

## Dataset
As of today (Oct 2021), the following solutions are available for TF-DF:

1.  To use **Yggdrasil distributed dataset reading**. This solution is the
    fastest and the one that gives the best results, as it is currently the
    only one that guarantees that each example is read only once. The downside
    is that this solution does not support TensorFlow pre-processing. The
    "Yggdrasil DF GRPC distribute strategy" only supports this option for
    dataset reading.

2.  To use a **ParameterServerV2 distributed dataset** with dataset file
    sharding using the TF-DF worker index (see the sketch below). This
    solution is the most natural for TF users.

Currently, using a ParameterServerV2 distributed dataset with context or with
tf.data.service is not compatible with TF-DF.
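
For concreteness, here is a minimal sketch of option 2's file sharding. The
file pattern and record format are illustrative assumptions, not from this
page; the idea is that each worker uses its worker index to read a disjoint
subset of the dataset files, so every example is read exactly once.

```python
import tensorflow as tf

def make_worker_dataset(input_context: tf.distribute.InputContext):
  """Sketch: shard the dataset files by worker index (illustrative only)."""
  files = tf.data.Dataset.list_files("/path/to/dataset-*", shuffle=False)
  # Keep only the files assigned to this worker: every file, and therefore
  # every example, is read by exactly one worker.
  files = files.shard(input_context.num_input_pipelines,
                      input_context.input_pipeline_id)
  # Assumed record format; replace with your own parsing logic.
  return files.interleave(tf.data.TFRecordDataset)
```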
## Examples

Following are some examples of distributed training.

### Distribution with Yggdrasil distributed dataset reading and TF ParameterServerV2 strategy
5962``` python
6063import tensorflow_decision_forests as tfdf
See Yggdrasil Decision Forests
[supported formats](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format)
for the possible values of `dataset_format`.
### Distribution with ParameterServerV2 distributed dataset and TF ParameterServerV2 strategy

The dataset construction is configuration-specific; the sketch below assumes a
cluster described by the `TF_CONFIG` environment variable and a `dataset_fn`
that shards the dataset files by worker index (see the sketch in the Dataset
section above).

```python
import tensorflow_decision_forests as tfdf
import tensorflow as tf

# Assumption: the cluster is described by the TF_CONFIG environment variable.
cluster_resolver = tf.distribute.cluster_resolver.TFConfigClusterResolver()
strategy = tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

def dataset_fn(input_context):
  # Build the per-worker tf.data.Dataset here, sharding the dataset files by
  # worker index so that each example is read only once.
  ...

with strategy.scope():
  model = tfdf.keras.DistributedGradientBoostedTreesModel()

model.fit(
    tf.keras.utils.experimental.DatasetCreator(dataset_fn))

print("Trained model")
model.summary()
```
### Distribution with Yggdrasil distributed dataset reading and Yggdrasil DF GRPC distribute strategy
```python
import tensorflow_decision_forests as tfdf
import tensorflow as tf

deployment_config = tfdf.keras.core.YggdrasilDeploymentConfig()
deployment_config.try_resume_training = True
deployment_config.distribute.implementation_key = "GRPC"
socket_addresses = deployment_config.distribute.Extensions[
    tfdf.keras.core.grpc_pb2.grpc].socket_addresses

# Socket addresses of ":grpc_worker_main" running instances.
socket_addresses.addresses.add(ip="127.0.0.1", port=2001)
socket_addresses.addresses.add(ip="127.0.0.2", port=2001)
socket_addresses.addresses.add(ip="127.0.0.3", port=2001)
socket_addresses.addresses.add(ip="127.0.0.4", port=2001)

model = tfdf.keras.DistributedGradientBoostedTreesModel(
    advanced_arguments=tfdf.keras.AdvancedArguments(
        yggdrasil_deployment_config=deployment_config))

model.fit_on_dataset_path(
    train_path="/path/to/dataset@100000",
    label_key="label_key",
    dataset_format="tfrecord+tfe")

print("Trained model")
model.summary()
```
See Yggdrasil Decision Forests
[supported formats](https://github.com/google/yggdrasil-decision-forests/blob/main/documentation/user_manual.md#dataset-path-and-format)
for the possible values of `dataset_format`.
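
Each address in `socket_addresses` must point at a running Yggdrasil GRPC
worker. As a hypothetical local launcher (the binary name comes from the
`:grpc_worker_main` comment above; the `--port` flag and the binary's location
are assumptions to adapt to your build), one could start four workers on one
machine with distinct ports and list `127.0.0.1` with those ports in
`socket_addresses` instead:

```python
# Hypothetical: start four local Yggdrasil GRPC workers. Assumes a compiled
# "grpc_worker_main" binary on the PATH that accepts a --port flag.
import subprocess

workers = [
    subprocess.Popen(["grpc_worker_main", f"--port={port}"])
    for port in (2001, 2002, 2003, 2004)
]
```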