1 change: 1 addition & 0 deletions .gitignore
@@ -20,3 +20,4 @@ results/*
**MNIST/
**cert/
.history/
.DS_Store
227 changes: 133 additions & 94 deletions openfl/component/collaborator/collaborator.py
@@ -157,38 +157,71 @@ def ping(self):
"""Ping the Aggregator."""
self.client.ping()

def run(self):
"""Run the collaborator."""
# Experiment begin
self.callbacks.on_experiment_begin()
def run(self) -> None:
"""
Run the collaborator main loop.

Handles experiment lifecycle, round execution, and error logging.
"""
try:
self.callbacks.on_experiment_begin()
self._execute_collaborator_rounds()
self.callbacks.on_experiment_end()
logger.info("Received shutdown signal. Exiting...")
except Exception as experiment_error:
logger.critical(
f"Critical error in collaborator execution. Error: {experiment_error}",
exc_info=True,
)
self.callbacks.on_experiment_end({"error": str(experiment_error)})
logger.critical("Collaborator is shutting down due to critical error.")
raise RuntimeError("Collaborator execution failed") from experiment_error

def _execute_collaborator_rounds(self) -> None:
"""
Execute rounds until a shutdown signal is received.

Each round consists of receiving tasks, executing them, and reporting results.
If any task fails, the round is aborted and the error is logged.
"""
while True:
tasks, round_num, sleep_time, time_to_quit = self.client.get_tasks()

if time_to_quit:
break

if not tasks:
sleep(sleep_time)
continue
try:
logger.info("Round: %d Received Tasks: %s", round_num, tasks)
self.callbacks.on_round_begin(round_num)
logs = self._execute_round_tasks(tasks, round_num)
self.tensor_db.clean_up(self.db_store_rounds)
self.callbacks.on_round_end(round_num, logs)
except Exception as round_error:
logger.error(
f"Error during round {round_num} execution. Error: {round_error}", exc_info=True
)
sleep(sleep_time or 10)

# Round begin
logger.info("Round: %d Received Tasks: %s", round_num, tasks)
self.callbacks.on_round_begin(round_num)
def _execute_round_tasks(self, tasks: list, round_number: int) -> dict:
"""
Execute all tasks in a round.

# Run tasks
logs = {}
for task in tasks:
metrics = self.do_task(task, round_num)
logs.update(metrics)
Args:
tasks: List of tasks to execute.
round_number: Current round number.

# Round end
self.tensor_db.clean_up(self.db_store_rounds)
self.callbacks.on_round_end(round_num, logs)
Returns:
Dictionary of logs/metrics from task execution.

# Experiment end
self.callbacks.on_experiment_end()
logger.info("Received shutdown signal. Exiting...")
Raises:
Exception: If any task execution fails, the round is aborted.
"""
logs = {}
for task in tasks:
metrics = self.do_task(task, round_number)
logs.update(metrics)
return logs

def do_task(self, task, round_number) -> dict:
"""Perform the specified task.
@@ -270,97 +303,103 @@ def do_task(self, task, round_number) -> dict:

return metrics

def get_data_for_tensorkey(self, tensor_key):
"""Resolve the tensor corresponding to the requested tensorkey.
def get_data_for_tensorkey(self, tensor_key) -> object:
"""
Resolve and return the tensor for the requested TensorKey.

This function checks the local cache, previous rounds, and the aggregator as needed.

Args:
tensor_key (namedtuple): Tensorkey that will be resolved locally or
remotely. May be the product of other tensors.
tensor_key: The TensorKey to resolve.

Returns:
nparray: The decompressed tensor associated with the requested
tensor key.
The decompressed tensor associated with the requested tensor key.

Raises:
Exception: If the tensor cannot be retrieved or reconstructed.
"""
# try to get from the store
tensor_name, origin, round_number, report, tags = tensor_key
logger.debug("Attempting to retrieve tensor %s from local store", tensor_key)
nparray = self.tensor_db.get_tensor_from_cache(tensor_key)

# if None and origin is our client, request it from the client
if nparray is None:
if origin == self.collaborator_name:
logger.info(
f"Attempting to find locally stored {tensor_name} tensor from prior round..."
)
prior_round = round_number - 1
while prior_round >= 0:
nparray = self.tensor_db.get_tensor_from_cache(
TensorKey(tensor_name, origin, prior_round, report, tags)
try:
nparray = self.tensor_db.get_tensor_from_cache(tensor_key)
if nparray is None:
if origin == self.collaborator_name:
logger.info(
f"Attempting to find locally stored {tensor_name} "
f"tensor from prior round..."
)
if nparray is not None:
logger.debug(
f"Found tensor {tensor_name} in local TensorDB for round {prior_round}"
prior_round = round_number - 1
while prior_round >= 0:
nparray = self.tensor_db.get_tensor_from_cache(
TensorKey(tensor_name, origin, prior_round, report, tags)
)
return nparray
prior_round -= 1
logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...")
# Determine whether there are additional compression related
# dependencies.
# Typically, dependencies are only relevant to model layers
tensor_dependencies = self.tensor_codec.find_dependencies(
tensor_key, self.use_delta_updates
)
logger.debug(
"Unable to get tensor from local store..."
"attempting to retrieve from client len tensor_dependencies"
f" tensor_key {tensor_key}"
)
if len(tensor_dependencies) > 0:
# Resolve dependencies
# tensor_dependencies[0] corresponds to the prior version
# of the model.
# If it exists locally, should pull the remote delta because
# this is the least costly path
prior_model_layer = self.tensor_db.get_tensor_from_cache(tensor_dependencies[0])
if prior_model_layer is not None:
uncompressed_delta = self.get_aggregated_tensor_from_aggregator(
tensor_dependencies[1]
)
new_model_tk, nparray = self.tensor_codec.apply_delta(
tensor_dependencies[1],
uncompressed_delta,
prior_model_layer,
creates_model=True,
if nparray is not None:
logger.debug(
f"Found tensor {tensor_name} in local TensorDB "
f"for round {prior_round}"
)
return nparray
prior_round -= 1
logger.info(f"Cannot find any prior version of tensor {tensor_name} locally...")
# Determine whether there are additional compression related
# dependencies.
# Typically, dependencies are only relevant to model layers
tensor_dependencies = self.tensor_codec.find_dependencies(
tensor_key, self.use_delta_updates
)
logger.debug(
    "Unable to get tensor from local store; "
    "attempting to retrieve from client. "
    f"tensor_key: {tensor_key}"
)
if len(tensor_dependencies) > 0:
# Resolve dependencies
# tensor_dependencies[0] corresponds to the prior version
# of the model.
# If it exists locally, should pull the remote delta because
# this is the least costly path
prior_model_layer = self.tensor_db.get_tensor_from_cache(tensor_dependencies[0])
if prior_model_layer is not None:
uncompressed_delta = self.get_aggregated_tensor_from_aggregator(
tensor_dependencies[1]
)
new_model_tk, nparray = self.tensor_codec.apply_delta(
tensor_dependencies[1],
uncompressed_delta,
prior_model_layer,
creates_model=True,
)
self.tensor_db.cache_tensor({new_model_tk: nparray})
else:
logger.info(
"Could not find previous model layer. "
"Fetching latest layer from aggregator"
)
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
elif "model" in tags:
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
self.tensor_db.cache_tensor({new_model_tk: nparray})
else:
tensor_name, origin, round_number, report, tags = tensor_key
tags = (self.collaborator_name,) + tags
tensor_key = (tensor_name, origin, round_number, report, tags)
logger.info(
"Could not find previous model layer.Fetching latest layer from aggregator"
"Could not find previous model layer."
f"Fetching latest layer from aggregator {tensor_key}"
)
# The original model tensor should be fetched from aggregator
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
elif "model" in tags:
# Pulling the model for the first time
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
else:
# we should try fetching the tensor from aggregator
tensor_name, origin, round_number, report, tags = tensor_key
tags = (self.collaborator_name,) + tags
tensor_key = (tensor_name, origin, round_number, report, tags)
logger.info(
"Could not find previous model layer."
f"Fetching latest layer from aggregator {tensor_key}"
)
nparray = self.get_aggregated_tensor_from_aggregator(
tensor_key, require_lossless=True
)
else:
logger.debug("Found tensor %s in local TensorDB", tensor_key)

logger.debug("Found tensor %s in local TensorDB", tensor_key)
except Exception as get_tensor_error:
logger.error(
f"Error retrieving tensor {tensor_key}. Error: {get_tensor_error}", exc_info=True
)
raise
return nparray

def get_aggregated_tensor_from_aggregator(self, tensor_key, require_lossless=False):
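`get_data_for_tensorkey` resolves a tensor through a chain of fallbacks: an exact hit in the local TensorDB, a copy cached from an earlier round (for tensors this collaborator produced), a compressed delta applied on top of last round's cached model layer, and finally a lossless fetch from the aggregator. A minimal sketch of that chain, assuming a plain dict as the cache and stub fetch callables (`resolve_tensor` and this simplified `TensorKey` are illustrative, not OpenFL's actual codec/DB interfaces):

```python
from collections import namedtuple

# Simplified stand-in for the TensorKey namedtuple unpacked above.
TensorKey = namedtuple("TensorKey", ["name", "origin", "round", "report", "tags"])


def resolve_tensor(key, cache, collaborator_name, fetch_delta, fetch_lossless):
    """Fallback chain sketched from get_data_for_tensorkey."""
    if key in cache:  # 1. exact hit in the local store
        return cache[key]
    if key.origin == collaborator_name:
        # 2. locally produced tensors: reuse any prior round's copy
        for prior in range(key.round - 1, -1, -1):
            prior_key = key._replace(round=prior)
            if prior_key in cache:
                return cache[prior_key]
        raise KeyError(f"no local copy of {key.name} in any prior round")
    # 3. aggregator-origin tensors: applying a remote delta to last
    #    round's cached layer is the least costly path over the wire
    prior_key = key._replace(round=key.round - 1)
    if prior_key in cache:
        value = cache[prior_key] + fetch_delta(key)  # stands in for apply_delta
    else:
        value = fetch_lossless(key)  # 4. full lossless fetch
    cache[key] = value  # cache so later lookups hit step 1
    return value


# Usage: round 3's layer is cached, so round 4 needs only the delta.
cache = {TensorKey("conv1.weight", "aggregator", 3, False, ("model",)): 1.0}
key = TensorKey("conv1.weight", "aggregator", 4, False, ("model",))
print(resolve_tensor(key, cache, "col-1", lambda k: 0.25, lambda k: 0.0))  # 1.25
```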
36 changes: 22 additions & 14 deletions openfl/interface/collaborator.py
@@ -62,26 +62,34 @@ def collaborator(context):
help="The certified common name of the collaborator.",
)
def start_(plan, collaborator_name, data_config):
"""Starts a collaborator service."""
"""
Starts a collaborator service.

Args:
plan: Path to the FL plan YAML file.
collaborator_name: The certified common name of the collaborator.
data_config: Path to the dataset shard configuration file.
"""
if plan and is_directory_traversal(plan):
echo("Federated learning plan path is out of the openfl workspace scope.")
sys.exit(1)
if data_config and is_directory_traversal(data_config):
echo("The data set/shard configuration file path is out of the openfl workspace scope.")
sys.exit(1)

plan_obj = Plan.parse(
plan_config_path=Path(plan).absolute(),
data_config_path=Path(data_config).absolute(),
)

# TODO: Need to restructure data loader config file loader
logger.info(f"Data paths: {plan_obj.cols_data_paths}")
echo(f"Data = {plan_obj.cols_data_paths}")
logger.info("🧿 Starting a Collaborator Service.")

collaborator = plan_obj.get_collaborator(collaborator_name)
collaborator.run()
try:
plan_obj = Plan.parse(
plan_config_path=Path(plan).absolute(),
data_config_path=Path(data_config).absolute(),
)
logger.info(f"Data paths: {plan_obj.cols_data_paths}")
echo(f"Data = {plan_obj.cols_data_paths}")
logger.info("🧿 Starting a Collaborator Service.")
collaborator = plan_obj.get_collaborator(collaborator_name)
collaborator.run()
except Exception as e:
logger.critical(f"Critical error starting or running collaborator: {e}", exc_info=True)
echo(style(f"Collaborator failed with error: {e}", fg="red"))
sys.exit(1)


@collaborator.command(name="ping")
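The CLI change wraps plan parsing and the collaborator run in a single guard so any startup failure is logged with a traceback, echoed in red, and mapped to exit code 1. A minimal Click sketch of the same pattern (the command name, option, and the file-extension check are hypothetical placeholders for OpenFL's real validation and `Plan.parse`):

```python
import logging
import sys

import click  # the same CLI toolkit openfl/interface uses

logger = logging.getLogger(__name__)


@click.command(name="start")
@click.option("-p", "--plan", required=True, help="Path to the FL plan YAML file.")
def start_sketch(plan):
    """Fail-fast guard: log with traceback, echo in red, exit non-zero."""
    try:
        # Stand-in for Plan.parse(...) and collaborator.run().
        if not plan.endswith((".yaml", ".yml")):
            raise ValueError(f"not a plan file: {plan}")
        click.echo(f"Collaborator would start with plan {plan}")
    except Exception as e:
        logger.critical(f"Critical error starting or running collaborator: {e}", exc_info=True)
        click.echo(click.style(f"Collaborator failed with error: {e}", fg="red"))
        sys.exit(1)


if __name__ == "__main__":
    start_sketch()
```

Exiting with a non-zero code, rather than letting the exception escape Click, gives supervisors and orchestration scripts a reliable failure signal, while the red echo surfaces the error to the operator without digging through logs.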