Skip to content

Commit cfc0e98

Browse files
authored
Merge pull request #39 from ilabcode/dev
Dev
2 parents f6fc5d6 + 67d31a2 commit cfc0e98

File tree

12 files changed

+493
-155
lines changed

12 files changed

+493
-155
lines changed

src/ActionModels_variations/utils/give_inputs.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function ActionModels.give_inputs!(hgf::HGF, inputs::Array)
4646
#Take each row in the array
4747
for input in eachrow(inputs)
4848
#Input it to the hgf
49-
update_hgf!(hgf, input)
49+
update_hgf!(hgf, Vector(input))
5050
end
5151

5252
return nothing

src/ActionModels_variations/utils/reset.jl

Lines changed: 46 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,28 +7,60 @@ function ActionModels.reset!(hgf::HGF)
77

88
#For categorical state nodes
99
if node isa CategoricalStateNode
10-
#Reset the posterior to all 0's
11-
node.states.posterior .= zero(Real)
12-
#Set to missing
13-
node.states.prediction = missing
14-
node.states.value_prediction_error = missing
10+
#Set states to vectors of missing
11+
node.states.posterior .= missing
12+
node.states.value_prediction_error .= missing
13+
#Empty prediction state
14+
empty!(node.states.prediction)
1515

16-
#For other nodes
16+
#For binary input nodes
17+
elseif node isa BinaryInputNode
18+
#Set states to missing
19+
node.states.value_prediction_error .= missing
20+
node.states.input_value = missing
21+
22+
#For continuous state nodes
23+
elseif node isa ContinuousStateNode
24+
#Set posterior to initial belief
25+
node.states.posterior_mean = node.params.initial_mean
26+
node.states.posterior_precision = node.params.initial_precision
27+
#For other states
28+
for state_name in [
29+
:value_prediction_error,
30+
:volatility_prediction_error,
31+
:prediction_mean,
32+
:prediction_volatility,
33+
:prediction_precision,
34+
:auxiliary_prediction_precision,
35+
]
36+
#Set the state to missing
37+
setfield!(node.states, state_name, missing)
38+
end
39+
40+
#For continuous input nodes
41+
elseif node isa ContinuousInputNode
42+
43+
#For all states except auxiliary prediction precision
44+
for state_name in [
45+
:input_value,
46+
:value_prediction_error,
47+
:volatility_prediction_error,
48+
:prediction_volatility,
49+
:prediction_precision,
50+
]
51+
#Set the state to missing
52+
setfield!(node.states, state_name, missing)
53+
end
54+
55+
#For other nodes
1756
else
1857
#For each state
1958
for state_name in fieldnames(typeof(node.states))
20-
#Set the state to first value in history
59+
#Set the state to missing
2160
setfield!(node.states, state_name, missing)
2261
end
2362
end
2463

25-
#For continuous state nodes
26-
if node isa ContinuousStateNode
27-
#Set the initial posterior
28-
node.states.posterior_mean = node.params.initial_mean
29-
node.states.posterior_precision = node.params.initial_precision
30-
end
31-
3264
#For each state in the history
3365
for state_name in fieldnames(typeof(node.history))
3466

src/create_hgf/create_premade_hgf.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ function premade_hgf(model_name::String, config::Dict = Dict(); verbose = true)
1616
"binary_3level" => premade_binary_3level, #The standard binary input 3 level HGF
1717
"JGET" => premade_JGET, #The JGET model
1818
"categorical_3level" => premade_categorical_3level, #The standard categorical input 3 level HGF
19+
"categorical_3level_state_transitions" =>
20+
premade_categorical_3level_state_transitions, #Categorical 3 level HGF for learning state transitions
1921
)
2022

2123
#Check that the specified model is in the list of keys

src/create_hgf/init_hgf.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,8 +282,13 @@ function init_hgf(;
282282
push!(node.category_parent_order, parent.name)
283283
end
284284

285-
#Set posterior to node
286-
node.states.posterior = zeros(length(node.value_parents))
285+
#Set posterior to vector of zeros equal to the number of categories
286+
node.states.posterior = Vector{Union{Real,Missing}}(missing,length(node.value_parents))
287+
push!(node.history.posterior, node.states.posterior)
288+
289+
#Set posterior to vector of missing equal to the number of categories
290+
node.states.value_prediction_error = node.states.posterior
291+
push!(node.history.value_prediction_error, node.states.value_prediction_error)
287292

288293
#For other nodes
289294
else

src/premade_models/premade_hgfs.jl

Lines changed: 187 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ end
324324
function premade_categorical_3level(config::Dict; verbose::Bool = true)
325325

326326
#Defaults
327-
spec_defaults = Dict(
327+
defaults = Dict(
328328
"n_categories" => 4,
329329
("x2", "evolution_rate") => 0,
330330
("x2", "initial_mean") => 0,
@@ -338,18 +338,18 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true)
338338

339339
#Warn the user about used defaults and misspecified keys
340340
if verbose
341-
warn_premade_defaults(spec_defaults, config)
341+
warn_premade_defaults(defaults, config)
342342
end
343343

344344
#Merge to overwrite defaults
345-
config = merge(spec_defaults, config)
345+
config = merge(defaults, config)
346346

347347

348348
##Prep category node parent names
349349
#Vector for category node binary parent names
350-
category_binary_parent_names = []
350+
category_binary_parent_names = Vector{String}()
351351
#Vector for binary node continuous parent names
352-
binary_continuous_parent_names = []
352+
binary_continuous_parent_names = Vector{String}()
353353
#Populate the above vectors with node names
354354
for category_number = 1:config["n_categories"]
355355
push!(category_binary_parent_names, "x1_" * string(category_number))
@@ -361,7 +361,7 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true)
361361
input_nodes = Dict("name" => "u", "type" => "categorical")
362362

363363
##List of state nodes
364-
state_nodes =[Dict{String, Any}("name" => "x1", "type" => "categorical")]
364+
state_nodes = [Dict{String,Any}("name" => "x1", "type" => "categorical")]
365365

366366
#Add category node binary parents
367367
for node_name in category_binary_parent_names
@@ -431,4 +431,184 @@ function premade_categorical_3level(config::Dict; verbose::Bool = true)
431431
edges = edges,
432432
verbose = false,
433433
)
434-
end
434+
end
435+
436+
function premade_categorical_3level_state_transitions(config::Dict; verbose::Bool = true)
437+
438+
#Defaults
439+
defaults = Dict(
440+
"n_categories" => 4,
441+
("x2", "evolution_rate") => 0,
442+
("x2", "initial_mean") => 0,
443+
("x2", "initial_precision") => 1,
444+
("x3", "evolution_rate") => 0,
445+
("x3", "initial_mean") => 0,
446+
("x3", "initial_precision") => 1,
447+
("x1", "x2", "value_coupling") => 1,
448+
("x2", "x3", "volatility_coupling") => 1,
449+
)
450+
451+
#Warn the user about used defaults and misspecified keys
452+
if verbose
453+
warn_premade_defaults(defaults, config)
454+
end
455+
456+
#Merge to overwrite defaults
457+
config = merge(defaults, config)
458+
459+
460+
##Prepare node names
461+
#Empty lists
462+
categorical_input_node_names = Vector{String}()
463+
categorical_state_node_names = Vector{String}()
464+
categorical_node_binary_parent_names = Vector{String}()
465+
binary_node_continuous_parent_names = Vector{String}()
466+
467+
#Go through each category that the transition may have been from
468+
for category_from = 1:config["n_categories"]
469+
#One input node and its state node parent for each
470+
push!(categorical_input_node_names, "u" * string(category_from))
471+
push!(categorical_state_node_names, "x1_" * string(category_from))
472+
#Go through each category that the transition may have been to
473+
for category_to = 1:config["n_categories"]
474+
#Each categorical state node has a binary parent for each
475+
push!(
476+
categorical_node_binary_parent_names,
477+
"x1_" * string(category_from) * "_" * string(category_to),
478+
)
479+
#And each binary parent has a continuous parent of its own
480+
push!(
481+
binary_node_continuous_parent_names,
482+
"x2_" * string(category_from) * "_" * string(category_to),
483+
)
484+
end
485+
end
486+
487+
##Create input nodes
488+
#Initialize list
489+
input_nodes = Vector{Dict}()
490+
491+
#For each categorical input node
492+
for node_name in categorical_input_node_names
493+
#Add it to the list
494+
push!(input_nodes, Dict("name" => node_name, "type" => "categorical"))
495+
end
496+
497+
##Create state nodes
498+
#Initialize list
499+
state_nodes = Vector{Dict}()
500+
501+
#For each cateogrical state node
502+
for node_name in categorical_state_node_names
503+
#Add it to the list
504+
push!(state_nodes, Dict("name" => node_name, "type" => "categorical"))
505+
end
506+
507+
#For each categorical node binary parent
508+
for node_name in categorical_node_binary_parent_names
509+
#Add it to the list
510+
push!(state_nodes, Dict("name" => node_name, "type" => "binary"))
511+
end
512+
513+
#For each binary node continuous parent
514+
for node_name in binary_node_continuous_parent_names
515+
#Add it to the list, with parameter settings from the config
516+
push!(
517+
state_nodes,
518+
Dict(
519+
"name" => node_name,
520+
"type" => "continuous",
521+
"evolution_rate" => config[("x2", "evolution_rate")],
522+
"initial_mean" => config[("x2", "initial_mean")],
523+
"initial_precision" => config[("x2", "initial_precision")],
524+
),
525+
)
526+
end
527+
528+
#Add the shared volatility parent of the continuous nodes
529+
push!(
530+
state_nodes,
531+
Dict(
532+
"name" => "x3",
533+
"type" => "continuous",
534+
"evolution_rate" => config[("x3", "evolution_rate")],
535+
"initial_mean" => config[("x3", "initial_mean")],
536+
"initial_precision" => config[("x3", "initial_precision")],
537+
),
538+
)
539+
540+
##Create child-parent relations
541+
#Initialize list
542+
edges = Vector{Dict}()
543+
544+
#For each categorical input node and its corresponding state node parent
545+
for (child_name, parent_name) in
546+
zip(categorical_input_node_names, categorical_state_node_names)
547+
#Add their relation to the list
548+
push!(edges, Dict("child" => child_name, "value_parents" => parent_name))
549+
end
550+
551+
#For each categorical state node
552+
for child_node_name in categorical_state_node_names
553+
#Get the category it represents transitions from
554+
(child_supername, child_category_from) = split(child_node_name, "_")
555+
556+
#For each potential parent node
557+
for parent_node_name in categorical_node_binary_parent_names
558+
#Get the category it represents transitions from
559+
(parent_supername, parent_category_from, parent_category_to) =
560+
split(parent_node_name, "_")
561+
562+
#If these match
563+
if parent_category_from == child_category_from
564+
#Add the parent as parent of the child
565+
push!(
566+
edges,
567+
Dict("child" => child_node_name, "value_parents" => parent_node_name),
568+
)
569+
end
570+
end
571+
end
572+
573+
#For each binary parent of categorical nodes and their corresponding continuous parents
574+
for (child_name, parent_name) in
575+
zip(categorical_node_binary_parent_names, binary_node_continuous_parent_names)
576+
#Add their relations to the list, with the same value coupling
577+
push!(
578+
edges,
579+
Dict(
580+
"child" => child_name,
581+
"value_parents" => (parent_name, config[("x1", "x2", "value_coupling")]),
582+
),
583+
)
584+
end
585+
586+
#Add the shared continuous node volatility parent to the continuous nodes
587+
for child_name in binary_node_continuous_parent_names
588+
push!(
589+
edges,
590+
Dict(
591+
"child" => child_name,
592+
"volatility_parents" => ("x3", config[("x2", "x3", "volatility_coupling")]),
593+
),
594+
)
595+
end
596+
597+
#Initialize the HGF
598+
init_hgf(
599+
input_nodes = input_nodes,
600+
state_nodes = state_nodes,
601+
edges = edges,
602+
verbose = false,
603+
)
604+
end
605+
606+
607+
608+
609+
610+
611+
612+
613+
614+

0 commit comments

Comments
 (0)