diff --git a/.gitignore b/.gitignore
index f74ec22..e760fcb 100644
--- a/.gitignore
+++ b/.gitignore
@@ -25,3 +25,4 @@ Manifest.toml
 !Bernhard_Ahrens.png
 *.err
 *.out
+/docs/src/tutorials/folds.md
\ No newline at end of file
diff --git a/docs/Project.toml b/docs/Project.toml
index b474eb1..52df290 100644
--- a/docs/Project.toml
+++ b/docs/Project.toml
@@ -5,6 +5,8 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 DocumenterVitepress = "4710194d-e776-4893-9690-8d956a29c365"
 EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
 Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
+Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
 
 [sources]
 EasyHybrid = {path = ".."}
diff --git a/docs/literate/tutorials/folds.jl b/docs/literate/tutorials/folds.jl
new file mode 100644
index 0000000..b05fbd2
--- /dev/null
+++ b/docs/literate/tutorials/folds.jl
@@ -0,0 +1,104 @@
+# # Cross-Validation in EasyHybrid.jl
+#
+# This tutorial demonstrates one option for cross-validation in EasyHybrid.
+# The code for this tutorial can be found in [docs/literate/tutorials](https://github.com/EarthyScience/EasyHybrid.jl/tree/main/docs/literate/tutorials/) => folds.jl.
+#
+# ## 1. Load Packages
+
+using EasyHybrid
+using OhMyThreads
+using CairoMakie
+
+# ## 2. Data Loading and Preprocessing
+
+# Load synthetic dataset from GitHub
+df = load_timeseries_netcdf("https://github.com/bask0/q10hybrid/raw/master/data/Synthetic4BookChap.nc");
+
+# Select a subset of data for faster execution
+df = df[1:20000, :];
+first(df, 5)
+
+# ## 3. Define the Physical Model
+
+"""
+    RbQ10(; ta, Q10, rb, tref=15.0f0)
+
+Respiration model with Q10 temperature sensitivity.
+
+- `ta`: air temperature [°C]
+- `Q10`: temperature sensitivity factor [-]
+- `rb`: basal respiration rate [μmol/m²/s]
+- `tref`: reference temperature [°C] (default: 15.0)
+"""
+function RbQ10(; ta, Q10, rb, tref = 15.0f0)
+    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
+    return (; reco, Q10, rb)
+end
+
+# ## 4. Define Model Parameters
+
+# Parameter specification: (default, lower_bound, upper_bound)
+parameters = (
+    rb  = (3.0f0, 0.0f0, 13.0f0),
+    Q10 = (2.0f0, 1.0f0, 4.0f0)
+)
+
+# ## 5. Configure Hybrid Model Components
+
+# Define input variables
+# Forcing variables (temperature)
+forcing = [:ta]
+# Predictor variables (solar radiation, and its derivative)
+predictors = [:sw_pot, :dsw_pot]
+# Target variable (respiration)
+target = [:reco]
+
+# Parameter classification
+# Global parameters (same for all samples)
+global_param_names = [:Q10]
+# Neural network predicted parameters
+neural_param_names = [:rb]
+
+# ## 6. Construct the Hybrid Model
+
+hybrid_model = constructHybridModel(
+    predictors,
+    forcing,
+    target,
+    RbQ10,
+    parameters,
+    neural_param_names,
+    global_param_names,
+    hidden_layers = [16, 16],
+    activation = sigmoid,
+    scale_nn_outputs = true,
+    input_batchnorm = true
+)
+
+# ## 7. Model Training: k-Fold Cross-Validation
+
+k = 3
+folds = make_folds(df, k = k, shuffle = true)
+
+results = Vector{Any}(undef, k)
+
+@time @tasks for val_fold in 1:k
+    @info "Split data outside of train function. Training fold $val_fold of $k"
+    sdata = split_data(df, hybrid_model; val_fold = val_fold, folds = folds)
+    out = train(
+        hybrid_model,
+        sdata,
+        ();
+        nepochs = 10,
+        patience = 10,
+        batchsize = 512,             # Batch size for training
+        opt = RMSProp(0.001),        # Optimizer and learning rate
+        monitor_names = [:rb, :Q10],
+        hybrid_name = "folds_$(val_fold)",
+        folder_to_save = "CV_results",
+        file_name = "trained_model_folds_$(val_fold).jld2",
+        show_progress = false,
+        plotting = false
+    )
+    results[val_fold] = out
+end
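> Editorial aside, not part of the patch: once the loop above has filled `results`, the k runs can be compared side by side. The sketch below assumes each entry is a `TrainResults` whose validation losses are reachable via a `val_history` field; that field access is an assumption, so inspect one result (e.g. `propertynames(results[1])`) for the real layout.

```julia
# Illustrative sketch only: summarize the k cross-validation runs.
using Statistics

# hypothetical accessor: final recorded validation loss of each fold
val_losses = [last(out.val_history) for out in results]
println("per-fold final val loss: ", val_losses)
println("mean: ", mean(val_losses), "  std: ", std(val_losses))
```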
Training fold $val_fold of $k" + sdata = split_data(df, hybrid_model; val_fold = val_fold, folds = folds) + out = train( + hybrid_model, + sdata, + (); + nepochs = 10, + patience = 10, + batchsize = 512, # Batch size for training + opt = RMSProp(0.001), # Optimizer and learning rate + monitor_names = [:rb, :Q10], + hybrid_name = "folds_$(val_fold)", + folder_to_save = "CV_results", + file_name = "trained_model_folds_$(val_fold).jld2", + show_progress = false, + plotting = false + ) + results[val_fold] = out +end diff --git a/docs/make.jl b/docs/make.jl index 72e85e6..9b66216 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -1,36 +1,78 @@ +# docs/make.jl using EasyHybrid using Documenter, DocumenterVitepress +literate_root = joinpath(@__DIR__, "literate") + +# collect all .jl files recursively under docs/literate +jl_files = isdir(literate_root) ? + [joinpath(root, f) for (root, _, files) in walkdir(literate_root) for f in files if endswith(f, ".jl")] : + String[] + +if !isempty(jl_files) + @info "Running Literate.jl on $(length(jl_files)) files..." + using Literate + src_root = joinpath(@__DIR__, "src") + + function render_tree(indir::String, outdir::String) + isdir(indir) || return + for (root, _, files) in walkdir(indir) + rel = relpath(root, indir) + target = rel == "." ? outdir : joinpath(outdir, rel) + mkpath(target) + for f in files + endswith(f, ".jl") || continue + inpath = joinpath(root, f) + @info "Literate -> " * relpath(inpath, literate_root) + Literate.markdown( + inpath, target; + documenter = true, + execute = false, + credit = false, + ) + end + end + end + + # Typical folders you might want; add/remove as you wish + render_tree(joinpath(literate_root, "tutorials"), joinpath(src_root, "tutorials")) + render_tree(joinpath(literate_root, "research"), joinpath(src_root, "research")) +else + @info "No Literate sources found — skipping Literate.jl step." +end + +# ----------------------------------------------------------------------------- makedocs(; - modules=[EasyHybrid], - authors="Lazaro Alonso, Bernhard Ahrens, Markus Reichstein", - repo="https://github.com/EarthyScience/EasyHybrid.jl", - sitename="EasyHybrid.jl", - format=DocumenterVitepress.MarkdownVitepress( + modules = [EasyHybrid], + authors = "Lazaro Alonso, Bernhard Ahrens, Markus Reichstein", + repo = "https://github.com/EarthyScience/EasyHybrid.jl", + sitename = "EasyHybrid.jl", + format = DocumenterVitepress.MarkdownVitepress( repo = "https://github.com/EarthyScience/EasyHybrid.jl", devurl = "dev", ), - pages=[ + pages = [ "Home" => "index.md", "Get Started" => "get_started.md", "Tutorial" => [ - "Exponential Response" => "tutorials/exponential_res.md", - "Hyperparameter Tuning" => "tutorials/hyperparameter_tuning.md", - "Slurm" => "tutorials/slurm.md" + "Exponential Response" => "tutorials/exponential_res.md", + "Hyperparameter Tuning" => "tutorials/hyperparameter_tuning.md", + "Slurm" => "tutorials/slurm.md", + "Cross-validation" => "tutorials/folds.md", ], - "Research" =>[ - "Overview" => "research/overview.md" - "RbQ10" => "research/RbQ10_results.md" - "BulkDensitySOC" => "research/BulkDensitySOC_results.md" + "Research" => [ + "Overview" => "research/overview.md", + "RbQ10" => "research/RbQ10_results.md", + "BulkDensitySOC" => "research/BulkDensitySOC_results.md", ], "API" => "api.md", ], ) DocumenterVitepress.deploydocs(; - repo = "github.com/EarthyScience/EasyHybrid.jl", # this must be the full URL! - target=joinpath(@__DIR__, "build"), + repo = "github.com/EarthyScience/EasyHybrid.jl", # full URL! 
diff --git a/projects/book_chapter/Project.toml b/projects/book_chapter/Project.toml
index eb8d618..7c0bbe6 100644
--- a/projects/book_chapter/Project.toml
+++ b/projects/book_chapter/Project.toml
@@ -1,5 +1,9 @@
 [deps]
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
+DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
 Hyperopt = "93e5fe13-2215-51db-baaf-2e9a34fb2712"
+MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 NCDatasets = "85f8d34a-cbdd-5861-8df4-14fed0d494ab"
+OhMyThreads = "67456a42-1dca-4109-a031-0a68de7e3ad5"
 WGLMakie = "276b4fcb-3e11-5398-bf8b-a0c2d153d008"
diff --git a/src/EasyHybrid.jl b/src/EasyHybrid.jl
index 022e5b3..4d7067c 100644
--- a/src/EasyHybrid.jl
+++ b/src/EasyHybrid.jl
@@ -56,5 +56,6 @@ include("utils/helpers_for_HybridModel.jl")
 include("plotrecipes.jl")
 include("utils/helpers_data_loading.jl")
 include("tune.jl")
+include("utils/helpers_cross_validation.jl")
 
 end
diff --git a/src/train.jl b/src/train.jl
index 68e23a9..0fb79ed 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -1,4 +1,4 @@
-export train, TrainResults
+export train, TrainResults, prepare_data, split_data
 # beneficial for plotting based on type TrainResults?
 struct TrainResults
     train_history
@@ -17,8 +17,8 @@ end
 """
     train(hybridModel, data, save_ps; nepochs=200, batchsize=10, opt=Adam(0.01), patience=typemax(Int),
         file_name=nothing, loss_types=[:mse, :r2], training_loss=:mse, agg=sum, train_from=nothing,
-        random_seed=161803, shuffleobs=false, yscale=log10, monitor_names=[], return_model=:best,
-        split_by_id=nothing, split_data_at=0.8, plotting=true, show_progress=true, hybrid_name=randstring(10))
+        random_seed=161803, yscale=log10, monitor_names=[], return_model=:best,
+        plotting=true, show_progress=true, hybrid_name=randstring(10), kwargs...)
 
 Train a hybrid model using the provided data and save the training process to a file in JLD2 format.
 Default output file is `trained_model.jld2` at the current working directory under `output_tmp`.
@@ -39,10 +39,12 @@
 - `loss_types`: A vector of loss types to compute during training (default: `[:mse, :r2]`).
 - `agg`: The aggregation function to apply to the computed losses (default: `sum`).
 
-## Data Handling:
+## Data Handling (passed via kwargs):
 - `shuffleobs`: Whether to shuffle the training data (default: false).
 - `split_by_id`: Column name or function to split data by ID (default: nothing -> no ID-based splitting).
 - `split_data_at`: Fraction of data to use for training when splitting (default: 0.8).
+- `folds`: Vector or column name of fold assignments (1..k), one per sample/column for k-fold cross-validation (default: nothing).
+- `val_fold`: The validation fold to use when `folds` is provided (default: nothing).
 
 ## Training State and Reproducibility:
 - `train_from`: A tuple of physical parameters and state to start training from or an output of `train` (default: nothing -> new training).
@@ -72,10 +74,7 @@ function train(hybridModel, data, save_ps;
     loss_types=[:mse, :r2],
     agg=sum,
 
-    # Data handling
-    shuffleobs=false,
-    split_by_id=nothing,
-    split_data_at=0.8,
+    # Data handling parameters are now passed via kwargs...
 
     # Training state and reproducibility
     train_from=nothing,
@@ -109,9 +108,8 @@ function train(hybridModel, data, save_ps;
     if !isnothing(random_seed)
         Random.seed!(random_seed)
     end
-
-    # ? split training and validation data
-    (x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; split_by_id=split_by_id, shuffleobs=shuffleobs, split_data_at=split_data_at)
+
+    (x_train, y_train), (x_val, y_val) = split_data(data, hybridModel; kwargs...)
 
     train_loader = DataLoader((x_train, y_train), batchsize=batchsize, shuffle=true);
 
@@ -382,67 +380,81 @@ function header_and_paddings(nt; digits=5)
     return headers, paddings
 end
 
-function split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-
+function split_data(data::Tuple{Tuple, Tuple}, hybridModel; kwargs...)
+    @warn "data was prepared already, none of the keyword arguments for split_data will be used"
+    return data
+end
+
+function split_data(
+    data::Union{DataFrame, KeyedArray, Tuple, AbstractDimArray},
+    hybridModel;
+    split_by_id::Union{Nothing,Symbol,AbstractVector}=nothing,
+    folds::Union{Nothing,AbstractVector,Symbol}=nothing,
+    val_fold::Union{Nothing,Int}=nothing,
+    shuffleobs::Bool=false,
+    split_data_at::Real=0.8,
+    kwargs...
+)
     data_ = prepare_data(hybridModel, data)
-    # all the KeyedArray thing!
-
-    if !isnothing(split_by_id)
-        if isa(split_by_id, Symbol)
-            ids = getbyname(data, split_by_id)
-            unique_ids = unique(ids)
-        elseif isa(split_by_id, AbstractVector)
-            ids = split_by_id
-            unique_ids = unique(ids)
-            split_by_id = "split_by_id"
-        end
+    if split_by_id !== nothing && folds !== nothing
+
+        throw(ArgumentError("split_by_id and folds are not supported together; do the split when constructing folds"))
+
+    elseif split_by_id !== nothing
+        # --- Option A: split by ID ---
+        ids = isa(split_by_id, Symbol) ? getbyname(data, split_by_id) : split_by_id
+        unique_ids = unique(ids)
         train_ids, val_ids = splitobs(unique_ids; at=split_data_at, shuffle=shuffleobs)
+        train_idx = findall(in(train_ids), ids)
+        val_idx = findall(in(val_ids), ids)
 
-        train_idx = findall(id -> id in train_ids, ids)
-        val_idx = findall(id -> id in val_ids, ids)
+        @info "Splitting data by $(split_by_id)"
+        @info "Number of unique $(split_by_id): $(length(unique_ids))"
+        @info "Train IDs: $(length(train_ids)) | Val IDs: $(length(val_ids))"
 
-        @info "Splitting data by $split_by_id"
-        @info "Number of unique $split_by_id's: $(length(unique_ids))"
-        @info "Number of $split_by_id's in training set: $(length(train_ids))"
-        @info "Number of $split_by_id's in validation set: $(length(val_ids))"
-
         x_all, y_all = data_
+        x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx)
+        x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx)
+        return (x_train, y_train), (x_val, y_val)
+
+    elseif folds !== nothing || val_fold !== nothing
+        # --- Option B: external K-fold assignment ---
+        @assert val_fold !== nothing "Provide val_fold when using folds."
+        @assert folds !== nothing "Provide folds when using val_fold."
+        @warn "shuffleobs is not supported when using folds and val_fold; it will be ignored. Shuffling should be done during fold construction."
+        x_all, y_all = data_
+        f = isa(folds, Symbol) ? getbyname(data, folds) : folds
+        n = size(x_all, 2)
+        @assert length(f) == n "length(folds) ($(length(f))) must equal number of samples/columns ($n)."
+        @assert 1 ≤ val_fold ≤ maximum(f) "val_fold=$val_fold is out of range 1:$(maximum(f))."
 
-        x_train, y_train = x_all[:, train_idx], y_all[:, train_idx]
-        x_val, y_val = x_all[:, val_idx], y_all[:, val_idx]
-    else
-        (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
-    end
-
-    return (x_train, y_train), (x_val, y_val)
-end
+        val_idx = findall(==(val_fold), f)
+        @assert !isempty(val_idx) "No samples assigned to validation fold $val_fold."
+        train_idx = setdiff(1:n, val_idx)
 
-function split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-    data_ = prepare_data(hybridModel, data)
-    (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
-    return (x_train, y_train), (x_val, y_val)
-end
+        @info "K-fold via external assignments: val_fold=$val_fold → train=$(length(train_idx)) val=$(length(val_idx))"
 
-function split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-    data_ = prepare_data(hybridModel, data)
-    (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
-    return (x_train, y_train), (x_val, y_val)
-end
+        x_train, y_train = view(x_all, :, train_idx), view(y_all, :, train_idx)
+        x_val, y_val = view(x_all, :, val_idx), view(y_all, :, val_idx)
+        return (x_train, y_train), (x_val, y_val)
 
-function split_data(data::Tuple{Tuple, Tuple}, hybridModel; kwargs...)
-    return data
+    else
+        # --- Fallback: simple random/chronological split of prepared data ---
+        (x_train, y_train), (x_val, y_val) = splitobs(data_; at=split_data_at, shuffle=shuffleobs)
+        return (x_train, y_train), (x_val, y_val)
+    end
 end
 
 """
     split_data(data, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...)
-    split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-    split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
-    split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8)
+    split_data(data::Union{DataFrame, KeyedArray}, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, folds=nothing, val_fold=nothing, kwargs...)
+    split_data(data::AbstractDimArray, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...)
+    split_data(data::Tuple, hybridModel; split_by_id=nothing, shuffleobs=false, split_data_at=0.8, kwargs...)
     split_data(data::Tuple{Tuple, Tuple}, hybridModel; kwargs...)
 
-Split data into training and validation sets, either randomly or by grouping by ID.
+Split data into training and validation sets, either randomly, by grouping by ID, or using external fold assignments.
 
 # Arguments:
 - `data`: The data to split, which can be a DataFrame, KeyedArray, AbstractDimArray, or Tuple
@@ -450,9 +462,11 @@
 - `split_by_id=nothing`: Either `nothing` for random splitting, a `Symbol` for column-based splitting, or an `AbstractVector` for custom ID-based splitting
 - `shuffleobs=false`: Whether to shuffle observations during splitting
 - `split_data_at=0.8`: Ratio of data to use for training
+- `folds`: Vector or column name of fold assignments (1..k), one per sample/column for k-fold cross-validation
+- `val_fold`: The validation fold to use when `folds` is provided
 
 # Behavior:
-- For DataFrame/KeyedArray: Supports both random and ID-based splitting with logging
+- For DataFrame/KeyedArray: Supports random splitting, ID-based splitting, and external fold assignments
 - For AbstractDimArray/Tuple: Random splitting only after data preparation
 - For pre-split Tuple{Tuple, Tuple}: Returns input unchanged
 
@@ -497,7 +511,7 @@ end
 
 function prepare_data(hm, data::AbstractDimArray)
     predictors_forcing, targets = get_prediction_target_names(hm)
-    return (data[col=At(predictors_forcing)], data[col=At(targets)])
+    return (data[col=At(predictors_forcing)], data[col=At(targets)]) # TODO check what this should be rows or cols, I would say rows, but maybe it does not matter
 end
 
 function prepare_data(hm, data::Tuple)
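> Editorial aside, not part of the patch: the call patterns enabled by the reworked `split_data` look like this, assuming `df` and `hybrid_model` are set up as in the tutorial above. Only keywords introduced by this diff are used.

```julia
# Fold assignments can be passed as a vector ...
folds = make_folds(df; k = 5, shuffle = true)
(x_train, y_train), (x_val, y_val) = split_data(df, hybrid_model; folds = folds, val_fold = 2)

# ... or stored in a column and passed by name
df.folds = folds
(x_train, y_train), (x_val, y_val) = split_data(df, hybrid_model; folds = :folds, val_fold = 2)

# Mixing strategies is rejected by design:
# split_data(df, hybrid_model; folds = :folds, val_fold = 1, split_by_id = :id)  # throws ArgumentError
```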
diff --git a/src/utils/helpers_cross_validation.jl b/src/utils/helpers_cross_validation.jl
new file mode 100644
index 0000000..3380a09
--- /dev/null
+++ b/src/utils/helpers_cross_validation.jl
@@ -0,0 +1,26 @@
+export make_folds
+
+"""
+    make_folds(df::DataFrame; k::Int=5, shuffle=true) -> Vector{Int}
+
+Assigns each observation in the DataFrame `df` to one of `k` folds for cross-validation.
+
+# Arguments
+- `df::DataFrame`: The input DataFrame whose rows are to be split into folds.
+- `k::Int=5`: Number of folds to create.
+- `shuffle=true`: Whether to shuffle the data before assigning folds.
+
+# Returns
+- `folds::Vector{Int}`: A vector of length `nrow(df)` where each entry is an integer in `1:k` indicating the fold assignment for that observation.
+"""
+function make_folds(df::DataFrame; k::Int=5, shuffle=true)
+    n = numobs(df)
+    _, val_idx = kfolds(n, k)
+    folds = fill(0, n)
+    perm = shuffle ? randperm(n) : 1:n
+    for (f, idx) in enumerate(val_idx)
+        fidx = perm[idx]
+        folds[fidx] .= f
+    end
+    return folds
+end
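> Editorial aside, not part of the patch: a quick sanity check of `make_folds`, assuming an environment where EasyHybrid (which exports it) and DataFrames are available. Every row should receive a label in `1:k`, and fold sizes should differ by at most one.

```julia
using EasyHybrid, DataFrames

df = DataFrame(x = randn(10))
folds = make_folds(df; k = 3, shuffle = false)

@assert length(folds) == nrow(df)
@assert sort(unique(folds)) == 1:3
counts = [count(==(f), folds) for f in 1:3]   # e.g. [4, 3, 3]
@assert maximum(counts) - minimum(counts) <= 1
```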
diff --git a/test/Project.toml b/test/Project.toml
index a40c858..6e8c31b 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,8 +1,12 @@
 [deps]
 AxisKeys = "94b1ba4f-4ee9-5380-92f1-94cde586c3c5"
+ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
+DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+DimensionalData = "0703355e-b756-11e9-17c0-8b28908087d0"
 EasyHybrid = "61bb816a-e6af-4913-ab9e-91bff2e122e3"
 Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
diff --git a/test/runtests.jl b/test/runtests.jl
index 22bef63..d3c7396 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -9,6 +9,9 @@ dk_twos = gen_linear_data_2outputs()
 # Include GenericHybridModel tests
 include("test_generic_hybrid_model.jl")
 
+# Include SplitData tests
+include("test_split_data_train.jl")
+
 @testset "LinearHM" begin
     # test model instantiation
diff --git a/test/test_split_data_train.jl b/test/test_split_data_train.jl
new file mode 100644
index 0000000..c24de21
--- /dev/null
+++ b/test/test_split_data_train.jl
@@ -0,0 +1,133 @@
+# test/test_split_data_train.jl
+using Test
+using Random
+using EasyHybrid
+using Lux
+using DataFrames
+using Statistics
+using DimensionalData
+using ChainRulesCore
+
+# ------------------------------------------------------------------------------
+# Synthetic data similar to the example's columns (no network calls)
+# ------------------------------------------------------------------------------
+function make_synth_df(n::Int=512; seed::Int=42)
+    rng = MersenneTwister(seed)
+    ta = 10 .+ 10 .* randn(rng, n)              # air temperature [°C]
+    sw_pot = abs.(50 .+ 20 .* randn(rng, n))    # solar radiation-ish
+    dsw_pot = vcat(0.0, diff(sw_pot))           # simple derivative
+    true_Q10 = 2.0
+    true_rb = 3.0 .+ 0.02 .* (sw_pot .- mean(sw_pot))
+    tref = 15.0
+    reco = true_rb .* (true_Q10 .^ (0.1 .* (ta .- tref))) .+ 0.1 .* randn(rng, n)
+    DataFrame(; ta = Float32.(ta),
+        sw_pot = Float32.(sw_pot),
+        dsw_pot = Float32.(dsw_pot),
+        reco = Float32.(reco),
+        id = 1:n)
+end
+
+# ------------------------------------------------------------------------------
+# RbQ10 physical model (from example)
+# ------------------------------------------------------------------------------
+function RbQ10(; ta, Q10, rb, tref = 15.0f0)
+    reco = rb .* Q10 .^ (0.1f0 .* (ta .- tref))
+    return (; reco, Q10, rb)
+end
+
+# Parameter spec analogous to the example
+const RbQ10_PARAMS = (
+    rb  = (3.0f0, 0.0f0, 13.0f0),
+    Q10 = (2.0f0, 1.0f0, 4.0f0),
+)
+
+# ------------------------------------------------------------------------------
+# Tests
+# ------------------------------------------------------------------------------
+@testset "Book Chapter Example - RbQ10 Hybrid" begin
+    df = make_synth_df(32)   # keep it small/fast
+
+    forcing = [:ta]
+    predictors = [:sw_pot, :dsw_pot]
+    target = [:reco]
+    global_param_names = [:Q10]
+    neural_param_names = [:rb]
+
+    @testset "test DataFrame and thereby KeyedArray" begin
+        model = constructHybridModel(
+            predictors, forcing, target, RbQ10,
+            RbQ10_PARAMS, neural_param_names, global_param_names
+        )
+        @test model isa SingleNNHybridModel
+        # prepare_data should produce something consumable by split_data
+        ka = prepare_data(model, df)
+        @test !isnothing(ka)
+
+        trainshort(ka; kwargs...) = train(model, ka, ();
+            nepochs = 1,
+            batchsize = 12,
+            plotting = false,
+            show_progress = false,
+            hybrid_name = "test",
+            kwargs...
+        )
+
+        out = trainshort(ka)
+        @test !isnothing(out)
+
+        out = trainshort(ka, shuffleobs = true)
+        @test !isnothing(out)
+
+        out = trainshort(ka, split_data_at = 0.8)
+        @test !isnothing(out)
+
+        out = trainshort(ka, shuffleobs = true, split_data_at = 0.8)
+        @test !isnothing(out)
+
+        # only doable on df, not ka, since that row gets deleted at the moment
+        out = trainshort(df, split_by_id = :id)
+        @test !isnothing(out)
+
+        out = trainshort(ka, split_by_id = df.id)
+        @test !isnothing(out)
+
+        out = trainshort(ka, split_by_id = df.id, shuffleobs = true)
+        @test !isnothing(out)
+
+        out = trainshort(ka, split_by_id = df.id, shuffleobs = false)
+        @test !isnothing(out)
+
+        folds = make_folds(df, k=3, shuffle=true)
+        @test !isnothing(folds)
+
+        df.folds = folds
+
+        out = trainshort(ka, folds = folds, val_fold = 1)
+        @test !isnothing(out)
+
+        out = trainshort(df, folds = :folds, val_fold = 1)
+        @test !isnothing(out)
+
+        out = trainshort(df, folds = :folds, val_fold = 1, shuffleobs = true)
+        @test !isnothing(out)
+
+        @test_throws ArgumentError trainshort(df; folds = :folds, val_fold = 1, shuffleobs = true, split_by_id = :id)
+
+        sdata = split_data(df, model, split_by_id = :id)
+        @test !isnothing(sdata)
+
+        out = trainshort(sdata)
+        @test !isnothing(out)
+
+        mat = vcat(ka[1], ka[2])
+        da = DimArray(mat, (Dim{:col}(mat.keys[1]), Dim{:row}(1:size(mat, 2))))'
+        ka = prepare_data(model, da)
+        @test !isnothing(ka)
+
+        # TODO: this is not working, transpose da columns to rows?
+        #dtuple_tuple = split_data(da, model)
+        #@test !isnothing(dtuple_tuple)
+        # TODO: this is not working, need to fix GenericHybrid Model for DimensionalData
+        # out = trainshort(dtuple_tuple)
+    end
+end