1 change: 1 addition & 0 deletions projects/book_chapter/Project.toml
@@ -1,6 +1,7 @@
[deps]
EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008"
153 changes: 153 additions & 0 deletions projects/book_chapter/folds.jl
@@ -0,0 +1,153 @@
# CC BY-SA 4.0
# =============================================================================
# EasyHybrid Example: Synthetic Data Analysis
# =============================================================================
# This example demonstrates how to use EasyHybrid to train a hybrid model
# on synthetic data for respiration modeling with Q10 temperature sensitivity.
# =============================================================================

# =============================================================================
# Project Setup and Environment
# =============================================================================
using Pkg

# Set project path and activate environment
project_path = "projects/book_chapter"
Pkg.activate(project_path)

# Check if manifest exists, create project if needed
manifest_path = joinpath(project_path, "Manifest.toml")
if !isfile(manifest_path)
    package_path = pwd()
    if !endswith(package_path, "EasyHybrid")
        @error "You opened the wrong directory. Please open the EasyHybrid folder, create a new project in the projects folder, and pass its relative path as `project_path`."
    end
    Pkg.develop(path=package_path)
    Pkg.instantiate()
end

using EasyHybrid

# =============================================================================
# Data Loading and Preprocessing
# =============================================================================
# Load synthetic dataset from GitHub
ds = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc")

# Select a subset of data for faster execution
ds = ds[1:20000, :]

# =============================================================================
# Define the Physical Model
# =============================================================================
# RbQ10 model: Respiration model with Q10 temperature sensitivity
# Parameters:
# - ta: air temperature [°C]
# - Q10: temperature sensitivity factor [-]
# - rb: basal respiration rate [μmol/m²/s]
# - tref: reference temperature [°C] (default: 15.0)
function RbQ10(;ta, Q10, rb, tref = 15.0f0)
    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
    return (; reco, Q10, rb)
end
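
# A quick sanity check (illustrative addition, not in the original script): at
# the reference temperature the exponent is zero, so `reco` should equal `rb`.
let out = RbQ10(ta = 15.0f0, Q10 = 2.0f0, rb = 3.0f0)
    @assert out.reco ≈ 3.0f0
end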

# =============================================================================
# Define Model Parameters
# =============================================================================
# Parameter specification: (default, lower_bound, upper_bound)
parameters = (
    # Parameter name | Default | Lower | Upper | Description
    rb  = ( 3.0f0, 0.0f0, 13.0f0 ),  # Basal respiration [μmol/m²/s]
    Q10 = ( 2.0f0, 1.0f0,  4.0f0 ),  # Temperature sensitivity factor [-]
)
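
# Illustrative aside (not in the original script): each entry is a
# (default, lower, upper) tuple, so bounds can be unpacked positionally.
rb_default, rb_lower, rb_upper = parameters.rb
@assert rb_lower ≤ rb_default ≤ rb_upper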

# =============================================================================
# Configure Hybrid Model Components
# =============================================================================
# Define input variables
forcing = [:ta] # Forcing variables (temperature)
predictors = [:sw_pot, :dsw_pot] # Predictor variables (solar radiation, and its derivative)

# Target variable
target = [:reco] # Target variable (respiration)

# Parameter classification
global_param_names = [:Q10] # Global parameters (same for all samples)
neural_param_names = [:rb] # Neural network predicted parameters

# =============================================================================
# Construct the Hybrid Model
# =============================================================================
# Create hybrid model using the unified constructor
hybrid_model = constructHybridModel(
    predictors,                   # Input features
    forcing,                      # Forcing variables
    target,                       # Target variables
    RbQ10,                        # Process-based model function
    parameters,                   # Parameter definitions
    neural_param_names,           # NN-predicted parameters
    global_param_names,           # Global parameters
    hidden_layers = [16, 16],     # Neural network architecture
    activation = sigmoid,         # Activation function
    scale_nn_outputs = true,      # Scale neural network outputs
    input_batchnorm = true        # Apply batch normalization to inputs
)

# =============================================================================
# Model Training
# =============================================================================
using WGLMakie

using MLUtils
using Random   # provides randperm, used in make_folds below

function make_folds(ds; k::Int=5, shuffle=true)
    n = numobs(ds)
    # MLUtils.kfolds(n, k) returns (train_indices, val_indices); only the k
    # validation index vectors are needed to label each observation with a fold.
    _, val_idx = MLUtils.kfolds(n, k)
    folds = fill(0, n)
    perm = shuffle ? randperm(n) : collect(1:n)
    for (f, idx) in enumerate(val_idx)
        fidx = perm[idx]
        folds[fidx] .= f
    end
    return folds
end

k = 3
folds = make_folds(ds, k=k, shuffle=true)
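
# Sanity checks (illustrative additions): every observation received a fold
# label in 1:k, and the folds are roughly balanced in size.
@assert all(f -> 1 ≤ f ≤ k, folds)
@info "Fold sizes" sizes = [count(==(f), folds) for f in 1:k]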

results = Vector{Any}(undef, k)

for val_fold in 1:k
    @info "Training fold $val_fold of $k"
    out = train(
        hybrid_model,
        ds,
        ();
        nepochs = 100,
        patience = 10,
        batchsize = 512,        # Batch size for training
        opt = RMSProp(0.001),   # Optimizer and learning rate
        monitor_names = [:rb, :Q10],
        folds = folds,
        val_fold = val_fold
    )
    results[val_fold] = out
end



for val_fold in 1:k
    @info "Split data outside of train function. Training fold $val_fold of $k"
    sdata = split_data(ds, hybrid_model; val_fold = val_fold, folds = folds)
Collaborator Author:
Yes, we can also do the split outside of `train`. I guess then we don't blow up `train` with more keyword arguments and can get rid of the additional ones I added for k-fold. @lazarusA

Member:
yes, please.

    out = train(
        hybrid_model,
        sdata,
        ();
        nepochs = 100,
        patience = 10,
        batchsize = 512,        # Batch size for training
        opt = RMSProp(0.001),   # Optimizer and learning rate
        monitor_names = [:rb, :Q10]
    )
    results[val_fold] = out
end
128 changes: 74 additions & 54 deletions src/train.jl
@@ -17,8 +17,8 @@
"""
    train(hybridModel, data, save_ps; nepochs=200, batchsize=10, opt=Adam(0.01), patience=typemax(Int),
        file_name=nothing, loss_types=[:mse, :r2], training_loss=:mse, agg=sum, train_from=nothing,
-       random_seed=161803, shuffleobs=false, yscale=log10, monitor_names=[], return_model=:best,
-       split_by_id=nothing, split_data_at=0.8, plotting=true, show_progress=true, hybrid_name=randstring(10))
+       random_seed=161803, yscale=log10, monitor_names=[], return_model=:best,
+       plotting=true, show_progress=true, hybrid_name=randstring(10), kwargs...)

Train a hybrid model using the provided data and save the training process to a file in JLD2 format.
Default output file is `trained_model.jld2` at the current working directory under `output_tmp`.
@@ -39,10 +39,12 @@
- `loss_types`: A vector of loss types to compute during training (default: `[:mse, :r2]`).
- `agg`: The aggregation function to apply to the computed losses (default: `sum`).

-## Data Handling:
+## Data Handling (passed via kwargs):
- `shuffleobs`: Whether to shuffle the training data (default: false).
- `split_by_id`: Column name or function to split data by ID (default: nothing -> no ID-based splitting).
- `split_data_at`: Fraction of data to use for training when splitting (default: 0.8).
+- `folds`: Vector or column name of fold assignments (1..k), one per sample/column, for k-fold cross-validation (default: `nothing`).
+- `val_fold`: The validation fold to hold out when `folds` is provided (default: `nothing`).

## Training State and Reproducibility:
- `train_from`: A tuple of physical parameters and state to start training from or an output of `train` (default: nothing -> new training).
@@ -72,10 +74,7 @@ function train(hybridModel, data, save_ps;
    loss_types=[:mse, :r2],
    agg=sum,

-    # Data handling
-    shuffleobs=false,
-    split_by_id=nothing,
-    split_data_at=0.8,
+    # Data handling parameters are now passed via kwargs...
Collaborator Author:
This would not blow up the length of `train`'s argument list any further; it would even decrease it.
    # Training state and reproducibility
    train_from=nothing,
@@ -109,9 +108,8 @@
    if !isnothing(random_seed)
        Random.seed!(random_seed)
    end

-    # ? split training and validation data
-    (x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; split_by_id=split_by_id, shuffleobs=shuffleobs, split_data_at=split_data_at)
+    (x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; kwargs...)
Collaborator Author:
Pass everything for data handling via kwargs...?

Member:
Maybe; we need to make sure that works (I tried this already and somehow it was failing in some cases), hence some tests will be needed. Let's do these updates to split_data and train in a new PR, only dealing with that.


    train_loader = DataLoader((x_train, y_train), batchsize=batchsize, shuffle=true);

@@ -382,49 +380,13 @@ function header_and_paddings(nt; digits=5)
    return headers, paddings
end

-function split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-
-    data_ = prepare_data(hybridModel, data)
-    # all the KeyedArray thing!
-
-    if !isnothing(split_by_id)
-        if isa(split_by_id, Symbol)
-            ids = getbyname(data, split_by_id)
-            unique_ids = unique(ids)
-        elseif isa(split_by_id, AbstractVector)
-            ids = split_by_id
-            unique_ids = unique(ids)
-            split_by_id = "split_by_id"
-        end
-
-        train_ids, val_ids = splitobs(unique_ids; at=split_data_at, shuffle=shuffleobs)
-
-        train_idx = findall(id -> id in train_ids, ids)
-        val_idx = findall(id -> id in val_ids, ids)
-
-        @info "Splitting data by $split_by_id"
-        @info "Number of unique $split_by_id's: $(length(unique_ids))"
-        @info "Number of $split_by_id's in training set: $(length(train_ids))"
-        @info "Number of $split_by_id's in validation set: $(length(val_ids))"
-
-        x_all, y_all = data_
-
-        x_train, y_train = x_all[:, train_idx], y_all[:, train_idx]
-        x_val, y_val = x_all[:, val_idx], y_all[:, val_idx]
-    else
-        (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
-    end
-
-    return (x_train, y_train), (x_val, y_val)
-end

-function split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
+function split_data(data::AbstractDimArray, hybridModel; shuffleobs=false, split_data_at=0.8, kwargs...)
    data_ = prepare_data(hybridModel, data)
    (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
    return (x_train, y_train), (x_val, y_val)
end

-function split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
+function split_data(data::Tuple, hybridModel; shuffleobs=false, split_data_at=0.8, kwargs...)
    data_ = prepare_data(hybridModel, data)
    (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
    return (x_train, y_train), (x_val, y_val)
@@ -434,25 +396,83 @@ function split_data(data::Tuple{Tuple, Tuple}, hybridModel; kwargs...)
    return data
end

function split_data(
    data::Union{DataFrame, KeyedArray},
    hybridModel;
    split_by_id::Union{Nothing,Symbol,AbstractVector}=nothing,
    folds::Union{Nothing,AbstractVector,Symbol}=nothing,
    val_fold::Union{Nothing,Int}=nothing,
    shuffleobs::Bool=false,
    split_data_at::Real=0.8,
    kwargs...
)
    data_ = prepare_data(hybridModel, data)

    if split_by_id !== nothing
        # --- Option A: split by ID ---
        ids = isa(split_by_id, Symbol) ? getbyname(data, split_by_id) : split_by_id
        unique_ids = unique(ids)
        train_ids, val_ids = splitobs(unique_ids; at=split_data_at, shuffle=shuffleobs)
        train_idx = findall(in(train_ids), ids)
        val_idx = findall(in(val_ids), ids)

        @info "Splitting data by $(split_by_id)"
        @info "Number of unique $(split_by_id): $(length(unique_ids))"
        @info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))"

        x_all, y_all = data_
        x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx)
        x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx)
        return (x_train, y_train), (x_val, y_val)

    elseif folds !== nothing || val_fold !== nothing
        # --- Option B: external K-fold assignment ---
        @assert val_fold !== nothing "Provide val_fold when using folds."
        @assert folds !== nothing "Provide folds when using val_fold."
        x_all, y_all = data_
        f = isa(folds, Symbol) ? getbyname(data, folds) : folds
        n = size(x_all, 2)
        @assert length(f) == n "length(folds) ($(length(f))) must equal number of samples/columns ($n)."
        @assert 1 ≤ val_fold ≤ maximum(f) "val_fold=$val_fold is out of range 1:$(maximum(f))."

        val_idx = findall(==(val_fold), f)
        @assert !isempty(val_idx) "No samples assigned to validation fold $val_fold."
        train_idx = setdiff(1:n, val_idx)

        @info "K-fold via external assignments: val_fold=$val_fold → train=$(length(train_idx)) val=$(length(val_idx))"

        x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx)
        x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx)
        return (x_train, y_train), (x_val, y_val)

    else
        # --- Fallback: simple random/chronological split of prepared data ---
        (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
        return (x_train, y_train), (x_val, y_val)
    end
end
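
# Hedged usage sketch (not part of this PR; `df`, `model`, and the fold labels
# below are hypothetical placeholders): hold out fold 2 of 3 as validation.
#
#   folds = repeat(1:3, inner = 100)   # 300 observations labeled 1..3
#   (x_train, y_train), (x_val, y_val) =
#       split_data(df, model; folds = folds, val_fold = 2)
#   size(x_val, 2) == 100              # fold 2 held out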


"""
-    split_data(data, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-    split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-    split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-    split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
+    split_data(data, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...)
+    split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, folds=nothing, val_fold=nothing, kwargs...)
+    split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...)
+    split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...)
split_data(data::Tuple{Tuple, Tuple}, hybridModel; kwargs...)

-Split data into training and validation sets, either randomly or by grouping by ID.
+Split data into training and validation sets, either randomly, by grouping by ID, or using external fold assignments.

# Arguments:
- `data`: The data to split, which can be a DataFrame, KeyedArray, AbstractDimArray, or Tuple
- `hybridModel`: The hybrid model object used for data preparation
- `split_by_id=nothing`: Either `nothing` for random splitting, a `Symbol` for column-based splitting, or an `AbstractVector` for custom ID-based splitting
- `shuffleobs=false`: Whether to shuffle observations during splitting
- `split_data_at=0.8`: Ratio of data to use for training
+- `folds=nothing`: Vector or column name of fold assignments (1..k), one per sample/column, for k-fold cross-validation
+- `val_fold=nothing`: The validation fold to hold out when `folds` is provided

# Behavior:
-- For DataFrame/KeyedArray: Supports both random and ID-based splitting with logging
+- For DataFrame/KeyedArray: Supports random splitting, ID-based splitting, and external fold assignments
- For AbstractDimArray/Tuple: Random splitting only after data preparation
- For pre-split Tuple{Tuple, Tuple}: Returns input unchanged
