|
16 | 16 | from openfl.databases import PersistentTensorDB, TensorDB
|
17 | 17 | from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage
|
18 | 18 | from openfl.pipelines import NoCompressionPipeline, TensorCodec
|
19 |
| -from openfl.protocols import base_pb2, utils |
| 19 | +from openfl.protocols import utils |
20 | 20 | from openfl.protocols.base_pb2 import NamedTensor
|
21 | 21 | from openfl.utilities import TaskResultKey, TensorKey, change_tags
|
22 | 22 |
|
@@ -83,39 +83,13 @@ def __init__(
|
83 | 83 | single_col_cert_common_name=None,
|
84 | 84 | compression_pipeline=None,
|
85 | 85 | db_store_rounds=1,
|
86 |
| - initial_tensor_dict=None, |
87 | 86 | log_memory_usage=False,
|
88 | 87 | write_logs=False,
|
89 | 88 | callbacks: Optional[List] = [],
|
90 | 89 | persist_checkpoint=True,
|
91 | 90 | persistent_db_path=None,
|
92 | 91 | secure_aggregation=False,
|
93 | 92 | ):
|
94 |
| - """Initializes the Aggregator. |
95 |
| -
|
96 |
| - Args: |
97 |
| - aggregator_uuid (int): Aggregation ID. |
98 |
| - federation_uuid (str): Federation ID. |
99 |
| - authorized_cols (list of str): The list of IDs of enrolled |
100 |
| - collaborators. |
101 |
| - init_state_path (str): The location of the initial weight file. |
102 |
| - best_state_path (str): The file location to store the weight of |
103 |
| - the best model. |
104 |
| - last_state_path (str): The file location to store the latest |
105 |
| - weight. |
106 |
| - assigner: Assigner object. |
107 |
| - straggler_handling_policy (optional): Straggler handling policy. |
108 |
| - rounds_to_train (int, optional): Number of rounds to train. |
109 |
| - Defaults to 256. |
110 |
| - single_col_cert_common_name (str, optional): Common name for single |
111 |
| - collaborator certificate. Defaults to None. |
112 |
| - compression_pipeline (optional): Compression pipeline. Defaults to |
113 |
| - NoCompressionPipeline. |
114 |
| - db_store_rounds (int, optional): Rounds to store in TensorDB. |
115 |
| - Defaults to 1. |
116 |
| - initial_tensor_dict (dict, optional): Initial tensor dictionary. |
117 |
| - callbacks: List of callbacks to be used during the experiment. |
118 |
| - """ |
119 | 93 | self.round_number = 0
|
120 | 94 | self.next_model_round_number = 0
|
121 | 95 |
|
@@ -205,20 +179,10 @@ def __init__(
|
205 | 179 | last_state_path=self.last_state_path,
|
206 | 180 | )
|
207 | 181 |
|
208 |
| - if initial_tensor_dict: |
209 |
| - self._load_initial_tensors_from_dict(initial_tensor_dict) |
210 |
| - self.model = utils.construct_model_proto( |
211 |
| - tensor_dict=initial_tensor_dict, |
212 |
| - round_number=0, |
213 |
| - tensor_pipe=self.compression_pipeline, |
214 |
| - ) |
215 |
| - else: |
216 |
| - if self.connector: |
217 |
| - # The model definition will be handled by the respective framework |
218 |
| - self.model = {} |
219 |
| - else: |
220 |
| - self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path) |
221 |
| - self._load_initial_tensors() # keys are TensorKeys |
| 182 | + self.model = {} |
| 183 | + if not self.connector: |
| 184 | + self.model = utils.load_proto(self.init_state_path) |
| 185 | + self._load_initial_tensors() # keys are TensorKeys |
222 | 186 |
|
223 | 187 | self._secure_aggregation_enabled = secure_aggregation
|
224 | 188 | if self._secure_aggregation_enabled:
|
@@ -337,23 +301,6 @@ def _load_initial_tensors(self):
|
337 | 301 | self.tensor_db.cache_tensor(tensor_key_dict)
|
338 | 302 | logger.debug("This is the initial tensor_db: %s", self.tensor_db)
|
339 | 303 |
|
340 |
| - def _load_initial_tensors_from_dict(self, tensor_dict): |
341 |
| - """Load all of the tensors required to begin federated learning. |
342 |
| -
|
343 |
| - Required tensors are: \ |
344 |
| - 1. Initial model. |
345 |
| -
|
346 |
| - Returns: |
347 |
| - None |
348 |
| - """ |
349 |
| - tensor_key_dict = { |
350 |
| - TensorKey(k, self.uuid, self.round_number, False, ("model",)): v |
351 |
| - for k, v in tensor_dict.items() |
352 |
| - } |
353 |
| - # all initial model tensors are loaded here |
354 |
| - self.tensor_db.cache_tensor(tensor_key_dict) |
355 |
| - logger.debug("This is the initial tensor_db: %s", self.tensor_db) |
356 |
| - |
357 | 304 | def _save_model(self, round_number, file_path):
|
358 | 305 | """Save the best or latest model.
|
359 | 306 |
|
@@ -485,10 +432,7 @@ def get_tasks(self, collaborator_name):
|
485 | 432 |
|
486 | 433 | # first, if it is time to quit, inform the collaborator
|
487 | 434 | if self._time_to_quit():
|
488 |
| - logger.info( |
489 |
| - "Sending signal to collaborator %s to shutdown...", |
490 |
| - collaborator_name, |
491 |
| - ) |
| 435 | + logger.info("Sending signal to collaborator %s to shutdown...", collaborator_name) |
492 | 436 | self.quit_job_sent_to.append(collaborator_name)
|
493 | 437 |
|
494 | 438 | tasks = None
|
@@ -853,7 +797,6 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict:
|
853 | 797 | )
|
854 | 798 | # Leave out straggler for the round even if they've partially
|
855 | 799 | # completed given tasks
|
856 |
| - collaborators_for_task = [] |
857 | 800 | collaborators_for_task = [
|
858 | 801 | c for c in all_collaborators_for_task if c in self.collaborators_done
|
859 | 802 | ]
|
@@ -925,11 +868,7 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict:
|
925 | 868 | f"{agg_results:f}"
|
926 | 869 | )
|
927 | 870 | self._save_model(round_number, self.best_state_path)
|
928 |
| - else: |
929 |
| - logger.info( |
930 |
| - f"Round {round_number}: best score observed {agg_results:f} " |
931 |
| - "(model not saved in evaluation mode)" |
932 |
| - ) |
| 871 | + |
933 | 872 | if "trained" in tags:
|
934 | 873 | self._prepare_trained(tensor_name, origin, round_number, report, agg_results)
|
935 | 874 |
|
|
0 commit comments