Skip to content

Commit a3c2ec1

Browse files
authored
Merge pull request #171 from ComputationalPsychiatry/dev
Dev
2 parents a3db70a + 52c64c2 commit a3c2ec1

File tree

4 files changed

+45
-18
lines changed

4 files changed

+45
-18
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ActionModels"
22
uuid = "320cf53b-cc3b-4b34-9a10-0ecb113566a3"
33
authors = ["Peter Thestrup Waade [email protected]", "Luke Ring [email protected]", "Malte Lau Møller", "Christoph Mathys [email protected]"]
4-
version = "0.7.2"
4+
version = "0.7.3"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

docs/julia_files/B_user_guide/4_model_fitting.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ chns = sample_posterior!(model)
9999
# We can specify the number of samples and chains to sample with the `n_samples` and `n_chains` keyword arguments.
100100
# The `init_params` keyword argument can be used to specify how the initial parameters for the chains are set.
101101
# It can be set to `:MAP` or `:MLE` to use the maximum a posteriori or maximum likelihood estimates as the initial parameters, respectively.
102-
# It can be set to `:sample_prior` to draw a single sample from the prior distribution, or to `nothing` to use Turing's default of random values between -2 and 2 as the initial parameters.`
102+
# It can be set to `:sample_prior` to draw a single sample from the prior distribution, or to `nothing` to use Turing's default of random values between -2 and 2 as the initial parameters.
103103
# Finally, a vector of initial parameters can be passed, which will be used as the initial parameters for all chains.
104104
# Other arguments for the sampling can also be passed. This includes the autodifferentiation backend to use, which can be set with the `ad_type` keyword argument, and the sampler to use, which can be set with the `sampler` keyword argument.
105105
# Notably, `sample_posterior!` will return the already sampled `Chains` object if the posterior has already been sampled. Set `resample = true` to override the already sampled posterior.

src/fitting_models/turing_model/create_model.jl

Lines changed: 22 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -175,9 +175,26 @@ function create_model(
175175
state_name => state.initial_value for
176176
(state_name, state) in pairs(action_model.states)
177177
)
178+
179+
## Create population data ##
180+
#Remove action and observation columns
181+
population_data = data[!,setdiff(Symbol.(names(data)), vcat(observation_cols, action_cols))]
182+
#If there are session columns
183+
if length(session_cols) > 0
184+
#Only one row per session
185+
population_data = unique(population_data, session_cols)
186+
#Sort population data by session columns
187+
population_data = sort(population_data, session_cols)
188+
else
189+
#If there are no session columns, just take the first row
190+
population_data = population_data[1:1, :]
191+
end
178192

179-
## Group data by sessions ##
180-
grouped_data = groupby(data, session_cols)
193+
## Create sessions data ##
194+
#Only keep actions, observations and session columns
195+
sessions_data = data[!, unique(vcat(collect(observation_cols), collect(action_cols), session_cols))]
196+
#Group sessions data by session columns
197+
sessions_data = groupby(sessions_data, session_cols, sort = true)
181198

182199
## Create IDs for each session ##
183200
session_ids = [
@@ -186,18 +203,18 @@ function create_model(
186203
string(col_name) * id_column_separator * string(first(subdata)[col_name]) for col_name in session_cols
187204
],
188205
id_separator,
189-
) for subdata in grouped_data
206+
) for subdata in sessions_data
190207
]
191208

192209
## Extract observations and actions ##
193210
observations = Vector{Tuple{observation_types_data...}}[
194211
Tuple{observation_types_data...}.(
195212
eachrow(session_data[!, collect(observation_cols)]),
196-
) for session_data in grouped_data
213+
) for session_data in sessions_data
197214
]
198215
actions = Vector{Tuple{action_types_data...}}[
199216
Tuple{action_types_data...}.(eachrow(session_data[!, collect(action_cols)])) for
200-
session_data in grouped_data
217+
session_data in sessions_data
201218
]
202219

203220
### CREATE MODEL ###
@@ -220,15 +237,6 @@ function create_model(
220237
initial_states,
221238
)
222239

223-
## Extract population data ##
224-
if length(session_cols) == 0
225-
#If there are no session columns, the population data is the same as the data
226-
population_data = data[1:1, :]
227-
else
228-
#Otherwise, extract the population data from the grouped data
229-
population_data = unique(data, session_cols)
230-
end
231-
232240
return ModelFit(
233241
model = model,
234242
population_model_type = population_model_type,

src/fitting_models/turing_model/population_models/glm_population_model.jl

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,12 @@ function create_model(
8888
regressions = F[regressions]
8989
end
9090

91+
#If only a symbol was specified for session cols
92+
if session_cols isa Symbol
93+
#Convert single session column to vector
94+
session_cols = [session_cols]
95+
end
96+
9197
#Make sure that single formulas are made into Regression objects
9298
regressions = [
9399
regression isa Regression ? regression : Regression(regression) for
@@ -107,8 +113,21 @@ function create_model(
107113
kwargs...,
108114
)
109115

110-
#Extract just the data needed for the linear regression
111-
population_data = unique(data, session_cols)
116+
## Create population data ##
117+
#Remove action and observation columns
118+
population_data =
119+
data[!, setdiff(Symbol.(names(data)), vcat(observation_cols, action_cols))]
120+
#If there are session columns
121+
if length(session_cols) > 0
122+
#Only one row per session
123+
population_data = unique(population_data, session_cols)
124+
#Sort population data by session columns
125+
population_data = sort(population_data, session_cols)
126+
else
127+
#If there are no session columns, just take the first row
128+
population_data = population_data[1:1, :]
129+
end
130+
112131
#Extract number of sessions
113132
n_sessions = nrow(population_data)
114133

0 commit comments

Comments
 (0)