Skip to content

Commit d1b8b76

Browse files
authored
Merge pull request #165 from ilabcode/dev
minor update
2 parents 66b5bec + 6fd9ae5 commit d1b8b76

File tree

13 files changed

+14
-95
lines changed

13 files changed

+14
-95
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ authors = [ "Peter Thestrup Waade [email protected]",
44
"Anna Hedvig Møller [email protected]",
55
"Jacopo Comoglio [email protected]",
66
"Christoph Mathys [email protected]"]
7-
version = "0.6.0"
7+
version = "0.6.1"
88

99

1010
[deps]

docs/julia_files/tutorials/classic_binary.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ inputs = CSV.read(data_path * "classic_binary_inputs.csv", DataFrame)[!, 1];
2121

2222
# Create an HGF
2323
hgf_parameters = Dict(
24-
("u", "category_means") => Real[0.0, 1.0],
25-
("u", "input_precision") => Inf,
2624
("xprob", "volatility") => -2.5,
2725
("xprob", "initial_mean") => 0,
2826
("xprob", "initial_precision") => 1,
@@ -54,8 +52,6 @@ plot_trajectory(agent, ("xvol", "posterior"))
5452
# Set fixed parameters
5553
fixed_parameters = Dict(
5654
"action_noise" => 0.2,
57-
("u", "category_means") => Real[0.0, 1.0],
58-
("u", "input_precision") => Inf,
5955
("xprob", "initial_mean") => 0,
6056
("xprob", "initial_precision") => 1,
6157
("xvol", "initial_mean") => 1,

docs/julia_files/user_guide/fitting_hgf_models.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ using HierarchicalGaussianFiltering
4545
# We will define a binary 3-level HGF and its parameters
4646

4747
hgf_parameters = Dict(
48-
("u", "category_means") => Real[0.0, 1.0],
49-
("u", "input_precision") => Inf,
5048
("xprob", "volatility") => -2.5,
5149
("xprob", "initial_mean") => 0,
5250
("xprob", "initial_precision") => 1,
@@ -86,8 +84,6 @@ plot_trajectory!(agent, ("xbin", "prediction"))
8684
# Set fixed parameters. We choose to fit the evolution rate of the xprob node.
8785
fixed_parameters = Dict(
8886
"action_noise" => 0.2,
89-
("u", "category_means") => Real[0.0, 1.0],
90-
("u", "input_precision") => Inf,
9187
("xprob", "initial_mean") => 0,
9288
("xprob", "initial_precision") => 1,
9389
("xvol", "initial_mean") => 1,

docs/julia_files/user_guide/utility_functions.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,6 @@ agent_parameter = Dict("action_noise" => 0.3)
6262
#We also specify our HGF and custom parameter settings:
6363

6464
hgf_parameters = Dict(
65-
("u", "category_means") => Real[0.0, 1.0],
66-
("u", "input_precision") => Inf,
6765
("xprob", "volatility") => -2.5,
6866
("xprob", "initial_mean") => 0,
6967
("xprob", "initial_precision") => 1,

src/ActionModels_variations/utils/set_parameters.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,6 @@ function ActionModels.set_parameters!(
4040
)
4141
end
4242

43-
#If the param is a vector of category_means
44-
if param_value isa Vector
45-
#Convert it to a vector of reals
46-
param_value = convert(Vector{Real}, param_value)
47-
end
48-
4943
#Set the parameter value
5044
setfield!(node.parameters, Symbol(param_name), param_value)
5145

src/create_hgf/hgf_structs.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,6 @@ end
340340
Configuration of parameters in binary input node. Default category mean set to [0,1]
341341
"""
342342
Base.@kwdef mutable struct BinaryInputNodeParameters
343-
category_means::Vector{Union{Real}} = [0, 1]
344-
input_precision::Real = Inf
345343
coupling_strengths::Dict{String,Real} = Dict{String,Real}()
346344
end
347345

src/premade_models/premade_hgfs/premade_binary_2level.jl

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,6 @@ function premade_binary_2level(config::Dict; verbose::Bool = true)
1111

1212
#Defaults
1313
spec_defaults = Dict(
14-
("u", "category_means") => [0, 1],
15-
("u", "input_precision") => Inf,
1614
("xprob", "volatility") => -2,
1715
("xprob", "drift") => 0,
1816
("xprob", "autoconnection_strength") => 1,

src/premade_models/premade_hgfs/premade_binary_3level.jl

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,6 @@ This HGF has five shared parameters:
1313
"coupling_strengths_xprob_xvol"
1414
1515
# Config defaults:
16-
- ("u", "category_means"): [0, 1]
17-
- ("u", "input_precision"): Inf
1816
- ("xprob", "volatility"): -2
1917
- ("xvol", "volatility"): -2
2018
- ("xbin", "xprob", "coupling_strength"): 1
@@ -28,8 +26,6 @@ function premade_binary_3level(config::Dict; verbose::Bool = true)
2826

2927
#Defaults
3028
defaults = Dict(
31-
("u", "category_means") => [0, 1],
32-
("u", "input_precision") => Inf,
3329
("xprob", "volatility") => -2,
3430
("xprob", "drift") => 0,
3531
("xprob", "autoconnection_strength") => 1,

src/update_hgf/node_updates/binary_state_node.jl

Lines changed: 3 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,8 @@ function calculate_posterior_precision(node::BinaryStateNode)
9999
child = node.edges.observation_children[1]
100100

101101
#Simple update with inifinte precision
102-
if child.parameters.input_precision == Inf
103-
posterior_precision = Inf
104-
#Update with finite precision
105-
else
106-
posterior_precision =
107-
1 / (node.states.posterior_mean * (1 - node.states.posterior_mean))
108-
end
102+
posterior_precision = Inf
103+
109104

110105
## If the child is a category child ##
111106
elseif length(node.edges.category_children) > 0
@@ -141,30 +136,7 @@ function calculate_posterior_mean(node::BinaryStateNode, update_type::HGFUpdateT
141136
#Set the posterior to missing
142137
posterior_mean = missing
143138
else
144-
#Update with infinte input precision
145-
if child.parameters.input_precision == Inf
146-
posterior_mean = child.states.input_value
147-
148-
#Update with finite input precision
149-
else
150-
posterior_mean =
151-
node.states.prediction_mean * exp(
152-
-0.5 *
153-
node.states.prediction_precision *
154-
child.parameters.category_means[1]^2,
155-
) / (
156-
node.states.prediction_mean * exp(
157-
-0.5 *
158-
node.states.prediction_precision *
159-
child.parameters.category_means[1]^2,
160-
) +
161-
(1 - node.states.prediction_mean) * exp(
162-
-0.5 *
163-
node.states.prediction_precision *
164-
child.parameters.category_means[2]^2,
165-
)
166-
)
167-
end
139+
posterior_mean = child.states.input_value
168140
end
169141

170142
## If the child is a category child ##

src/utils/get_surprise.jl

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -103,40 +103,15 @@ function get_surprise(node::BinaryInputNode)
103103
parents_prediction_mean += parent.states.prediction_mean
104104
end
105105

106-
#If the input precision is infinite
107-
if node.parameters.input_precision == Inf
108-
109-
#If a 1 was observed
110-
if node.states.input_value == 0
111-
#Get surprise
112-
surprise = -log(1 - parents_prediction_mean)
113-
114-
#If a 0 was observed
115-
elseif node.states.input_value == 1
116-
#Get surprise
117-
surprise = -log(parents_prediction_mean)
118-
end
119-
120-
#If the input precision is finite
121-
else
122-
#Get the surprise
123-
surprise =
124-
-log(
125-
parents_prediction_mean * pdf(
126-
Normal(
127-
node.parameters.category_means[1],
128-
node.parameters.input_precision,
129-
),
130-
node.states.input_value,
131-
) +
132-
(1 - parents_prediction_mean) * pdf(
133-
Normal(
134-
node.parameters.category_means[2],
135-
node.parameters.input_precision,
136-
),
137-
node.states.input_value,
138-
),
139-
)
106+
#If a 1 was observed
107+
if node.states.input_value == 0
108+
#Get surprise
109+
surprise = -log(1 - parents_prediction_mean)
110+
111+
#If a 0 was observed
112+
elseif node.states.input_value == 1
113+
#Get surprise
114+
surprise = -log(parents_prediction_mean)
140115
end
141116

142117
return surprise

0 commit comments

Comments
 (0)