diff --git a/.gitignore b/.gitignore index 8ff79ba..5a2f3d5 100644 --- a/.gitignore +++ b/.gitignore @@ -2,12 +2,9 @@ *.jl.cov *.jl.mem Manifest.toml -!/docs/Manifest.toml -!/test/Manifest.toml !/binder/Manifest.toml -/attic/ +attic/ /.vscode/ -/docs/attic/ /docs/build/ /docs/.CondaPkg/ /docs/LocalPreferences.toml diff --git a/Project.toml b/Project.toml index 24ec8c9..4e24eff 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComputerAdaptiveTesting" uuid = "5a0d4f34-1f62-4a66-80fe-87aba0485488" authors = ["Frankie Robertson"] -version = "0.3.2" +version = "0.4.0" [deps] Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" @@ -11,6 +11,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" EffectSizes = "e248de7e-9197-5860-972e-353a2af44d75" +ElasticArrays = "fdbdab4c-e67f-52f5-8c3f-e7b388dad3d4" FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" FittedItemBanks = "3f797b09-34e4-41d7-acf6-3302ae3248a5" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -23,6 +24,7 @@ MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" Mmap = "a63ad114-7e13-5084-954f-fe012c677804" PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a" PsychometricsBazaarBase = "b0d9cada-d963-45e9-a4c6-4746243987f1" +QuickHeaps = "30b38841-0f52-47f8-a5f8-18d5d4064379" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" @@ -46,8 +48,9 @@ DataFrames = "1.6.1" Distributions = "^0.25.88" DocStringExtensions = " ^0.9" EffectSizes = "^1.0.1" +ElasticArrays = "1.2.12" FillArrays = "0.13, 1.5.0" -FittedItemBanks = "^0.6.3" +FittedItemBanks = "^0.7.2" ForwardDiff = "1" HypothesisTests = "^0.10.12, ^0.11.0" Interpolations = "^0.14, ^0.15" @@ -59,10 +62,10 @@ MacroTools = "^0.5.6" Mmap = "^1.11" Optim = "1.7.3" PrecompileTools = "1.2.1" -PsychometricsBazaarBase = "^0.8.1" +PsychometricsBazaarBase = "^0.8.4" +QuickHeaps = "0.2.2" Random = "^1.11" Reexport = "1" -ResumableFunctions = "^0.6" Setfield = "^1" SparseArrays = "^1.11" StaticArrays = "1" @@ -75,8 +78,7 @@ julia = "^1.11" Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b" Optim = "429524aa-4258-5aef-a3af-852621145aeb" -ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Aqua", "JET", "Optim", "ResumableFunctions", "Test"] +test = ["Aqua", "JET", "Optim", "Test"] diff --git a/benchmark/benchmarks.jl b/benchmark/benchmarks.jl index ff1c61e..5466a4e 100644 --- a/benchmark/benchmarks.jl +++ b/benchmark/benchmarks.jl @@ -8,7 +8,6 @@ using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel4PL using ComputerAdaptiveTesting.Aggregators using PsychometricsBazaarBase.Optimizers using PsychometricsBazaarBase.Integrators: even_grid -using ComputerAdaptiveTesting.NextItemRules: mirtcat_quadpts using ComputerAdaptiveTesting.NextItemRules: ExpectationBasedItemCriterion, PointResponseExpectation using ComputerAdaptiveTesting.NextItemRules @@ -27,10 +26,10 @@ function prepare_4pls(group) num_questions = 20, num_testees = 1 ) - integrator = even_grid(-6.0, 6.0, mirtcat_quadpts(1)) + integrator = even_grid(-6.0, 6.0, 121) optimizer = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) - dist_ability_estimator = PriorAbilityEstimator() + dist_ability_estimator = PosteriorAbilityEstimator() ability_estimators = [ ("mean", MeanAbilityEstimator(dist_ability_estimator, integrator)), ("mode", ModeAbilityEstimator(dist_ability_estimator, optimizer)) @@ -38,10 +37,10 @@ function prepare_4pls(group) response_idxs = sample(rng, 1:20, 10) for (est_nick, ability_estimator) in ability_estimators - next_item_rule = ItemStrategyNextItemRule( + next_item_rule = ItemCriterionRule( ExhaustiveSearch(), ExpectationBasedItemCriterion(PointResponseExpectation(ability_estimator), - AbilityVarianceStateCriterion( + AbilityVariance( integrator, distribution_estimator(ability_estimator))) ) next_item_rule = preallocate(next_item_rule) diff --git a/docs/examples/examples/ability_convergence_3pl.jl b/docs/examples/examples/ability_convergence_3pl.jl index cb3608e..b31b1f0 100644 --- a/docs/examples/examples/ability_convergence_3pl.jl +++ b/docs/examples/examples/ability_convergence_3pl.jl @@ -20,9 +20,9 @@ using Distributions: Normal, cdf using AlgebraOfGraphics using ComputerAdaptiveTesting using ComputerAdaptiveTesting.Sim: auto_responder -using ComputerAdaptiveTesting.NextItemRules: AbilityVarianceStateCriterion -using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition -using ComputerAdaptiveTesting.Aggregators: PriorAbilityEstimator, +using ComputerAdaptiveTesting.NextItemRules: AbilityVariance +using ComputerAdaptiveTesting.TerminationConditions: FixedLength +using ComputerAdaptiveTesting.Aggregators: PosteriorAbilityEstimator, MeanAbilityEstimator, LikelihoodAbilityEstimator using FittedItemBanks using ComputerAdaptiveTesting.Responses: BooleanResponse @@ -46,18 +46,18 @@ using FittedItemBanks.DummyData: dummy_full, std_normal, SimpleItemBankSpec, Std # CatRecorder collects information which can be used to draw different types of plots. max_questions = 99 integrator = FixedGKIntegrator(-6, 6, 80) -dist_ability_est = PriorAbilityEstimator(std_normal) +dist_ability_est = PosteriorAbilityEstimator(std_normal) ability_estimator = MeanAbilityEstimator(dist_ability_est, integrator) rules = CatRules(ability_estimator, - AbilityVarianceStateCriterion(dist_ability_est, integrator), - FixedItemsTerminationCondition(max_questions)) + AbilityVariance(dist_ability_est, integrator), + FixedLength(max_questions)) points = 500 xs = range(-2.5, 2.5, length = points) raw_estimator = LikelihoodAbilityEstimator() recorder = CatRecorder(xs, responses, integrator, raw_estimator, ability_estimator) for testee_idx in axes(responses, 2) - tracked_responses, θ = run_cat(CatLoopConfig(rules = rules, + tracked_responses, θ = run_cat(CatLoop(rules = rules, get_response = auto_responder(@view responses[:, testee_idx]), new_response_callback = (tracked_responses, terminating) -> recorder(tracked_responses, testee_idx, diff --git a/docs/examples/examples/ability_convergence_mirt.jl b/docs/examples/examples/ability_convergence_mirt.jl index d71ac95..b8cb144 100644 --- a/docs/examples/examples/ability_convergence_mirt.jl +++ b/docs/examples/examples/ability_convergence_mirt.jl @@ -21,8 +21,8 @@ using AlgebraOfGraphics using ComputerAdaptiveTesting using ComputerAdaptiveTesting.Sim: auto_responder using ComputerAdaptiveTesting.NextItemRules: DRuleItemCriterion -using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition -using ComputerAdaptiveTesting.Aggregators: PriorAbilityEstimator, +using ComputerAdaptiveTesting.TerminationConditions: FixedLength +using ComputerAdaptiveTesting.Aggregators: PosteriorAbilityEstimator, MeanAbilityEstimator, LikelihoodAbilityEstimator using FittedItemBanks import PsychometricsBazaarBase.IntegralCoeffs @@ -49,11 +49,11 @@ using ComputerAdaptiveTesting.Responses: BooleanResponse # CatRecorder collects information which can be used to draw different types of plots. max_questions = 9 integrator = CubaIntegrator([-6.0, -6.0], [6.0, 6.0], CubaVegas(); rtol = 1e-2) -ability_estimator = MeanAbilityEstimator(PriorAbilityEstimator(std_mv_normal(dims)), +ability_estimator = MeanAbilityEstimator(PosteriorAbilityEstimator(std_mv_normal(dims)), integrator) rules = CatRules(ability_estimator, DRuleItemCriterion(ability_estimator), - FixedItemsTerminationCondition(max_questions)) + FixedLength(max_questions)) # XXX: We shouldn't need to specify xs here since the distributions are not used -- rework points = 3 @@ -67,7 +67,7 @@ recorder = CatRecorder(xs, abilities) for testee_idx in axes(responses, 2) @debug "Running for testee" testee_idx - tracked_responses, θ = run_cat(CatLoopConfig(rules = rules, + tracked_responses, θ = run_cat(CatLoop(rules = rules, get_response = auto_responder(@view responses[:, testee_idx]), new_response_callback = (tracked_responses, terminating) -> recorder(tracked_responses, testee_idx, diff --git a/docs/examples/examples/vocab_iq.jl b/docs/examples/examples/vocab_iq.jl index 01bad7e..009217d 100644 --- a/docs/examples/examples/vocab_iq.jl +++ b/docs/examples/examples/vocab_iq.jl @@ -5,9 +5,9 @@ # --- #md # Running a CAT based based on real response data -# +# # This example shows how to run a CAT end-to-end on real data. -# +# # First a 1-dimensional IRT model is fitted based on open response data to the # vocabulary IQ test using the IRTSupport package which internally, this uses # the `mirt` R package. Next, the model is used to administer the test @@ -37,13 +37,13 @@ function run_vocab_iq_cat() item_bank, labels = get_item_bank() integrator = FixedGKIntegrator(-6, 6, 61) ability_integrator = AbilityIntegrator(integrator) - dist_ability_est = PriorAbilityEstimator(std_normal) + dist_ability_est = PosteriorAbilityEstimator(std_normal) optimizer = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) ability_estimator = ModeAbilityEstimator(dist_ability_est, optimizer) @info "run_cat" ability_estimator rules = CatRules(ability_estimator, - AbilityVarianceStateCriterion(dist_ability_est, ability_integrator), - FixedItemsTerminationCondition(45)) + AbilityVariance(dist_ability_est, ability_integrator), + FixedLength(45)) function get_response(response_idx, response_name) params = item_params(item_bank, response_idx) println("Parameters for next question: $params") @@ -63,7 +63,7 @@ function run_vocab_iq_cat() println("Got ability estimate: $ability ± $var") println("") end - loop_config = CatLoopConfig(rules = rules, + loop_config = CatLoop(rules = rules, get_response = get_response, new_response_callback = new_response_callback) run_cat(loop_config, item_bank) diff --git a/docs/src/api.md b/docs/src/api.md index eb2a274..8755a67 100644 --- a/docs/src/api.md +++ b/docs/src/api.md @@ -8,5 +8,5 @@ CurrentModule = ComputerAdaptiveTesting ``` ```@autodocs -Modules = [ComputerAdaptiveTesting, ComputerAdaptiveTesting.Aggregators, ComputerAdaptiveTesting.Responses, ComputerAdaptiveTesting.Sim, ComputerAdaptiveTesting.TerminationConditions, ComputerAdaptiveTesting.NextItemRules, ComputerAdaptiveTesting.CatConfig] +Modules = [ComputerAdaptiveTesting, ComputerAdaptiveTesting.Aggregators, ComputerAdaptiveTesting.Responses, ComputerAdaptiveTesting.Sim, ComputerAdaptiveTesting.TerminationConditions, ComputerAdaptiveTesting.NextItemRules, ComputerAdaptiveTesting.Rules] ``` diff --git a/docs/src/creating_a_cat.md b/docs/src/creating_a_cat.md index 8b2664b..35927bd 100644 --- a/docs/src/creating_a_cat.md +++ b/docs/src/creating_a_cat.md @@ -13,7 +13,7 @@ The configuration of a CAT is built up as a tree of configuration structs. These structs are all subtypes of `CatConfigBase`. ```@docs; canonical=false -ComputerAdaptiveTesting.CatConfig.CatConfigBase +ComputerAdaptiveTesting.ConfigBase.CatConfigBase ``` The constructors for the configuration structs in this package tend to have @@ -59,7 +59,7 @@ next item selection rule, and the stopping rule. `CatRules` has explicit and implicit constructors. ```@docs; canonical=false -ComputerAdaptiveTesting.CatConfig.CatRules +ComputerAdaptiveTesting.CatRules ``` ### Next item selection with `NextItemRule` @@ -79,13 +79,13 @@ ComputerAdaptiveTesting.NextItemRules.RandomNextItemRule Other rules are created by combining a `ItemCriterion` -- which somehow rates items according to how good they are -- with a `NextItemStrategy` using an -`ItemStrategyNextItemRule`, which acts as an adapter. The default +`ItemCriterionRule`, which acts as an adapter. The default `NextItemStrategy` (and currently only) is `ExhaustiveSearch`. When using the implicit constructors, `ItemCriterion` can therefore be used directly without wrapping in any place an NextItemRule is expected. ```@docs; canonical=false -ComputerAdaptiveTesting.NextItemRules.ItemStrategyNextItemRule +ComputerAdaptiveTesting.NextItemRules.ItemCriterionRule ``` ```@docs; canonical=false @@ -114,17 +114,17 @@ takes a `ResponseExpectation`: either `PointResponseExpectation` or good a particular state is in terms getting a good estimate of the test takers ability. They look one ply ahead to get the expected value of the ``StateCriterion`` after selecting the given item. The -`AbilityVarianceStateCriterion` looks at the variance of the ability ``\theta`` +`AbilityVariance` looks at the variance of the ability ``\theta`` estimate at that state. ### Stopping rules with `TerminationCondition` -Currently the only `TerminationCondition` is `FixedItemsTerminationCondition`, which ends the test after a fixed number of items. +Currently the only `TerminationCondition` is `FixedLength`, which ends the test after a fixed number of items. ```@docs; canonical=false ComputerAdaptiveTesting.TerminationConditions.TerminationCondition ``` ```@docs; canonical=false -ComputerAdaptiveTesting.TerminationConditions.FixedItemsTerminationCondition +ComputerAdaptiveTesting.TerminationConditions.FixedLength ``` diff --git a/docs/src/stateful.md b/docs/src/stateful.md index 2c62260..0fea784 100644 --- a/docs/src/stateful.md +++ b/docs/src/stateful.md @@ -28,9 +28,9 @@ Stateful.get_ability There is an implementation in terms of [CatRules](@ref): ```@docs -Stateful.StatefulCatConfig +Stateful.StatefulCatRules ``` ## Usage -Just as [CatLoopConfig](@ref) can wrap [CatRules](@ref), you can also use it with any implementor of [Stateful.StatefulCat](@ref), and run using [Sim.run_cat](@ref). \ No newline at end of file +Just as [CatLoop](@ref) can wrap [CatRules](@ref), you can also use it with any implementor of [Stateful.StatefulCat](@ref), and run using [Sim.run_cat](@ref). \ No newline at end of file diff --git a/docs/src/using_your_cat.md b/docs/src/using_your_cat.md index dd33d4d..04711c1 100644 --- a/docs/src/using_your_cat.md +++ b/docs/src/using_your_cat.md @@ -9,10 +9,10 @@ a number of ways you can use it. This section covers a few. See also the [Examples](@ref demo-page). -When you've set up your CAT using [CatRules](@ref), you can wrap it in a [CatLoopConfig](@ref) and run it with [run_cat](@ref). +When you've set up your CAT using [CatRules](@ref), you can wrap it in a [CatLoop](@ref) and run it with [run_cat](@ref). ```@docs; canonical=false -CatLoopConfig +CatLoop run_cat ``` diff --git a/ext/TestExt.jl b/ext/TestExt.jl index c305dfc..996c573 100644 --- a/ext/TestExt.jl +++ b/ext/TestExt.jl @@ -2,7 +2,7 @@ module TestExt using Test using ComputerAdaptiveTesting: Stateful -using FittedItemBanks: AbstractItemBank, ItemResponse, resp +using FittedItemBanks: AbstractItemBank, ItemResponse, resp_vec export test_stateful_cat_1d_dich_ib, test_stateful_cat_item_bank_1d_dich_ib @@ -96,18 +96,38 @@ function test_stateful_cat_item_bank_1d_dich_ib( cat::Stateful.StatefulCat, item_bank::AbstractItemBank, points=[-.78, 0.0, .78], - margin=0.05, + margin=0.01, ) if length(item_bank) != Stateful.item_bank_size(cat) error("Item bank length does not match the cat's item bank size.") end for i in 1:length(item_bank) for point in points - cat_prob = Stateful.item_response_function(cat, i, true, point) - ib_prob = resp(ItemResponse(item_bank, i), true, point) + cat_prob = Stateful.item_response_functions(cat, i, point) + ib_prob = resp_vec(ItemResponse(item_bank, i), point) @test cat_prob ≈ ib_prob rtol=margin end end end +function test_ability( + cat1::Stateful.StatefulCat, + cat2::Stateful.StatefulCat, + item_bank_length; + margin=0.01 +) + if item_bank_length < 4 + error("Item bank length must be at least 4.") + end + for cat in (cat1, cat2) + Stateful.add_response!(cat, 1, false) + Stateful.add_response!(cat, 2, true) + Stateful.add_response!(cat, 3, false) + Stateful.add_response!(cat, 4, true) + end + ability1 = Stateful.get_ability(cat1) + ability2 = Stateful.get_ability(cat2) + @test ability1[1] ≈ ability2[1] rtol=margin +end + end \ No newline at end of file diff --git a/profile/next_items.jl b/profile/next_items.jl index e23ad07..ef8b5fc 100644 --- a/profile/next_items.jl +++ b/profile/next_items.jl @@ -20,7 +20,7 @@ function get_ability_estimator(multidim) integrator = FixedGKIntegrator(-6.0, 6.0) dist = Normal() end - return PriorAbilityEstimator(dist, integrator) + return PosteriorAbilityEstimator(dist, integrator) end function prepare_empty(item_bank, actual_responses, ability_tracker) diff --git a/src/aggregators/Aggregators.jl b/src/Aggregators/Aggregators.jl similarity index 89% rename from src/aggregators/Aggregators.jl rename to src/Aggregators/Aggregators.jl index b73e188..4e5b967 100644 --- a/src/aggregators/Aggregators.jl +++ b/src/Aggregators/Aggregators.jl @@ -10,6 +10,7 @@ using StaticArrays: SVector using Distributions: Distribution, Normal, Distributions using Base.Threads using ForwardDiff: ForwardDiff +using LogarithmicNumbers: Logarithmic, ULogarithmic using FittedItemBanks: AbstractItemBank, ContinuousDomain, DichotomousSmoothedItemBank, DiscreteIndexableDomain, @@ -24,12 +25,16 @@ using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_type_sloppy using PsychometricsBazaarBase.Integrators: Integrators, BareIntegrationResult, - FixedGridIntegrator, IntReturnType, + FixedGridIntegrator, + IntReturnType, IntValue, Integrator, PreallocatedFixedGridIntegrator, normdenom -using PsychometricsBazaarBase.Optimizers: OneDimOptimOptimizer, Optimizer +using PsychometricsBazaarBase.Optimizers: OneDimOptimOptimizer, Optimizer, Optimizers using PsychometricsBazaarBase.ConstDistributions: std_normal, std_mv_normal +using PsychometricsBazaarBase.IndentWrappers: indent +import Distributions: pdf +import Base: show import FittedItemBanks import PsychometricsBazaarBase.IntegralCoeffs @@ -38,7 +43,8 @@ export AbilityEstimator, TrackedResponses export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker export ClosedFormNormalAbilityTracker, track! export response_expectation, expectation, distribution_estimator -export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator +export PointAbilityEstimator, PosteriorAbilityEstimator +export SafeLikelihoodAbilityEstimator, LikelihoodAbilityEstimator export ModeAbilityEstimator, MeanAbilityEstimator export Speculator, replace_speculation!, normdenom, maybe_tracked_ability_estimate export AbilityIntegrator, AbilityOptimizer @@ -67,9 +73,16 @@ function AbilityEstimator(::ContinuousDomain, bits...) integrator) end +# Mark as a scalar for broadcasting +Base.broadcastable(ir::AbilityEstimator) = Ref(ir) + abstract type DistributionAbilityEstimator <: AbilityEstimator end function DistributionAbilityEstimator(bits...) @returnsome find1_instance(DistributionAbilityEstimator, bits) + point_ability_estimator = find1_instance(PointAbilityEstimator, bits) + if point_ability_estimator !== nothing + return distribution_estimator(point_ability_estimator) + end end abstract type PointAbilityEstimator <: AbilityEstimator end @@ -163,6 +176,9 @@ function TrackedResponses(responses, item_bank) TrackedResponses(responses, item_bank, NullAbilityTracker()) end +# Mark as a scalar for broadcasting +Base.broadcastable(ir::TrackedResponses) = Ref(ir) + function Responses.AbilityLikelihood(tracked_responses::TrackedResponses{ BareResponsesT, ItemBankT, @@ -194,6 +210,10 @@ function (integrator::FunctionIntegrator{IntegratorT})(f::F, integrator.integrator(FunctionProduct(f, lh_function), ncomp) end +function show(io::IO, ::MIME"text/plain", responses::FunctionIntegrator) + show(io, MIME("text/plain"), responses.integrator) +end + # Defaults const optim_tol = 1e-12 const int_tol = 1e-8 diff --git a/src/aggregators/ability_estimator.jl b/src/Aggregators/ability_estimator.jl similarity index 68% rename from src/aggregators/ability_estimator.jl rename to src/Aggregators/ability_estimator.jl index 87a372e..ed2d204 100644 --- a/src/aggregators/ability_estimator.jl +++ b/src/Aggregators/ability_estimator.jl @@ -11,6 +11,8 @@ function Integrators.normdenom(rett::IntReturnType, rett(integrator(IntegralCoeffs.one, 0, est, tracked_responses)) end +# This is not type piracy, but maybe a slightly distasteful overload +# TODO: Fix this interface? function pdf(ability_est::DistributionAbilityEstimator, tracked_responses::TrackedResponses, x) @@ -24,24 +26,83 @@ function pdf(::LikelihoodAbilityEstimator, AbilityLikelihood(tracked_responses) end -struct PriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstimator +function show(io::IO, ::MIME"text/plain", ability_estimator::LikelihoodAbilityEstimator) + println(io, "Ability likelihood distribution") +end + +struct PosteriorAbilityEstimator{PriorT <: Distribution} <: DistributionAbilityEstimator prior::PriorT end -function PriorAbilityEstimator(; ncomp = 0) +function PosteriorAbilityEstimator(; ncomp = 0) if ncomp == 0 - return PriorAbilityEstimator(std_normal) + return PosteriorAbilityEstimator(std_normal) else - return PriorAbilityEstimator(std_mv_normal(ncomp)) + return PosteriorAbilityEstimator(std_mv_normal(ncomp)) end end -function pdf(est::PriorAbilityEstimator, +function pdf(est::PosteriorAbilityEstimator, tracked_responses::TrackedResponses) IntegralCoeffs.PriorApply(IntegralCoeffs.Prior(est.prior), AbilityLikelihood(tracked_responses)) end +function multiple_response_types_guard(tracked_responses) + if length(tracked_responses.responses.values) == 0 + return false + end + seen_value = tracked_responses.responses.values[1] + for value in tracked_responses.responses.values + if value !== seen_value + return true + end + end + return false +end + +function show(io::IO, ::MIME"text/plain", ability_estimator::PosteriorAbilityEstimator) + println(io, "Ability posterior distribution") + indent_io = indent(io, 2) + print(indent_io, "Prior: ") + show(indent_io, MIME("text/plain"), ability_estimator.prior) + println(io) +end + +struct GuardedAbilityEstimator{T <: DistributionAbilityEstimator, U <: DistributionAbilityEstimator, F} <: DistributionAbilityEstimator + est::T + fallback::U + guard::F +end + +function pdf(est::GuardedAbilityEstimator, + tracked_responses::TrackedResponses) + if est.guard(tracked_responses) + return pdf(est.est, tracked_responses) + else + return pdf(est.fallback, tracked_responses) + end +end + +function SafeLikelihoodAbilityEstimator(args...; kwargs...) + GuardedAbilityEstimator( + LikelihoodAbilityEstimator(), + PosteriorAbilityEstimator(args...), + multiple_response_types_guard + ) +end + +unlog(x) = x +unlog(x::Logarithmic{T}) where {T} = T(x) +unlog(x::ULogarithmic{T}) where {T} = T(x) +unlog(x::AbstractVector{Logarithmic{T}}) where {T} = T.(x) +unlog(x::AbstractVector{ULogarithmic{T}}) where {T} = T.(x) +#=unlog(x::ErrorIntegrationResult{Logarithmic{T}}) where {T} = T(x) +unlog(x::ErrorIntegrationResult{ULogarithmic{T}}) where {T} = T(x) +unlog(x::ErrorIntegrationResult{AbstractVector{Logarithmic{T}}}) where {T} = T.(x) +unlog(x::ErrorIntegrationResult{AbstractVector{ULogarithmic{T}}}) where {T} = T.(x) +=# + function expectation(rett::IntReturnType, f::F, ncomp, @@ -49,7 +110,7 @@ function expectation(rett::IntReturnType, est::DistributionAbilityEstimator, tracked_responses::TrackedResponses, denom = normdenom(rett, integrator, est, tracked_responses)) where {F} - rett(integrator(f, ncomp, est, tracked_responses)) / denom + unlog(rett(integrator(f, ncomp, est, tracked_responses)) / denom) end function expectation(f::F, @@ -163,6 +224,13 @@ function ModeAbilityEstimator(bits...) ModeAbilityEstimator(dist_est, optimizer) end +function show(io::IO, ::MIME"text/plain", ability_estimator::ModeAbilityEstimator) + println(io, "Estimate ability using its mode") + indent_io = indent(io, 2) + show(indent_io, MIME("text/plain"), ability_estimator.dist_est) + show(indent_io, MIME("text/plain"), ability_estimator.optim) +end + struct MeanAbilityEstimator{ DistEst <: DistributionAbilityEstimator, IntegratorT <: AbilityIntegrator @@ -178,6 +246,14 @@ function MeanAbilityEstimator(bits...) MeanAbilityEstimator(dist_est, integrator) end +function show(io::IO, ::MIME"text/plain", ability_estimator::MeanAbilityEstimator) + println(io, "Estimate ability using its mean") + indent_io = indent(io, 2) + show(indent_io, MIME("text/plain"), ability_estimator.dist_est) + print(indent_io, "Integrator: ") + show(indent_io, MIME("text/plain"), ability_estimator.integrator) +end + function distribution_estimator(dist_est::DistributionAbilityEstimator)::DistributionAbilityEstimator dist_est end @@ -231,7 +307,7 @@ function (est::MeanAbilityEstimator{AbilityEstimatorT, RiemannEnumerationIntegra tracked_responses) end -function maybe_apply_prior(f::F, est::PriorAbilityEstimator) where {F} +function maybe_apply_prior(f::F, est::PosteriorAbilityEstimator) where {F} IntegralCoeffs.PriorApply(IntegralCoeffs.Prior(est.prior), f) end diff --git a/src/aggregators/ability_tracker.jl b/src/Aggregators/ability_tracker.jl similarity index 100% rename from src/aggregators/ability_tracker.jl rename to src/Aggregators/ability_tracker.jl diff --git a/src/aggregators/ability_trackers/closed_form_normal.jl b/src/Aggregators/ability_trackers/closed_form_normal.jl similarity index 98% rename from src/aggregators/ability_trackers/closed_form_normal.jl rename to src/Aggregators/ability_trackers/closed_form_normal.jl index 0c81ffc..83505ea 100644 --- a/src/aggregators/ability_trackers/closed_form_normal.jl +++ b/src/Aggregators/ability_trackers/closed_form_normal.jl @@ -2,7 +2,7 @@ mutable struct ClosedFormNormalAbilityTracker <: AbilityTracker cur_ability::VarNormal end -function ClosedFormNormalAbilityTracker(prior_ability_estimator::PriorAbilityEstimator) +function ClosedFormNormalAbilityTracker(prior_ability_estimator::PosteriorAbilityEstimator) @warn "ClosedFormNormalAbilityTracker is based on equations from Liden 1998 / Owen 1975, but these appear to give poor results" prior = prior_ability_estimator.prior if !(prior isa Normal) diff --git a/src/aggregators/ability_trackers/grid.jl b/src/Aggregators/ability_trackers/grid.jl similarity index 100% rename from src/aggregators/ability_trackers/grid.jl rename to src/Aggregators/ability_trackers/grid.jl diff --git a/src/aggregators/ability_trackers/laplace.jl b/src/Aggregators/ability_trackers/laplace.jl similarity index 100% rename from src/aggregators/ability_trackers/laplace.jl rename to src/Aggregators/ability_trackers/laplace.jl diff --git a/src/aggregators/ability_trackers/point.jl b/src/Aggregators/ability_trackers/point.jl similarity index 100% rename from src/aggregators/ability_trackers/point.jl rename to src/Aggregators/ability_trackers/point.jl diff --git a/src/aggregators/optimizers.jl b/src/Aggregators/optimizers.jl similarity index 60% rename from src/aggregators/optimizers.jl rename to src/Aggregators/optimizers.jl index 2c45006..314b6c3 100644 --- a/src/aggregators/optimizers.jl +++ b/src/Aggregators/optimizers.jl @@ -10,6 +10,22 @@ function (optim::FunctionOptimizer)(f::F, optim.optim(comp_f) end +function show(io::IO, ::MIME"text/plain", optim::FunctionOptimizer) + indent_io = indent(io, 2) + if optim.optim isa Optimizers.OneDimOptimOptimizer || optim.optim isa Optimizers.MultiDimOptimOptimizer || optim.optim isa Optimizers.NativeOneDimOptimOptimizer + inner = optim.optim + println(io, "Optimizer:") + if optim.optim isa Optimizers.NativeOneDimOptimOptimizer + name = typeof(inner.method).name.name + else + name = typeof(inner.optim).name.name + end + println(indent_io, "Method: ", name) + println(indent_io, "Lo: ", inner.lo) + println(indent_io, "Hi: ", inner.hi) + end +end + #= """ Argmax + max over the ability likihood given a set of responses with a given @@ -32,7 +48,7 @@ function (optim::EnumerationOptimizer)(f::F, ability_likelihood.item_bank; lo = lo, hi = hi) do (x, prob) - # @inline + # @inline fprob = f(x) * prob if fprob >= cur_max[] cur_argmax[] = x @@ -47,5 +63,6 @@ function (optim::AbilityOptimizer)(f::F, est, tracked_responses::TrackedResponses; kwargs...) where {F} - optim(maybe_apply_prior(f, est), AbilityLikelihood(tracked_responses); kwargs...) + #optim(maybe_apply_prior(f, est), AbilityLikelihood(tracked_responses); kwargs...) + optim(f, pdf(est, tracked_responses); kwargs...) end diff --git a/src/aggregators/riemann.jl b/src/Aggregators/riemann.jl similarity index 71% rename from src/aggregators/riemann.jl rename to src/Aggregators/riemann.jl index 243e0ae..cf5ea82 100644 --- a/src/aggregators/riemann.jl +++ b/src/Aggregators/riemann.jl @@ -26,13 +26,32 @@ function (integrator::RiemannEnumerationIntegrator)(f::F, return BareIntegrationResult(result) end -function (integrator::Union{RiemannEnumerationIntegrator, FunctionIntegrator})(f::F, - ncomp, - est, - tracked_responses::TrackedResponses; - kwargs...) where {F} - integrator(maybe_apply_prior(f, est), +function (integrator::RiemannEnumerationIntegrator)( + f::F, + ncomp, + est, + tracked_responses::TrackedResponses; + kwargs... +) where {F} + integrator( + maybe_apply_prior(f, est), ncomp, AbilityLikelihood(tracked_responses); - kwargs...) + kwargs... + ) +end + +function (integrator::FunctionIntegrator)( + f::F, + ncomp, + est, + tracked_responses::TrackedResponses; + kwargs... +) where {F} + integrator( + f, + ncomp, + pdf(est, tracked_responses); + kwargs... + ) end diff --git a/src/aggregators/slow.jl b/src/Aggregators/slow.jl similarity index 100% rename from src/aggregators/slow.jl rename to src/Aggregators/slow.jl diff --git a/src/aggregators/speculators.jl b/src/Aggregators/speculators.jl similarity index 100% rename from src/aggregators/speculators.jl rename to src/Aggregators/speculators.jl diff --git a/src/aggregators/tracked.jl b/src/Aggregators/tracked.jl similarity index 100% rename from src/aggregators/tracked.jl rename to src/Aggregators/tracked.jl diff --git a/src/Comparison.jl b/src/Comparison/Comparison.jl similarity index 69% rename from src/Comparison.jl rename to src/Comparison/Comparison.jl index 40aa9e2..046bbcb 100644 --- a/src/Comparison.jl +++ b/src/Comparison/Comparison.jl @@ -23,6 +23,8 @@ export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy export ReplayResponsesExecutionStrategy export CatComparisonConfig +include("./watchdog.jl") + struct RandomCatComparison true_abilities::Array{Float64} rand_abilities::Array{Float64, 3} @@ -82,8 +84,7 @@ end abstract type CatComparisonExecutionStrategy end -struct CatComparisonConfig{ - StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple} +struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple} """ A named tuple with the (named) CatRules (or compatable) to be compared """ @@ -102,6 +103,18 @@ struct CatComparisonConfig{ The phases to run, optionally paired with a callback """ phases::PhasesT + """ + Where to sample for likelihood + """ + sample_points::Union{Vector{Float64}, Nothing} + """ + Skips + """ + skip_callback + """ + Watchdog timeout + """ + timeout::Float64 end """ @@ -109,6 +122,7 @@ end rules::NamedTuple{Symbol, StatefulCat}, strategy::CatComparisonExecutionStrategy, phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}}, + skips::Set{Tuple{Symbol, Symbol}}, callback::Callable ) -> CatComparisonConfig @@ -123,18 +137,24 @@ no callback is provided. The exact phases depend on the strategy used. See their individual documentation for more. """ -function CatComparisonConfig(; rules, strategy, phases = nothing, callback = nothing) +function CatComparisonConfig(; rules, strategy, phases = nothing, skip_callback = ((_, _, _) -> false), sample_points = nothing, callback = nothing, timeout = Inf) if callback === nothing callback = (info; kwargs...) -> nothing end if phases === nothing phases = (:before_next_item, :after_next_item) end - # TODO: normalize phases into named tuple if !(phases isa NamedTuple) phases = NamedTuple((phase => callback for phase in phases)) end - CatComparisonConfig(rules, strategy, phases) + CatComparisonConfig( + rules, + strategy, + phases, + sample_points, + skip_callback, + timeout + ) end # Comparison scenarios: @@ -158,7 +178,6 @@ end #phase_func=nothing; function measure_all(comparison, system, cat, phase; kwargs...) - @info "measure_all" phase system kwargs if !(phase in keys(comparison.phases)) return end @@ -273,7 +292,6 @@ function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExec num_items=size, system_name=name ) - @info "next_item" name timed_next_item.time strategy.time_limit if timed_next_item.time < strategy.time_limit push!(next_current_cats, name => cat) end @@ -300,108 +318,165 @@ end struct ReplayResponsesExecutionStrategy <: CatComparisonExecutionStrategy responses::BareResponses + time_limit::Float64 +end + +ReplayResponsesExecutionStrategy(responses) = ReplayResponsesExecutionStrategy(responses, Inf) + +function should_run(comparison, name, cat, phase) + return phase in keys(comparison.phases) && + !comparison.skip_callback(name, cat, phase) end # Which questions to ask: Specified # Which answer to use: From response memory function run_comparison(comparison::CatComparisonConfig{ReplayResponsesExecutionStrategy}) strategy = comparison.strategy - for (items_answered, response) in zip( - Iterators.countfrom(0), Iterators.flatten((strategy.responses, [nothing]))) - for (name, cat) in pairs(comparison.rules) - if :before_item_criteria in comparison.phases - timed_item_criteria = @timed Stateful.item_criteria(cat) - measure_all( - comparison, - name, - cat, - :before_item_criteria, - items_answered = items_answered, - item_criteria = timed_item_criteria.value, - timing = timed_item_criteria - ) - end - if :before_ranked_items in comparison.phases - timed_ranked_items = @timed Stateful.ranked_items(cat) - measure_all( - comparison, - name, - cat, - :before_ranked_items, - items_answered = items_answered, - ranked_items = timed_ranked_items.value, - timing = timed_ranked_items - ) - end - if :before_ability in comparison.phases - timed_get_ability = @timed Stateful.get_ability(cat) - measure_all( - comparison, - name, - cat, - :before_ability, - items_answered = items_answered, - ability = timed_get_ability.value, - timing = timed_get_ability - ) - end - measure_all( - comparison, - name, - cat, - :before_next_item, - items_answered = items_answered - ) - timed_next_item = @timed Stateful.next_item(cat) - next_item = timed_next_item.value - measure_all( - comparison, - name, - cat, - :after_next_item, - next_item = next_item, - timing = timed_next_item, - items_answered = items_answered - ) - if :after_item_criteria in comparison.phases - # TOOD: Combine with next_item if possible and requested? - timed_item_criteria = @timed Stateful.item_criteria(cat) - measure_all( - comparison, - name, - cat, - :after_item_criteria, - items_answered = items_answered, - item_criteria = timed_item_criteria.value, - timing = timed_item_criteria - ) + current_cats = Dict(pairs(comparison.rules)) + function check_time(name, timer) + if timer.time >= strategy.time_limit + if name in keys(current_cats) + @info "Time limit exceeded" name timer.time + delete!(current_cats, name) end - if :after_ranked_items in comparison.phases - timed_ranked_items = @timed Stateful.ranked_items(cat) + end + end + watchdog = WatchdogTask(comparison.timeout) + start!(watchdog) do + for (items_answered, response) in zip( + Iterators.countfrom(0), Iterators.flatten((strategy.responses, [nothing]))) + for (name, cat) in pairs(current_cats) + println("") + println("Starting $name for $items_answered") + flush(stdout) + if should_run(comparison, name, cat, :before_item_criteria) + reset!(watchdog, "$name item_criteria") + timed_item_criteria = @timed Stateful.item_criteria(cat) + check_time(name, timed_item_criteria) + measure_all( + comparison, + name, + cat, + :before_item_criteria, + items_answered = items_answered, + item_criteria = timed_item_criteria.value, + timing = timed_item_criteria + ) + end + if should_run(comparison, name, cat, :before_ranked_items) + reset!(watchdog, "$name ranked_items") + timed_ranked_items = @timed Stateful.ranked_items(cat) + check_time(name, timed_ranked_items) + measure_all( + comparison, + name, + cat, + :before_ranked_items, + items_answered = items_answered, + ranked_items = timed_ranked_items.value, + timing = timed_ranked_items + ) + end + if should_run(comparison, name, cat, :before_ability) + reset!(watchdog, "$name get_ability") + timed_get_ability = @timed Stateful.get_ability(cat) + check_time(name, timed_get_ability) + measure_all( + comparison, + name, + cat, + :before_ability, + items_answered = items_answered, + ability = timed_get_ability.value, + timing = timed_get_ability + ) + end measure_all( comparison, name, cat, - :after_ranked_items, - items_answered = items_answered, - ranked_items = timed_ranked_items.value, - timing = timed_ranked_items + :before_next_item, + items_answered = items_answered ) - end - if :after_ability in comparison.phases - timed_get_ability = @timed Stateful.get_ability(cat) + reset!(watchdog, "$name next_item") + timed_next_item = @timed Stateful.next_item(cat) + check_time(name, timed_next_item) + next_item = timed_next_item.value measure_all( comparison, name, cat, - :after_ability, - items_answered = items_answered, - ability = timed_get_ability.value, - timing = timed_get_ability + :after_next_item, + next_item = next_item, + timing = timed_next_item, + items_answered = items_answered ) - end - if response !== nothing - Stateful.add_response!(cat, response.index, response.value) + if should_run(comparison, name, cat, :after_item_criteria) + # TOOD: Combine with next_item if possible and requested? + reset!(watchdog, "$name item_criteria") + timed_item_criteria = @timed Stateful.item_criteria(cat) + check_time(name, timed_item_criteria) + if timed_item_criteria.value !== nothing + measure_all( + comparison, + name, + cat, + :after_item_criteria, + items_answered = items_answered, + item_criteria = timed_item_criteria.value, + timing = timed_item_criteria + ) + end + end + if should_run(comparison, name, cat, :after_ranked_items) + reset!(watchdog, "$name ranked_items") + timed_ranked_items = @timed Stateful.ranked_items(cat) + check_time(name, timed_ranked_items) + if timed_ranked_items.value !== nothing + measure_all( + comparison, + name, + cat, + :after_ranked_items, + items_answered = items_answered, + ranked_items = timed_ranked_items.value, + timing = timed_ranked_items + ) + end + end + if should_run(comparison, name, cat, :after_likelihood) + reset!(watchdog, "$name likelihood") + timed_likelihood = @timed Stateful.likelihood.(Ref(cat), comparison.sample_points) + check_time(name, timed_likelihood) + measure_all( + comparison, + name, + cat, + :after_likelihood, + items_answered = items_answered, + sample_points = comparison.sample_points, + likelihood = timed_likelihood.value, + timing = timed_likelihood + ) + + end + if should_run(comparison, name, cat, :after_ability) + reset!(watchdog, "$name get_ability") + timed_get_ability = @timed Stateful.get_ability(cat) + check_time(name, timed_get_ability) + measure_all( + comparison, + name, + cat, + :after_ability, + items_answered = items_answered, + ability = timed_get_ability.value, + timing = timed_get_ability + ) + end + if response !== nothing + Stateful.add_response!(cat, response.index, response.value) + end end end end diff --git a/src/Comparison/watchdog.jl b/src/Comparison/watchdog.jl new file mode 100644 index 0000000..7a4fda8 --- /dev/null +++ b/src/Comparison/watchdog.jl @@ -0,0 +1,154 @@ +using Base.Threads: nthreads + + +abstract type AbstractWatchdogTask end + +mutable struct WatchdogTask <: AbstractWatchdogTask + timeout::Float64 + channel::Channel + task::Union{Task, Nothing} +end + +function run_watchdog(timeout, channel, worker_task) + #Core.println("Starting watchdog") + #Base.flush(stdout) + reset_timestamp = time() + deadline = reset_timestamp + timeout + msg = nothing + active = false + die = false + #Core.println("X") + #Base.flush(stdout) + l = ReentrantLock() + #Core.println("Y") + #Base.flush(stdout) + activation = Threads.Condition(l) + #Core.println("Blam") + #Base.flush(stdout) + @async begin + #Core.println("Subloop") + while true + cmd = take!(channel) + #Core.println("Take") + if haskey(cmd, :kill) + die = true + lock(l) do + notify(activation) + end + break + end + lock(l) do + if haskey(cmd, :reset_timestamp) + reset_timestamp = cmd[:reset_timestamp] + deadline = reset_timestamp + timeout + end + if haskey(cmd, :msg) + msg = cmd[:msg] + end + if haskey(cmd, :active) + active = cmd[:active] + if active + #Core.println("Notify") + notify(activation) + end + end + end + end + end + loop = true + while loop + #Core.println("Aquiring lock") + loop = lock(l) do + while !active && !die + #Core.println("Waiting for activation") + wait(activation) + end + if die + return false + end + if active + unlock(l) + try + delay = deadline - time() + #Core.println("Sleeping for $delay") + sleep(max(delay, 0.0)) + finally + lock(l) + end + end + if die + return false + end + overrun = time() - deadline + if overrun > 0 && active + msg = "WATCHDOG TIMEOUT: $msg timed after after $(timeout)s (overran $(overrun)s)" + unlock(l) + put!(channel, (; kill=true)) + Core.println("") + Core.println(msg) + Core.println("") + Base.flush(Core.stdout) + sleep(0.1) + schedule(worker_task, InterruptException(), error=true) + # Wait a proper amount of time here since otherwise we will probably not get a stacktrace + sleep(5.0) + if istaskdone(worker_task) + return false + end + ccall(:uv_kill, Cint, (Cint, Cint), getpid(), Base.SIGTERM) + sleep(1.0) + ccall(:uv_kill, Cint, (Cint, Cint), getpid(), Base.SIGKILL) + sleep(1.0) + exit(1) # This is done last since it doesn't always take down the parent + end + return true + end + end +end + +function WatchdogTask(timeout::Float64) + if timeout !== Inf + channel = Channel{Any}(Inf) + WatchdogTask(timeout, channel, nothing) + else + NullWatchdog() + end + #WatchdogTask(task, timeout, channel, nothing) +end + +function start!(f, watchdog::WatchdogTask) + if nthreads(:interactive) < 1 || nthreads(:default) < 1 + error("WatchdogTask: Need an interactive and default thread") + end + worker_task = Threads.@spawn :default f() + watchdog.task = Threads.@spawn :interactive run_watchdog(watchdog.timeout, watchdog.channel, worker_task) + wait(worker_task) + put!(watchdog.channel, (; kill=true)) + wait(watchdog.task) +end + +function reset!(watchdog::WatchdogTask, msg=nothing) + if istaskdone(watchdog.task) + wait(watchdog.task) + end + payload = (; + active=true, + reset_timestamp=time(), + ) + if msg !== nothing + payload = (; payload..., msg=msg) + end + #@info "Put" payload + put!(watchdog.channel, payload) +end + +function deactivate!(watchdog::WatchdogTask) + if istaskdone(watchdog.task) + wait(watchdog.task) + end + put(watchdog.channel, (; active=false)) +end + +struct NullWatchdog <: AbstractWatchdogTask end +function reset!(::NullWatchdog, msg=nothing) end +function deactivate!(::NullWatchdog) end \ No newline at end of file diff --git a/src/Compat/CatR.jl b/src/Compat/CatR.jl new file mode 100644 index 0000000..c17cb6e --- /dev/null +++ b/src/Compat/CatR.jl @@ -0,0 +1,111 @@ +module CatR + +using ComputerAdaptiveTesting.Aggregators: AbilityIntegrator, + LikelihoodAbilityEstimator, + DistributionAbilityEstimator, + ModeAbilityEstimator, + MeanAbilityEstimator, + PosteriorAbilityEstimator +using ComputerAdaptiveTesting.TerminationConditions: RunForever +using ComputerAdaptiveTesting.Rules: CatRules +using ComputerAdaptiveTesting.NextItemRules +using PsychometricsBazaarBase: Integrators, Optimizers + +public next_item_aliases, ability_estimator_aliases, assemble_rules + +function _next_item_aliases() + res = Dict{String, Any}() + for (nick, mk_item_criterion) in ( + "MFI" => InformationItemCriterion, + "bOpt" => UrryItemCriterion, + ) + res[nick] = (bits...; kwargs...) -> ItemCriterionRule( + ExhaustiveSearch(), + mk_item_criterion(bits...)) + end + res["MEPV"] = (bits...; posterior_ability_estimator, kwargs...) -> ItemCriterionRule( + ExhaustiveSearch(), + ExpectationBasedItemCriterion(bits..., + AbilityVariance(posterior_ability_estimator, AbilityIntegrator(bits...)))) + res["MEI"] = (bits...; kwargs...) -> ItemCriterionRule( + ExhaustiveSearch(), + ExpectationBasedItemCriterion(bits..., + InformationItemCriterion(bits...))) + #"MLWI", #"MPWI", + return res + #"thOpt", + #"progressive", + #"proportional", + #"KL", + #"KLP", + #"GDI", + #"GDIP", + #"random" +end + +""" +This mapping provides next item rules through the same names that they are +available through in the `catR` R package. TODO compability with `mirtcat` +""" +const next_item_aliases = _next_item_aliases() + +function _ability_estimator_aliases() + res = Dict{String, Any}() + res["BM"] = (; optimizer, kwargs...) -> ModeAbilityEstimator(PosteriorAbilityEstimator(), optimizer) + res["ML"] = (; optimizer, kwargs...) -> ModeAbilityEstimator(LikelihoodAbilityEstimator(), optimizer) + res["EAP"] = (; integrator, kwargs...) -> MeanAbilityEstimator(PosteriorAbilityEstimator(), integrator) + #res["WL"] + #res["ROB"] + return res +end + +const ability_estimator_aliases = _ability_estimator_aliases() + +#= + for (resp_exp, resp_exp_nick) in resp_exp_nick_pairs + next_item_rule = NextItemRule( + ExpectationBasedItemCriterion(resp_exp, AbilityVariance(numtools.integrator, distribution_estimator(abil_est))) + ) + next_item_rule = preallocate(next_item_rule) + est_next_item_rule_pairs[Symbol("$(abil_est_str)_mepv_$(resp_exp_nick)")] = (abil_est, next_item_rule) + next_item_rule = NextItemRule( + ExpectationBasedItemCriterion(resp_exp, InformationItemCriterion(abil_est)) + ) + next_item_rule = preallocate(next_item_rule) + est_next_item_rule_pairs[Symbol("$(abil_est_str)_mei_$(resp_exp_nick)")] = (abil_est, next_item_rule) + end + est_next_item_rule_pairs[Symbol("$(abil_est_str)_mi")] = (abil_est, InformationItemCriterion(abil_est)) +=# + + +function setup_integrator(lo=-4.0, hi=4.0, pts=33) + Integrators.MidpointIntegrator(range(lo, hi, pts)) +end + +function setup_optimizer(lo=-4.0, hi=4.0) + Optimizers.NativeOneDimOptimOptimizer(; lo, hi) +end + +function assemble_rules(; + criterion, + method, + start_item = 1 + #prior_dist="norm", + #prior_par=@SVector[0.0, 1.0], + #info_type="observed" +) + integrator = setup_integrator() + optimizer = setup_optimizer() + ability_estimator = ability_estimator_aliases[method](; integrator, optimizer) + posterior_ability_estimator = PosteriorAbilityEstimator() + raw_next_item = next_item_aliases[criterion](ability_estimator, integrator, optimizer; posterior_ability_estimator=posterior_ability_estimator) + next_item = FixedFirstItem(start_item, raw_next_item) + CatRules(; + next_item, + termination_condition = RunForever(), + ability_estimator, + #ability_tracker::AbilityTrackerT = NullAbilityTracker() + ) +end + +end diff --git a/src/Compat/Compat.jl b/src/Compat/Compat.jl new file mode 100644 index 0000000..a8a29ab --- /dev/null +++ b/src/Compat/Compat.jl @@ -0,0 +1,6 @@ +module Compat + +include("./CatR.jl") +include("./MirtCAT.jl") + +end \ No newline at end of file diff --git a/src/Compat/MirtCAT.jl b/src/Compat/MirtCAT.jl new file mode 100644 index 0000000..491fb24 --- /dev/null +++ b/src/Compat/MirtCAT.jl @@ -0,0 +1,156 @@ +module MirtCAT + +using ComputerAdaptiveTesting.Aggregators: SafeLikelihoodAbilityEstimator, + LikelihoodAbilityEstimator, + DistributionAbilityEstimator, + ModeAbilityEstimator, + MeanAbilityEstimator, + PosteriorAbilityEstimator, + AbilityEstimator, + distribution_estimator +using ComputerAdaptiveTesting.TerminationConditions: RunForever +using ComputerAdaptiveTesting.NextItemRules +using ComputerAdaptiveTesting.Rules: CatRules +using PsychometricsBazaarBase: Integrators, Optimizers + +public next_item_aliases, ability_estimator_aliases, assemble_rules + +function _next_item_helper(item_criterion_callback) + function _helper(ability_estimator, posterior_ability_estimator, integrator, optimizer) + bits = [ + ability_estimator, + integrator, + optimizer, + ] + item_criterion = item_criterion_callback(; bits, ability_estimator, posterior_ability_estimator, integrator, optimizer) + return ItemCriterionRule(ExhaustiveSearch(), item_criterion) + end + return _helper +end + +const next_item_aliases = Dict( + # "MI' for the maximum information + "MI" => _next_item_helper((; bits, ability_estimator, rest...) -> InformationItemCriterion(ability_estimator)), + # 'MEPV' for minimum expected posterior variance + "MEPV" => _next_item_helper((; bits, ability_estimator, posterior_ability_estimator, integrator, rest...) -> ExpectationBasedItemCriterion( + ability_estimator, + AbilityVariance(posterior_ability_estimator, integrator))), + "MEI" => _next_item_helper((; bits, ability_estimator, rest...) -> ExpectationBasedItemCriterion( + ability_estimator, + PointItemCategoryCriterion(EmpiricalInformationPointwiseItemCategoryCriterion(), ability_estimator) + )), + "MLWI" => _next_item_helper((; bits, ability_estimator, integrator, rest...) -> LikelihoodWeightedItemCriterion( + TotalItemInformation(RawEmpiricalInformationPointwiseItemCategoryCriterion()), + distribution_estimator(ability_estimator), + integrator + )), + "MPWI" => _next_item_helper((; bits, ability_estimator, posterior_ability_estimator, integrator, rest...) -> LikelihoodWeightedItemCriterion( + TotalItemInformation(RawEmpiricalInformationPointwiseItemCategoryCriterion()), + distribution_estimator(posterior_ability_estimator), + integrator + )), + "Drule" => _next_item_helper((; bits, ability_estimator, rest...) -> DRuleItemCriterion(ability_estimator)), + "Trule" => _next_item_helper((; bits, ability_estimator, rest...) -> TRuleItemCriterion(ability_estimator)) +) + +# 'IKLP' as well as 'IKL' for the integration based Kullback-Leibler criteria with and without the prior density weight, +# respectively, and their root-n items administered weighted counter-parts, 'IKLn' and 'IKLPn'. +#= +Possible inputs for multidimensional adaptive tests include: 'Drule' for the +maximum determinant of the information matrix, 'Trule' for the maximum +(potentially weighted) trace of the information matrix, 'Arule' for the minimum (potentially weighted) trace of the asymptotic covariance matrix, 'Erule' +for the minimum value of the information matrix, and 'Wrule' for the weighted +information criteria. For each of these rules, the posterior weight for the latent trait scores can also be included with the 'DPrule', 'TPrule', 'APrule', +'EPrule', and 'WPrule', respectively. +Applicable to both unidimensional and multidimensional tests are the 'KL' and +'KLn' for point-wise Kullback-Leibler divergence and point-wise KullbackLeibler with a decreasing delta value (delta*sqrt(n), where n is the number +of items previous answered), respectively. The delta criteria is defined in the +design object +Non-adaptive methods applicable even when no mo object is passed are: 'random' +to randomly select items, and 'seq' for selecting items sequentially +=# + +const ability_estimator_aliases = Dict( + "MAP" => (; optimizer, ncomp, kwargs...) -> ModeAbilityEstimator(PosteriorAbilityEstimator(; ncomp=ncomp), optimizer), + "ML" => (; optimizer, ncomp, kwargs...) -> ModeAbilityEstimator(SafeLikelihoodAbilityEstimator(; ncomp=ncomp), optimizer), + "EAP" => (; integrator, ncomp, kwargs...) -> MeanAbilityEstimator(PosteriorAbilityEstimator(; ncomp=ncomp), integrator), +# "WLE" for weighted likelihood estimation +# "EAPsum" for the expected a-posteriori for each sum score +) + +#= +• "plausible" for a single plausible value imputation for each case. This is +equivalent to setting plausible.draws = 1 +• "classify" for the posteriori classification probabilities (only applicable +when the input model was of class MixtureClass) +=# + +function mirtcat_quadpts(nfact) + if nfact == 1 + return 121 + elseif nfact == 2 + return 61 + elseif nfact == 3 + return 31 + elseif nfact == 4 + return 19 + elseif nfact == 5 + return 11 + else + return 5 + end +end + +function setup_integrator(lo=-6.0, hi=6.0, pts=mirtcat_quadpts(1)) + Integrators.even_grid(lo, hi, pts) +end + +function setup_optimizer(lo=-6.0, hi=6.0) + # TODO: Is this correct? + # mirtcat uses the `nlm` function from the `stats` package + # Source: https://github.com/philchalmers/mirt/blob/46b5db3a0120d87b8f1b034e6111fc5fb352a698/R/fscores.internal.R#L957C25-L957C28 + # It looks like no gradient is passed, so the numerical gradient will be used + # Source: https://github.com/philchalmers/mirt/blob/46b5db3a0120d87b8f1b034e6111fc5fb352a698/R/fscores.internal.R#L623 + # This is what we get by default so do this + # Main difference is probably in the line search + # https://stats.stackexchange.com/questions/272880/algorithm-used-in-nlm-function-in-r + # So just use Newton() with defaults for now + # Except then we can't have box constraints so I suppose IPNewton + if lo isa AbstractVector && hi isa AbstractVector + Optimizers.MultiDimOptimOptimizer(lo, hi, Optimizers.IPNewton()) + else + Optimizers.OneDimOptimOptimizer(lo, hi, Optimizers.IPNewton()) + end +end + +function assemble_rules(; + criteria = "MI", + method = "MAP", + start_item = 1, + ncomp = 0 +) + if ncomp == 0 + lo = -6.0 + hi = 6.0 + pts = mirtcat_quadpts(1) + theta_lim = 20.0 + else + lo = fill(-6.0, ncomp) + hi = fill(6.0, ncomp) + pts = mirtcat_quadpts(ncomp) + theta_lim = fill(20.0, ncomp) + end + integrator = setup_integrator(lo, hi, pts) + optimizer = setup_optimizer(-theta_lim, theta_lim) + ability_estimator = ability_estimator_aliases[method](; integrator, optimizer, ncomp) + posterior_ability_estimator = PosteriorAbilityEstimator(; ncomp) + raw_next_item = next_item_aliases[criteria](ability_estimator, posterior_ability_estimator, integrator, optimizer) + next_item = FixedFirstItem(start_item, raw_next_item) + CatRules(; + next_item, + ability_estimator, + termination_condition = RunForever(), + ) +end + +end diff --git a/src/ComputerAdaptiveTesting.jl b/src/ComputerAdaptiveTesting.jl index 328a71c..f1228e0 100644 --- a/src/ComputerAdaptiveTesting.jl +++ b/src/ComputerAdaptiveTesting.jl @@ -5,17 +5,14 @@ include("./hacks.jl") using Reexport: Reexport, @reexport # Modules -export ConfigBase, Responses, Aggregators +export Responses, Aggregators export NextItemRules, TerminationConditions -export CatConfig, Sim, DecisionTree +export Sim, DecisionTree export Stateful, Comparison # Extension modules public require_testext -# Vendored dependencies -include("./vendor/PushVectors.jl") - # Config base include("./ConfigBase.jl") @@ -23,26 +20,27 @@ include("./ConfigBase.jl") include("./Responses.jl") # Near base -include("./aggregators/Aggregators.jl") +include("./Aggregators/Aggregators.jl") # Extra item banks include("./logitembank.jl") # Stages -include("./next_item_rules/NextItemRules.jl") +include("./NextItemRules/NextItemRules.jl") include("./TerminationConditions.jl") # Combining / running -include("./CatConfig.jl") -include("./Sim.jl") -include("./decision_tree/DecisionTree.jl") +include("./Rules.jl") +include("./Sim/Sim.jl") +include("./DecisionTree/DecisionTree.jl") -# Stateful layer and comparison +# Stateful layer, compat, and comparison include("./Stateful.jl") -include("./Comparison.jl") +include("./Compat/Compat.jl") +include("./Comparison/Comparison.jl") -@reexport using .CatConfig: CatLoopConfig, CatRules -@reexport using .Sim: run_cat +@reexport using .Rules: CatRules +@reexport using .Sim: CatLoop, run_cat @reexport using .NextItemRules: preallocate include("./precompiles.jl") diff --git a/src/decision_tree/DecisionTree.jl b/src/DecisionTree/DecisionTree.jl similarity index 96% rename from src/decision_tree/DecisionTree.jl rename to src/DecisionTree/DecisionTree.jl index b42ad58..97e354c 100644 --- a/src/decision_tree/DecisionTree.jl +++ b/src/DecisionTree/DecisionTree.jl @@ -3,7 +3,6 @@ module DecisionTree using Mmap: mmap using ComputerAdaptiveTesting.ConfigBase: CatConfigBase -using ComputerAdaptiveTesting.PushVectors using ComputerAdaptiveTesting.NextItemRules using ComputerAdaptiveTesting.Aggregators using ComputerAdaptiveTesting.Responses: BareResponses, Response, add_response!, pop_response! @@ -18,15 +17,19 @@ end Base.@kwdef mutable struct TreePosition max_depth::UInt cur_depth::UInt - todo::PushVector{AgendaItem, Vector{AgendaItem}} + todo::Vector{AgendaItem} parent_ability::Float64 end function TreePosition(max_depth) - TreePosition(max_depth = max_depth, + todo = Vector{AgendaItem}() + sizehint!(todo, max_depth) + TreePosition(; + max_depth, cur_depth = 0, - todo = PushVector{AgendaItem}(max_depth), - parent_ability = 0.0) + todo, + parent_ability = 0.0 + ) end function next!(state::TreePosition, responses, item_bank, question, ability) diff --git a/src/decision_tree/mmap.jl b/src/DecisionTree/mmap.jl similarity index 100% rename from src/decision_tree/mmap.jl rename to src/DecisionTree/mmap.jl diff --git a/src/decision_tree/sim.jl b/src/DecisionTree/sim.jl similarity index 90% rename from src/decision_tree/sim.jl rename to src/DecisionTree/sim.jl index 73b9db5..5d67f84 100644 --- a/src/decision_tree/sim.jl +++ b/src/DecisionTree/sim.jl @@ -1,9 +1,9 @@ import ComputerAdaptiveTesting: Sim """ -Run a given CatLoopConfig with a MaterializedDecisionTree +Run a given CatLoop with a MaterializedDecisionTree """ -function Sim.run_cat(cat_config::Sim.CatLoopConfig{RulesT}, +function Sim.run_cat(cat_config::Sim.CatLoop{RulesT}, item_bank::AbstractItemBank; ib_labels = nothing) where {RulesT <: MaterializedDecisionTree} (; rules, get_response, new_response_callback) = cat_config diff --git a/src/next_item_rules/NextItemRules.jl b/src/NextItemRules/NextItemRules.jl similarity index 65% rename from src/next_item_rules/NextItemRules.jl rename to src/NextItemRules/NextItemRules.jl index 43d9c4e..def1dfb 100644 --- a/src/next_item_rules/NextItemRules.jl +++ b/src/NextItemRules/NextItemRules.jl @@ -11,7 +11,7 @@ Springer, New York, NY. module NextItemRules using DocStringExtensions: FUNCTIONNAME, TYPEDEF, TYPEDFIELDS -using PsychometricsBazaarBase.Parameters: @with_kw +using PsychometricsBazaarBase.Parameters using LinearAlgebra: det, tr using Random: AbstractRNG, Xoshiro @@ -19,14 +19,15 @@ using ..Responses: BareResponses using ..ConfigBase using PsychometricsBazaarBase.ConfigTools: @requiresome, @returnsome, find1_instance, find1_type -using PsychometricsBazaarBase.Integrators: Integrator +using PsychometricsBazaarBase.Integrators: Integrator, intval using PsychometricsBazaarBase: Integrators +using PsychometricsBazaarBase.IndentWrappers: indent import PsychometricsBazaarBase.IntegralCoeffs using FittedItemBanks: AbstractItemBank, DiscreteDomain, DomainType, ItemResponse, OneDimContinuousDomain, domdims, item_params, - resp, resp_vec, responses + resp, resp_vec, responses, subset_view using ..Aggregators -using ..Aggregators: covariance_matrix +using ..Aggregators: covariance_matrix, FunctionProduct using Distributions: logccdf, logcdf, pdf using Base.Threads @@ -34,16 +35,22 @@ using Base.Order using StaticArrays: SVector using ConstructionBase: constructorof import ForwardDiff +import Base: show -export ExpectationBasedItemCriterion, AbilityVarianceStateCriterion, init_thread -export NextItemRule, ItemStrategyNextItemRule +export ExpectationBasedItemCriterion, AbilityVariance, init_thread +export NextItemRule, ItemCriterionRule export UrryItemCriterion, InformationItemCriterion +export LikelihoodWeightedItemCriterion, PointItemCriterion +export LikelihoodWeightedItemCategoryCriterion, PointItemCategoryCriterion +export ObservedInformationPointwiseItemCategoryCriterion +export RawEmpiricalInformationPointwiseItemCategoryCriterion +export EmpiricalInformationPointwiseItemCategoryCriterion +export TotalItemInformation export RandomNextItemRule -export ExhaustiveSearch -export catr_next_item_aliases +export FixedRuleSequencer, MemoryNextItemRule, FixedFirstItem +export ExhaustiveSearch, RandomesqueStrategy export preallocate -export compute_criteria, compute_criterion, compute_multi_criterion, - compute_pointwise_criterion +export compute_criteria, compute_criterion, compute_multi_criterion export best_item export PointResponseExpectation, DistributionResponseExpectation export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer @@ -60,7 +67,11 @@ include("./prelude/preallocate.jl") # Selection strategies include("./strategies/random.jl") +include("./strategies/randomesque.jl") +include("./strategies/sequential.jl") include("./strategies/exhaustive.jl") +include("./strategies/pointwise.jl") +include("./strategies/balance.jl") # Combinators include("./combinators/expectation.jl") @@ -68,15 +79,15 @@ include("./combinators/scalarizers.jl") include("./combinators/likelihood.jl") # Criteria -include("./criteria/item/information_special.jl") -include("./criteria/item/information_support.jl") include("./criteria/item/information.jl") include("./criteria/item/urry.jl") include("./criteria/state/ability_variance.jl") +include("./criteria/pointwise/information_special.jl") +include("./criteria/pointwise/information_support.jl") +include("./criteria/pointwise/information.jl") include("./criteria/pointwise/kl.jl") # Porcelain include("./porcelain/porcelain.jl") -include("./porcelain/aliases.jl") end diff --git a/src/next_item_rules/combinators/expectation.jl b/src/NextItemRules/combinators/expectation.jl similarity index 77% rename from src/next_item_rules/combinators/expectation.jl rename to src/NextItemRules/combinators/expectation.jl index 61ac76f..7e51841 100644 --- a/src/next_item_rules/combinators/expectation.jl +++ b/src/NextItemRules/combinators/expectation.jl @@ -39,6 +39,12 @@ function Aggregators.response_expectation( item_idx) end +function show(io::IO, ::MIME"text/plain", point_response_expectation::PointResponseExpectation) + println(io, "Expected response at point ability estimate") + indent_io = indent(io, 2) + show(indent_io, MIME("text/plain"), point_response_expectation.ability_estimator) +end + struct DistributionResponseExpectation{ DistributionAbilityEstimatorT <: DistributionAbilityEstimator, AbilityIntegratorT <: AbilityIntegrator @@ -67,7 +73,7 @@ item 1-ply ahead. """ struct ExpectationBasedItemCriterion{ ResponseExpectationT <: ResponseExpectation, - CriterionT <: Union{StateCriterion, ItemCriterion} + CriterionT <: Union{StateCriterion, ItemCriterion, ItemCategoryCriterion}, } <: ItemCriterion response_expectation::ResponseExpectationT criterion::CriterionT @@ -75,7 +81,8 @@ end function _get_some_criterion(bits...; kwargs...) @returnsome StateCriterion(bits...; kwargs...) - @returnsome ItemCriterion(bits...; kwargs...) + @returnsome ItemCriterion(bits...; skip_expectation=true, kwargs...) + @returnsome ItemCategoryCriterion(bits...) end function ExpectationBasedItemCriterion(bits...; @@ -95,13 +102,16 @@ function init_thread(::ExpectationBasedItemCriterion, responses::TrackedResponse Speculator(responses, 1) end -function _generic_criterion(criterion::StateCriterion, tracked_responses, item_idx) +function _generic_criterion(criterion::StateCriterion, tracked_responses, _item_idx, _response) compute_criterion(criterion, tracked_responses) end # TODO: Support init_thread for wrapped ItemCriterion -function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx) +function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx, _response) compute_criterion(criterion, tracked_responses, item_idx) end +function _generic_criterion(criterion::ItemCategoryCriterion, tracked_responses, item_idx, response) + compute_criterion(criterion, tracked_responses, item_idx, response) +end function compute_criterion( item_criterion::ExpectationBasedItemCriterion, @@ -116,7 +126,14 @@ function compute_criterion( for (prob, possible_response) in zip(exp_resp, possible_responses) replace_speculation!(speculator, SVector(item_idx), SVector(possible_response)) res += prob * - _generic_criterion(item_criterion.criterion, speculator.responses, item_idx) + _generic_criterion(item_criterion.criterion, speculator.responses, item_idx, possible_response) end res end + +function show(io::IO, ::MIME"text/plain", item_criterion::ExpectationBasedItemCriterion) + println(io, "Optimize an state/item/item-category criterion based on an expected response") + indent_io = indent(io, 2) + show(indent_io, MIME"text/plain"(), item_criterion.response_expectation) + show(indent_io, MIME"text/plain"(), item_criterion.criterion) +end diff --git a/src/NextItemRules/combinators/likelihood.jl b/src/NextItemRules/combinators/likelihood.jl new file mode 100644 index 0000000..6da661c --- /dev/null +++ b/src/NextItemRules/combinators/likelihood.jl @@ -0,0 +1,97 @@ +struct LikelihoodWeightedItemCriterion{ + PointwiseItemCriterionT <: PointwiseItemCriterion, + AbilityIntegratorT <: AbilityIntegrator, + AbilityEstimatorT <: DistributionAbilityEstimator +} <: ItemCriterion + criterion::PointwiseItemCriterionT + integrator::AbilityIntegratorT + estimator::AbilityEstimatorT +end + +function LikelihoodWeightedItemCriterion(bits...) + @requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...) + (dist_est, integrator) = dist_est_integrator_pair + criterion = PointwiseItemCriterion(bits...) + return LikelihoodWeightedItemCriterion(criterion, integrator, dist_est) +end + +function compute_criterion( + lwic::LikelihoodWeightedItemCriterion, + tracked_responses::TrackedResponses, + item_idx +) + func = FunctionProduct( + pdf(lwic.estimator, tracked_responses), ability -> compute_criterion(lwic.criterion, tracked_responses, item_idx, ability)) + intval(lwic.integrator(func, 0, lwic.estimator, tracked_responses)) +end + +struct PointItemCriterion{ + PointwiseItemCriterionT <: PointwiseItemCriterion, + AbilityEstimatorT <: PointAbilityEstimator +} <: ItemCriterion + criterion::PointwiseItemCriterionT + estimator::AbilityEstimatorT +end + +function compute_criterion( + pic::PointItemCriterion, + tracked_responses::TrackedResponses, + item_idx +) + ability = maybe_tracked_ability_estimate( + tracked_responses, + pic.estimator + ) + return compute_criterion(pic.criterion, tracked_responses, item_idx, ability) +end + +struct LikelihoodWeightedItemCategoryCriterion{ + PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion, + AbilityIntegratorT <: AbilityIntegrator, + AbilityEstimatorT <: DistributionAbilityEstimator +} <: ItemCategoryCriterion + criterion::PointwiseItemCategoryCriterionT + integrator::AbilityIntegratorT + estimator::AbilityEstimatorT +end + +function LikelihoodWeightedItemCategoryCriterion(bits...) + @requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...) + (dist_est, integrator) = dist_est_integrator_pair + criterion = PointwiseItemCategoryCriterion(bits...) + return LikelihoodWeightedItemCategoryCriterion(criterion, integrator, dist_est) +end + +function compute_criterion( + lwicc::LikelihoodWeightedItemCategoryCriterion, + tracked_responses::TrackedResponses, + item_idx, + category +) + func = FunctionProduct( + pdf(lwicc.estimator, tracked_responses), + ability -> compute_criterion(lwicc.criterion, tracked_responses, item_idx, ability, category) + ) + intval(lwicc.integrator(func, 0, lwicc.estimator, tracked_responses)) +end + +struct PointItemCategoryCriterion{ + PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion, + AbilityEstimatorT <: PointAbilityEstimator +} <: ItemCategoryCriterion + criterion::PointwiseItemCategoryCriterionT + estimator::AbilityEstimatorT +end + +function compute_criterion( + pic::PointItemCategoryCriterion, + tracked_responses::TrackedResponses, + item_idx, + category +) + ability = maybe_tracked_ability_estimate( + tracked_responses, + pic.estimator + ) + return compute_criterion(pic.criterion, tracked_responses, item_idx, ability, category) +end \ No newline at end of file diff --git a/src/next_item_rules/combinators/scalarizers.jl b/src/NextItemRules/combinators/scalarizers.jl similarity index 100% rename from src/next_item_rules/combinators/scalarizers.jl rename to src/NextItemRules/combinators/scalarizers.jl diff --git a/src/next_item_rules/criteria/item/information.jl b/src/NextItemRules/criteria/item/information.jl similarity index 70% rename from src/next_item_rules/criteria/item/information.jl rename to src/NextItemRules/criteria/item/information.jl index 04987e4..bcdbab6 100644 --- a/src/next_item_rules/criteria/item/information.jl +++ b/src/NextItemRules/criteria/item/information.jl @@ -1,12 +1,20 @@ # TODO: Should have Variants for point ability versus distribution ability -struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: +@kw_only struct InformationItemCriterion{AbilityEstimatorT <: PointAbilityEstimator, F} <: ItemCriterion ability_estimator::AbilityEstimatorT expected_item_information::F end -function InformationItemCriterion(ability_estimator) - InformationItemCriterion(ability_estimator, expected_item_information) +function InformationItemCriterion(ability_estimator::PointAbilityEstimator) + InformationItemCriterion(; + ability_estimator, + expected_item_information + ) +end + +function InformationItemCriterion(bits...) + @requiresome ability_estimator = PointAbilityEstimator(bits...) + InformationItemCriterion(ability_estimator) end function compute_criterion( @@ -18,14 +26,15 @@ function compute_criterion( return -item_criterion.expected_item_information(ir, ability) end -struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: +struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F, G} <: ItemMultiCriterion ability_estimator::AbilityEstimatorT - expected_item_information::F + known_item_information::F + expected_item_information::G end function InformationMatrixCriteria(ability_estimator) - InformationMatrixCriteria(ability_estimator, expected_item_information) + InformationMatrixCriteria(ability_estimator, expected_item_information, expected_item_information) end function init_thread(item_criterion::InformationMatrixCriteria, @@ -34,7 +43,8 @@ function init_thread(item_criterion::InformationMatrixCriteria, # θ update. # TODO: Update this to use track!(...) mechanism ability = maybe_tracked_ability_estimate(responses, item_criterion.ability_estimator) - responses_information(responses.item_bank, responses.responses, ability) + responses_information(responses.item_bank, responses.responses, ability; + information_func=item_criterion.known_item_information) end function compute_multi_criterion( @@ -44,9 +54,9 @@ function compute_multi_criterion( # TODO: Add in information from the prior ability = maybe_tracked_ability_estimate( tracked_responses, item_criterion.ability_estimator) - return acc_info .+ - item_criterion.expected_item_information( + exp_info = item_criterion.expected_item_information( ItemResponse(tracked_responses.item_bank, item_idx), ability) + return acc_info .+ exp_info end should_minimize(::InformationMatrixCriteria) = false diff --git a/src/next_item_rules/criteria/item/urry.jl b/src/NextItemRules/criteria/item/urry.jl similarity index 82% rename from src/next_item_rules/criteria/item/urry.jl rename to src/NextItemRules/criteria/item/urry.jl index e71a82b..177c36f 100644 --- a/src/next_item_rules/criteria/item/urry.jl +++ b/src/NextItemRules/criteria/item/urry.jl @@ -9,6 +9,11 @@ struct UrryItemCriterion{AbilityEstimatorT <: PointAbilityEstimator} <: ItemCrit ability_estimator::AbilityEstimatorT end +function UrryItemCriterion(bits...) + @requiresome ability_estimator = PointAbilityEstimator(bits...) + UrryItemCriterion(ability_estimator) +end + # TODO: Slow + poor error handling function raw_difficulty(item_bank, item_idx) item_params(item_bank, item_idx).difficulty diff --git a/src/NextItemRules/criteria/pointwise/information.jl b/src/NextItemRules/criteria/pointwise/information.jl new file mode 100644 index 0000000..1104bef --- /dev/null +++ b/src/NextItemRules/criteria/pointwise/information.jl @@ -0,0 +1,145 @@ +""" +This calculates the pointwise information criterion for an item response model. +""" +struct ObservedInformationPointwiseItemCategoryCriterion <: PointwiseItemCategoryCriterion end + +function compute_criterion( + ::ObservedInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability, + category +) + actual = -double_derivative((ability -> log_resp(ir, category, ability)), ability) .* resp(ir, category, ability) + -actual +end + +function compute_criterion_vec( + ::ObservedInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability +) + actual = -double_derivative((ability -> log_resp_vec(ir, ability)), ability) .* resp_vec(ir, ability) + -actual +end + +function show(io::IO, ::MIME"text/plain", ::ObservedInformationPointwiseItemCategoryCriterion) + println(io, "Observed pointwise item-category information") +end + +""" +See EmpiricalInformationPointwiseItemCategoryCriterion for more details. +""" +struct RawEmpiricalInformationPointwiseItemCategoryCriterion <: PointwiseItemCategoryCriterion end + +function compute_criterion( + ::RawEmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability, + category +) + actual = ForwardDiff.derivative(ability -> resp(ir, category, ability), ability) ^ 2 / resp(ir, category, ability) + -actual +end + +function compute_criterion_vec( + ::RawEmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability +) + actual = ForwardDiff.derivative(ability -> resp_vec(ir, ability), ability) .^ 2 ./ resp_vec(ir, ability) + -actual +end + + +function show(io::IO, ::MIME"text/plain", ::RawEmpiricalInformationPointwiseItemCategoryCriterion) + println(io, "Raw empirical pointwise item-category information") +end + +""" +In equation 10 of [1] we see that we can compute information using 2nd derivatives of log likelihood or 1st derivative squared. +For single categories, we need to an extra term which disappears when we calculate the total see [2]. +For this reason +`RawEmpiricalInformationPointwiseItemCategoryCriterion` +computes without this factor, while +`EmpiricalInformationPointwiseItemCategoryCriterion` +computes with it. + +So in general, only use the former with `TotalItemInformation` + +[1] +``Information Functions of the Generalized Partial Credit Model'' +Eiji Muraki +https://doi.org/10.1177/014662169301700403 + +[2] +https://mark.reid.name/blog/fisher-information-and-log-likelihood.html +""" +struct EmpiricalInformationPointwiseItemCategoryCriterion <: PointwiseItemCategoryCriterion end + +function compute_criterion( + ::EmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability, + category +) + actual = -compute_criterion( + RawEmpiricalInformationPointwiseItemCategoryCriterion(), + ir, + ability, + category + ) .- double_derivative((ability -> resp(ir, category, ability)), ability) + -actual +end + +function compute_criterion_vec( + ::EmpiricalInformationPointwiseItemCategoryCriterion, + ir::ItemResponse, + ability +) + actual = -compute_criterion_vec( + RawEmpiricalInformationPointwiseItemCategoryCriterion(), + ir, + ability + ) .- double_derivative((ability -> resp_vec(ir, ability)), ability) + -actual +end + +function show(io::IO, ::MIME"text/plain", ::EmpiricalInformationPointwiseItemCategoryCriterion) + println(io, "Empirical pointwise item-category information") +end + +#= +""" +This implements Fisher information as a pointwise item criterion. +It uses ForwardDiff to find the second derivative of the log prob for the current item and ability estimate. +It then uses the expected outcome at the given ability estimate to weight the outcomes. + +\[ +E_{\thetaHAT}(log(\frac{d^2 log\thetaHAT}{d\theta)) +\] +""" +=# +struct TotalItemInformation{PointwiseItemCategoryCriterionT <: PointwiseItemCategoryCriterion} <: PointwiseItemCriterion + pcic::PointwiseItemCategoryCriterionT +end + +function compute_criterion( + tii::TotalItemInformation, + ir::ItemResponse, + ability +) + sum(compute_criterion_vec(tii.pcic, ir, ability)) +end + +function show(io::IO, ::MIME"text/plain", rule::TotalItemInformation) + if rule.pcic isa ObservedInformationPointwiseItemCategoryCriterion + println(io, "Observed pointwise item information") + elseif rule.pcic isa RawEmpiricalInformationPointwiseItemCategoryCriterion + println(io, "Raw empirical pointwise item information") + elseif rule.pcic isa EmpiricalInformationPointwiseItemCategoryCriterion + println(io, "Empirical pointwise item information") + else + print(io, "Total ") + show(io, MIME("text/plain"), rule.pcic) + end +end \ No newline at end of file diff --git a/src/next_item_rules/criteria/item/information_special.jl b/src/NextItemRules/criteria/pointwise/information_special.jl similarity index 100% rename from src/next_item_rules/criteria/item/information_special.jl rename to src/NextItemRules/criteria/pointwise/information_special.jl diff --git a/src/next_item_rules/criteria/item/information_support.jl b/src/NextItemRules/criteria/pointwise/information_support.jl similarity index 70% rename from src/next_item_rules/criteria/item/information_support.jl rename to src/NextItemRules/criteria/pointwise/information_support.jl index c63e6af..9f987c6 100644 --- a/src/next_item_rules/criteria/item/information_support.jl +++ b/src/NextItemRules/criteria/pointwise/information_support.jl @@ -1,6 +1,6 @@ using FittedItemBanks: CdfMirtItemBank, - GuessItemBank, SlipItemBank, TransferItemBank, AnySlipOrGuessItemBank -using FittedItemBanks: inner_item_response, norm_abil, y_offset, irf_size + TransferItemBank, GuessAndSlipItemBank +using FittedItemBanks: inner_item_response, norm_abil, irf_size using StatsFuns: logaddexp function log_resp_vec(ir::ItemResponse{<:TransferItemBank}, θ) @@ -30,9 +30,10 @@ function log_resp(ir::ItemResponse{<:CdfMirtItemBank}, val, θ) end end +#= # XXX: Not sure if this is optimal numerically or speed wise -- possibly it # would be better to just transform to linear space in this case? -@inline function log_transform_irf_y(guess::Float64, slip::Float64, y) +@inline function log_transform_irf_y(guess, slip, y) # log space version of guess + irf_size(guess, slip) * y logaddexp(log(guess), log(irf_size(guess, slip)) + y) end @@ -63,6 +64,11 @@ end function log_resp(ir::ItemResponse{<:AnySlipOrGuessItemBank}, val, θ) log_transform_irf_y(ir, val, log_resp(inner_item_response(ir), val, θ)) end +=# + +log_resp(ir::ItemResponse{<:GuessAndSlipItemBank}, response, θ) = log(resp(ir, response, θ)) +log_resp(ir::ItemResponse{<:GuessAndSlipItemBank}, θ) = log(resp(ir, θ)) +log_resp_vec(ir::ItemResponse{<:GuessAndSlipItemBank}, θ) = log.(resp_vec(ir, θ)) function vector_hessian(f, x, n) out = ForwardDiff.jacobian(x -> ForwardDiff.jacobian(f, x), x) @@ -73,7 +79,7 @@ function double_derivative(f, x) ForwardDiff.derivative(x -> ForwardDiff.derivative(f, x), x) end -function expected_item_information(ir::ItemResponse, θ::Float64) +function expected_item_information(ir::ItemResponse, θ::Number) exp_resp = resp_vec(ir, θ) d² = double_derivative((θ -> log_resp_vec(ir, θ)), θ) -sum(exp_resp .* d²) @@ -81,21 +87,35 @@ end # TODO: Unclear whether this should be implemented with ExpectationBasedItemCriterion # TODO: This is not implementing DRule but postposterior DRule -function expected_item_information(ir::ItemResponse, θ::Vector{Float64}) +function expected_item_information(ir::ItemResponse, θ::Vector) exp_resp = resp_vec(ir, θ) n = domdims(ir.item_bank) hess = vector_hessian(θ -> log_resp_vec(ir, θ), θ, n) - -dropdims(sum((exp_resp .* (@view hess[2, :, :])), dims = 1), dims = 1) + return -sum(eachslice(hess, dims=1) .* exp_resp) end +expected_item_information(ir::ItemResponse, _, θ::Vector) = expected_item_information(ir, θ) + function known_item_information(ir::ItemResponse, resp_value, θ) -ForwardDiff.hessian(θ -> log_resp(ir, resp_value, θ), θ) end -function responses_information(item_bank::AbstractItemBank, responses::BareResponses, θ) +function responses_information(item_bank::AbstractItemBank, responses::BareResponses, θ; information_func=known_item_information) d = domdims(item_bank) reduce(.+, - (known_item_information(ItemResponse(item_bank, resp_idx), resp_value > 0, θ) + (information_func(ItemResponse(item_bank, resp_idx), resp_value > 0, θ) for (resp_idx, resp_value) in zip(responses.indices, responses.values)); init = zeros(d, d)) end + +using ComputerAdaptiveTesting: ItemBanks + +function log_resp_vec(ir::ItemResponse{<:ItemBanks.LogItemBank}, θ) + # XXX: Should not destruct the logarithmic number here + # Works for now + log.(resp_vec(ItemBanks.inner_ir(ir), θ)) +end + +function log_resp(ir::ItemResponse{<:ItemBanks.LogItemBank}, resp, θ) + log(resp(ItemBanks.inner_ir(ir), resp, θ)) +end \ No newline at end of file diff --git a/src/next_item_rules/criteria/pointwise/kl.jl b/src/NextItemRules/criteria/pointwise/kl.jl similarity index 95% rename from src/next_item_rules/criteria/pointwise/kl.jl rename to src/NextItemRules/criteria/pointwise/kl.jl index 630680c..efaf115 100644 --- a/src/next_item_rules/criteria/pointwise/kl.jl +++ b/src/NextItemRules/criteria/pointwise/kl.jl @@ -22,10 +22,11 @@ function PosteriorExpectedKLInformationItemCriterion(bits...) point_estimator, distribution_estimator, integrator) end -function compute_pointwise_criterion( +function compute_criterion( item_criterion::PosteriorExpectedKLInformationItemCriterion, tracked_responses::TrackedResponses, - item_idx) + item_idx, + theta) theta_0 = maybe_tracked_ability_estimate(tracked_responses, item_criterion.point_estimator) item_response = ItemResponse(tracked_responses.item_bank, item_idx) diff --git a/src/next_item_rules/criteria/state/ability_variance.jl b/src/NextItemRules/criteria/state/ability_variance.jl similarity index 66% rename from src/next_item_rules/criteria/state/ability_variance.jl rename to src/NextItemRules/criteria/state/ability_variance.jl index 343a873..47d7232 100644 --- a/src/next_item_rules/criteria/state/ability_variance.jl +++ b/src/NextItemRules/criteria/state/ability_variance.jl @@ -5,7 +5,7 @@ $(TYPEDFIELDS) This `StateCriterion` returns the variance of the ability estimate given a set of responses. """ -struct AbilityVarianceStateCriterion{ +struct AbilityVariance{ DistEst <: DistributionAbilityEstimator, IntegratorT <: AbilityIntegrator } <: StateCriterion @@ -14,30 +14,15 @@ struct AbilityVarianceStateCriterion{ skip_zero::Bool end -function _get_dist_est_and_integrator(bits...) - # XXX: Weakness in this initialisation system is showing now - # This needs ot be explicitly passed dist_est and integrator, but this may - # be burried within a MeanAbilityEstimator - dist_est = DistributionAbilityEstimator(bits...) - integrator = AbilityIntegrator(bits...) - if dist_est !== nothing && integrator !== nothing - return (dist_est, integrator) - end - # So let's just handle this case individually for now - # (Is this going to cause a problem with this being picked over something more appropriate?) - @requiresome mean_ability_est = MeanAbilityEstimator(bits...) - return (mean_ability_est.dist_est, mean_ability_est.integrator) -end - -function AbilityVarianceStateCriterion(bits...) +function AbilityVariance(bits...) skip_zero = false - @returnsome find1_instance(AbilityVarianceStateCriterion, bits) - @requiresome dist_est_integrator_pair = _get_dist_est_and_integrator(bits...) + @returnsome find1_instance(AbilityVariance, bits) + @requiresome dist_est_integrator_pair = get_dist_est_and_integrator(bits...) (dist_est, integrator) = dist_est_integrator_pair - return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero) + return AbilityVariance(dist_est, integrator, skip_zero) end -function compute_criterion(criterion::AbilityVarianceStateCriterion, +function compute_criterion(criterion::AbilityVariance, tracked_responses::TrackedResponses)::Float64 # XXX: Not sure if the estimator should come from somewhere else here denom = normdenom(criterion.integrator, @@ -50,7 +35,7 @@ function compute_criterion(criterion::AbilityVarianceStateCriterion, criterion, DomainType(tracked_responses.item_bank), tracked_responses, denom) end -function compute_criterion(criterion::AbilityVarianceStateCriterion, +function compute_criterion(criterion::AbilityVariance, ::Union{OneDimContinuousDomain, DiscreteDomain}, tracked_responses::TrackedResponses, denom)::Float64 @@ -63,7 +48,7 @@ function compute_criterion(criterion::AbilityVarianceStateCriterion, end function compute_criterion( - criterion::AbilityVarianceStateCriterion, + criterion::AbilityVariance, ::Vector, tracked_responses::TrackedResponses, denom @@ -83,6 +68,13 @@ function compute_criterion( denom) end +function show(io::IO, ::MIME"text/plain", criterion::AbilityVariance) + println(io, "Minimise variance of ability estimate") + indent_io = indent(io, 2) + show(indent_io, MIME("text/plain"), criterion.dist_est) + show(indent_io, MIME("text/plain"), criterion.integrator) +end + struct AbilityCovarianceStateMultiCriterion{ DistEstT <: DistributionAbilityEstimator, IntegratorT <: AbilityIntegrator @@ -94,7 +86,7 @@ end function AbilityCovarianceStateMultiCriterion(bits...) skip_zero = false - @requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...) + @requiresome (dist_est, integrator) = get_dist_est_and_integrator(bits...) return AbilityCovarianceStateMultiCriterion(dist_est, integrator, skip_zero) end diff --git a/src/NextItemRules/porcelain/aliases.jl b/src/NextItemRules/porcelain/aliases.jl new file mode 100644 index 0000000..e69de29 diff --git a/src/next_item_rules/porcelain/porcelain.jl b/src/NextItemRules/porcelain/porcelain.jl similarity index 100% rename from src/next_item_rules/porcelain/porcelain.jl rename to src/NextItemRules/porcelain/porcelain.jl diff --git a/src/next_item_rules/prelude/abstract.jl b/src/NextItemRules/prelude/abstract.jl similarity index 58% rename from src/next_item_rules/prelude/abstract.jl rename to src/NextItemRules/prelude/abstract.jl index fdc68c9..52ee8ae 100644 --- a/src/next_item_rules/prelude/abstract.jl +++ b/src/NextItemRules/prelude/abstract.jl @@ -6,21 +6,21 @@ Abstract base type for all item selection rules. All descendants of this type are expected to implement the interface `(::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`. -In practice, all adaptive rules in this package use `ItemStrategyNextItemRule`. +In practice, all adaptive rules in this package use `ItemCriterionRule`. $(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true) Implicit constructor for $(FUNCTIONNAME). Uses any given `NextItemRule` or -delegates to `ItemStrategyNextItemRule` the default instance. +delegates to `ItemCriterionRule` the default instance. """ abstract type NextItemRule <: CatConfigBase end """ $(TYPEDEF) -Abstract type for next item strategies, tightly coupled with `ItemStrategyNextItemRule`. +Abstract type for next item strategies, tightly coupled with `ItemCriterionRule`. All descendants of this type are expected to implement the interface -`(rule::ItemStrategyNextItemRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses, +`(rule::ItemCriterionRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses, items) where {ItemCriterionT <: } `(strategy::NextItemStrategy)(; parallel=true)::NextItemStrategy` """ @@ -29,21 +29,32 @@ abstract type NextItemStrategy <: CatConfigBase end """ $(TYPEDEF) -Abstract type for next item criteria +Abstract base type all criteria should inherit from """ +abstract type CriterionBase <: CatConfigBase end +abstract type SubItemCriterionBase <: CatConfigBase end + abstract type ItemCriterion <: CatConfigBase end """ $(TYPEDEF) """ -abstract type StateCriterion <: CatConfigBase end +abstract type StateCriterion <: CriterionBase end + +""" +$(TYPEDEF) +""" +abstract type PointwiseItemCriterion <: SubItemCriterionBase end """ $(TYPEDEF) """ -abstract type PointwiseItemCriterion <: CatConfigBase end +abstract type ItemCategoryCriterion <: SubItemCriterionBase end -abstract type PurePointwiseItemCriterion <: PointwiseItemCriterion end +""" +$(TYPEDEF) +""" +abstract type PointwiseItemCategoryCriterion <: SubItemCriterionBase end abstract type MatrixScalarizer end abstract type StateMultiCriterion end diff --git a/src/next_item_rules/prelude/criteria.jl b/src/NextItemRules/prelude/criteria.jl similarity index 60% rename from src/next_item_rules/prelude/criteria.jl rename to src/NextItemRules/prelude/criteria.jl index 277c65d..f64c8ea 100644 --- a/src/next_item_rules/prelude/criteria.jl +++ b/src/NextItemRules/prelude/criteria.jl @@ -1,13 +1,15 @@ #= Single dimensional =# -function ItemCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing) +function ItemCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing, skip_expectation = false) @returnsome find1_instance(ItemCriterion, bits) @returnsome find1_type(ItemCriterion, bits) typ->typ( ability_estimator = ability_estimator, ability_tracker = ability_tracker) - @returnsome ExpectationBasedItemCriterion(bits...; - ability_estimator = ability_estimator, - ability_tracker = ability_tracker) + if !skip_expectation + @returnsome ExpectationBasedItemCriterion(bits...; + ability_estimator = ability_estimator, + ability_tracker = ability_tracker) + end end function StateCriterion(bits...; ability_estimator = nothing, ability_tracker = nothing) @@ -15,6 +17,21 @@ function StateCriterion(bits...; ability_estimator = nothing, ability_tracker = @returnsome find1_type(StateCriterion, bits) typ->typ() end +function ItemCategoryCriterion(bits...) + @returnsome find1_instance(ItemCategoryCriterion, bits) + @returnsome find1_type(ItemCategoryCriterion, bits) typ->typ() +end + +function PointwiseItemCriterion(bits...) + @returnsome find1_instance(PointwiseItemCriterion, bits) + @returnsome find1_type(PointwiseItemCriterion, bits) typ->typ() +end + +function PointwiseItemCategoryCriterion(bits...) + @returnsome find1_instance(PointwiseItemCategoryCriterion, bits) + @returnsome find1_type(PointwiseItemCategoryCriterion, bits) typ->typ() +end + function init_thread(::ItemCriterion, ::TrackedResponses) nothing end @@ -58,7 +75,7 @@ function compute_criteria( end function compute_criteria( - rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT}, + rule::ItemCriterionRule{StrategyT, ItemCriterionT}, responses, items ) where {StrategyT, ItemCriterionT <: ItemCriterion} @@ -66,19 +83,15 @@ function compute_criteria( end function compute_criteria( - rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT}, + rule::ItemCriterionRule{StrategyT, ItemCriterionT}, responses::TrackedResponses ) where {StrategyT, ItemCriterionT <: ItemCriterion} compute_criteria(rule.criterion, responses) end -function compute_pointwise_criterion( - ppic::PurePointwiseItemCriterion, tracked_responses, item_idx) - compute_pointwise_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx)) -end - -struct PurePointwiseItemCriterionFunction{PointwiseItemCriterionT <: PointwiseItemCriterion} - item_response::ItemResponse +function compute_criterion( + ppic::SubItemCriterionBase, tracked_responses::TrackedResponses, item_idx, args...) + compute_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx), args...) end function init_thread(::ItemMultiCriterion, ::TrackedResponses) @@ -98,3 +111,18 @@ function compute_multi_criterion( state_criterion::StateMultiCriterion, ::Nothing, tracked_responses) compute_multi_criterion(state_criterion, tracked_responses) end + +function get_dist_est_and_integrator(bits...) + # XXX: Weakness in this initialisation system is showing now + # This needs ot be explicitly passed dist_est and integrator, but this may + # be burried within a MeanAbilityEstimator + dist_est = DistributionAbilityEstimator(bits...) + integrator = AbilityIntegrator(bits...) + if dist_est !== nothing && integrator !== nothing + return (dist_est, integrator) + end + # So let's just handle this case individually for now + # (Is this going to cause a problem with this being picked over something more appropriate?) + @requiresome mean_ability_est = MeanAbilityEstimator(bits...) + return (mean_ability_est.dist_est, mean_ability_est.integrator) +end diff --git a/src/next_item_rules/prelude/next_item_rule.jl b/src/NextItemRules/prelude/next_item_rule.jl similarity index 51% rename from src/next_item_rules/prelude/next_item_rule.jl rename to src/NextItemRules/prelude/next_item_rule.jl index bd708e8..c61836c 100644 --- a/src/next_item_rules/prelude/next_item_rule.jl +++ b/src/NextItemRules/prelude/next_item_rule.jl @@ -1,37 +1,35 @@ function NextItemRule(bits...; ability_estimator = nothing, - ability_tracker = nothing, - parallel = true) + ability_tracker = nothing) @returnsome find1_instance(NextItemRule, bits) - @returnsome ItemStrategyNextItemRule(bits..., + @returnsome ItemCriterionRule(bits..., ability_estimator = ability_estimator, - ability_tracker = ability_tracker, - parallel = parallel) + ability_tracker = ability_tracker) end -function NextItemStrategy(; parallel = true) - ExhaustiveSearch(parallel) +function NextItemStrategy() + ExhaustiveSearch() end -function NextItemStrategy(bits...; parallel = true) +function NextItemStrategy(bits...) @returnsome find1_instance(NextItemStrategy, bits) - @returnsome find1_type(NextItemStrategy, bits) typ->typ(; parallel = parallel) - @returnsome NextItemStrategy(; parallel = parallel) + @returnsome find1_type(NextItemStrategy, bits) typ->typ() + @returnsome NextItemStrategy() end """ $(TYPEDEF) $(TYPEDFIELDS) -`ItemStrategyNextItemRule` which together with a `NextItemStrategy` acts as an +`ItemCriterionRule` which together with a `NextItemStrategy` acts as an adapter by which an `ItemCriterion` can serve as a `NextItemRule`. - $(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true) + $(FUNCTIONNAME)(bits...; ability_estimator=nothing Implicit constructor for $(FUNCTIONNAME). Will default to `ExhaustiveSearch` when no `NextItemStrategy` is given. """ -struct ItemStrategyNextItemRule{ +struct ItemCriterionRule{ NextItemStrategyT <: NextItemStrategy, ItemCriterionT <: ItemCriterion } <: NextItemRule @@ -39,19 +37,30 @@ struct ItemStrategyNextItemRule{ criterion::ItemCriterionT end -function ItemStrategyNextItemRule(bits...; - parallel = true, +function ItemCriterionRule(bits...; ability_estimator = nothing, ability_tracker = nothing) - strategy = NextItemStrategy(bits...; parallel = parallel) + strategy = NextItemStrategy(bits...) criterion = ItemCriterion(bits...; ability_estimator = ability_estimator, ability_tracker = ability_tracker) if strategy !== nothing && criterion !== nothing - return ItemStrategyNextItemRule(strategy, criterion) + return ItemCriterionRule(strategy, criterion) end end function best_item(rule::NextItemRule, tracked_responses::TrackedResponses) best_item(rule, tracked_responses, tracked_responses.item_bank) -end \ No newline at end of file +end + +function Base.show(io::IO, ::MIME"text/plain", rule::ItemCriterionRule) + println(io, "Pick optimal item criterion according to strategy") + indent_io = indent(io, 2) + show(indent_io, MIME"text/plain"(), rule.strategy) + show(indent_io, MIME"text/plain"(), rule.criterion) +end + +# Default implementation +function compute_criteria(::NextItemRule, ::TrackedResponses) + nothing +end diff --git a/src/next_item_rules/prelude/preallocate.jl b/src/NextItemRules/prelude/preallocate.jl similarity index 100% rename from src/next_item_rules/prelude/preallocate.jl rename to src/NextItemRules/prelude/preallocate.jl diff --git a/src/NextItemRules/strategies/balance.jl b/src/NextItemRules/strategies/balance.jl new file mode 100644 index 0000000..5adcd20 --- /dev/null +++ b/src/NextItemRules/strategies/balance.jl @@ -0,0 +1,89 @@ +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +This content balancing procedure takes target proportions for each group of items. +At each step the group with the lowest ratio of seen items to target is selected. + +http://dx.doi.org/10.1207/s15324818ame0403_4 +""" +struct GreedyForcedContentBalancer{InnerRuleT <: NextItemRule} <: NextItemRule + targets::Vector{Float64} + groups::Vector{Int} + inner_rule::InnerRuleT +end + +function GreedyForcedContentBalancer(targets::Dict, groups, bits...) + targets_vec = zeros(Float64, length(targets)) + groups_idxs = zeros(Int, length(groups)) + group_lookup = Dict{Any, Int}() + for (idx, group) in enumerate(groups) + if haskey(group_lookup, group) + group_idx = group_lookup[group] + else + group_idx = length(group_lookup) + 1 + group_lookup[group] = group_idx + end + groups_idxs[idx] = group_idx + end + if length(group_lookup) != length(targets) + error("Number of groups $(length(group_lookup)) does not match number of targets $(length(targets))") + end + for (group, group_idx) in pairs(group_lookup) + targets_vec[group_idx] = get(targets, group, 0.0) + end + GreedyForcedContentBalancer(targets_vec, groups_idxs, bits...) +end + +function GreedyForcedContentBalancer(targets::AbstractVector, groups, bits...) + GreedyForcedContentBalancer(targets, groups, NextItemRule(bits...)) +end + +function show(io::IO, ::MIME"text/plain", rule::GreedyForcedContentBalancer) + indent_io = indent(io, 2) + println(io, "Greedy + forced content balancing") + println(indent_io, "Target ratio: " * join(rule.targets, ", ")) + show(indent_io, MIME("text/plain"), rule.inner_rule) +end + +function next_item_bank(targets, groups, responses, items) + seen = zeros(UInt, size(targets)) + indices = responses.responses.indices + for group_idx in groups[indices] + seen[group_idx] += 1 + end + next_group_idx = argmin(seen ./ targets) + matching_indicator = groups .== next_group_idx + next_items = subset_view(items, matching_indicator) + return (next_items, matching_indicator) +end + +function best_item( + rule::GreedyForcedContentBalancer, + responses::TrackedResponses, + items +) + next_items, matching_indicator = next_item_bank(rule.targets, rule.groups, responses, items) + inner_idx = best_item(rule.inner_rule, responses, next_items) + for (outer_idx, in_group) in enumerate(matching_indicator) + if in_group + inner_idx -= 1 + if inner_idx <= 0 + return outer_idx + end + end + end + error("No item found in group length $(length(next_items)) with inner index $inner_idx") +end + +function compute_criteria( + rule::GreedyForcedContentBalancer, + responses::TrackedResponses, + items +) + next_items, matching_indicator = next_item_bank(rule.targets, rule.groups, responses, items) + criteria = compute_criteria(rule.inner_rule, responses, next_items) + expanded = fill(Inf, length(items)) + expanded[matching_indicator] .= criteria + return expanded +end \ No newline at end of file diff --git a/src/next_item_rules/strategies/exhaustive.jl b/src/NextItemRules/strategies/exhaustive.jl similarity index 62% rename from src/next_item_rules/strategies/exhaustive.jl rename to src/NextItemRules/strategies/exhaustive.jl index 7b47429..c550b8c 100644 --- a/src/next_item_rules/strategies/exhaustive.jl +++ b/src/NextItemRules/strategies/exhaustive.jl @@ -1,21 +1,18 @@ -function exhaustive_search(objective::ItemCriterionT, - responses::TrackedResponseT, - items::AbstractItemBank)::Tuple{ - Int, - Float64 -} where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses} - #pre_next_item(expectation_tracker, items) - objective_state = init_thread(objective, responses) +function exhaustive_search( + callback, + answered_items::AbstractVector{Int}, + items::AbstractItemBank +)::Tuple{Int, Float64} min_obj_idx::Int = -1 min_obj_val::Float64 = Inf for item_idx in eachindex(items) # TODO: Add these back in #@init irf_states_storage = zeros(Int, length(responses) + 1) - if (findfirst(idx -> idx == item_idx, responses.responses.indices) !== nothing) + if (findfirst(idx -> idx == item_idx, answered_items) !== nothing) continue end - obj_val = compute_criterion(objective, objective_state, responses, item_idx) + obj_val = callback(item_idx) if obj_val <= min_obj_val min_obj_val = obj_val @@ -25,17 +22,27 @@ function exhaustive_search(objective::ItemCriterionT, return (min_obj_idx, min_obj_val) end +function exhaustive_search(objective::ItemCriterionT, + responses::TrackedResponseT, + items::AbstractItemBank)::Tuple{ + Int, + Float64 +} where {ItemCriterionT <: ItemCriterion, TrackedResponseT <: TrackedResponses} + objective_state = init_thread(objective, responses) + return exhaustive_search(responses.responses.indices, items) do item_idx + return compute_criterion(objective, objective_state, responses, item_idx) + end +end + """ $(TYPEDEF) $(TYPEDFIELDS) """ -@with_kw struct ExhaustiveSearch <: NextItemStrategy - parallel::Bool = false -end +struct ExhaustiveSearch <: NextItemStrategy end function best_item( - rule::ItemStrategyNextItemRule{ExhaustiveSearch, ItemCriterionT}, + rule::ItemCriterionRule{ExhaustiveSearch, ItemCriterionT}, responses::TrackedResponses, items ) where {ItemCriterionT <: ItemCriterion} diff --git a/src/NextItemRules/strategies/pointwise.jl b/src/NextItemRules/strategies/pointwise.jl new file mode 100644 index 0000000..e0a5616 --- /dev/null +++ b/src/NextItemRules/strategies/pointwise.jl @@ -0,0 +1,30 @@ +struct PointwiseNextItemRule{CriterionT <: PointwiseItemCriterion, PointsT <: AbstractArray{<:Number}} <: NextItemRule + criterion::CriterionT + points::PointsT +end + +function best_item(rule::PointwiseNextItemRule, responses::TrackedResponses, items) + num_responses = length(responses.responses.indices) + next_index = num_responses + 1 + if next_index > length(rule.points) + error("Number of responses exceeds the number of points defined in the rule.") + end + current_point = rule.points[next_index] + idx, _ = exhaustive_search(responses.responses.indices, items) do item_idx + return compute_criterion(rule.criterion, ItemResponse(items, item_idx), current_point) + end + return idx +end + +function show(io::IO, ::MIME"text/plain", rule::PointwiseNextItemRule) + println(io, "Optimize a pointwise criterion at specified points") + indent_io = indent(io, 2) + points_desc = join(rule.points, ", ") + println(indent_io, "Points: $points_desc") + show(indent_io, MIME("text/plain"), rule.criterion) +end + + +function PointwiseFirstNextItemRule(criterion, points, rule) + FixedRuleSequencer((length(points),), (PointwiseNextItemRule(criterion, points), rule)) +end diff --git a/src/next_item_rules/strategies/random.jl b/src/NextItemRules/strategies/random.jl similarity index 100% rename from src/next_item_rules/strategies/random.jl rename to src/NextItemRules/strategies/random.jl diff --git a/src/NextItemRules/strategies/randomesque.jl b/src/NextItemRules/strategies/randomesque.jl new file mode 100644 index 0000000..b3e0ac5 --- /dev/null +++ b/src/NextItemRules/strategies/randomesque.jl @@ -0,0 +1,58 @@ +using QuickHeaps: BinaryHeap, FastMax, Node, get_val +using StatsBase: sample + + +function randomesque( + rng::AbstractRNG, + objective::ItemCriterion, + responses::TrackedResponses, + items::AbstractItemBank, + k::Int +) + objective_state = init_thread(objective, responses) + heap = BinaryHeap{Node{Int, Float64}}(FastMax) + sizehint!(heap, k) + for item_idx in eachindex(items) + if (findfirst(idx -> idx == item_idx, responses.responses.indices) !== nothing) + continue + end + + obj_val = compute_criterion(objective, objective_state, responses, item_idx) + + if length(heap) < k + push!(heap, Node(item_idx, obj_val)) + elseif obj_val < get_val(peek(heap)) + heap[1] = Node(item_idx, obj_val) + end + end + if length(heap) >= 1 + Tuple(sample(rng, heap)) + else + return (-1, Inf) + end +end + +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +http://dx.doi.org/10.1207/s15324818ame0204_6 +""" +struct RandomesqueStrategy <: NextItemStrategy + rng::AbstractRNG + k::Int +end + +RandomesqueStrategy(k::Int) = RandomesqueStrategy(Xoshiro(), k) + +function best_item( + rule::ItemCriterionRule{RandomesqueStrategy, ItemCriterionT}, + responses::TrackedResponses, + items +) where {ItemCriterionT <: ItemCriterion} + randomesque(rule.strategy.rng, rule.criterion, responses, items, rule.strategy.k)[1] +end + +function show(io::IO, ::MIME"text/plain", rule::RandomesqueStrategy) + println(io, "Randomesque strategy with k = $(rule.k)") +end \ No newline at end of file diff --git a/src/NextItemRules/strategies/sequential.jl b/src/NextItemRules/strategies/sequential.jl new file mode 100644 index 0000000..7f98653 --- /dev/null +++ b/src/NextItemRules/strategies/sequential.jl @@ -0,0 +1,68 @@ +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +""" +@kwdef struct FixedRuleSequencer{RulesT} <: NextItemRule + # Tuple of Ints + breaks::Tuple{Int} + # Tuple of NextItemRules + rules::RulesT +end + +#tuple_len(::NTuple{N, Any}) where {N} = Val{N}() + +function current_rule(rule::FixedRuleSequencer, responses::TrackedResponses) + for brk in 1:length(rule.breaks) + if length(responses) < rule.breaks[brk] + return rule.rules[brk] + end + end + return rule.rules[end] +end + +function best_item(rule::FixedRuleSequencer, responses::TrackedResponses, items) + return best_item(current_rule(rule, responses), responses, items) +end + +function compute_criteria(rule::FixedRuleSequencer, responses::TrackedResponses) + return compute_criteria(current_rule(rule, responses), responses) +end + +function show(io::IO, ::MIME"text/plain", rule::FixedRuleSequencer) + indent_io = indent(io, 2) + println(io, "Fixed rule sequencing:") + print(indent_io, "Firstly: ") + show(indent_io, MIME("text/plain"), rule.rules[1]) + for (responses, rule) in zip(rule.breaks, rule.rules[2:end]) + print(indent_io, "After $responses responses: ") + show(indent_io, MIME("text/plain"), rule) + end +end + +""" +$(TYPEDEF) +$(TYPEDFIELDS) + +""" +@kwdef struct MemoryNextItemRule{MemoryT} <: NextItemRule + item_idxs::MemoryT +end + +function best_item(rule::MemoryNextItemRule, responses::TrackedResponses, _items) + return rule.item_idxs[length(responses) + 1] + # XXX: A few problems with this: + # 1. Could run out of `item_idxs` + # 2. Could return an item not in `items` + # 3: Will not work if this is sequenced after items have already been administered + # TODO: Add some basic error checking -- can only panic +end + +function show(io::IO, ::MIME"text/plain", rule::MemoryNextItemRule) + item_list = join(rule.item_idxs, ", ") + println(io, "Present the items indexed: $item_list") +end + +function FixedFirstItem(item_idx::Int, rule::NextItemRule) + FixedRuleSequencer((1,), (MemoryNextItemRule((item_idx,)), rule)) +end \ No newline at end of file diff --git a/src/CatConfig.jl b/src/Rules.jl similarity index 79% rename from src/CatConfig.jl rename to src/Rules.jl index e809f39..3ec8c02 100644 --- a/src/CatConfig.jl +++ b/src/Rules.jl @@ -1,15 +1,17 @@ -module CatConfig +module Rules -export CatRules, CatLoopConfig +export CatRules using DocStringExtensions using PsychometricsBazaarBase.Parameters +using PsychometricsBazaarBase.IndentWrappers: indent using ..Aggregators: AbilityEstimator, AbilityTracker, ConsAbilityTracker, NullAbilityTracker using ..NextItemRules: NextItemRule using ..TerminationConditions: TerminationCondition using ..ConfigBase +import Base: show """ $(TYPEDEF) @@ -19,7 +21,7 @@ Configuration of the rules for a CAT. This all includes all the basic rules for the CAT's operation, but not the item bank, nor any of the interactivity hooks needed to actually run the CAT. -This may be more a more convenient layer to integrate than CatLoopConfig if you +This may be more a more convenient layer to integrate than CatLoop if you want to write your own CAT loop rather than using hooks. $(FUNCTIONNAME)(; next_item=..., termination_condition=..., ability_estimator=..., ability_tracker=...) @@ -79,6 +81,17 @@ function CatRules(bits...) ability_tracker = collect_trackers(next_item, ability_tracker)) end +function show(io::IO, ::MIME"text/plain", rules::CatRules) + print(io, "Next item rule: ") + show(io, MIME("text/plain"), rules.next_item) + println(io) + print(io, "Termination condition: ") + show(io, MIME("text/plain"), rules.termination_condition) + println(io) + print(io, "Ability estimator: ") + show(io, MIME("text/plain"), rules.ability_estimator) +end + function _find_ability_estimator_and_tracker(bits...) ability_estimator = AbilityEstimator(bits...) ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator) @@ -113,33 +126,4 @@ function collect_trackers(next_item_rule::NextItemRule, ability_tracker::Ability end end -""" -```julia -struct $(FUNCTIONNAME) -$(FUNCTIONNAME)(; rules=..., get_response=..., new_response_callback=...) -``` -$(TYPEDFIELDS) - -Configuration for a simulatable CAT. -""" -@with_kw struct CatLoopConfig{CatEngineT} <: CatConfigBase - """ - An object which implements the CAT engine. - Implementations exist for: - * [CatRules](@ref) - * [Stateful.StatefulCat](@ref ComputerAdaptiveTesting.Stateful.StatefulCat) - """ - rules::CatEngineT # e.g. CatRules - """ - The function `(index, label) -> Int8`` which obtains the testee's response for - a given question, e.g. by prompting or simulation from data. - """ - get_response::Any - """ - A callback called each time there is a new responses. - If provided, it is passed `(responses::TrackedResponses, terminating)`. - """ - new_response_callback = nothing -end - end diff --git a/src/Sim/Sim.jl b/src/Sim/Sim.jl new file mode 100644 index 0000000..75511c9 --- /dev/null +++ b/src/Sim/Sim.jl @@ -0,0 +1,36 @@ +module Sim + +using DataFrames: DataFrame +using ElasticArrays +using ElasticArrays: sizehint_lastdim! +using DocStringExtensions +using StatsBase +using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse +using PsychometricsBazaarBase.Integrators +using PsychometricsBazaarBase.IndentWrappers: indent +using ..ConfigBase +using ..Responses +using ..Rules: CatRules +using ..Aggregators: TrackedResponses, + add_response!, + Aggregators, + AbilityIntegrator, + AbilityEstimator, + LikelihoodAbilityEstimator, + PosteriorAbilityEstimator, + ModeAbilityEstimator, + MeanAbilityEstimator, + LikelihoodAbilityEstimator, + RiemannEnumerationIntegrator +using ..NextItemRules: compute_criteria, best_item +import Base: show + +export CatRecorder, CatRecording +export CatLoop, record! +export run_cat, prompt_response, auto_responder + +include("./recorder.jl") +include("./loop.jl") +include("./run.jl") + +end diff --git a/src/Sim/loop.jl b/src/Sim/loop.jl new file mode 100644 index 0000000..629968c --- /dev/null +++ b/src/Sim/loop.jl @@ -0,0 +1,54 @@ +""" +```julia +struct $(FUNCTIONNAME) +$(FUNCTIONNAME)(; rules=..., get_response=..., new_response_callback=...) +``` +$(TYPEDFIELDS) + +Configuration for a simulatable CAT. +""" +struct CatLoop{CatEngineT} <: CatConfigBase + """ + An object which implements the CAT engine. + Implementations exist for: + * [CatRules](@ref) + * [Stateful.StatefulCat](@ref ComputerAdaptiveTesting.Stateful.StatefulCat) + """ + rules::CatEngineT # e.g. CatRules + """ + The function `(index, label) -> Int8`` which obtains the testee's response for + a given question, e.g. by prompting or simulation from data. + """ + get_response::Any + """ + A callback called each time there is a new responses. + If provided, it is passed `(responses::TrackedResponses, terminating)`. + """ + new_response_callback +end + +function CatLoop(; + rules, + get_response, + new_response_callback = nothing, + new_response_callbacks = Any[], + recorder = nothing +) + new_response_callbacks = collect(new_response_callbacks) + if new_response_callback !== nothing + push!(new_response_callbacks, new_response_callback) + end + if recorder !== nothing && showable(MIME("text/plain"), rules) + buf = IOBuffer() + show(buf, MIME("text/plain"), rules) + recorder.recording.rules_description = String(take!(buf)) + push!(new_response_callbacks, catrecorder_callback(recorder)) + end + function all_callbacks(responses, terminating) + for callback in new_response_callbacks + callback(responses, terminating) + end + nothing + end + CatLoop{typeof(rules)}(rules, get_response, all_callbacks) +end \ No newline at end of file diff --git a/src/Sim/recorder.jl b/src/Sim/recorder.jl new file mode 100644 index 0000000..b897a88 --- /dev/null +++ b/src/Sim/recorder.jl @@ -0,0 +1,397 @@ +function empty_capacity(typ, size) + ret = typ[] + sizehint!(ret, size) + return ret +end + +function empty_capacity(typ, dims...) + ret = ElasticArray{typ}(undef, dims[1:end - 1]..., 0) + sizehint_lastdim!(ret, dims[end]) + return ret +end + +# Elastic arrays do not support `push!` directly, so we define our own +elastic_push!(xs::AbstractVector, value) = push!(xs, value) +elastic_push!(xs::ElasticArray, value) = append!(xs, value) + +Base.@kwdef mutable struct CatRecording{LikelihoodsT <: NamedTuple} + #ability_ests::AbilityVecT + #xs::Union{Nothing, AbilityVecT} + #likelihoods::Matrix{Float64} + #raw_likelihoods::Matrix{Float64} + data::LikelihoodsT + item_index::Vector{Int} + item_correctness::Vector{Bool} + rules_description::Union{Nothing, String} = nothing +end + +function Base.getproperty(obj::CatRecording, sym::Symbol) + if hasfield(CatRecording, sym) + return getfield(obj, sym) + else + return getproperty(obj.data, sym) + end +end + +Base.@kwdef struct CatRecorder{RequestsT <: NamedTuple, LikelihoodsT <: NamedTuple} + recording::CatRecording{LikelihoodsT} + requests::RequestsT + #integrator::AbilityIntegrator + #raw_estimator::LikelihoodAbilityEstimator + #ability_estimator::AbilityEstimator +end + +function CatRecording( + data, + expected_responses=0 +) + CatRecording(; + data=data, + item_index=empty_capacity(Int, expected_responses), + item_correctness=empty_capacity(Bool, expected_responses) + ) +end + +function prepare_dataframe(recording::CatRecording) + cols::Vector{Pair{String, Vector{Any}}} = [ + "Item" => recording.item_index, + "Response" => recording.item_correctness, + ] + for (name, value) in pairs(recording.data) + #@show name value.type keys(value) size(value.data) + if value.data isa AbstractVector + push!(cols, String(name) => value.data) + end + end + return DataFrame(cols) +end + +function show(io::IO, ::MIME"text/plain", recording::CatRecording) + println(io, "Recording of a Computer-Adaptive Test") + if recording.rules_description === nothing + println(io, " Unknown CAT configuration") + else + println(io, " CAT configuration:") + for line in split(strip(recording.rules_description, '\n'), "\n") + println(io, " ", line) + end + end + println(io) + println(io, " Recorded information:") + df = prepare_dataframe(recording) + buf = IOBuffer() + show(buf, MIME("text/plain"), df; summary=false, eltypes=false, rowlabel=:Number) + seekstart(buf) + for line in eachline(buf) + println(io, " ", line) + end + #println(io) + #println(io, " Final information:") +end + +#= +function CatRecording( + xs, + points, + ability_ests, + num_questions, + num_respondents, + actual_abilities = nothing) + num_values = num_questions * num_respondents + if xs === nothing + xs_vec = nothing + else + xs_vec = collect(xs) + end + + CatRecorder(1, + 1, + points, + zeros(Int, num_values), + ability_ests, + zeros(Float64, num_values), + zeros(Int, num_values), + xs_vec, + zeros(points, num_values), + zeros(points, num_values), + zeros(points, num_values), + zeros(num_questions, num_respondents), + zeros(Int, num_questions, num_respondents), + zeros(Bool, num_questions, num_respondents), + Dict{Tuple{Int, Int}, Int}(), + actual_abilities) +end +=# + +function record!(recording::CatRecording, responses; data...) + #push_ability_est!(recording.ability_ests, recording.col_idx, ability_est) + + item_index = responses.indices[end] + item_correct = responses.values[end] > 0 + push!(recording.item_index, item_index) + push!(recording.item_correctness, item_correct) +end + +#= +""" +$(TYPEDSIGNATURES) +""" +function CatRecorder( + xs, + points, + ability_ests, + num_questions, + num_respondents, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing) + CatRecorder( + CatRecording( + xs, + points, + ability_ests, + num_questions, + num_respondents, + actual_abilities + ), + AbilityIntegrator(integrator), + raw_estimator, + ability_estimator, + ) +end + +function CatRecorder( + xs::AbstractVector{Float64}, + responses, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + points = size(xs, 1) + num_questions = size(responses, 1) + num_respondents = size(responses, 2) + num_values = num_questions * num_respondents + CatRecorder( + xs, + points, + zeros(num_values), + num_questions, + num_respondents, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end + +function CatRecorder( + xs::AbstractMatrix{Float64}, + responses, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + dims = size(xs, 1) + points = size(xs, 2) + num_questions = size(responses, 1) + num_respondents = size(responses, 2) + num_values = num_questions * num_respondents + CatRecorder(xs, + points, + zeros(dims, num_values), + num_questions, + num_respondents, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end + +function CatRecorder( + xs::AbstractVector{Float64}, + max_responses::Int, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + points = size(xs, 1) + CatRecorder(xs, + points, + zeros(max_responses), + max_responses, + 1, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end + +function CatRecorder( + xs::AbstractMatrix{Float64}, + max_responses::Int, + integrator, + raw_estimator, + ability_estimator, + actual_abilities = nothing + ) + dims = size(xs, 1) + points = size(xs, 2) + CatRecorder(xs, + points, + zeros(dims, max_responses), + max_responses, + 1, + integrator, + raw_estimator, + ability_estimator, + actual_abilities) +end +=# + +function name_to_label(name) + titlecase(join(split(String(name), "_"), " ")) +end + +function CatRecorder(dims::Int, expected_responses::Int; requests...) + out = [] + sizehint!(out, length(requests)) + for (name, request) in pairs(requests) + extra = (;) + if request.type in (:ability, :ability_stddev) + data = empty_capacity(Float64, expected_responses) + elseif request.type == :ability_distribution + if dims == 0 + data = empty_capacity(Float64, length(request.points), expected_responses) + else + data = empty_capacity(Float64, dims, length(request.points), expected_responses) + end + extra = (; points = request.points) + end + push!(out, (name => (; + label=haskey(request, :label) ? request.label : name_to_label(name), + type=request.type, + data, + extra... + ))) + end + return CatRecorder(; + recording=CatRecording(NamedTuple(out), expected_responses), + requests=NamedTuple(requests), + ) + #= + CatRecording( + xs, + points, + ability_ests, + num_questions, + num_respondents, + actual_abilities + ), + AbilityIntegrator(integrator), + raw_estimator, + ability_estimator + =# +end + + +function push_ability_est!(ability_ests::AbstractMatrix{Float64}, col_idx, ability_est) + ability_ests[:, col_idx] = ability_est +end + +function push_ability_est!(ability_ests::AbstractVector{Float64}, col_idx, ability_est) + ability_ests[col_idx] = ability_est +end + +function eachmatcol(xs::AbstractMatrix) + eachcol(xs) +end + +function eachmatcol(xs::AbstractVector) + xs +end + +#= +function save_sampled(xs::Nothing, integrator::RiemannEnumerationIntegrator, + recorder::CatRecorder, tracked_responses, ir, item_correct) + # In this case, the item bank is probably sampled so we can use that + + # Save likelihoods + dist_est = distribution_estimator(recorder.ability_estimator) + denom = normdenom(integrator, dist_est, tracked_responses) + recorder.likelihoods[:, recorder.col_idx] = function_ys( + Aggregators.pdf( + dist_est, + tracked_responses + ) + ) ./ denom + raw_denom = normdenom(integrator, recorder.raw_estimator, tracked_responses) + recorder.raw_likelihoods[:, recorder.col_idx] = function_ys( + Aggregators.pdf( + recorder.raw_estimator, + tracked_responses + ) + ) ./ raw_denom + + # Save item responses + recorder.item_responses[:, recorder.col_idx] = item_ys(ir, item_correct) +end +=# + +function sample_likelihood(tracked_responses, xs, dist_est, integrator) + # Save likelihoods + num = Aggregators.pdf.( + dist_est, + tracked_responses, + eachmatcol(xs) + ) + denom = normdenom(integrator, dist_est, tracked_responses) + return num ./ denom +end + +#= + raw_denom = normdenom(integrator, recorder.raw_estimator, tracked_responses) + recorder.raw_likelihoods[:, recorder.col_idx] = Aggregators.pdf.( + Ref(recorder.raw_estimator), + Ref(tracked_responses), + eachmatcol(xs)) ./ raw_denom +=# + +function service_requests!( + #xs, integrator, recorder::CatRecorder, tracked_responses, ir, item_correct) + recorder::CatRecorder, tracked_responses, ir, item_correct +) + out = recorder.recording.data + for (name, request) in pairs(recorder.requests) + if request.type in (:ability, :ability_stddev) + push!(out[name].data, request.estimator(tracked_responses)) + elseif request.type == :ability_distribution + likelihood_sample = sample_likelihood(tracked_responses, request.points, request.estimator, request.integrator) + elastic_push!(out[name].data, likelihood_sample) + end + end + + #= + # Save item responses + recorder.item_responses[:, recorder.col_idx] = resp.(Ref(ir), + item_correct, + eachmatcol(xs)) + =# +end + +""" +$(TYPEDSIGNATURES) +""" +function record!(recorder::CatRecorder, tracked_responses) + item_index = tracked_responses.responses.indices[end] + item_correct = tracked_responses.responses.values[end] > 0 + ir = ItemResponse(tracked_responses.item_bank, item_index) + service_requests!(recorder, tracked_responses, ir, item_correct) + record!(recorder.recording, tracked_responses.responses) +end + +function catrecorder_callback(recoder::CatRecorder) + return (tracked_responses, _) -> record!(recoder, tracked_responses) +end diff --git a/src/Sim.jl b/src/Sim/run.jl similarity index 79% rename from src/Sim.jl rename to src/Sim/run.jl index 39bc9b8..b6327f7 100644 --- a/src/Sim.jl +++ b/src/Sim/run.jl @@ -1,15 +1,3 @@ -module Sim - -using DocStringExtensions -using StatsBase -using FittedItemBanks: AbstractItemBank, ResponseType -using ..Responses -using ..CatConfig: CatLoopConfig, CatRules -using ..Aggregators: TrackedResponses, add_response!, Aggregators -using ..NextItemRules: compute_criteria, best_item - -export run_cat, prompt_response, auto_responder - """ $(TYPEDSIGNATURES) @@ -45,17 +33,17 @@ end """ ```julia -$(FUNCTIONNAME)(cat_config::CatLoopConfig, item_bank::AbstractItemBank; ib_labels=nothing) +$(FUNCTIONNAME)(cat_config::CatLoop, item_bank::AbstractItemBank; ib_labels=nothing) ``` -Run a given [CatLoopConfig](@ref) `cat_config` on the given `item_bank`. +Run a given [CatLoop](@ref) `cat_config` on the given `item_bank`. If `ib_labels` is not given, default labels of the form `<>` are passed to the callback. """ -function run_cat(cat_config::CatLoopConfig{RulesT}, +function run_cat(loop::CatLoop{RulesT}, item_bank::AbstractItemBank; ib_labels = nothing) where {RulesT <: CatRules} - (; rules, get_response, new_response_callback) = cat_config + (; rules, get_response, new_response_callback) = loop (; next_item, termination_condition, ability_estimator, ability_tracker) = rules responses = TrackedResponses(BareResponses(ResponseType(item_bank)), item_bank, @@ -93,6 +81,4 @@ function run_cat(cat_config::CatLoopConfig{RulesT}, end end (responses.responses, ability_estimator(responses)) -end - -end +end \ No newline at end of file diff --git a/src/Stateful.jl b/src/Stateful.jl index eb7d535..f8482e6 100644 --- a/src/Stateful.jl +++ b/src/Stateful.jl @@ -7,14 +7,14 @@ module Stateful using DocStringExtensions -using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp -using ..Aggregators: TrackedResponses, Aggregators -using ..CatConfig: CatLoopConfig, CatRules +using FittedItemBanks: AbstractItemBank, ResponseType, ItemResponse, resp_vec +using ..Aggregators: TrackedResponses, Aggregators, pdf, distribution_estimator +using ..Rules: CatRules using ..Responses: BareResponses, Response, Responses using ..NextItemRules: compute_criteria, best_item -using ..Sim: Sim, item_label +using ..Sim: CatLoop, Sim, item_label -export StatefulCat, StatefulCatConfig +export StatefulCat, StatefulCatRules public next_item, ranked_items, item_criteria public add_response!, rollback!, reset!, get_responses, get_ability @@ -45,6 +45,7 @@ $(FUNCTIONNAME)(config::StatefulCat) -> AbstractVector{IndexT} Return a vector of indices of the sorted from best to worst item according to the CAT. """ function ranked_items end +function ranked_items(::StatefulCat) nothing end """ ```julia @@ -56,6 +57,7 @@ Returns a vector of criteria values for each item in the item bank. The criteria can vary, but should attempt to interoperate with ComputerAdaptiveTesting.jl. """ function item_criteria end +function item_criteria(::StatefulCat) nothing end """ ```julia @@ -124,6 +126,15 @@ but should attempt to interoperate with ComputerAdaptiveTesting.jl. """ function get_ability end +""" +```julia +$(FUNCTIONNAME)(config::StatefulCat, ability::AbilityT) -> Float64 +``` + +TODO +""" +function likelihood end + """ ```julia $(FUNCTIONNAME)(config::StatefulCat) @@ -135,16 +146,17 @@ function item_bank_size end """ ```julia -$(FUNCTIONNAME)(config::StatefulCat, index::IndexT, response::ResponseT, ability::AbilityT) -> Float +$(FUNCTIONNAME)(config::StatefulCat, index::IndexT, ability::AbilityT) -> AbstractVector{Float} ```` -Return the probability of a `response` to item at `index` for someone with -a certain `ability` according to the IRT model backing the CAT. +Return the vector of probability of different responses to item at +`index` for someone with a certain `ability` according to the IRT +model backing the CAT. """ -function item_response_function end +function item_response_functions end ## Running the CAT -function Sim.run_cat(cat_config::CatLoopConfig{RulesT}, +function Sim.run_cat(cat_config::CatLoop{RulesT}, ib_labels = nothing) where {RulesT <: StatefulCat} (; stateful_cat, get_response, new_response_callback) = cat_config while true @@ -178,51 +190,57 @@ $(TYPEDSIGNATURES) This is a the `StatefulCat` implementation in terms of `CatRules`. It is also the de-facto standard for the behavior of the interface. """ -struct StatefulCatConfig{TrackedResponsesT <: TrackedResponses} <: StatefulCat +struct StatefulCatRules{TrackedResponsesT <: TrackedResponses} <: StatefulCat rules::CatRules tracked_responses::Ref{TrackedResponsesT} end -function StatefulCatConfig(rules::CatRules, item_bank::AbstractItemBank) +function StatefulCatRules(rules::CatRules, item_bank::AbstractItemBank) bare_responses = BareResponses(ResponseType(item_bank)) tracked_responses = TrackedResponses( bare_responses, item_bank, rules.ability_tracker ) - return StatefulCatConfig(rules, Ref(tracked_responses)) + return StatefulCatRules(rules, Ref(tracked_responses)) end -function next_item(config::StatefulCatConfig) +StatefulCat(rules::CatRules, item_bank::AbstractItemBank) = StatefulCatRules(rules, item_bank) + +function next_item(config::StatefulCatRules) return best_item(config.rules.next_item, config.tracked_responses[]) end -function ranked_items(config::StatefulCatConfig) - return sortperm(compute_criteria( - config.rules.next_item, config.tracked_responses[])) +function ranked_items(config::StatefulCatRules) + criteria = compute_criteria( + config.rules.next_item, config.tracked_responses[]) + if criteria === nothing + return nothing + end + return sortperm(criteria) end -function item_criteria(config::StatefulCatConfig) +function item_criteria(config::StatefulCatRules) return compute_criteria( config.rules.next_item, config.tracked_responses[]) end -function add_response!(config::StatefulCatConfig, index, response) +function add_response!(config::StatefulCatRules, index, response) tracked_responses = config.tracked_responses[] Responses.add_response!( tracked_responses, Response( ResponseType(tracked_responses.item_bank), index, response)) end -function rollback!(config::StatefulCatConfig) +function rollback!(config::StatefulCatRules) Responses.pop_response!(config.tracked_responses[]) end -function reset!(config::StatefulCatConfig) +function reset!(config::StatefulCatRules) empty!(config.tracked_responses[]) end -function set_item_bank!(config::StatefulCatConfig, item_bank) +function set_item_bank!(config::StatefulCatRules, item_bank) bare_responses = BareResponses(ResponseType(item_bank)) config.tracked_responses[] = TrackedResponses( bare_responses, @@ -231,22 +249,26 @@ function set_item_bank!(config::StatefulCatConfig, item_bank) ) end -function get_responses(config::StatefulCatConfig) +function get_responses(config::StatefulCatRules) return config.tracked_responses[].responses end -function get_ability(config::StatefulCatConfig) +function get_ability(config::StatefulCatRules) return (config.rules.ability_estimator(config.tracked_responses[]), nothing) end -function item_bank_size(config::StatefulCatConfig) +function likelihood(config::StatefulCatRules, ability) + pdf(distribution_estimator(config.rules.ability_estimator), config.tracked_responses[], ability) +end + +function item_bank_size(config::StatefulCatRules) return length(config.tracked_responses[].item_bank) end -function item_response_function(config::StatefulCatConfig, index, response, ability) +function item_response_functions(config::StatefulCatRules, index, ability) item_bank = config.tracked_responses[].item_bank item_response = ItemResponse(item_bank, index) - return resp(item_response, response, ability) + return resp_vec(item_response, ability) end ## TODO: Implementation for MaterializedDecisionTree diff --git a/src/TerminationConditions.jl b/src/TerminationConditions.jl index f0b5261..45867c7 100644 --- a/src/TerminationConditions.jl +++ b/src/TerminationConditions.jl @@ -6,10 +6,10 @@ using ..Aggregators: TrackedResponses using ..ConfigBase using PsychometricsBazaarBase.ConfigTools: @returnsome, find1_instance using FittedItemBanks +import Base: show -export TerminationCondition, - FixedItemsTerminationCondition, SimpleFunctionTerminationCondition -export RunForeverTerminationCondition +export TerminationCondition, FixedLength, TerminationTest +export RunForever """ $(TYPEDEF) @@ -24,24 +24,28 @@ end $(TYPEDEF) $(TYPEDFIELDS) """ -struct FixedItemsTerminationCondition{} <: TerminationCondition +struct FixedLength{} <: TerminationCondition num_items::Int64 end -function (condition::FixedItemsTerminationCondition)(responses::TrackedResponses, +function (condition::FixedLength)(responses::TrackedResponses, items::AbstractItemBank) length(responses) >= condition.num_items end -struct SimpleFunctionTerminationCondition{F} <: TerminationCondition - func::F +function show(io::IO, ::MIME"text/plain", condition::FixedLength) + println(io, "Terminate test after administering $(condition.num_items) items") end -function (condition::SimpleFunctionTerminationCondition)(responses::TrackedResponses, + +struct TerminationTest{F} <: TerminationCondition + condition::F +end +function (condition::TerminationTest)(responses::TrackedResponses, items::AbstractItemBank) - condition.func(responses, items) + condition.condition(responses, items) end -struct RunForeverTerminationCondition <: TerminationCondition end -function (condition::RunForeverTerminationCondition)(::TrackedResponses, ::AbstractItemBank) +struct RunForever <: TerminationCondition end +function (condition::RunForever)(::TrackedResponses, ::AbstractItemBank) return false end diff --git a/src/logitembank.jl b/src/logitembank.jl index 23312e5..bfb19c1 100644 --- a/src/logitembank.jl +++ b/src/logitembank.jl @@ -21,18 +21,18 @@ inner_ir(ir::ItemResponse{<:LogItemBank}) = ItemResponse(ir.item_bank.inner, ir. ## TODO: Support item banks with other response types e.g. Float32 function FittedItemBanks.resp(ir::ItemResponse{<:LogItemBank}, θ) - exp(ULogarithmic{Float64}, FittedItemBanks.log_resp(inner_ir(ir), θ)) + exp(ULogarithmic, FittedItemBanks.log_resp(inner_ir(ir), θ)) end function FittedItemBanks.resp(ir::ItemResponse{<:LogItemBank}, response, θ) exp( - ULogarithmic{Float64}, + ULogarithmic, FittedItemBanks.log_resp(inner_ir(ir), response, θ) ) end function FittedItemBanks.resp_vec(ir::ItemResponse{<:LogItemBank}, θ) - exp.(ULogarithmic{Float64}, FittedItemBanks.log_resp_vec(inner_ir(ir), θ)) + exp.(ULogarithmic, FittedItemBanks.log_resp_vec(inner_ir(ir), θ)) end @forward LogItemBank.inner Base.length, diff --git a/src/next_item_rules/combinators/likelihood.jl b/src/next_item_rules/combinators/likelihood.jl deleted file mode 100644 index 03da6a6..0000000 --- a/src/next_item_rules/combinators/likelihood.jl +++ /dev/null @@ -1,19 +0,0 @@ -struct LikelihoodWeightedItemCriterion{ - PointwiseItemCriterionT <: PointwiseItemCriterion, - AbilityIntegratorT <: AbilityIntegrator, - AbilityEstimatorT <: DistributionAbilityEstimator -} <: ItemCriterion - criterion::PointwiseItemCriterionT - integrator::AbilityIntegratorT - estimator::AbilityEstimatorT -end - -function compute_criterion( - lwic::LikelihoodWeightedItemCriterion, - tracked_responses::TrackedResponses, - item_idx -) - func = FunctionProduct( - pdf(lwic.estimator, tracked_responses), lwic.criterion(tracked_responses, item_idx)) - lwic.integrator(func, 0, lwic.estimator, tracked_responses) -end diff --git a/src/next_item_rules/porcelain/aliases.jl b/src/next_item_rules/porcelain/aliases.jl deleted file mode 100644 index 392b9a5..0000000 --- a/src/next_item_rules/porcelain/aliases.jl +++ /dev/null @@ -1,98 +0,0 @@ -""" -This mapping provides next item rules through the same names that they are -available through in the `catR` R package. TODO compability with `mirtcat` -""" -const catr_next_item_aliases = Dict( - "MFI" => (ability_estimator; parallel = true) -> ItemStrategyNextItemRule( - ExhaustiveSearch(parallel), - InformationItemCriterion(ability_estimator)), - "bOpt" => (ability_estimator; parallel = true) -> ItemStrategyNextItemRule( - ExhaustiveSearch(parallel), - UrryItemCriterion(ability_estimator)), - "MEPV" => (ability_estimator; parallel = true) -> ItemStrategyNextItemRule( - ExhaustiveSearch(parallel), - ExpectationBasedItemCriterion(ability_estimator, - AbilityVarianceStateCriterion(ability_estimator))) #"MLWI", #"MPWI", #"MEI", -) - -#"thOpt", -#"progressive", -#"proportional", -#"KL", -#"KLP", -#"GDI", -#"GDIP", -#"random" - -function _mirtcat_helper(item_criterion_callback) - function _helper(bits...; ability_estimator = nothing) - ability_estimator = AbilityEstimator(bits...; ability_estimator = ability_estimator) - item_criterion = item_criterion_callback( - [bits..., ability_estimator], ability_estimator) - return ItemStrategyNextItemRule(ExhaustiveSearch(), item_criterion) - end - return _helper -end - -const mirtcat_next_item_aliases = Dict( - # "MI' for the maximum information - "MI" => _mirtcat_helper((bits, ability_estimator) -> InformationItemCriterion(ability_estimator)), - # 'MEPV' for minimum expected posterior variance - "MEPV" => _mirtcat_helper((bits, ability_estimator) -> ExpectationBasedItemCriterion( - ability_estimator, - AbilityVarianceStateCriterion(bits...))), - "Drule" => _mirtcat_helper((bits, ability_estimator) -> DRuleItemCriteron(ability_estimator)), - "Trule" => _mirtcat_helper((bits, ability_estimator) -> TRuleItemCriteron(ability_estimator)) -) - -# 'MLWI' for maximum likelihood weighted information -#"MLWI" => _mirtcat_helper((bits, ability_estimator) -> InformationItemCriterion(ability_estimator)) -# 'MPWI' for maximum posterior weighted information -# 'MEI' for maximum expected information -# 'IKLP' as well as 'IKL' for the integration based Kullback-Leibler criteria with and without the prior density weight, -# respectively, and their root-n items administered weighted counter-parts, 'IKLn' and 'IKLPn'. -#= -Possible inputs for multidimensional adaptive tests include: 'Drule' for the -maximum determinant of the information matrix, 'Trule' for the maximum -(potentially weighted) trace of the information matrix, 'Arule' for the minimum (potentially weighted) trace of the asymptotic covariance matrix, 'Erule' -for the minimum value of the information matrix, and 'Wrule' for the weighted -information criteria. For each of these rules, the posterior weight for the latent trait scores can also be included with the 'DPrule', 'TPrule', 'APrule', -'EPrule', and 'WPrule', respectively. -Applicable to both unidimensional and multidimensional tests are the 'KL' and -'KLn' for point-wise Kullback-Leibler divergence and point-wise KullbackLeibler with a decreasing delta value (delta*sqrt(n), where n is the number -of items previous answered), respectively. The delta criteria is defined in the -design object -Non-adaptive methods applicable even when no mo object is passed are: 'random' -to randomly select items, and 'seq' for selecting items sequentially -=# - -const mirtcat_ability_estimator_aliases = Dict( -# "MAP" for the maximum a-posteriori (i.e, Bayes modal) -# "ML" for maximum likelihood -# "WLE" for weighted likelihood estimation -# "EAPsum" for the expected a-posteriori for each sum score -# "EAP" for the expected a-posteriori (default). -) - -#= -• "plausible" for a single plausible value imputation for each case. This is -equivalent to setting plausible.draws = 1 -• "classify" for the posteriori classification probabilities (only applicable -when the input model was of class MixtureClass) -=# - -function mirtcat_quadpts(nfact) - if nfact == 1 - return 121 - elseif nfact == 2 - return 61 - elseif nfact == 3 - return 31 - elseif nfact == 4 - return 19 - elseif nfact == 5 - return 11 - else - return 5 - end -end diff --git a/src/precompiles.jl b/src/precompiles.jl index c7b44f5..c05ef84 100644 --- a/src/precompiles.jl +++ b/src/precompiles.jl @@ -7,8 +7,10 @@ using PrecompileTools: @compile_workload, @setup_workload using Random: default_rng using .Aggregators: LikelihoodAbilityEstimator, MeanAbilityEstimator, GriddedAbilityTracker, AbilityIntegrator - using .NextItemRules: catr_next_item_aliases, preallocate + using .NextItemRules: preallocate, ExhaustiveSearch, ItemCriterionRule, + ExpectationBasedItemCriterion, AbilityVariance using .Stateful: Stateful + using .ComputerAdaptiveTesting: CatRules rng = default_rng(42) spec = SimpleItemBankSpec(StdModel2PL(), OneDimContinuousDomain(), BooleanResponse()) @@ -19,10 +21,13 @@ using PrecompileTools: @compile_workload, @setup_workload lh_grid_tracker = GriddedAbilityTracker(lh_ability_est, integrator) ability_integrator = AbilityIntegrator(integrator, lh_grid_tracker) ability_estimator = MeanAbilityEstimator(lh_ability_est, ability_integrator) - next_item_rule = catr_next_item_aliases["MEPV"](ability_estimator) - cat = Stateful.StatefulCatConfig(CatConfig.CatRules(; + next_item_rule = ItemCriterionRule( + ExhaustiveSearch(), + ExpectationBasedItemCriterion(ability_estimator, + AbilityVariance(ability_estimator))) + cat = Stateful.StatefulCatRules(CatRules(; next_item=next_item_rule, - termination_condition=TerminationConditions.RunForeverTerminationCondition(), + termination_condition=TerminationConditions.RunForever(), ability_estimator=ability_estimator ), item_bank) Stateful.add_response!(cat, 1, 0) diff --git a/src/vendor/PushVectors.jl b/src/vendor/PushVectors.jl deleted file mode 100644 index 9949a43..0000000 --- a/src/vendor/PushVectors.jl +++ /dev/null @@ -1,102 +0,0 @@ -module PushVectors - -export PushVector, finish! - -mutable struct PushVector{T, V <: AbstractVector{T}} <: AbstractVector{T} - "Vector used for storage." - parent::V - "Number of elements held by `parent`." - len::Int -end - -""" - PushVector{T}(sizehint = 4) - -Create a `PushVector` for elements typed `T`, with no initial elements. `sizehint` -determines the initial allocated size. -""" -function PushVector{T}(sizehint::Integer = 4) where {T} - sizehint ≥ 0 || throw(DomainError(sizehint, "Invalid initial size.")) - PushVector(Vector{T}(undef, sizehint), 0) -end - -@inline Base.length(v::PushVector) = v.len - -@inline Base.size(v::PushVector) = (v.len,) - -function Base.sizehint!(v::PushVector, n) - if length(v.parent) < n || n ≥ v.len - resize!(v.parent, n) - end - nothing -end - -@inline function Base.getindex(v::PushVector, i) - @boundscheck checkbounds(v, i) - @inbounds v.parent[i] -end - -@inline function Base.setindex!(v::PushVector, x, i) - @boundscheck checkbounds(v, i) - @inbounds v.parent[i] = x -end - -function Base.push!(v::PushVector, x) - v.len += 1 - if v.len > length(v.parent) - resize!(v.parent, v.len * 2) - end - v.parent[v.len] = x - v -end - -function Base.pop!(v::PushVector) - isempty(v) && throw(ArgumentError("vector must be non-empty")) - x = v.parent[v.len] - v.len -= 1 - x -end - -function Base.resize!(v::PushVector, n) - if n < v.len - v.len = n - elseif n > v.len - if n > length(v.parent) - resize!(v.parent, n) - end - v.len = n - end - v -end - -Base.empty!(v::PushVector) = (v.len = 0; v) - -function Base.append!(v::PushVector, xs) - ι_xs = eachindex(xs) # allow generalized indexing - l = length(ι_xs) - if l ≤ 4 - for x in xs - push!(v, x) - end - else - L = l + v.len - if length(v.parent) < L - resize!(v.parent, nextpow(2, nextpow(2, L))) - end - @inbounds copyto!(v.parent, v.len + 1, xs, first(ι_xs), l) - v.len += l - end - v -end - -""" - finish!(v) - -Shrink the buffer `v` to its current content and return that vector. - -!!! NOTE - Consequences are undefined if you modify `v` after this. -""" -finish!(v::PushVector) = resize!(v.parent, v.len) - -end # module diff --git a/test/ability_estimator_1d.jl b/test/ability_estimator_1d.jl index 7c80037..c76bea0 100644 --- a/test/ability_estimator_1d.jl +++ b/test/ability_estimator_1d.jl @@ -32,7 +32,7 @@ tracked_responses_1d = TrackedResponses(responses_1d, item_bank_1d, NullAbilityT integrator_1d = AbilityIntegrator(FixedGKIntegrator(-6.0, 6.0, 61)) optimizer_1d = AbilityOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) lh_est_1d = LikelihoodAbilityEstimator() -pa_est_1d = PriorAbilityEstimator(Normal(1.0, 0.2)) +pa_est_1d = PosteriorAbilityEstimator(Normal(1.0, 0.2)) eap_1d = MeanAbilityEstimator(pa_est_1d, integrator_1d) map_1d = ModeAbilityEstimator(pa_est_1d, optimizer_1d) mle_mean_1d = MeanAbilityEstimator(lh_est_1d, integrator_1d) @@ -72,7 +72,7 @@ mle_mode_1d = ModeAbilityEstimator(lh_est_1d, optimizer_1d) ) end - ability_variance_state_criterion = AbilityVarianceStateCriterion( + ability_variance_state_criterion = AbilityVariance( lh_est_1d, integrator_1d) ability_variance_item_criterion = ExpectationBasedItemCriterion( mle_mean_1d, diff --git a/test/ability_estimator_2d.jl b/test/ability_estimator_2d.jl index 6929bea..7e820d7 100644 --- a/test/ability_estimator_2d.jl +++ b/test/ability_estimator_2d.jl @@ -34,7 +34,7 @@ integrator_2d = AbilityIntegrator(MultiDimFixedGKIntegrator([-6.0, -6.0], [6.0, optimizer_2d = AbilityOptimizer(MultiDimOptimOptimizer( [-6.0, -6.0], [6.0, 6.0], NelderMead())) lh_est_2d = LikelihoodAbilityEstimator() -pa_est_2d = PriorAbilityEstimator(MvNormal([1.0, 1.0], ScalMat(2, 0.2))) +pa_est_2d = PosteriorAbilityEstimator(MvNormal([1.0, 1.0], ScalMat(2, 0.2))) eap_2d = MeanAbilityEstimator(pa_est_2d, integrator_2d) map_2d = ModeAbilityEstimator(pa_est_2d, optimizer_2d) mle_mean_2d = MeanAbilityEstimator(lh_est_2d, integrator_2d) @@ -67,7 +67,7 @@ mle_mode_2d = ModeAbilityEstimator(lh_est_2d, optimizer_2d) # Item closer to the current estimate (1, 1) close_item = 5 # Item further from the current estimate - far_item = 6 + far_item = 7 close_info = compute_criterion( information_criterion, tracked_responses_2d, close_item) diff --git a/test/compat.jl b/test/compat.jl new file mode 100644 index 0000000..984e7ba --- /dev/null +++ b/test/compat.jl @@ -0,0 +1,76 @@ +@testset "Compat" begin + using FittedItemBanks.DummyData: dummy_full + using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL, + BooleanResponse + using ComputerAdaptiveTesting.Aggregators: TrackedResponses, NullAbilityTracker + using ComputerAdaptiveTesting.TerminationConditions: FixedLength + using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule + using ComputerAdaptiveTesting.Responses: BareResponses, ResponseType + using ComputerAdaptiveTesting: Stateful + using ComputerAdaptiveTesting: require_testext + using ComputerAdaptiveTesting.ItemBanks: LogItemBank + using ComputerAdaptiveTesting.NextItemRules: best_item + using ComputerAdaptiveTesting: Compat + using Test: @test, @testset + + #include("./dummy.jl") + #using .Dummy + using Random + + rng = Random.default_rng(42) + (item_bank, abilities, true_responses) = dummy_full( + Random.default_rng(42), + SimpleItemBankSpec(StdModel3PL(), OneDimContinuousDomain(), BooleanResponse()); + num_questions = 4, + num_testees = 1 + ) + half_responses = BareResponses( + ResponseType(item_bank), + [1, 2], + Vector{Bool}(true_responses[1:2, 1]) + ) + + @testset "CatJL" begin + tracked_responses = TrackedResponses(half_responses, item_bank, NullAbilityTracker()) + for method in ("EAP", "MAP", "ML") + @testset "Ability estimation $method" begin + rules = Compat.MirtCAT.assemble_rules(; + criteria="MI", + method + ) + @test -6.0 <= rules.ability_estimator(tracked_responses) <= 6.0 + end + end + for criteria in ("MI", "MEPV") + @testset "Next item $criteria" begin + rules = Compat.MirtCAT.assemble_rules(; + criteria, + method="EAP" + ) + @test best_item(rules.next_item, tracked_responses) in 3:4 + end + end + end + + @testset "CatR" begin + tracked_responses = TrackedResponses(half_responses, item_bank, NullAbilityTracker()) + for method in ("EAP", "BM", "ML") + @testset "Ability estimation $method" begin + rules = Compat.CatR.assemble_rules(; + criterion="MFI", + method + ) + @test -6.0 <= rules.ability_estimator(tracked_responses) <= 6.0 + end + end + for criterion in ("MFI", "bOpt", "MEPV", "MEI") + @testset "Next item $criterion" begin + rules = Compat.CatR.assemble_rules(; + criterion, + method="EAP" + ) + @test best_item(rules.next_item, tracked_responses) in 3:4 + end + end + end +end \ No newline at end of file diff --git a/test/dt.jl b/test/dt.jl index 8c89b07..65d0660 100644 --- a/test/dt.jl +++ b/test/dt.jl @@ -9,19 +9,19 @@ ability_estimator = MeanAbilityEstimator(LikelihoodAbilityEstimator(), integrato get_response = auto_responder(@view true_responses[:, 1]) @testset "decision tree round trip" begin - next_item_rule = ItemStrategyNextItemRule( - AbilityVarianceStateCriterion( + next_item_rule = ItemCriterionRule( + AbilityVariance( distribution_estimator(ability_estimator), integrator), ability_estimator = ability_estimator ) - termination_condition = FixedItemsTerminationCondition(4) + termination_condition = FixedLength(4) cat_rules = CatRules( next_item = next_item_rule, termination_condition = termination_condition, ability_estimator = ability_estimator ) - cat_loop_config = CatLoopConfig( + cat_loop_config = CatLoop( rules = cat_rules, get_response = get_response ) @@ -33,7 +33,7 @@ get_response = auto_responder(@view true_responses[:, 1]) ability_estimator = ability_estimator ) dt_materialized = generate_dt_cat(dt_generation_config, item_bank) - dt_loop_config = CatLoopConfig( + dt_loop_config = CatLoop( rules = dt_materialized, get_response = get_response ) @@ -45,7 +45,7 @@ get_response = auto_responder(@view true_responses[:, 1]) tempdir = mktempdir() save_mmap(tempdir, dt_materialized) dt_rt = load_mmap(tempdir) - dt_rt_loop_config = CatLoopConfig( + dt_rt_loop_config = CatLoop( rules = dt_rt, get_response = get_response ) diff --git a/test/dummy.jl b/test/dummy.jl index 5f71ddc..6d52a53 100644 --- a/test/dummy.jl +++ b/test/dummy.jl @@ -10,7 +10,6 @@ using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.Optimizers using Optim using Random -using ResumableFunctions struct DummyAbilityEstimator <: AbilityEstimator val::Any @@ -24,69 +23,61 @@ const optimizers_1d = [ FunctionOptimizer(OneDimOptimOptimizer(-6.0, 6.0, NelderMead())) ] const integrators_1d = [ - FunctionIntegrator(QuadGKIntegrator(-6, 6, 5)), + FunctionIntegrator(QuadGKIntegrator(lo=-6.0, hi=6.0, order=5)), FunctionIntegrator(FixedGKIntegrator(-6, 6, 80)) ] const ability_estimators_1d = [ ((:integrator,), - (stuff) -> MeanAbilityEstimator(PriorAbilityEstimator(std_normal), stuff.integrator)), + (stuff) -> MeanAbilityEstimator(PosteriorAbilityEstimator(std_normal), stuff.integrator)), ((:optimizer,), - (stuff) -> ModeAbilityEstimator(PriorAbilityEstimator(std_normal), stuff.optimizer)), + (stuff) -> ModeAbilityEstimator(PosteriorAbilityEstimator(std_normal), stuff.optimizer)), ((:integrator,), (stuff) -> MeanAbilityEstimator(LikelihoodAbilityEstimator(), stuff.integrator)), ((:optimizer,), - (stuff) -> ModeAbilityEstimator(LikelihoodAbilityEstimator(), stuff.optimizer)) + (stuff) -> ModeAbilityEstimator(SafeLikelihoodAbilityEstimator(), stuff.optimizer)) ] const criteria_1d = [ ((:integrator, :est), - (stuff) -> AbilityVarianceStateCriterion( + (stuff) -> AbilityVariance( distribution_estimator(stuff.est), stuff.integrator)), ((:est,), (stuff) -> InformationItemCriterion(stuff.est)), ((:est,), (stuff) -> UrryItemCriterion(stuff.est)), ((), (stuff) -> RandomNextItemRule()) ] -@resumable function _get_stuffs(needed) +function _get_stuffs(needed) if :est in needed - for (extra_needed, mk_est) in ability_estimators_1d + return ( + (; stuff..., est = mk_est(stuff)) + for (extra_needed, mk_est) in ability_estimators_1d for stuff in _get_stuffs(setdiff(needed, Set((:est,))) ∪ extra_needed) - x = (; stuff..., est = mk_est(stuff)) - @yield x - end - end - return + ) end if :integrator in needed - for new_integrator in integrators_1d + return ( + (; stuff..., integrator = new_integrator) + for new_integrator in integrators_1d for stuff in _get_stuffs(setdiff(needed, Set((:integrator,)))) - x = (; stuff..., integrator = new_integrator) - @yield x - end - end - return + ) end if :optimizer in needed - pop!(needed, :optimizer) - for new_optimizer in optimizers_1d + return ( + (; stuff..., optimizer = new_optimizer) + for new_optimizer in optimizers_1d for stuff in _get_stuffs(setdiff(needed, Set((:optimizer,)))) - x = (; stuff..., optimizer = new_optimizer) - @yield x - end - end - return + ) end - x = NamedTuple() - @yield x - return + return [NamedTuple()] end -@resumable function get_stuffs(needed) - add_dummy_est = !(:est in needed) - for stuff in _get_stuffs(needed) - if add_dummy_est - stuff = (; stuff..., est = DummyAbilityEstimator(0.0)) - end - @yield stuff +function get_stuffs(needed) + if !(:est in needed) + return ( + (; stuff..., est = DummyAbilityEstimator(0.0)) + for stuff in _get_stuffs(needed) + ) + else + return _get_stuffs(needed) end end diff --git a/test/runtests.jl b/test/runtests.jl index 0610f2a..beb5390 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,6 @@ using ComputerAdaptiveTesting.Aggregators using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel3PL, VectorContinuousDomain, BooleanResponse, std_normal using FittedItemBanks -using ComputerAdaptiveTesting.CatConfig using ComputerAdaptiveTesting.Responses using ComputerAdaptiveTesting.NextItemRules using ComputerAdaptiveTesting.TerminationConditions @@ -12,12 +11,11 @@ using ComputerAdaptiveTesting.Sim using PsychometricsBazaarBase.Integrators using PsychometricsBazaarBase.Optimizers using ComputerAdaptiveTesting.DecisionTree -using ComputerAdaptiveTesting: Stateful +using ComputerAdaptiveTesting: Stateful, CatRules using Distributions using Distributions: ZeroMeanIsoNormal, Zeros, ScalMat using Optim using Random -using ResumableFunctions using Test @@ -32,4 +30,5 @@ using .Dummy include("./smoke.jl") include("./dt.jl") include("./stateful.jl") + include("./compat.jl") end diff --git a/test/smoke.jl b/test/smoke.jl index 4a4c176..f5b1891 100644 --- a/test/smoke.jl +++ b/test/smoke.jl @@ -1,5 +1,17 @@ #(item_bank, abilities, responses) = dummy_full(Random.default_rng(42), SimpleItemBankSpec(StdModel4PL(), VectorContinuousDomain(), BooleanResponse()), 2; num_questions=100, num_testees=3) +using Random +using ComputerAdaptiveTesting +using ComputerAdaptiveTesting.Aggregators +using ComputerAdaptiveTesting.TerminationConditions +using ComputerAdaptiveTesting.Sim +using FittedItemBanks +using FittedItemBanks.DummyData: dummy_full, SimpleItemBankSpec, StdModel3PL, + VectorContinuousDomain, BooleanResponse, std_normal + +include("./dummy.jl") +using .Dummy + @testset "Smoke test 1d" begin (item_bank, abilities, true_responses) = dummy_full( Random.default_rng(42), @@ -10,13 +22,13 @@ function test1d(ability_estimator, bits...) rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), ability_estimator, bits... ) for testee_idx in axes(true_responses, 2) responses, ability = run_cat( - CatLoopConfig( + CatLoop( rules = rules, get_response = auto_responder(@view true_responses[:, testee_idx]) ), diff --git a/test/stateful.jl b/test/stateful.jl index 2bb84c8..19f0b54 100644 --- a/test/stateful.jl +++ b/test/stateful.jl @@ -3,11 +3,10 @@ using FittedItemBanks.DummyData: dummy_full using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL, BooleanResponse - using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition + using ComputerAdaptiveTesting.TerminationConditions: FixedLength using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule using ComputerAdaptiveTesting: Stateful using ComputerAdaptiveTesting: require_testext - using ResumableFunctions using Test: @test, @testset include("./dummy.jl") @@ -24,15 +23,15 @@ num_testees = 2 ) - @testset "StatefulCatConfig basic usage" begin + @testset "StatefulCatRules basic usage" begin rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) # Initialize config - cat_config = Stateful.StatefulCatConfig(rules, item_bank) + cat_config = Stateful.StatefulCatRules(rules, item_bank) # Test initialization state @test isempty(Stateful.get_responses(cat_config)) @@ -54,11 +53,11 @@ @testset "Stateful next item selection" begin rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) - cat_config = Stateful.StatefulCatConfig(rules, item_bank) + cat_config = Stateful.StatefulCatRules(rules, item_bank) # Test first item selection first_item = Stateful.next_item(cat_config) @@ -73,13 +72,13 @@ @testset "Standard interface tests" begin rules = CatRules( - FixedItemsTerminationCondition(2), + FixedLength(2), Dummy.DummyAbilityEstimator(0.0), RandomNextItemRule() ) # Initialize config - cat_config = Stateful.StatefulCatConfig(rules, item_bank) + cat_config = Stateful.StatefulCatRules(rules, item_bank) # Run the standard interface tests TestExt = require_testext()