Skip to content

Commit 2d4a674

Browse files
Merge pull request #3795 from AayushSabharwal/as/new-io
feat: more robust inputs/outputs handling
2 parents f51778d + 7a8201f commit 2d4a674

File tree

6 files changed

+86
-23
lines changed

6 files changed

+86
-23
lines changed

src/inputoutput.jl

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using Symbolics: get_variables
55
Return all variables that mare marked as inputs. See also [`unbound_inputs`](@ref)
66
See also [`bound_inputs`](@ref), [`unbound_inputs`](@ref)
77
"""
8-
inputs(sys) = [filter(isinput, unknowns(sys)); filter(isinput, parameters(sys))]
8+
inputs(sys) = collect(get_inputs(sys))
99

1010
"""
1111
outputs(sys)
@@ -14,13 +14,7 @@ Return all variables that mare marked as outputs. See also [`unbound_outputs`](@
1414
See also [`bound_outputs`](@ref), [`unbound_outputs`](@ref)
1515
"""
1616
function outputs(sys)
17-
o = observed(sys)
18-
rhss = [eq.rhs for eq in o]
19-
lhss = [eq.lhs for eq in o]
20-
unique([filter(isoutput, unknowns(sys))
21-
filter(isoutput, parameters(sys))
22-
filter(x -> iscall(x) && isoutput(x), rhss) # observed can return equations with complicated expressions, we are only looking for single Terms
23-
filter(x -> iscall(x) && isoutput(x), lhss)])
17+
return collect(get_outputs(sys))
2418
end
2519

2620
"""
@@ -288,7 +282,12 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
288282
push!(new_fullvars, v)
289283
end
290284
end
291-
ninputs == 0 && return state
285+
if ninputs == 0
286+
@set! sys.inputs = OrderedSet{BasicSymbolic}()
287+
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
288+
state.sys = sys
289+
return state
290+
end
292291

293292
nvars = ndsts(graph) - ninputs
294293
new_graph = BipartiteGraph(nsrcs(graph), nvars, Val(false))
@@ -318,6 +317,8 @@ function inputs_to_parameters!(state::TransformationState, inputsyms)
318317
ps = parameters(sys)
319318

320319
@set! sys.ps = [ps; new_parameters]
320+
@set! sys.inputs = OrderedSet{BasicSymbolic}(new_parameters)
321+
@set! sys.outputs = OrderedSet{BasicSymbolic}(filter(isoutput, fullvars))
321322
@set! state.sys = sys
322323
@set! state.fullvars = Vector{BasicSymbolic}(new_fullvars)
323324
@set! state.structure = structure

src/linearization.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -572,8 +572,7 @@ struct IONotFoundError <: Exception
572572
end
573573

574574
function Base.showerror(io::IO, err::IONotFoundError)
575-
println(io,
576-
"The following $(err.variant) provided to `mtkcompile` were not found in the system:")
575+
println(io, "The following $(err.variant) provided to `mtkcompile` were not found in the system:")
577576
maybe_namespace_issue = false
578577
for var in err.not_found
579578
println(io, " ", var)

src/systems/abstractsystem.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,8 @@ const SYS_PROPS = [:eqs
784784
:parent
785785
:is_dde
786786
:tstops
787+
:inputs
788+
:outputs
787789
:index_cache
788790
:isscheduled
789791
:costs
@@ -1820,6 +1822,17 @@ function push_vars!(stmt, name, typ, vars)
18201822
ex = nameof(s)
18211823
end
18221824
push!(vars_expr.args, ex)
1825+
1826+
meta_kvps = Expr[]
1827+
if isinput(s)
1828+
push!(meta_kvps, :(input = true))
1829+
end
1830+
if isoutput(s)
1831+
push!(meta_kvps, :(output = true))
1832+
end
1833+
if !isempty(meta_kvps)
1834+
push!(vars_expr.args, Expr(:vect, meta_kvps...))
1835+
end
18231836
end
18241837
push!(stmt, :($name = $collect($vars_expr)))
18251838
return

src/systems/system.jl

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,16 @@ struct System <: IntermediateDeprecationSystem
190190
"""
191191
tstops::Vector{Any}
192192
"""
193+
$INTERNAL_FIELD_WARNING
194+
The list of input variables of the system.
195+
"""
196+
inputs::OrderedSet{BasicSymbolic}
197+
"""
198+
$INTERNAL_FIELD_WARNING
199+
The list of output variables of the system.
200+
"""
201+
outputs::OrderedSet{BasicSymbolic}
202+
"""
193203
The `TearingState` of the system post-simplification with `mtkcompile`.
194204
"""
195205
tearing_state::Any
@@ -255,8 +265,9 @@ struct System <: IntermediateDeprecationSystem
255265
brownians, iv, observed, parameter_dependencies, var_to_name, name, description,
256266
defaults, guesses, systems, initialization_eqs, continuous_events, discrete_events,
257267
connector_type, assertions = Dict{BasicSymbolic, String}(),
258-
metadata = MetadataT(), gui_metadata = nothing,
259-
is_dde = false, tstops = [], tearing_state = nothing, namespacing = true,
268+
metadata = MetadataT(), gui_metadata = nothing, is_dde = false, tstops = [],
269+
inputs = Set{BasicSymbolic}(), outputs = Set{BasicSymbolic}(),
270+
tearing_state = nothing, namespacing = true,
260271
complete = false, index_cache = nothing, ignored_connections = nothing,
261272
preface = nothing, parent = nothing, initializesystem = nothing,
262273
is_initializesystem = false, is_discrete = false, isscheduled = false,
@@ -296,7 +307,8 @@ struct System <: IntermediateDeprecationSystem
296307
observed, parameter_dependencies, var_to_name, name, description, defaults,
297308
guesses, systems, initialization_eqs, continuous_events, discrete_events,
298309
connector_type, assertions, metadata, gui_metadata, is_dde,
299-
tstops, tearing_state, namespacing, complete, index_cache, ignored_connections,
310+
tstops, inputs, outputs, tearing_state, namespacing,
311+
complete, index_cache, ignored_connections,
300312
preface, parent, initializesystem, is_initializesystem, is_discrete,
301313
isscheduled, schedule)
302314
end
@@ -332,7 +344,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
332344
continuous_events = SymbolicContinuousCallback[], discrete_events = SymbolicDiscreteCallback[],
333345
connector_type = nothing, assertions = Dict{BasicSymbolic, String}(),
334346
metadata = MetadataT(), gui_metadata = nothing,
335-
is_dde = nothing, tstops = [], tearing_state = nothing,
347+
is_dde = nothing, tstops = [], inputs = OrderedSet{BasicSymbolic}(),
348+
outputs = OrderedSet{BasicSymbolic}(), tearing_state = nothing,
336349
ignored_connections = nothing, parent = nothing,
337350
description = "", name = nothing, discover_from_metadata = true,
338351
initializesystem = nothing, is_initializesystem = false, is_discrete = false,
@@ -367,15 +380,35 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
367380

368381
defaults = anydict(defaults)
369382
guesses = anydict(guesses)
383+
inputs = OrderedSet{BasicSymbolic}(inputs)
384+
outputs = OrderedSet{BasicSymbolic}(outputs)
385+
for subsys in systems
386+
for var in ModelingToolkit.inputs(subsys)
387+
push!(inputs, renamespace(subsys, var))
388+
end
389+
for var in ModelingToolkit.outputs(subsys)
390+
push!(outputs, renamespace(subsys, var))
391+
end
392+
end
370393
var_to_name = anydict()
371394

372395
let defaults = discover_from_metadata ? defaults : Dict(),
373-
guesses = discover_from_metadata ? guesses : Dict()
396+
guesses = discover_from_metadata ? guesses : Dict(),
397+
inputs = discover_from_metadata ? inputs : Set(),
398+
outputs = discover_from_metadata ? outputs : Set()
374399

375400
process_variables!(var_to_name, defaults, guesses, dvs)
376401
process_variables!(var_to_name, defaults, guesses, ps)
377402
process_variables!(var_to_name, defaults, guesses, [eq.lhs for eq in observed])
378403
process_variables!(var_to_name, defaults, guesses, [eq.rhs for eq in observed])
404+
405+
for var in dvs
406+
if isinput(var)
407+
push!(inputs, var)
408+
elseif isoutput(var)
409+
push!(outputs, var)
410+
end
411+
end
379412
end
380413
filter!(!(isnothing last), defaults)
381414
filter!(!(isnothing last), guesses)
@@ -417,7 +450,8 @@ function System(eqs::Vector{Equation}, iv, dvs, ps, brownians = [];
417450
costs, consolidate, dvs, ps, brownians, iv, observed, Equation[],
418451
var_to_name, name, description, defaults, guesses, systems, initialization_eqs,
419452
continuous_events, discrete_events, connector_type, assertions, metadata, gui_metadata, is_dde,
420-
tstops, tearing_state, true, false, nothing, ignored_connections, preface, parent,
453+
tstops, inputs, outputs, tearing_state, true, false,
454+
nothing, ignored_connections, preface, parent,
421455
initializesystem, is_initializesystem, is_discrete; checks)
422456
end
423457

@@ -731,6 +765,7 @@ function flatten(sys::System, noeqs = false)
731765
discrete_events = discrete_events(sys), assertions = assertions(sys),
732766
is_dde = is_dde(sys), tstops = symbolic_tstops(sys),
733767
initialization_eqs = initialization_equations(sys),
768+
inputs = inputs(sys), outputs = outputs(sys),
734769
# without this, any defaults/guesses obtained from metadata that were
735770
# later removed by the user will be re-added. Right now, we just want to
736771
# retain `defaults(sys)` as-is.
@@ -1143,6 +1178,8 @@ function Base.isapprox(sysa::System, sysb::System)
11431178
isequal(get_metadata(sysa), get_metadata(sysb)) &&
11441179
isequal(get_is_dde(sysa), get_is_dde(sysb)) &&
11451180
issetequal(get_tstops(sysa), get_tstops(sysb)) &&
1181+
issetequal(get_inputs(sysa), get_inputs(sysb)) &&
1182+
issetequal(get_outputs(sysa), get_outputs(sysb)) &&
11461183
safe_issetequal(get_ignored_connections(sysa), get_ignored_connections(sysb)) &&
11471184
isequal(get_is_initializesystem(sysa), get_is_initializesystem(sysb)) &&
11481185
isequal(get_is_discrete(sysa), get_is_discrete(sysb)) &&

src/systems/systemstructure.jl

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -993,13 +993,9 @@ function _mtkcompile!(state::TearingState; simplify = false,
993993
else
994994
check_consistency = true
995995
end
996-
has_io = !isempty(inputs) || !isempty(outputs) !== nothing ||
997-
!isempty(disturbance_inputs)
998996
orig_inputs = Set()
999-
if has_io
1000-
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
1001-
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
1002-
end
997+
ModelingToolkit.markio!(state, orig_inputs, inputs, outputs, disturbance_inputs)
998+
state = ModelingToolkit.inputs_to_parameters!(state, [inputs; disturbance_inputs])
1003999
trivial_tearing!(state)
10041000
sys, mm = ModelingToolkit.alias_elimination!(state; fully_determined, kwargs...)
10051001
if check_consistency

test/input_output_handling.jl

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -468,3 +468,20 @@ end
468468
x = [1.0]
469469
@test_nowarn f[1](x, u, p, 0.0)
470470
end
471+
472+
@testset "Observed inputs and outputs" begin
473+
@variables x(t) y(t) [input = true] z(t) [output = true]
474+
eqs = [D(x) ~ x + y + z
475+
y ~ z]
476+
@named sys = System(eqs, t)
477+
@test issetequal(ModelingToolkit.inputs(sys), [y])
478+
@test issetequal(ModelingToolkit.outputs(sys), [z])
479+
480+
ss1 = mtkcompile(sys, inputs = [y], outputs = [z])
481+
@test issetequal(ModelingToolkit.inputs(ss1), [y])
482+
@test issetequal(ModelingToolkit.outputs(ss1), [z])
483+
484+
ss2 = mtkcompile(sys, inputs = [z], outputs = [y])
485+
@test issetequal(ModelingToolkit.inputs(ss2), [z])
486+
@test issetequal(ModelingToolkit.outputs(ss2), [y])
487+
end

0 commit comments

Comments
 (0)