Skip to content

Commit 612e4a1

Browse files
authored
Enable Simultanous WANDB and Tensorboard logging in torchtitan
Differential Revision: D82159896 Pull Request resolved: #1698
1 parent 71dea16 commit 612e4a1

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

torchtitan/components/metrics.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,25 @@ def close(self) -> None:
163163
if self.wandb.run is not None:
164164
self.wandb.finish()
165165

166+
class LoggerContainer(BaseLogger):
167+
"""Container to call all loggers enabled in the job config."""
168+
def __init__(self) -> None:
169+
self._loggers : list[BaseLogger] = []
170+
171+
def add_logger(self, logger_instance: BaseLogger) -> None:
172+
self._loggers.append(logger_instance)
173+
174+
def log(self, metrics: dict[str, Any], step: int) -> None:
175+
for logger_instance in self._loggers:
176+
logger_instance.log(metrics, step)
177+
178+
@property
179+
def number_of_loggers(self) -> int:
180+
return len(self._loggers)
181+
182+
def close(self) -> None:
183+
for logger_instance in self._loggers:
184+
logger_instance.close()
166185

167186
def ensure_pp_loss_visible(
168187
parallel_dims: ParallelDims, job_config: JobConfig, color: Color
@@ -274,11 +293,15 @@ def _build_metric_logger(
274293
base_log_dir, f"rank_{torch.distributed.get_rank()}"
275294
)
276295

296+
# Create logger container
297+
logger_container = LoggerContainer()
298+
277299
# Create loggers in priority order
278300
if metrics_config.enable_wandb:
279301
logger.debug("Attempting to create WandB logger")
280302
try:
281-
return WandBLogger(base_log_dir, job_config, tag)
303+
wandb_logger = WandBLogger(base_log_dir, job_config, tag)
304+
logger_container.add_logger(wandb_logger)
282305
except Exception as e:
283306
if "No module named 'wandb'" in str(e):
284307
logger.error(
@@ -289,10 +312,12 @@ def _build_metric_logger(
289312

290313
if metrics_config.enable_tensorboard:
291314
logger.debug("Creating TensorBoard logger")
292-
return TensorBoardLogger(base_log_dir, tag)
315+
tensorboard_logger = TensorBoardLogger(base_log_dir, tag)
316+
logger_container.add_logger(tensorboard_logger)
293317

294-
logger.debug("No loggers enabled, returning BaseLogger")
295-
return BaseLogger()
318+
if logger_container.number_of_loggers == 0:
319+
logger.debug("No loggers enabled, returning an emtpy LoggerContainer")
320+
return logger_container
296321

297322

298323
class MetricsProcessor:

0 commit comments

Comments
 (0)