1+ export register_parameters!
2+
13include (" trace.jl" )
24
35"""
@@ -8,13 +10,14 @@ A generative function based on a shallowly embedding modeling language based on
810Constructed using the `@gen` keyword.
911Most methods in the generative function interface involve a end-to-end execution of the function.
1012"""
11- struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
13+ mutable struct DynamicDSLFunction{T} <: GenerativeFunction{T,DynamicDSLTrace}
1214 arg_types:: Vector{Type}
1315 has_defaults:: Bool
1416 arg_defaults:: Vector{Union{Some{Any},Nothing}}
1517 julia_function:: Function
1618 has_argument_grads:: Vector{Bool}
1719 accepts_output_grad:: Bool
20+ parameters:: Union{Vector,Function}
1821end
1922
2023function DynamicDSLFunction (arg_types:: Vector{Type} ,
@@ -26,29 +29,65 @@ function DynamicDSLFunction(arg_types::Vector{Type},
2629 return DynamicDSLFunction {T} (arg_types,
2730 has_defaults, arg_defaults,
2831 julia_function,
29- has_argument_grads, accepts_output_grad)
32+ has_argument_grads, accepts_output_grad, [] )
3033end
3134
32- function Base. show (io:: IO , gen_fn:: DynamicDSLFunction )
33- return " Gen DML generative function: $(gen_fn. julia_function) "
35+ function get_parameters (gen_fn:: DynamicDSLFunction , parameter_context)
36+ if isa (gen_fn. parameters, Vector)
37+ julia_store = get_julia_store (parameter_context)
38+ parameter_stores_to_ids = Dict {Any,Vector} ()
39+ parameter_ids = Tuple{GenerativeFunction,Symbol}[]
40+ for param in gen_fn. parameters
41+ if isa (param, Tuple{GenerativeFunction,Symbol})
42+ push! (parameter_ids, param)
43+ elseif isa (param, Symbol)
44+ push! (parameter_ids, (gen_fn, param))
45+ else
46+ throw (ArgumentError (" Invalid parameter declaration for DML generative function $gen_fn : $param " ))
47+ end
48+ end
49+ parameter_stores_to_ids[julia_store] = parameter_ids
50+ return parameter_stores_to_ids
51+ elseif isa (gen_fn. parameters, Function)
52+ return gen_fn. parameters (parameter_context)
53+ end
3454end
3555
36- function Base. show (io:: IO , :: MIME"text/plain" , gen_fn:: DynamicDSLFunction )
37- return " Gen DML generative function: $(gen_fn. julia_function) "
56+ """
57+ register_parameters!(gen_fn::DynamicDSLFunction, parameters)
58+
59+ Register the altrainable parameters that are used by a DML generative function.
60+
61+ This includes all parameters used within any calls made by the generative function.
62+
63+ There are two variants:
64+
65+ # TODO document the variants
66+ """
67+ function register_parameters! (gen_fn:: DynamicDSLFunction , parameters)
68+ gen_fn. parameters = parameters
69+ return nothing
3870end
3971
40- function get_parameters (gen_fn:: DynamicDSLFunction , parameter_context)
41- # TODO for this, we need to walk the code... (and throw errors when the
72+ function Base. show (io:: IO , gen_fn:: DynamicDSLFunction )
73+ print (io, " Gen DML generative function: $(gen_fn. julia_function) " )
74+ end
75+
76+ function Base. show (io:: IO , :: MIME"text/plain" , gen_fn:: DynamicDSLFunction )
77+ print (io, " Gen DML generative function: $(gen_fn. julia_function) " )
4278end
4379
44- function DynamicDSLTrace (gen_fn:: T , args, parameter_store:: JuliaParameterStore ) where {T<: DynamicDSLFunction }
80+ function DynamicDSLTrace (
81+ gen_fn:: T , args, parameter_store:: JuliaParameterStore ,
82+ parameter_context, registered_julia_parameters) where {T<: DynamicDSLFunction }
4583 # pad args with default values, if available
4684 if gen_fn. has_defaults && length (args) < length (gen_fn. arg_defaults)
4785 defaults = gen_fn. arg_defaults[length (args)+ 1 : end ]
4886 defaults = map (x -> something (x), defaults)
4987 args = Tuple (vcat (collect (args), defaults))
5088 end
51- return DynamicDSLTrace {T} (gen_fn, args, parameter_store)
89+ return DynamicDSLTrace {T} (
90+ gen_fn, args, parameter_store, parameter_context, registered_julia_parameters)
5291end
5392
5493accepts_output_grad (gen_fn:: DynamicDSLFunction ) = gen_fn. accepts_output_grad
@@ -118,10 +157,10 @@ end
118157function dynamic_param_impl (expr:: Expr )
119158 @assert expr. head == :genparam " Not a Gen param expression."
120159 name = expr. args[1 ]
121- Expr (:(= ), name, Expr (:call , GlobalRef (@__MODULE__ , :read_param ), state, QuoteNode (name)))
160+ Expr (:(= ), name, Expr (:call , GlobalRef (@__MODULE__ , :read_param! ), state, QuoteNode (name)))
122161end
123162
124- function read_param (state, name:: Symbol )
163+ function read_param! (state, name:: Symbol )
125164 parameter_id = get_parameter_id (state, name)
126165 store = get_parameter_store (state)
127166 return get_parameter_value (parameter_id, store)
0 commit comments