Skip to content

Commit 3fc5a0e

Browse files
committed
tests passing
1 parent d6f636c commit 3fc5a0e

39 files changed

+404
-230
lines changed

src/dynamic/assess.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ function traceat(state::GFAssessState, gen_fn::GenerativeFunction{T,U},
5656
end
5757

5858
function assess(
59-
gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap;
60-
parameter_context=default_parameter_context)
59+
gen_fn::DynamicDSLFunction, args::Tuple, choices::ChoiceMap,
60+
parameter_context::Dict)
6161
state = GFAssessState(gen_fn, choices, parameter_context)
6262
retval = exec(gen_fn, state, args)
6363

src/dynamic/backprop.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,8 +58,11 @@ mutable struct GFBackpropParamsState
5858
end
5959
end
6060

61-
function read_param(state::GFBackpropParamsState, name::Symbol)
61+
function read_param!(state::GFBackpropParamsState, name::Symbol)
6262
parameter_id = (state.active_gen_fn, name)
63+
if !(parameter_id in state.trace.registered_julia_parameters)
64+
throw(ArgumentError("parameter $parameter_id was not registered using register_parameters!"))
65+
end
6366
return state.tracked_params[parameter_id]
6467
end
6568

src/dynamic/dynamic.jl

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
export register_parameters!
2+
13
include("trace.jl")
24

35
"""
@@ -8,13 +10,14 @@ A generative function based on a shallowly embedding modeling language based on
810
Constructed using the `@gen` keyword.
911
Most methods in the generative function interface involve a end-to-end execution of the function.
1012
"""
11-
struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
13+
mutable struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
1214
arg_types::Vector{Type}
1315
has_defaults::Bool
1416
arg_defaults::Vector{Union{Some{Any},Nothing}}
1517
julia_function::Function
1618
has_argument_grads::Vector{Bool}
1719
accepts_output_grad::Bool
20+
parameters::Union{Vector,Function}
1821
end
1922

2023
function DynamicDSLFunction(arg_types::Vector{Type},
@@ -26,29 +29,65 @@ function DynamicDSLFunction(arg_types::Vector{Type},
2629
return DynamicDSLFunction{T}(arg_types,
2730
has_defaults, arg_defaults,
2831
julia_function,
29-
has_argument_grads, accepts_output_grad)
32+
has_argument_grads, accepts_output_grad, [])
3033
end
3134

32-
function Base.show(io::IO, gen_fn::DynamicDSLFunction)
33-
return "Gen DML generative function: $(gen_fn.julia_function)"
35+
function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
36+
if isa(gen_fn.parameters, Vector)
37+
julia_store = get_julia_store(parameter_context)
38+
parameter_stores_to_ids = Dict{Any,Vector}()
39+
parameter_ids = Tuple{GenerativeFunction,Symbol}[]
40+
for param in gen_fn.parameters
41+
if isa(param, Tuple{GenerativeFunction,Symbol})
42+
push!(parameter_ids, param)
43+
elseif isa(param, Symbol)
44+
push!(parameter_ids, (gen_fn, param))
45+
else
46+
throw(ArgumentError("Invalid parameter declaration for DML generative function $gen_fn: $param"))
47+
end
48+
end
49+
parameter_stores_to_ids[julia_store] = parameter_ids
50+
return parameter_stores_to_ids
51+
elseif isa(gen_fn.parameters, Function)
52+
return gen_fn.parameters(parameter_context)
53+
end
3454
end
3555

36-
function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
37-
return "Gen DML generative function: $(gen_fn.julia_function)"
56+
"""
57+
register_parameters!(gen_fn::DynamicDSLFunction, parameters)
58+
59+
Register the altrainable parameters that are used by a DML generative function.
60+
61+
This includes all parameters used within any calls made by the generative function.
62+
63+
There are two variants:
64+
65+
# TODO document the variants
66+
"""
67+
function register_parameters!(gen_fn::DynamicDSLFunction, parameters)
68+
gen_fn.parameters = parameters
69+
return nothing
3870
end
3971

40-
function get_parameters(gen_fn::DynamicDSLFunction, parameter_context)
41-
# TODO for this, we need to walk the code... (and throw errors when the
72+
function Base.show(io::IO, gen_fn::DynamicDSLFunction)
73+
print(io, "Gen DML generative function: $(gen_fn.julia_function)")
74+
end
75+
76+
function Base.show(io::IO, ::MIME"text/plain", gen_fn::DynamicDSLFunction)
77+
print(io, "Gen DML generative function: $(gen_fn.julia_function)")
4278
end
4379

44-
function DynamicDSLTrace(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T<:DynamicDSLFunction}
80+
function DynamicDSLTrace(
81+
gen_fn::T, args, parameter_store::JuliaParameterStore,
82+
parameter_context, registered_julia_parameters) where {T<:DynamicDSLFunction}
4583
# pad args with default values, if available
4684
if gen_fn.has_defaults && length(args) < length(gen_fn.arg_defaults)
4785
defaults = gen_fn.arg_defaults[length(args)+1:end]
4886
defaults = map(x -> something(x), defaults)
4987
args = Tuple(vcat(collect(args), defaults))
5088
end
51-
return DynamicDSLTrace{T}(gen_fn, args, parameter_store)
89+
return DynamicDSLTrace{T}(
90+
gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
5291
end
5392

5493
accepts_output_grad(gen_fn::DynamicDSLFunction) = gen_fn.accepts_output_grad
@@ -118,10 +157,10 @@ end
118157
function dynamic_param_impl(expr::Expr)
119158
@assert expr.head == :genparam "Not a Gen param expression."
120159
name = expr.args[1]
121-
Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param), state, QuoteNode(name)))
160+
Expr(:(=), name, Expr(:call, GlobalRef(@__MODULE__, :read_param!), state, QuoteNode(name)))
122161
end
123162

124-
function read_param(state, name::Symbol)
163+
function read_param!(state, name::Symbol)
125164
parameter_id = get_parameter_id(state, name)
126165
store = get_parameter_store(state)
127166
return get_parameter_value(parameter_id, store)

src/dynamic/generate.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ mutable struct GFGenerateState
88

99
function GFGenerateState(gen_fn, args, constraints, parameter_context)
1010
parameter_store = get_julia_store(parameter_context)
11-
trace = DynamicDSLTrace(gen_fn, args, parameter_store)
11+
registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}(
12+
get_parameters(gen_fn, parameter_context)[parameter_store])
13+
trace = DynamicDSLTrace(
14+
gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
1215
return new(trace, constraints, 0., AddressVisitor(), gen_fn, parameter_context)
1316
end
1417
end
@@ -23,7 +26,6 @@ function set_active_gen_fn!(state::GFGenerateState, gen_fn::GenerativeFunction)
2326
state.active_gen_fn = gen_fn
2427
end
2528

26-
2729
function traceat(state::GFGenerateState, dist::Distribution{T},
2830
args, key) where {T}
2931
local retval::T
@@ -53,7 +55,7 @@ function traceat(state::GFGenerateState, dist::Distribution{T},
5355
state.weight += score
5456
end
5557

56-
retval
58+
return retval
5759
end
5860

5961
function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},
@@ -69,8 +71,7 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},
6971

7072
# get subtrace
7173
(subtrace, weight) = generate(
72-
gen_fn, args, constraints;
73-
parameter_context=state.parameter_context)
74+
gen_fn, args, constraints, state.parameter_context)
7475

7576
# add to the trace
7677
add_call!(state.trace, key, subtrace)
@@ -81,14 +82,14 @@ function traceat(state::GFGenerateState, gen_fn::GenerativeFunction{T,U},
8182
# get return value
8283
retval = get_retval(subtrace)
8384

84-
retval
85+
return retval
8586
end
8687

8788
function generate(
88-
gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap;
89-
parameter_context=default_parameter_context)
89+
gen_fn::DynamicDSLFunction, args::Tuple, constraints::ChoiceMap,
90+
parameter_context::Dict)
9091
state = GFGenerateState(gen_fn, args, constraints, parameter_context)
9192
retval = exec(gen_fn, state, args)
9293
set_retval!(state.trace, retval)
93-
(state.trace, state.weight)
94+
return (state.trace, state.weight)
9495
end

src/dynamic/propose.jl

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ function traceat(state::GFProposeState, dist::Distribution{T},
3737
# update weight
3838
state.weight += logpdf(dist, retval, args...)
3939

40-
retval
40+
return retval
4141
end
4242

4343
function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U},
@@ -56,12 +56,10 @@ function traceat(state::GFProposeState, gen_fn::GenerativeFunction{T,U},
5656
# update weight
5757
state.weight += weight
5858

59-
retval
59+
return retval
6060
end
6161

62-
function propose(
63-
gen_fn::DynamicDSLFunction, args::Tuple;
64-
parameter_context=default_parameter_context)
62+
function propose(gen_fn::DynamicDSLFunction, args::Tuple, parameter_context::Dict)
6563
state = GFProposeState(gen_fn, parameter_context)
6664
retval = exec(gen_fn, state, args)
6765
return (state.choices, state.weight, retval)

src/dynamic/regenerate.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ mutable struct GFRegenerateState
88

99
function GFRegenerateState(gen_fn, args, prev_trace, selection)
1010
visitor = AddressVisitor()
11-
trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace))
12-
return new(prev_trace, trace, selection,
13-
0., visitor, gen_fn)
11+
trace = initialize_from(prev_trace, args)
12+
return new(prev_trace, trace, selection, 0.0, visitor, gen_fn)
1413
end
1514
end
1615

@@ -24,7 +23,6 @@ function set_active_gen_fn!(state::GFRegenerateState, gen_fn::GenerativeFunction
2423
state.active_gen_fn = gen_fn
2524
end
2625

27-
2826
function traceat(state::GFRegenerateState, dist::Distribution{T},
2927
args, key) where {T}
3028
local prev_retval::T
@@ -88,7 +86,8 @@ function traceat(state::GFRegenerateState, gen_fn::GenerativeFunction{T,U},
8886
(subtrace, weight, _) = regenerate(
8987
prev_subtrace, args, map((_) -> UnknownChange(), args), subselection)
9088
else
91-
(subtrace, weight) = generate(gen_fn, args, EmptyChoiceMap())
89+
(subtrace, weight) = generate(
90+
gen_fn, args, EmptyChoiceMap(), state.trace.parameter_context)
9291
end
9392

9493
# update weight

src/dynamic/simulate.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,10 @@ mutable struct GFSimulateState
77
function GFSimulateState(
88
gen_fn::GenerativeFunction, args::Tuple, parameter_context)
99
parameter_store = get_julia_store(parameter_context)
10-
trace = DynamicDSLTrace(gen_fn, args, parameter_store)
10+
registered_julia_parameters = Set{Tuple{GenerativeFunction,Symbol}}(
11+
get_parameters(gen_fn, parameter_context)[parameter_store])
12+
trace = DynamicDSLTrace(
13+
gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
1114
return new(trace, AddressVisitor(), gen_fn, parameter_context)
1215
end
1316
end
@@ -49,7 +52,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U},
4952
visit!(state.visitor, key)
5053

5154
# get subtrace
52-
subtrace = simulate(gen_fn, args; parameter_context=state.parameter_context)
55+
subtrace = simulate(gen_fn, args, state.parameter_context)
5356

5457
# add to the trace
5558
add_call!(state.trace, key, subtrace)
@@ -60,9 +63,7 @@ function traceat(state::GFSimulateState, gen_fn::GenerativeFunction{T,U},
6063
retval
6164
end
6265

63-
function simulate(
64-
gen_fn::DynamicDSLFunction, args::Tuple;
65-
parameter_context=default_parameter_context)
66+
function simulate(gen_fn::DynamicDSLFunction, args::Tuple, parameter_context::Dict)
6667
state = GFSimulateState(gen_fn, args, parameter_context)
6768
retval = exec(gen_fn, state, args)
6869
set_retval!(state.trace, retval)

src/dynamic/trace.jl

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,27 @@ mutable struct DynamicDSLTrace{T} <: Trace
3838
noise::Float64
3939
args::Tuple
4040
parameter_store::JuliaParameterStore
41+
parameter_context::Dict
42+
registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}} # for runtime cross-check
4143
retval::Any
42-
function DynamicDSLTrace{T}(gen_fn::T, args, parameter_store::JuliaParameterStore) where {T}
44+
function DynamicDSLTrace{T}(
45+
gen_fn::T, args, parameter_store::JuliaParameterStore, parameter_context,
46+
registered_julia_parameters::Set{Tuple{GenerativeFunction,Symbol}}) where {T}
4347
trie = Trie{Any,ChoiceOrCallRecord}()
4448
# retval is not known yet
45-
new(gen_fn, trie, true, 0, 0, args, parameter_store)
49+
new(
50+
gen_fn, trie, true, 0, 0, args, parameter_store,
51+
parameter_context, registered_julia_parameters)
4652
end
4753
end
4854

55+
function initialize_from(other::DynamicDSLTrace, args)
56+
gen_fn = get_gen_fn(other)
57+
return DynamicDSLTrace(
58+
gen_fn, args, other.parameter_store, other.parameter_context,
59+
other.registered_julia_parameters)
60+
end
61+
4962
get_parameter_store(trace::DynamicDSLTrace) = trace.parameter_store
5063

5164
set_retval!(trace::DynamicDSLTrace, retval) = (trace.retval = retval)

src/dynamic/update.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ mutable struct GFUpdateState
1010
function GFUpdateState(gen_fn, args, prev_trace, constraints)
1111
visitor = AddressVisitor()
1212
discard = choicemap()
13-
trace = DynamicDSLTrace(gen_fn, args, get_parameter_store(prev_trace))
14-
return new(prev_trace, trace, constraints,
15-
0., visitor, discard, gen_fn)
13+
parameter_store = get_parameter_store(prev_trace)
14+
trace = initialize_from(prev_trace, args)
15+
return new(prev_trace, trace, constraints, 0.0, visitor, discard, gen_fn)
1616
end
1717
end
1818

@@ -26,7 +26,6 @@ function set_active_gen_fn!(state::GFUpdateState, gen_fn::GenerativeFunction)
2626
state.active_gen_fn = gen_fn
2727
end
2828

29-
3029
function traceat(state::GFUpdateState, dist::Distribution{T},
3130
args::Tuple, key) where {T}
3231

@@ -101,7 +100,8 @@ function traceat(state::GFUpdateState, gen_fn::GenerativeFunction{T,U},
101100
(subtrace, weight, _, discard) = update(prev_subtrace,
102101
args, map((_) -> UnknownChange(), args), constraints)
103102
else
104-
(subtrace, weight) = generate(gen_fn, args, constraints)
103+
(subtrace, weight) = generate(
104+
gen_fn, args, constraints, state.trace.parameter_context)
105105
end
106106

107107
# update the weight

0 commit comments

Comments
 (0)