4
4
# First load packages
5
5
using ActionModels
6
6
using HierarchicalGaussianFiltering
7
- using Turing
8
7
using CSV
9
8
using DataFrames
10
9
using Plots
@@ -19,7 +18,7 @@ data_path = hgf_path * "/docs/tutorials/data/"
19
18
inputs = CSV. read (data_path * " classic_binary_inputs.csv" , DataFrame)[! , 1 ];
20
19
21
20
# Create an HGF
22
- hgf_params = Dict (
21
+ hgf_parameters = Dict (
23
22
(" u" , " category_means" ) => Real[0.0 , 1.0 ],
24
23
(" u" , " input_precision" ) => Inf ,
25
24
(" x2" , " evolution_rate" ) => - 2.5 ,
@@ -31,11 +30,12 @@ hgf_params = Dict(
31
30
(" x1" , " x2" , " value_coupling" ) => 1.0 ,
32
31
(" x2" , " x3" , " volatility_coupling" ) => 1.0 ,
33
32
);
34
- hgf = premade_hgf (" binary_3level" , hgf_params , verbose = false );
33
+ hgf = premade_hgf (" binary_3level" , hgf_parameters , verbose = false );
35
34
36
35
# Create an agent
37
- agent_params = Dict (" sigmoid_action_precision" => 5 );
38
- agent = premade_agent (" hgf_unit_square_sigmoid_action" , hgf, agent_params, verbose = false );
36
+ agent_parameters = Dict (" sigmoid_action_precision" => 5 );
37
+ agent =
38
+ premade_agent (" hgf_unit_square_sigmoid_action" , hgf, agent_parameters, verbose = false );
39
39
40
40
# Evolve agent and save actions
41
41
actions = give_inputs! (agent, inputs);
@@ -48,7 +48,7 @@ plot_trajectory(agent, ("x2", "posterior"))
48
48
plot_trajectory (agent, (" x3" , " posterior" ))
49
49
50
50
# Set fixed parameters
51
- fixed_params = Dict (
51
+ fixed_parameters = Dict (
52
52
" sigmoid_action_precision" => 5 ,
53
53
(" u" , " category_means" ) => Real[0.0 , 1.0 ],
54
54
(" u" , " input_precision" ) => Inf ,
@@ -58,26 +58,31 @@ fixed_params = Dict(
58
58
(" x3" , " initial_precision" ) => 1 ,
59
59
(" x1" , " x2" , " value_coupling" ) => 1.0 ,
60
60
(" x2" , " x3" , " volatility_coupling" ) => 1.0 ,
61
- (" x2" , " evolution_rate" ) => - 3.0 ,
62
61
(" x3" , " evolution_rate" ) => - 6.0 ,
63
62
);
64
63
65
64
# Set priors for parameter recovery
66
65
param_priors = Dict ((" x2" , " evolution_rate" ) => Normal (- 3.0 , 0.5 ));
67
66
68
67
# Prior predictive plot
69
- plot_predictive_simulation (param_priors, agent, inputs, (" x1" , " prediction_mean" ), n_simulations = 100 )
68
+ plot_predictive_simulation (
69
+ param_priors,
70
+ agent,
71
+ inputs,
72
+ (" x1" , " prediction_mean" ),
73
+ n_simulations = 100 ,
74
+ )
70
75
71
76
# Get the actions from the MATLAB tutorial
72
77
actions = CSV. read (data_path * " classic_binary_actions.csv" , DataFrame)[! , 1 ];
73
78
74
79
# Fit the actions
75
80
fitted_model = fit_model (
76
81
agent,
82
+ param_priors,
77
83
inputs,
78
84
actions,
79
- param_priors,
80
- fixed_params,
85
+ fixed_parameters = fixed_parameters,
81
86
verbose = true ,
82
87
n_iterations = 10 ,
83
88
)
@@ -89,4 +94,10 @@ plot(fitted_model)
89
94
plot_parameter_distribution (fitted_model, param_priors)
90
95
91
96
# Posterior predictive plot
92
- plot_predictive_simulation (fitted_model, agent, inputs, (" x1" , " prediction_mean" ), n_simulations = 3 )
97
+ plot_predictive_simulation (
98
+ fitted_model,
99
+ agent,
100
+ inputs,
101
+ (" x1" , " prediction_mean" ),
102
+ n_simulations = 3 ,
103
+ )
0 commit comments