@@ -163,6 +163,25 @@ def close(self) -> None:
163
163
if self .wandb .run is not None :
164
164
self .wandb .finish ()
165
165
166
+ class LoggerContainer (BaseLogger ):
167
+ """Container to call all loggers enabled in the job config."""
168
+ def __init__ (self ) -> None :
169
+ self ._loggers : list [BaseLogger ] = []
170
+
171
+ def add_logger (self , logger_instance : BaseLogger ) -> None :
172
+ self ._loggers .append (logger_instance )
173
+
174
+ def log (self , metrics : dict [str , Any ], step : int ) -> None :
175
+ for logger_instance in self ._loggers :
176
+ logger_instance .log (metrics , step )
177
+
178
+ @property
179
+ def number_of_loggers (self ) -> int :
180
+ return len (self ._loggers )
181
+
182
+ def close (self ) -> None :
183
+ for logger_instance in self ._loggers :
184
+ logger_instance .close ()
166
185
167
186
def ensure_pp_loss_visible (
168
187
parallel_dims : ParallelDims , job_config : JobConfig , color : Color
@@ -274,11 +293,15 @@ def _build_metric_logger(
274
293
base_log_dir , f"rank_{ torch .distributed .get_rank ()} "
275
294
)
276
295
296
+ # Create logger container
297
+ logger_container = LoggerContainer ()
298
+
277
299
# Create loggers in priority order
278
300
if metrics_config .enable_wandb :
279
301
logger .debug ("Attempting to create WandB logger" )
280
302
try :
281
- return WandBLogger (base_log_dir , job_config , tag )
303
+ wandb_logger = WandBLogger (base_log_dir , job_config , tag )
304
+ logger_container .add_logger (wandb_logger )
282
305
except Exception as e :
283
306
if "No module named 'wandb'" in str (e ):
284
307
logger .error (
@@ -289,10 +312,12 @@ def _build_metric_logger(
289
312
290
313
if metrics_config .enable_tensorboard :
291
314
logger .debug ("Creating TensorBoard logger" )
292
- return TensorBoardLogger (base_log_dir , tag )
315
+ tensorboard_logger = TensorBoardLogger (base_log_dir , tag )
316
+ logger_container .add_logger (tensorboard_logger )
293
317
294
- logger .debug ("No loggers enabled, returning BaseLogger" )
295
- return BaseLogger ()
318
+ if logger_container .number_of_loggers == 0 :
319
+ logger .debug ("No loggers enabled, returning an emtpy LoggerContainer" )
320
+ return logger_container
296
321
297
322
298
323
class MetricsProcessor :
0 commit comments