[wip] starting point for folds #153
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,7 @@ | ||
[deps] | ||
EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3" | ||
Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712" | ||
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" | ||
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab" | ||
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" | ||
WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008" |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
# CC BY-SA 4.0 | ||
# ============================================================================= | ||
# EasyHybrid Example: Synthetic Data Analysis | ||
# ============================================================================= | ||
# This example demonstrates how to use EasyHybrid to train a hybrid model | ||
# on synthetic data for respiration modeling with Q10 temperature sensitivity. | ||
# ============================================================================= | ||
|
||
# ============================================================================= | ||
# Project Setup and Environment | ||
# ============================================================================= | ||
using Pkg | ||
|
||
# Set project path and activate environment | ||
project_path = "projects/book_chapter" | ||
Pkg.activate(project_path) | ||
|
||
# Check if manifest exists, create project if needed | ||
manifest_path = joinpath(project_path, "Manifest.toml") | ||
if !isfile(manifest_path) | ||
package_path = pwd() | ||
if !endswith(package_path, "EasyHybrid") | ||
error("You opened in the wrong directory. Please open the EasyHybrid folder, create a new project in the projects folder, and provide the relative path to the project folder as project_path.") | ||
end | ||
Pkg.develop(path=package_path) | ||
Pkg.instantiate() | ||
end | ||
|
||
using EasyHybrid | ||
|
||
# ============================================================================= | ||
# Data Loading and Preprocessing | ||
# ============================================================================= | ||
# Load synthetic dataset from GitHub | ||
ds = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc") | ||
|
||
# Select a subset of data for faster execution | ||
ds = ds[1:20000, :] | ||
|
||
# ============================================================================= | ||
# Define the Physical Model | ||
# ============================================================================= | ||
# RbQ10 model: Respiration model with Q10 temperature sensitivity | ||
# Parameters: | ||
# - ta: air temperature [°C] | ||
# - Q10: temperature sensitivity factor [-] | ||
# - rb: basal respiration rate [μmol/m²/s] | ||
# - tref: reference temperature [°C] (default: 15.0) | ||
function RbQ10(;ta, Q10, rb, tref = 15.0f0) | ||
reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref)) | ||
return (; reco, Q10, rb) | ||
end | ||
|
||
# ============================================================================= | ||
# Define Model Parameters | ||
# ============================================================================= | ||
# Parameter specification: (default, lower_bound, upper_bound) | ||
parameters = ( | ||
# Parameter name | Default | Lower | Upper | Description | ||
rb = ( 3.0f0, 0.0f0, 13.0f0 ), # Basal respiration [μmol/m²/s] | ||
Q10 = ( 2.0f0, 1.0f0, 4.0f0 ), # Temperature sensitivity factor [-] | ||
) | ||
|
||
# ============================================================================= | ||
# Configure Hybrid Model Components | ||
# ============================================================================= | ||
# Define input variables | ||
forcing = [:ta] # Forcing variables (temperature) | ||
predictors = [:sw_pot, :dsw_pot] # Predictor variables (solar radiation, and its derivative) | ||
|
||
# Target variable | ||
target = [:reco] # Target variable (respiration) | ||
|
||
# Parameter classification | ||
global_param_names = [:Q10] # Global parameters (same for all samples) | ||
neural_param_names = [:rb] # Neural network predicted parameters | ||
|
||
# ============================================================================= | ||
# Construct the Hybrid Model | ||
# ============================================================================= | ||
# Create hybrid model using the unified constructor | ||
hybrid_model = constructHybridModel( | ||
predictors, # Input features | ||
forcing, # Forcing variables | ||
target, # Target variables | ||
RbQ10, # Process-based model function | ||
parameters, # Parameter definitions | ||
neural_param_names, # NN-predicted parameters | ||
global_param_names, # Global parameters | ||
hidden_layers = [16, 16], # Neural network architecture | ||
activation = sigmoid, # Activation function | ||
scale_nn_outputs = true, # Scale neural network outputs | ||
input_batchnorm = true # Apply batch normalization to inputs | ||
) | ||
|
||
# ============================================================================= | ||
# Model Training | ||
# ============================================================================= | ||
using WGLMakie | ||
|
||
using MLUtils, Random # Random provides randperm, used by make_folds | ||
|
||
function make_folds(ds; k::Int=5, shuffle=true) | ||
n = numobs(ds) | ||
_, val_idx = MLUtils.kfolds(n, k) | ||
folds = fill(0, n) | ||
perm = shuffle ? randperm(n) : 1:n | ||
for (f, idx) in enumerate(val_idx) | ||
fidx = perm[idx] | ||
folds[fidx] .= f | ||
end | ||
return folds | ||
end | ||
|
||
k = 3 | ||
folds = make_folds(ds, k=k, shuffle=true) | ||
|
||
results = Vector{Any}(undef, k) | ||
|
||
for val_fold in 1:k | ||
@info "Training fold $val_fold of $k" | ||
out = train( | ||
hybrid_model, | ||
ds, | ||
(); | ||
nepochs = 100, | ||
patience = 10, | ||
batchsize = 512, # Batch size for training | ||
opt = RMSProp(0.001), # Optimizer and learning rate | ||
monitor_names = [:rb, :Q10], | ||
folds = folds, | ||
val_fold = val_fold | ||
) | ||
results[val_fold] = out | ||
end | ||
|
||
|
||
|
||
for val_fold in 1:k | ||
@info "Split data outside of train function. Training fold $val_fold of $k" | ||
sdata = split_data(ds, hybrid_model; val_fold = val_fold, folds = folds) | ||
out = train( | ||
hybrid_model, | ||
sdata, | ||
(); | ||
nepochs = 100, | ||
patience = 10, | ||
batchsize = 512, # Batch size for training | ||
opt = RMSProp(0.001), # Optimizer and learning rate | ||
monitor_names = [:rb, :Q10] | ||
) | ||
results[val_fold] = out | ||
end |
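
The fold labels consumed by `folds`/`val_fold` above are just a vector with one integer in `1:k` per sample. As a rough illustration, the sketch below builds such a vector with the standard library only. `assign_folds` is a hypothetical helper, not part of EasyHybrid or MLUtils, and it deals labels round-robin over a permutation rather than taking the index blocks returned by `MLUtils.kfolds` as `make_folds` does.

```julia
using Random

# Hypothetical helper (not part of EasyHybrid or MLUtils): give each of the
# n samples exactly one fold label in 1..k.
function assign_folds(n::Int; k::Int=5, shuffle::Bool=true, rng=Random.default_rng())
    # Visit the samples in random order when shuffling, otherwise in order.
    perm = shuffle ? randperm(rng, n) : collect(1:n)
    folds = zeros(Int, n)
    for (pos, i) in enumerate(perm)
        folds[i] = mod1(pos, k)   # deal labels 1, 2, ..., k, 1, 2, ...
    end
    return folds
end

folds = assign_folds(10; k=3, rng=MersenneTwister(42))

# Fold sizes differ by at most one, whatever the permutation.
counts = [count(==(f), folds) for f in 1:3]   # [4, 3, 3]
```

Passing an explicit `rng` keeps the assignment reproducible across runs, which matters when results from different folds are compared later.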
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,8 +17,8 @@ end | |
""" | ||
train(hybridModel, data, save_ps; nepochs=200, batchsize=10, opt=Adam(0.01), patience=typemax(Int), | ||
file_name=nothing, loss_types=[:mse, :r2], training_loss=:mse, agg=sum, train_from=nothing, | ||
random_seed=161803, shuffleobs=false, yscale=log10, monitor_names=[], return_model=:best, | ||
split_by_id=nothing, split_data_at=0.8, plotting=true, show_progress=true, hybrid_name=randstring(10)) | ||
random_seed=161803, yscale=log10, monitor_names=[], return_model=:best, | ||
plotting=true, show_progress=true, hybrid_name=randstring(10), kwargs...) | ||
|
||
Train a hybrid model using the provided data and save the training process to a file in JLD2 format. | ||
Default output file is `trained_model.jld2` at the current working directory under `output_tmp`. | ||
|
@@ -39,10 +39,12 @@ Default output file is `trained_model.jld2` at the current working directory und | |
- `loss_types`: A vector of loss types to compute during training (default: `[:mse, :r2]`). | ||
- `agg`: The aggregation function to apply to the computed losses (default: `sum`). | ||
|
||
## Data Handling: | ||
## Data Handling (passed via kwargs): | ||
- `shuffleobs`: Whether to shuffle the training data (default: false). | ||
- `split_by_id`: Column name or function to split data by ID (default: nothing -> no ID-based splitting). | ||
- `split_data_at`: Fraction of data to use for training when splitting (default: 0.8). | ||
- `folds`: Vector or column name of fold assignments (1..k), one per sample/column for k-fold cross-validation (default: nothing). | ||
- `val_fold`: The validation fold to use when `folds` is provided (default: nothing). | ||
|
||
## Training State and Reproducibility: | ||
- `train_from`: A tuple of physical parameters and state to start training from or an output of `train` (default: nothing -> new training). | ||
|
@@ -72,10 +74,7 @@ function train(hybridModel, data, save_ps; | |
loss_types=[:mse, :r2], | ||
agg=sum, | ||
|
||
# Data handling | ||
shuffleobs=false, | ||
split_by_id=nothing, | ||
split_data_at=0.8, | ||
# Data handling parameters are now passed via kwargs... | ||
Review comment: This would not blow up the length of the `train` arguments further; it would even decrease it. |
||
|
||
# Training state and reproducibility | ||
train_from=nothing, | ||
|
@@ -109,9 +108,8 @@ function train(hybridModel, data, save_ps; | |
if !isnothing(random_seed) | ||
Random.seed!(random_seed) | ||
end | ||
|
||
# ? split training and validation data | ||
(x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; split_by_id=split_by_id, shuffleobs=shuffleobs, split_data_at=split_data_at) | ||
|
||
(x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; kwargs...) | ||
Review comment: Pass everything for data handling via `kwargs...`?
Reply: Maybe; we need to make that work (I tried this already and somehow it was failing in some cases), hence some tests will be needed. Let's do these updates to |
||
|
||
train_loader = DataLoader((x_train, y_train), batchsize=batchsize, shuffle=true); | ||
|
||
|
@@ -382,49 +380,13 @@ function header_and_paddings(nt; digits=5) | |
return headers, paddings | ||
end | ||
|
||
function split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8) | ||
|
||
data_ = prepare_data(hybridModel, data) | ||
# all the KeyedArray thing! | ||
|
||
if !isnothing(split_by_id) | ||
if isa(split_by_id, Symbol) | ||
ids = getbyname(data, split_by_id) | ||
unique_ids = unique(ids) | ||
elseif isa(split_by_id, AbstractVector) | ||
ids = split_by_id | ||
unique_ids = unique(ids) | ||
split_by_id = "split_by_id" | ||
end | ||
|
||
train_ids, val_ids = splitobs(unique_ids; at=split_data_at, shuffle=shuffleobs) | ||
|
||
train_idx = findall(id -> id in train_ids, ids) | ||
val_idx = findall(id -> id in val_ids, ids) | ||
|
||
@info "Splitting data by $split_by_id" | ||
@info "Number of unique $split_by_id's: $(length(unique_ids))" | ||
@info "Number of $split_by_id's in training set: $(length(train_ids))" | ||
@info "Number of $split_by_id's in validation set: $(length(val_ids))" | ||
|
||
x_all, y_all = data_ | ||
|
||
x_train, y_train = x_all[:, train_idx], y_all[:, train_idx] | ||
x_val, y_val = x_all[:, val_idx], y_all[:, val_idx] | ||
else | ||
(x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs) | ||
end | ||
|
||
return (x_train, y_train), (x_val, y_val) | ||
end | ||
|
||
function split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8) | ||
function split_data(data::AbstractDimArray, hybridModel; shuffleobs=false, split_data_at=0.8, kwargs...) | ||
data_ = prepare_data(hybridModel, data) | ||
(x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs) | ||
return (x_train, y_train), (x_val, y_val) | ||
end | ||
|
||
function split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8) | ||
function split_data(data::Tuple, hybridModel; shuffleobs=false, split_data_at=0.8, kwargs...) | ||
data_ = prepare_data(hybridModel, data) | ||
(x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs) | ||
return (x_train, y_train), (x_val, y_val) | ||
|
@@ -434,25 +396,83 @@ function split_data(data::Tuple{Tuple, Tuple}, hybridModel; kwargs...) | |
return data | ||
end | ||
|
||
function split_data( | ||
data::Union{DataFrame, KeyedArray}, | ||
hybridModel; | ||
split_by_id::Union{Nothing,Symbol,AbstractVector}=nothing, | ||
folds::Union{Nothing,AbstractVector,Symbol}=nothing, | ||
val_fold::Union{Nothing,Int}=nothing, | ||
shuffleobs::Bool=false, | ||
split_data_at::Real=0.8, | ||
kwargs... | ||
) | ||
data_ = prepare_data(hybridModel, data) | ||
|
||
if split_by_id !== nothing | ||
# --- Option A: split by ID --- | ||
ids = isa(split_by_id, Symbol) ? getbyname(data, split_by_id) : split_by_id | ||
unique_ids = unique(ids) | ||
train_ids, val_ids = splitobs(unique_ids; at=split_data_at, shuffle=shuffleobs) | ||
train_idx = findall(in(train_ids), ids) | ||
val_idx = findall(in(val_ids), ids) | ||
|
||
@info "Splitting data by $(split_by_id)" | ||
@info "Number of unique $(split_by_id): $(length(unique_ids))" | ||
@info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))" | ||
|
||
x_all, y_all = data_ | ||
x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx) | ||
x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx) | ||
return (x_train, y_train), (x_val, y_val) | ||
|
||
elseif folds !== nothing || val_fold !== nothing | ||
# --- Option B: external K-fold assignment --- | ||
@assert val_fold !== nothing "Provide val_fold when using folds." | ||
@assert folds !== nothing "Provide folds when using val_fold." | ||
x_all, y_all = data_ | ||
f = isa(folds, Symbol) ? getbyname(data, folds) : folds | ||
n = size(x_all, 2) | ||
@assert length(f) == n "length(folds) ($(length(f))) must equal number of samples/columns ($n)." | ||
@assert 1 ≤ val_fold ≤ maximum(f) "val_fold=$val_fold is out of range 1:$(maximum(f))." | ||
|
||
val_idx = findall(==(val_fold), f) | ||
@assert !isempty(val_idx) "No samples assigned to validation fold $val_fold." | ||
train_idx = setdiff(1:n, val_idx) | ||
|
||
@info "K-fold via external assignments: val_fold=$val_fold → train=$(length(train_idx)) val=$(length(val_idx))" | ||
|
||
x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx) | ||
x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx) | ||
return (x_train, y_train), (x_val, y_val) | ||
|
||
else | ||
# --- Fallback: simple random/chronological split of prepared data --- | ||
(x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs) | ||
return (x_train, y_train), (x_val, y_val) | ||
end | ||
end | ||
|
||
|
||
""" | ||
split_data(data, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8) | ||
split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8) | ||
split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8) | ||
split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8) | ||
split_data(data, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...) | ||
split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, folds=nothing, val_fold=nothing, kwargs...) | ||
split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...) | ||
split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...) | ||
split_data(data::Tuple{Tuple, Tuple}, hybridModel; kwargs...) | ||
|
||
Split data into training and validation sets, either randomly or by grouping by ID. | ||
Split data into training and validation sets, either randomly, by grouping by ID, or using external fold assignments. | ||
|
||
# Arguments: | ||
- `data`: The data to split, which can be a DataFrame, KeyedArray, AbstractDimArray, or Tuple | ||
- `hybridModel`: The hybrid model object used for data preparation | ||
- `split_by_id=nothing`: Either `nothing` for random splitting, a `Symbol` for column-based splitting, or an `AbstractVector` for custom ID-based splitting | ||
- `shuffleobs=false`: Whether to shuffle observations during splitting | ||
- `split_data_at=0.8`: Ratio of data to use for training | ||
- `folds`: Vector or column name of fold assignments (1..k), one per sample/column for k-fold cross-validation | ||
- `val_fold`: The validation fold to use when `folds` is provided | ||
|
||
# Behavior: | ||
- For DataFrame/KeyedArray: Supports both random and ID-based splitting with logging | ||
- For DataFrame/KeyedArray: Supports random splitting, ID-based splitting, and external fold assignments | ||
- For AbstractDimArray/Tuple: Random splitting only after data preparation | ||
- For pre-split Tuple{Tuple, Tuple}: Returns input unchanged | ||
|
||
|
Review comment: Yes, we can also do the split outside of train. I guess then we don't blow up train with more keyword arguments and can get rid of the additional ones I added for kfold. @lazarusA
Reply: Yes, please.
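
To make the "split outside of train" idea concrete, here is a minimal sketch of the fold-based branch on plain matrices with samples stored as columns. `split_by_fold` is a hypothetical name chosen for illustration; it mirrors the `findall`/`setdiff`/`view` logic shown in the diff but skips `prepare_data` and the `Symbol` lookup, so it is not the EasyHybrid implementation.

```julia
# Hypothetical helper: split column-wise samples into training and validation
# sets using an external fold assignment (one label per column).
function split_by_fold(x::AbstractMatrix, y::AbstractMatrix,
                       folds::AbstractVector{<:Integer}, val_fold::Integer)
    n = size(x, 2)
    length(folds) == n ||
        throw(ArgumentError("length(folds) must equal the number of columns"))
    val_idx = findall(==(val_fold), folds)
    isempty(val_idx) &&
        throw(ArgumentError("no samples assigned to fold $val_fold"))
    train_idx = setdiff(1:n, val_idx)
    # Views avoid copying the underlying data.
    return (view(x, :, train_idx), view(y, :, train_idx)),
           (view(x, :, val_idx), view(y, :, val_idx))
end

x = reshape(collect(1.0:12.0), 2, 6)   # 2 features × 6 samples
y = reshape(collect(1.0:6.0), 1, 6)    # 1 target × 6 samples
folds = [1, 2, 3, 1, 2, 3]

# One pre-split tuple per validation fold.
splits = [split_by_fold(x, y, folds, f) for f in 1:3]
(x_train, y_train), (x_val, y_val) = splits[2]
```

Each element of `splits` has the `((x_train, y_train), (x_val, y_val))` shape that the `split_data(data::Tuple{Tuple, Tuple}, ...)` method returns unchanged, so pre-split data of this form should in principle pass straight through `train`. Whether views of plain matrices are accepted there is not shown in the diff and would need a test.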