Commit 5e7c808

Merge pull request #43 from jdb78/test/hyperparameter_optimization
Test and docs for hyperparameter optimization
2 parents a3ee1e6 + ee2826e · commit 5e7c808

File tree

8 files changed (+280, -169 lines)

README.md

Lines changed: 2 additions & 1 deletion

@@ -10,6 +10,7 @@ Pytorch Forecasting aims to ease timeseries forecasting with neural networks for
   for real-world deployment and come with in-built interpretation capabilities
 - Multi-horizon timeseries metrics
 - Ranger optimizer for faster model training
+- Hyperparameter tuning with [optuna](https://optuna.readthedocs.io/)

 The package is built on [pytorch-lightning])(https://pytorch-lightning.readthedocs.io/) to allow training on CPUs, single and multiple GPUs out-of-the-box.

@@ -28,7 +29,7 @@ Visit the documentation at [https://pytorch-forecasting.readthedocs.io](https://
 # Available models

 - [Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting](https://arxiv.org/pdf/1912.09363.pdf)
-- [N-Beats](http://arxiv.org/abs/1905.10437)
+- [N-BEATS: Neural basis expansion analysis for interpretable time series forecasting](http://arxiv.org/abs/1905.10437)

 # Usage

docs/source/index.rst

Lines changed: 1 addition & 0 deletions

@@ -17,6 +17,7 @@ research alike. Specifically, the package provides
   for real-world deployment and come with in-built interpretation capabilities
 * Multi-horizon timeseries metrics
 * Ranger optimizer for faster model training
+* Hyperparameter tuning with `optuna <https://optuna.readthedocs.io/>`_

 The package is built on `PyTorch Lightning <https://pytorch-lightning.readthedocs.io/>`_ to allow
 training on CPUs, single and multiple GPUs out-of-the-box.

docs/source/models.rst

Lines changed: 3 additions & 0 deletions

@@ -11,6 +11,9 @@ Pytorch Forecasting provides a ``.from_dataset()`` method for each model that
 takes a :py:class:`~data.timeseries.TimeSeriesDataSet` and additional parameters
 that cannot directy derived from the dataset such as, e.g. ``learning_rate`` or ``hidden_size``.

+To tune models, `optuna <https://optuna.readthedocs.io/>`_ can be used. For example, tuning of the :py:class:`~models.temporal_fusion_transformer.TemporalFusionTransformer`
+is implemented by :py:func:`~models.temporal_fusion_transformer.tuning.optimize_hyperparameters`
+
 Details
 --------
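The new models.rst note only names the entry point, so a minimal usage sketch may help. It assumes train_dataloader and val_dataloader were created from a TimeSeriesDataSet (the function asserts this), the "optuna_checkpoints" folder name is hypothetical, and keyword names follow the tuning.py diff further down in this commit.

from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Minimal sketch: the dataloaders are assumed to come from TimeSeriesDataSet.to_dataloader();
# all search ranges keep their documented defaults.
study = optimize_hyperparameters(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_path="optuna_checkpoints",  # hypothetical folder for model checkpoints
    n_trials=100,
    max_epochs=20,
)
print(study.best_trial.params)  # best hyperparameters found by the study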

poetry.lock

Lines changed: 209 additions & 164 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -40,7 +40,7 @@ homepage = "https://pytorch-forecasting.readthedocs.io"
 [tool.poetry.dependencies]
 python= "^3.6.1"

-torch = "^1.6"
+torch = "^1.4"
 pytorch-lightning = "^0.9.0"
 optuna = "^2.0.0"
 scipy = "*"

pytorch_forecasting/models/nbeats/__init__.py

Lines changed: 4 additions & 0 deletions

@@ -37,6 +37,10 @@ def __init__(
         """
         Initialize NBeats Model - use its :py:meth:`~from_dataset` method if possible.

+        Based on the article
+        `N-BEATS: Neural basis expansion analysis for interpretable time series
+        forecasting <http://arxiv.org/abs/1905.10437>`_.
+
         Args:
             stack_types: One of the following values: “generic”, “seasonality" or “trend". A list of strings
                 of length 1 or ‘num_stacks’. Default and recommended value
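Because the docstring points users to from_dataset, a hedged construction sketch follows; "training" is an assumed TimeSeriesDataSet with a single target and no extra covariates, and the keyword value is illustrative rather than taken from this commit.

from pytorch_forecasting import NBeats

# Sketch only: `training` is an assumed TimeSeriesDataSet prepared for N-BEATS
# (univariate target, no additional covariates); the learning rate is an example value.
net = NBeats.from_dataset(
    training,
    learning_rate=3e-2,
)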

pytorch_forecasting/models/temporal_fusion_transformer/__init__.py

Lines changed: 19 additions & 0 deletions

@@ -60,6 +60,25 @@ def __init__(
         """
         Temporal Fusion Transformer for forecasting timeseries - use its :py:meth:`~from_dataset` method if possible.

+        Implementation of the article
+        `Temporal Fusion Transformers for Interpretable Multi-horizon Time Series
+        Forecasting <https://arxiv.org/pdf/1912.09363.pdf>`_.
+
+        Enhancements compared to the original implementation (apart from capabilities added through base model
+        such as monotone constraints):
+
+        * static variables can be continuous
+        * multiple categorical variables can be summarized with an EmbeddingBag
+        * variable encoder and decoder length by sample
+        * categorical embeddings are not transformed by variable selection network (because it is a redundant operation)
+        * variable dimension in variable selection network are scaled up via linear interpolation to reduce
+          number of parameters
+        * non-linear variable processing in variable selection network can be shared among decoder and encoder
+          (not shared by default)
+
+        Tune its hyperparameters with
+        :py:func:`~pytorch_forecasting.models.temporal_fusion_transformer.tuning.optimize_hyperparameters`.
+
         Args:

             hidden_size: hidden size of network which is its main hyperparameter and can range from 8 to 512
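To show where the hyperparameters that the tuner searches enter the model, here is a hedged construction sketch; "training" is an assumed TimeSeriesDataSet and the keyword values are illustrative starting points, mirroring the tuner's default ranges rather than recommendations from this commit.

from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# Sketch only: `training` is an assumed TimeSeriesDataSet; values are illustrative.
tft = TemporalFusionTransformer.from_dataset(
    training,
    hidden_size=16,             # main hyperparameter, 8 to 512 per the docstring
    attention_head_size=1,
    dropout=0.1,
    hidden_continuous_size=8,
    learning_rate=0.03,
    loss=QuantileLoss(),        # assumed choice of loss, not prescribed by this commit
)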

pytorch_forecasting/models/temporal_fusion_transformer/tuning.py

Lines changed: 41 additions & 3 deletions

@@ -42,10 +42,46 @@ def optimize_hyperparameters(
     hidden_continuous_size_range: Tuple[int, int] = (8, 64),
     attention_head_size_range: Tuple[int, int] = (1, 4),
     dropout_range: Tuple[float, float] = (0.1, 0.3),
+    learning_rate_range: Tuple[float, float] = (1e-5, 1.0),
+    use_learning_rate_finder: bool = True,
     trainer_kwargs: Dict[str, Any] = {},
     log_dir: str = "lightning_logs",
     **kwargs,
 ) -> optuna.Study:
+    """
+    Optimize Temporal Fusion Transformer hyperparameters.
+
+    Run hyperparameter optimization. Learning rate for is determined with
+    the PyTorch Lightning learning rate finder.
+
+    Args:
+        train_dataloader (DataLoader): dataloader for training model
+        val_dataloader (DataLoader): dataloader for validating model
+        model_path (str): folder to which model checkpoints are saved
+        max_epochs (int, optional): Maximum number of epochs to run training. Defaults to 20.
+        n_trials (int, optional): Number of hyperparameter trials to run. Defaults to 100.
+        timeout (float, optional): Time in seconds after which training is stopped regardless of number of epochs
+            or validation metric. Defaults to 3600*8.0.
+        hidden_size_range (Tuple[int, int], optional): Minimum and maximum of ``hidden_size`` hyperparameter. Defaults
+            to (16, 265).
+        hidden_continuous_size_range (Tuple[int, int], optional): Minimum and maximum of ``hidden_continuous_size``
+            hyperparameter. Defaults to (8, 64).
+        attention_head_size_range (Tuple[int, int], optional): Minimum and maximum of ``attention_head_size``
+            hyperparameter. Defaults to (1, 4).
+        dropout_range (Tuple[float, float], optional): Minimum and maximum of ``dropout`` hyperparameter. Defaults to
+            (0.1, 0.3).
+        learning_rate_range (Tuple[float, float], optional): Learning rate range. Defaults to (1e-5, 1.0).
+        use_learning_rate_finder (bool): If to use learning rate finder or optimize as part of hyperparameters.
+            Defaults to True.
+        trainer_kwargs (Dict[str, Any], optional): Additional arguments to the
+            `PyTorch Lightning trainer <https://pytorch-lightning.readthedocs.io/en/latest/trainer.html>`_ such
+            as ``limit_train_batches``. Defaults to {}.
+        log_dir (str, optional): Folder into which to log results for tensorboard. Defaults to "lightning_logs".
+        **kwargs: Additional arguments for the :py:class:`~TemporalFusionTransformer`.
+
+    Returns:
+        optuna.Study: optuna study results
+    """
     assert isinstance(train_dataloader.dataset, TimeSeriesDataSet) and isinstance(
         val_dataloader.dataset, TimeSeriesDataSet
     ), "dataloaders must be built from timeseriesdataset"

@@ -92,7 +128,7 @@ def objective(trial: optuna.Trial) -> float:
             **kwargs,
         )
         # find good learning rate
-        if "learning_rate" not in kwargs or isinstance(kwargs["learning_rate"], (tuple, list)):
+        if use_learning_rate_finder:
             lr_trainer = pl.Trainer(
                 gradient_clip_val=gradient_clip_val,
                 gpus=[0] if torch.cuda.is_available() else None,

@@ -103,9 +139,9 @@ def objective(trial: optuna.Trial) -> float:
                 train_dataloader=train_dataloader,
                 val_dataloaders=val_dataloader,
                 early_stop_threshold=10000.0,
-                min_lr=kwargs.get("learning_rate", [1e-5, 1.0])[0],
+                min_lr=learning_rate_range[0],
                 num_training=100,
-                max_lr=kwargs.get("learning_rate", [1e-5, 1.0])[1],
+                max_lr=learning_rate_range[1],
             )

             loss_finite = np.isfinite(res.results["loss"])

@@ -118,6 +154,8 @@ def objective(trial: optuna.Trial) -> float:
             optimal_lr = lr_smoothed[optimal_idx]
             print(f"Using learning rate of {optimal_lr:.3g}")
             model.hparams.learning_rate = optimal_lr
+        else:
+            model.hparams.learning_rate = trial.suggest_loguniform("learning_rate_range", *learning_rate_range)

         # fit
         trainer.fit(model, train_dataloader=train_dataloader, val_dataloaders=val_dataloader)
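The new use_learning_rate_finder flag switches between the PyTorch Lightning learning rate finder and letting optuna sample the learning rate from learning_rate_range. The hedged sketch below shows the second path; the dataloaders are assumed to exist as before and the folder name is hypothetical.

from pytorch_forecasting.models.temporal_fusion_transformer.tuning import optimize_hyperparameters

# Sketch of the non-finder path added in this commit: optuna samples the learning rate
# via trial.suggest_loguniform("learning_rate_range", ...) inside each trial.
study = optimize_hyperparameters(
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_path="optuna_checkpoints",   # hypothetical folder for model checkpoints
    use_learning_rate_finder=False,    # let optuna search the learning rate instead
    learning_rate_range=(1e-4, 0.1),   # narrower than the (1e-5, 1.0) default
    n_trials=30,
    timeout=3600.0,
)
print(study.best_trial.params)  # includes the sampled "learning_rate_range" value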
