Skip to content

Commit 8b1b27b

Browse files
Remove dead code
Signed-off-by: Shah, Karan <[email protected]>
1 parent 01923ac commit 8b1b27b

File tree

1 file changed

+7
-68
lines changed

1 file changed

+7
-68
lines changed

openfl/component/aggregator/aggregator.py

Lines changed: 7 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from openfl.databases import PersistentTensorDB, TensorDB
1717
from openfl.interface.aggregation_functions import SecureWeightedAverage, WeightedAverage
1818
from openfl.pipelines import NoCompressionPipeline, TensorCodec
19-
from openfl.protocols import base_pb2, utils
19+
from openfl.protocols import utils
2020
from openfl.protocols.base_pb2 import NamedTensor
2121
from openfl.utilities import TaskResultKey, TensorKey, change_tags
2222

@@ -83,39 +83,13 @@ def __init__(
8383
single_col_cert_common_name=None,
8484
compression_pipeline=None,
8585
db_store_rounds=1,
86-
initial_tensor_dict=None,
8786
log_memory_usage=False,
8887
write_logs=False,
8988
callbacks: Optional[List] = [],
9089
persist_checkpoint=True,
9190
persistent_db_path=None,
9291
secure_aggregation=False,
9392
):
94-
"""Initializes the Aggregator.
95-
96-
Args:
97-
aggregator_uuid (int): Aggregation ID.
98-
federation_uuid (str): Federation ID.
99-
authorized_cols (list of str): The list of IDs of enrolled
100-
collaborators.
101-
init_state_path (str): The location of the initial weight file.
102-
best_state_path (str): The file location to store the weight of
103-
the best model.
104-
last_state_path (str): The file location to store the latest
105-
weight.
106-
assigner: Assigner object.
107-
straggler_handling_policy (optional): Straggler handling policy.
108-
rounds_to_train (int, optional): Number of rounds to train.
109-
Defaults to 256.
110-
single_col_cert_common_name (str, optional): Common name for single
111-
collaborator certificate. Defaults to None.
112-
compression_pipeline (optional): Compression pipeline. Defaults to
113-
NoCompressionPipeline.
114-
db_store_rounds (int, optional): Rounds to store in TensorDB.
115-
Defaults to 1.
116-
initial_tensor_dict (dict, optional): Initial tensor dictionary.
117-
callbacks: List of callbacks to be used during the experiment.
118-
"""
11993
self.round_number = 0
12094
self.next_model_round_number = 0
12195

@@ -205,20 +179,10 @@ def __init__(
205179
last_state_path=self.last_state_path,
206180
)
207181

208-
if initial_tensor_dict:
209-
self._load_initial_tensors_from_dict(initial_tensor_dict)
210-
self.model = utils.construct_model_proto(
211-
tensor_dict=initial_tensor_dict,
212-
round_number=0,
213-
tensor_pipe=self.compression_pipeline,
214-
)
215-
else:
216-
if self.connector:
217-
# The model definition will be handled by the respective framework
218-
self.model = {}
219-
else:
220-
self.model: base_pb2.ModelProto = utils.load_proto(self.init_state_path)
221-
self._load_initial_tensors() # keys are TensorKeys
182+
self.model = {}
183+
if not self.connector:
184+
self.model = utils.load_proto(self.init_state_path)
185+
self._load_initial_tensors() # keys are TensorKeys
222186

223187
self._secure_aggregation_enabled = secure_aggregation
224188
if self._secure_aggregation_enabled:
@@ -337,23 +301,6 @@ def _load_initial_tensors(self):
337301
self.tensor_db.cache_tensor(tensor_key_dict)
338302
logger.debug("This is the initial tensor_db: %s", self.tensor_db)
339303

340-
def _load_initial_tensors_from_dict(self, tensor_dict):
341-
"""Load all of the tensors required to begin federated learning.
342-
343-
Required tensors are: \
344-
1. Initial model.
345-
346-
Returns:
347-
None
348-
"""
349-
tensor_key_dict = {
350-
TensorKey(k, self.uuid, self.round_number, False, ("model",)): v
351-
for k, v in tensor_dict.items()
352-
}
353-
# all initial model tensors are loaded here
354-
self.tensor_db.cache_tensor(tensor_key_dict)
355-
logger.debug("This is the initial tensor_db: %s", self.tensor_db)
356-
357304
def _save_model(self, round_number, file_path):
358305
"""Save the best or latest model.
359306
@@ -485,10 +432,7 @@ def get_tasks(self, collaborator_name):
485432

486433
# first, if it is time to quit, inform the collaborator
487434
if self._time_to_quit():
488-
logger.info(
489-
"Sending signal to collaborator %s to shutdown...",
490-
collaborator_name,
491-
)
435+
logger.info("Sending signal to collaborator %s to shutdown...", collaborator_name)
492436
self.quit_job_sent_to.append(collaborator_name)
493437

494438
tasks = None
@@ -853,7 +797,6 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict:
853797
)
854798
# Leave out straggler for the round even if they've partially
855799
# completed given tasks
856-
collaborators_for_task = []
857800
collaborators_for_task = [
858801
c for c in all_collaborators_for_task if c in self.collaborators_done
859802
]
@@ -925,11 +868,7 @@ def _compute_validation_related_task_metrics(self, task_name) -> dict:
925868
f"{agg_results:f}"
926869
)
927870
self._save_model(round_number, self.best_state_path)
928-
else:
929-
logger.info(
930-
f"Round {round_number}: best score observed {agg_results:f} "
931-
"(model not saved in evaluation mode)"
932-
)
871+
933872
if "trained" in tags:
934873
self._prepare_trained(tensor_name, origin, round_number, report, agg_results)
935874

0 commit comments

Comments
 (0)