train.py (52 changes: 47 additions, 5 deletions)
@@ -64,7 +64,12 @@
    has_functorch = True
except ImportError as e:
    has_functorch = False

# test for TensorBoard availability
try:
    from torch.utils.tensorboard import SummaryWriter
    has_tensorboard = True
except ImportError as e:
    has_tensorboard = False
has_compile = hasattr(torch, 'compile')


@@ -347,8 +352,8 @@
                   help='use the multi-epochs-loader to save time at the beginning of every epoch')
group.add_argument('--log-wandb', action='store_true', default=False,
                   help='log training and validation metrics to wandb')
group.add_argument('--log-tensorboard', default='', type=str, metavar='PATH',
                   help='log training and validation metrics to TensorBoard')


def _parse_args():
    # Do we have a config file to parse?
    args_config, remaining = config_parser.parse_known_args()
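
For context, this is how the new flag would be used when launching a run (the dataset path, model, and log directory below are illustrative, not part of the diff):

python train.py /data/imagenet --model resnet50 --log-tensorboard ./runs/resnet50-exp1
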
@@ -725,6 +730,18 @@ def main():
            _logger.warning(
                "You've requested to log metrics to wandb but package not found. "
                "Metrics not being logged to wandb, try `pip install wandb`")
    tensorboard_writer = None
    # check the flag and rank directly here; should_log_to_tensorboard() also
    # requires has_tensorboard, which would make the warning below unreachable
    if args.log_tensorboard and utils.is_primary(args):
        if has_tensorboard:
            tensorboard_writer = SummaryWriter(args.log_tensorboard)
        else:
            _logger.warning(
                "You've requested to log metrics to tensorboard but package not found. "
                "Metrics not being logged to tensorboard, try `pip install tensorboard`")

    # setup learning rate schedule and starting epoch
    updates_per_epoch = len(loader_train)
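
The writer lifecycle that main() wires up above follows the standard torch.utils.tensorboard pattern: construct once on the primary process, call add_scalar() during training, close() at exit. A minimal standalone sketch of that pattern (the tag name and log directory are made up for illustration):

from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter('./runs/demo')  # event files are written under this directory
for step in range(100):
    # add_scalar(tag, value, global_step) appends one point to the named curve
    writer.add_scalar('train/loss', 1.0 / (step + 1), step)
writer.close()  # flush any buffered events before shutdown
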
@@ -770,6 +787,7 @@ def main():
                loss_scaler=loss_scaler,
                model_ema=model_ema,
                mixup_fn=mixup_fn,
                tensorboard_writer=tensorboard_writer,
            )

            if args.distributed and args.dist_bn in ('broadcast', 'reduce'):
@@ -783,6 +801,8 @@
                validate_loss_fn,
                args,
                amp_autocast=amp_autocast,
                tensorboard_writer=tensorboard_writer,
                epoch=epoch,
            )

            if model_ema is not None and not args.model_ema_force_cpu:
@@ -796,6 +816,8 @@
                    args,
                    amp_autocast=amp_autocast,
                    log_suffix=' (EMA)',
                    tensorboard_writer=tensorboard_writer,
                    epoch=epoch,
                )
                eval_metrics = ema_eval_metrics

@@ -809,6 +831,7 @@
                    lr=sum(lrs) / len(lrs),
                    write_header=best_metric is None,
                    log_wandb=args.log_wandb and has_wandb,
                )

            if saver is not None:
@@ -825,6 +848,8 @@

    if best_metric is not None:
        _logger.info('*** Best metric: {0} (epoch {1})'.format(best_metric, best_epoch))
    if tensorboard_writer is not None:
        tensorboard_writer.close()


def train_one_epoch(
@@ -841,7 +866,8 @@ def train_one_epoch(
        amp_autocast=suppress,
        loss_scaler=None,
        model_ema=None,
        mixup_fn=None,
        tensorboard_writer=None,
):
    if args.mixup_off_epoch and epoch >= args.mixup_off_epoch:
        if args.prefetcher and loader.mixup_enabled:
@@ -903,6 +929,10 @@

        num_updates += 1
        batch_time_m.update(time.time() - end)
        # write the per-step loss and current LR to TensorBoard if enabled
        if should_log_to_tensorboard(args) and tensorboard_writer is not None:
            tensorboard_writer.add_scalar('train/loss', losses_m.val, num_updates)
            tensorboard_writer.add_scalar('train/lr', optimizer.param_groups[0]['lr'], num_updates)
        if last_batch or batch_idx % args.log_interval == 0:
            lrl = [param_group['lr'] for param_group in optimizer.param_groups]
            lr = sum(lrl) / len(lrl)
@@ -954,14 +984,21 @@
    return OrderedDict([('loss', losses_m.avg)])


def should_log_to_tensorboard(args):
    return args.log_tensorboard and utils.is_primary(args) and has_tensorboard


def validate(
        model,
        loader,
        loss_fn,
        args,
        device=torch.device('cuda'),
        amp_autocast=suppress,
        log_suffix='',
        tensorboard_writer=None,
        epoch=None,
):
    batch_time_m = utils.AverageMeter()
    losses_m = utils.AverageMeter()
@@ -1011,6 +1048,11 @@ def validate(

            batch_time_m.update(time.time() - end)
            end = time.time()
            # index validation scalars by a global step; scale epoch by the batch
            # count (last_idx + 1) so steps from consecutive epochs don't collide
            if should_log_to_tensorboard(args) and tensorboard_writer is not None and epoch is not None:
                global_step = epoch * (last_idx + 1) + batch_idx
                tensorboard_writer.add_scalar('val/loss', losses_m.val, global_step)
                tensorboard_writer.add_scalar('val/acc1', top1_m.val, global_step)
                tensorboard_writer.add_scalar('val/acc5', top5_m.val, global_step)
            if utils.is_primary(args) and (last_batch or batch_idx % args.log_interval == 0):
                log_name = 'Test' + log_suffix
                _logger.info(
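
Once a run has produced event files, the curves logged above (train/loss, train/lr, val/loss, val/acc1, val/acc5) can be viewed by pointing the standard TensorBoard CLI at whatever directory was passed to --log-tensorboard (the path below is the illustrative one from the earlier example):

tensorboard --logdir ./runs/resnet50-exp1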